aboutsummaryrefslogtreecommitdiff
path: root/libs
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-13 23:38:43 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-13 23:38:43 +0800
commit957a2a46e7725184776c3c72860e8215164cc4ef (patch)
tree43e098595db4ee332bca5f6caecfbd02369debbe /libs
parent1b8f01ce9706905c36c6f11ed9deac8548ad7341 (diff)
Implement distributed data parallel via torch elastic launcher
Diffstat (limited to 'libs')
-rw-r--r--libs/criteria.py19
-rw-r--r--libs/datautils.py11
-rw-r--r--libs/utils.py71
3 files changed, 86 insertions, 15 deletions
diff --git a/libs/criteria.py b/libs/criteria.py
index baa36ce..7d367c1 100644
--- a/libs/criteria.py
+++ b/libs/criteria.py
@@ -1,4 +1,6 @@
import torch
+import torch.distributed as dist
+import torch.distributed.rpc as rpc
from torch import nn, Tensor
from torch.nn import functional as F
@@ -8,10 +10,21 @@ class InfoNCELoss(nn.Module):
super().__init__()
self.temp = temp
+ @staticmethod
+ def _norm_and_stack(feat: Tensor) -> Tensor:
+ local_feat_norm = F.normalize(feat)
+ local_feat_norm_stack = torch.stack(local_feat_norm.chunk(2))
+
+ return local_feat_norm_stack
+
def forward(self, feature: Tensor) -> tuple[Tensor, Tensor]:
- bz = feature.size(0) // 2
- feat_norm = F.normalize(feature)
- feat1_norm, feat2_norm = feat_norm.split(bz)
+ feat_norm = torch.cat([
+ rpc.rpc_sync(f"worker{i}", self._norm_and_stack, (feature,))
+ for i in range(dist.get_world_size())
+ ], dim=1)
+ bz = feat_norm.size(1)
+
+ feat1_norm, feat2_norm = feat_norm[0], feat_norm[1]
logits = feat1_norm @ feat2_norm.T
pos_logits_mask = torch.eye(bz, dtype=torch.bool)
pos_logits = logits[pos_logits_mask].unsqueeze(-1)
diff --git a/libs/datautils.py b/libs/datautils.py
index 6a7c506..53222a8 100644
--- a/libs/datautils.py
+++ b/libs/datautils.py
@@ -125,3 +125,14 @@ class TwinTransform:
v1 = self.transform(x)
v2 = self.transform(x)
return v1, v2
+
+
+class ContinuousSampler(torch.utils.data.sampler.Sampler):
+ def __init__(self, sampler):
+ super(ContinuousSampler, self).__init__(sampler)
+ self.base_sampler = sampler
+
+ def __iter__(self):
+ while True:
+ for batch in self.base_sampler:
+ yield batch
diff --git a/libs/utils.py b/libs/utils.py
index 63ea116..90ae48f 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -6,10 +6,15 @@ from dataclasses import dataclass
from typing import Iterable, Callable
import torch
+from torch import distributed as dist
from torch.backends import cudnn
-from torch.utils.data import Dataset, DataLoader, RandomSampler
+from torch.distributed import rpc as rpc
+from torch.utils.data import Dataset, DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data.sampler import BatchSampler
from torch.utils.tensorboard import SummaryWriter
+from libs.datautils import ContinuousSampler
from libs.logging import CSV_EPOCH_LOGGER, CSV_BATCH_LOGGER, BaseBatchLogRecord, BaseEpochLogRecord, Loggers, \
init_csv_logger, csv_logger, tensorboard_logger
from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR
@@ -61,7 +66,7 @@ class Trainer(ABC):
self,
seed: int,
checkpoint_dir: str,
- device: torch.device,
+ device: int,
inf_mode: bool,
num_iters: int,
config: BaseConfig,
@@ -75,7 +80,7 @@ class Trainer(ABC):
)
models = self._init_models(config.dataset_config.dataset)
- models = {n: m.to(device) for n, m in models}
+ models = dict(self._models_to_devices(models, device))
optims = dict(self._configure_optimizers(models.items(), config.optim_config))
last_metrics = self._auto_load_checkpoint(
checkpoint_dir, inf_mode, **(models | optims)
@@ -155,22 +160,25 @@ class Trainer(ABC):
train_set: Dataset, test_set: Dataset,
inf_mode: bool, dataloader_config: BaseConfig.DataLoaderConfig
) -> tuple[DataLoader, DataLoader]:
+ train_sampler = DistributedSampler(train_set, shuffle=True)
+ test_sampler = DistributedSampler(test_set, shuffle=False)
if inf_mode:
- inf_sampler = RandomSampler(train_set,
- replacement=True,
- num_samples=int(1e20))
+ inf_train_sampler = ContinuousSampler(
+ BatchSampler(train_sampler,
+ dataloader_config.batch_size,
+ drop_last=True)
+ )
train_loader = DataLoader(train_set,
- sampler=inf_sampler,
- batch_size=dataloader_config.batch_size,
+ batch_sampler=inf_train_sampler,
num_workers=dataloader_config.num_workers)
else:
train_loader = DataLoader(train_set,
- shuffle=True,
batch_size=dataloader_config.batch_size,
+ sampler=train_sampler,
num_workers=dataloader_config.num_workers)
test_loader = DataLoader(test_set,
- shuffle=False,
batch_size=dataloader_config.batch_size,
+ sampler=test_sampler,
num_workers=dataloader_config.num_workers)
return train_loader, test_loader
@@ -182,6 +190,18 @@ class Trainer(ABC):
yield 'model_name', model
@staticmethod
+ def _models_to_devices(
+ models: Iterable[tuple[str, torch.nn.Module]],
+ device: int,
+ ) -> Iterable[tuple[str, torch.nn.Module]]:
+ for name, model in models:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model = model.to(device)
+ model = torch.nn.parallel.DistributedDataParallel(model, [device])
+
+ yield name, model
+
+ @staticmethod
@abstractmethod
def _configure_optimizers(
models: Iterable[tuple[str, torch.nn.Module]],
@@ -267,9 +287,36 @@ class Trainer(ABC):
torch.save(checkpoint, checkpoint_name)
@abstractmethod
- def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device):
+ def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: int):
pass
@abstractmethod
- def eval(self, loss_fn: Callable, device: torch.device):
+ def eval(self, loss_fn: Callable, device: int):
pass
+
+
+def elastic_launch(func):
+ dist.init_process_group(backend='nccl')
+ local_rank = int(os.environ['LOCAL_RANK'])
+ local_world_size = int(os.environ['LOCAL_WORLD_SIZE'])
+ ngpu_per_proc = torch.cuda.device_count() // local_world_size
+
+ assert ngpu_per_proc == 1
+
+ global_rank = dist.get_rank()
+ global_world_size = dist.get_world_size()
+ rpc.init_rpc(
+ f"worker{global_rank}",
+ rank=global_rank,
+ world_size=global_world_size,
+ rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
+ device_maps={f"worker{callee}": {
+ caller: callee for caller in range(global_world_size)
+ } for callee in range(global_world_size)}
+ )
+ )
+
+ func(local_rank, global_rank)
+
+ rpc.shutdown()
+ dist.destroy_process_group()