aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--libs/criteria.py19
-rw-r--r--libs/datautils.py11
-rw-r--r--libs/utils.py71
-rw-r--r--simclr/evaluate.py63
-rw-r--r--simclr/main.py67
-rw-r--r--supervised/baseline.py53
6 files changed, 200 insertions, 84 deletions
diff --git a/libs/criteria.py b/libs/criteria.py
index baa36ce..7d367c1 100644
--- a/libs/criteria.py
+++ b/libs/criteria.py
@@ -1,4 +1,6 @@
import torch
+import torch.distributed as dist
+import torch.distributed.rpc as rpc
from torch import nn, Tensor
from torch.nn import functional as F
@@ -8,10 +10,21 @@ class InfoNCELoss(nn.Module):
super().__init__()
self.temp = temp
+ @staticmethod
+ def _norm_and_stack(feat: Tensor) -> Tensor:
+ local_feat_norm = F.normalize(feat)
+ local_feat_norm_stack = torch.stack(local_feat_norm.chunk(2))
+
+ return local_feat_norm_stack
+
def forward(self, feature: Tensor) -> tuple[Tensor, Tensor]:
- bz = feature.size(0) // 2
- feat_norm = F.normalize(feature)
- feat1_norm, feat2_norm = feat_norm.split(bz)
+ feat_norm = torch.cat([
+ rpc.rpc_sync(f"worker{i}", self._norm_and_stack, (feature,))
+ for i in range(dist.get_world_size())
+ ], dim=1)
+ bz = feat_norm.size(1)
+
+ feat1_norm, feat2_norm = feat_norm[0], feat_norm[1]
logits = feat1_norm @ feat2_norm.T
pos_logits_mask = torch.eye(bz, dtype=torch.bool)
pos_logits = logits[pos_logits_mask].unsqueeze(-1)
diff --git a/libs/datautils.py b/libs/datautils.py
index 6a7c506..53222a8 100644
--- a/libs/datautils.py
+++ b/libs/datautils.py
@@ -125,3 +125,14 @@ class TwinTransform:
v1 = self.transform(x)
v2 = self.transform(x)
return v1, v2
+
+
+class ContinuousSampler(torch.utils.data.sampler.Sampler):
+ def __init__(self, sampler):
+ super(ContinuousSampler, self).__init__(sampler)
+ self.base_sampler = sampler
+
+ def __iter__(self):
+ while True:
+ for batch in self.base_sampler:
+ yield batch
diff --git a/libs/utils.py b/libs/utils.py
index 63ea116..90ae48f 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -6,10 +6,15 @@ from dataclasses import dataclass
from typing import Iterable, Callable
import torch
+from torch import distributed as dist
from torch.backends import cudnn
-from torch.utils.data import Dataset, DataLoader, RandomSampler
+from torch.distributed import rpc as rpc
+from torch.utils.data import Dataset, DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data.sampler import BatchSampler
from torch.utils.tensorboard import SummaryWriter
+from libs.datautils import ContinuousSampler
from libs.logging import CSV_EPOCH_LOGGER, CSV_BATCH_LOGGER, BaseBatchLogRecord, BaseEpochLogRecord, Loggers, \
init_csv_logger, csv_logger, tensorboard_logger
from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR
@@ -61,7 +66,7 @@ class Trainer(ABC):
self,
seed: int,
checkpoint_dir: str,
- device: torch.device,
+ device: int,
inf_mode: bool,
num_iters: int,
config: BaseConfig,
@@ -75,7 +80,7 @@ class Trainer(ABC):
)
models = self._init_models(config.dataset_config.dataset)
- models = {n: m.to(device) for n, m in models}
+ models = dict(self._models_to_devices(models, device))
optims = dict(self._configure_optimizers(models.items(), config.optim_config))
last_metrics = self._auto_load_checkpoint(
checkpoint_dir, inf_mode, **(models | optims)
@@ -155,22 +160,25 @@ class Trainer(ABC):
train_set: Dataset, test_set: Dataset,
inf_mode: bool, dataloader_config: BaseConfig.DataLoaderConfig
) -> tuple[DataLoader, DataLoader]:
+ train_sampler = DistributedSampler(train_set, shuffle=True)
+ test_sampler = DistributedSampler(test_set, shuffle=False)
if inf_mode:
- inf_sampler = RandomSampler(train_set,
- replacement=True,
- num_samples=int(1e20))
+ inf_train_sampler = ContinuousSampler(
+ BatchSampler(train_sampler,
+ dataloader_config.batch_size,
+ drop_last=True)
+ )
train_loader = DataLoader(train_set,
- sampler=inf_sampler,
- batch_size=dataloader_config.batch_size,
+ batch_sampler=inf_train_sampler,
num_workers=dataloader_config.num_workers)
else:
train_loader = DataLoader(train_set,
- shuffle=True,
batch_size=dataloader_config.batch_size,
+ sampler=train_sampler,
num_workers=dataloader_config.num_workers)
test_loader = DataLoader(test_set,
- shuffle=False,
batch_size=dataloader_config.batch_size,
+ sampler=test_sampler,
num_workers=dataloader_config.num_workers)
return train_loader, test_loader
@@ -182,6 +190,18 @@ class Trainer(ABC):
yield 'model_name', model
@staticmethod
+ def _models_to_devices(
+ models: Iterable[tuple[str, torch.nn.Module]],
+ device: int,
+ ) -> Iterable[tuple[str, torch.nn.Module]]:
+ for name, model in models:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model = model.to(device)
+ model = torch.nn.parallel.DistributedDataParallel(model, [device])
+
+ yield name, model
+
+ @staticmethod
@abstractmethod
def _configure_optimizers(
models: Iterable[tuple[str, torch.nn.Module]],
@@ -267,9 +287,36 @@ class Trainer(ABC):
torch.save(checkpoint, checkpoint_name)
@abstractmethod
- def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device):
+ def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: int):
pass
@abstractmethod
- def eval(self, loss_fn: Callable, device: torch.device):
+ def eval(self, loss_fn: Callable, device: int):
pass
+
+
+def elastic_launch(func):
+ dist.init_process_group(backend='nccl')
+ local_rank = int(os.environ['LOCAL_RANK'])
+ local_world_size = int(os.environ['LOCAL_WORLD_SIZE'])
+ ngpu_per_proc = torch.cuda.device_count() // local_world_size
+
+ assert ngpu_per_proc == 1
+
+ global_rank = dist.get_rank()
+ global_world_size = dist.get_world_size()
+ rpc.init_rpc(
+ f"worker{global_rank}",
+ rank=global_rank,
+ world_size=global_world_size,
+ rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
+ device_maps={f"worker{callee}": {
+ caller: callee for caller in range(global_world_size)
+ } for callee in range(global_world_size)}
+ )
+ )
+
+ func(local_rank, global_rank)
+
+ rpc.shutdown()
+ dist.destroy_process_group()
diff --git a/simclr/evaluate.py b/simclr/evaluate.py
index 5c41b84..23bd299 100644
--- a/simclr/evaluate.py
+++ b/simclr/evaluate.py
@@ -6,6 +6,7 @@ from typing import Iterable, Callable
import sys
import torch
+import torch.distributed as dist
import yaml
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10, CIFAR100, ImageNet
@@ -16,7 +17,7 @@ sys.path.insert(0, path)
from libs.optimizers import LARS
from libs.logging import Loggers, BaseBatchLogRecord, BaseEpochLogRecord
-from libs.utils import BaseConfig
+from libs.utils import BaseConfig, elastic_launch
from simclr.main import SimCLRTrainer, SimCLRConfig
from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50
@@ -190,7 +191,7 @@ class SimCLREvalTrainer(SimCLRTrainer):
if k in self.models['backbone'].state_dict()}
self.models['backbone'].load_state_dict(backbone_state_dict)
- def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device):
+ def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: int):
backbone, classifier = self.models.values()
optim_b, optim_c = self.optims.values()
sched_b, sched_c = self.scheds.values()
@@ -218,24 +219,33 @@ class SimCLREvalTrainer(SimCLRTrainer):
if self.finetune:
optim_b.step()
optim_c.step()
- self.log(logger, self.BatchLogRecord(
- batch, num_batches, global_batch, iter_, num_iters,
- optim_c.param_groups[0]['lr'], train_loss.item()
- ))
+
+ if logger is not None:
+ self.log(logger, self.BatchLogRecord(
+ batch, num_batches, global_batch, iter_, num_iters,
+ optim_c.param_groups[0]['lr'], train_loss.item()
+ ))
+ dist.barrier()
+
if (iter_ + 1) % (num_iters // 10) == 0:
- metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
- eval_loss = metrics[0].item()
- eval_accuracy = metrics[1].item()
- epoch_log = self.EpochLogRecord(iter_, num_iters,
- eval_loss, eval_accuracy)
- self.log(logger, epoch_log)
- self.save_checkpoint(epoch_log)
- if sched_b is not None and self.finetune:
- sched_b.step()
- if sched_c is not None:
- sched_c.step()
-
- def eval(self, loss_fn: Callable, device: torch.device):
+ # TODO Gather results from other workers
+ metrics = torch.Tensor(list(self.eval(loss_fn, device)))
+ if logger is not None:
+ metrics_mean = metrics.mean(0)
+ eval_loss = metrics_mean[0].item()
+ eval_accuracy = metrics_mean[1].item()
+ epoch_log = self.EpochLogRecord(iter_, num_iters,
+ eval_loss, eval_accuracy)
+ self.log(logger, epoch_log)
+ self.save_checkpoint(epoch_log)
+ dist.barrier()
+
+ if sched_b is not None and self.finetune:
+ sched_b.step()
+ if sched_c is not None:
+ sched_c.step()
+
+ def eval(self, loss_fn: Callable, device: int):
backbone, classifier = self.models.values()
backbone.eval()
classifier.eval()
@@ -250,14 +260,13 @@ class SimCLREvalTrainer(SimCLRTrainer):
yield loss.item(), accuracy.item()
-if __name__ == '__main__':
+def main(local_rank, global_rank):
args = parse_args_and_config()
config = SimCLREvalConfig.from_args(args)
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainer = SimCLREvalTrainer(
seed=args.seed,
checkpoint_dir=args.checkpoint_dir,
- device=device,
+ device=local_rank,
inf_mode=False,
num_iters=args.num_iters,
config=config,
@@ -267,5 +276,11 @@ if __name__ == '__main__':
finetune=args.finetune,
)
- loggers = trainer.init_logger(args.log_dir)
- trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, device)
+ loggers = None
+ if global_rank == 0:
+ loggers = trainer.init_logger(args.log_dir)
+ trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, local_rank)
+
+
+if __name__ == '__main__':
+ elastic_launch(main)
diff --git a/simclr/main.py b/simclr/main.py
index 69e2ab2..3456d41 100644
--- a/simclr/main.py
+++ b/simclr/main.py
@@ -1,11 +1,12 @@
import argparse
import os
-import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Callable
+import sys
import torch
+import torch.distributed as dist
import yaml
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10, CIFAR100, ImageNet
@@ -17,8 +18,7 @@ sys.path.insert(0, path)
from libs.criteria import InfoNCELoss
from libs.datautils import color_distortion, Clip, RandomGaussianBlur, TwinTransform
from libs.optimizers import LARS
-from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR
-from libs.utils import Trainer, BaseConfig
+from libs.utils import Trainer, BaseConfig, elastic_launch
from libs.logging import BaseBatchLogRecord, Loggers
from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50
@@ -254,11 +254,12 @@ class SimCLRTrainer(Trainer):
self.optims = {n: LARS(o) if config.optim_config.optim == 'lars' else o
for n, o in self.optims.items()}
- def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device):
+ def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: int):
model = self.models['model']
optim = self.optims['model_optim']
sched = self.scheds['model_optim_sched']
train_loader = iter(self.train_loader)
+
model.train()
for iter_ in range(self.restore_iter, num_iters):
input_, _ = next(train_loader)
@@ -268,28 +269,37 @@ class SimCLRTrainer(Trainer):
train_loss, train_accuracy = loss_fn(output)
train_loss.backward()
optim.step()
- self.log(logger, self.BatchLogRecord(
- iter_, num_iters, iter_, iter_, num_iters,
- optim.param_groups[0]['lr'],
- train_loss.item(), train_accuracy.item(),
- eval_loss=None, eval_accuracy=None,
- ))
- if (iter_ + 1) % (num_iters // 100) == 0:
- metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
- eval_loss = metrics[0].item()
- eval_accuracy = metrics[1].item()
- eval_log = self.BatchLogRecord(
+
+ if logger is not None:
+ self.log(logger, self.BatchLogRecord(
iter_, num_iters, iter_, iter_, num_iters,
- lr=None, train_loss=None, train_accuracy=None,
- eval_loss=eval_loss, eval_accuracy=eval_accuracy,
- )
- self.log(logger, eval_log)
- self.save_checkpoint(eval_log)
+ optim.param_groups[0]['lr'],
+ train_loss.item(), train_accuracy.item(),
+ eval_loss=None, eval_accuracy=None,
+ ))
+ dist.barrier()
+
+ if (iter_ + 1) % (num_iters // 100) == 0:
+ # TODO Gather results from other workers
+ metrics = torch.Tensor(list(self.eval(loss_fn, device)))
+ if logger is not None:
+ metrics_mean = metrics.mean(0)
+ eval_loss = metrics_mean[0].item()
+ eval_accuracy = metrics_mean[1].item()
+ eval_log = self.BatchLogRecord(
+ iter_, num_iters, iter_, iter_, num_iters,
+ lr=None, train_loss=None, train_accuracy=None,
+ eval_loss=eval_loss, eval_accuracy=eval_accuracy,
+ )
+ self.log(logger, eval_log)
+ self.save_checkpoint(eval_log)
model.train()
+ dist.barrier()
+
if sched is not None:
sched.step()
- def eval(self, loss_fn: Callable, device: torch.device):
+ def eval(self, loss_fn: Callable, device: int):
model = self.models['model']
model.eval()
with torch.no_grad():
@@ -300,14 +310,13 @@ class SimCLRTrainer(Trainer):
yield loss.item(), accuracy.item()
-if __name__ == '__main__':
+def main(local_rank, global_rank):
args = parse_args_and_config()
config = SimCLRConfig.from_args(args)
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainer = SimCLRTrainer(
seed=args.seed,
checkpoint_dir=args.checkpoint_dir,
- device=device,
+ device=local_rank,
inf_mode=True,
num_iters=args.num_iters,
config=config,
@@ -315,5 +324,11 @@ if __name__ == '__main__':
out_dim=args.out_dim,
)
- loggers = trainer.init_logger(args.log_dir)
- trainer.train(args.num_iters, InfoNCELoss(args.temp), loggers, device)
+ loggers = None
+ if global_rank == 0:
+ loggers = trainer.init_logger(args.log_dir)
+ trainer.train(args.num_iters, InfoNCELoss(args.temp), loggers, local_rank)
+
+
+if __name__ == '__main__':
+ elastic_launch(main)
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 4be1c97..9ce3cf0 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -6,6 +6,7 @@ from typing import Iterable, Callable
import sys
import torch
+import torch.distributed as dist
import yaml
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10, CIFAR100
@@ -15,7 +16,7 @@ path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
sys.path.insert(0, path)
from libs.datautils import Clip
-from libs.utils import Trainer, BaseConfig
+from libs.utils import Trainer, BaseConfig, elastic_launch
from libs.logging import BaseBatchLogRecord, BaseEpochLogRecord, Loggers
from models import CIFARResNet50, CIFARViTTiny
@@ -211,7 +212,7 @@ class SupBaselineTrainer(Trainer):
yield f"{model_name}_optim", optimizer
- def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device):
+ def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: int):
model = self.models['model']
optim = self.optims['model_optim']
sched = self.scheds['model_optim_sched']
@@ -227,23 +228,32 @@ class SupBaselineTrainer(Trainer):
train_loss = loss_fn(output, targets)
train_loss.backward()
optim.step()
- self.log(logger, self.BatchLogRecord(
- batch, num_batches, global_batch, iter_, num_iters,
- optim.param_groups[0]['lr'], train_loss.item()
- ))
- metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
- eval_loss = metrics[0].item()
- eval_accuracy = metrics[1].item()
- epoch_log = self.EpochLogRecord(iter_, num_iters,
- eval_loss, eval_accuracy)
- self.log(logger, epoch_log)
- self.save_checkpoint(epoch_log)
+
+ if logger is not None:
+ self.log(logger, self.BatchLogRecord(
+ batch, num_batches, global_batch, iter_, num_iters,
+ optim.param_groups[0]['lr'], train_loss.item()
+ ))
+ dist.barrier()
+
+ # TODO Gather results from other workers
+ metrics = torch.Tensor(list(self.eval(loss_fn, device)))
+ if logger is not None:
+ metrics_mean = metrics.mean(0)
+ eval_loss = metrics_mean[0].item()
+ eval_accuracy = metrics_mean[1].item()
+ epoch_log = self.EpochLogRecord(iter_, num_iters,
+ eval_loss, eval_accuracy)
+ self.log(logger, epoch_log)
+ self.save_checkpoint(epoch_log)
+ dist.barrier()
+
# Step after save checkpoint, otherwise the schedular will
# one iter ahead after restore
if sched is not None:
sched.step()
- def eval(self, loss_fn: Callable, device: torch.device):
+ def eval(self, loss_fn: Callable, device: int):
model = self.models['model']
model.eval()
with torch.no_grad():
@@ -256,20 +266,25 @@ class SupBaselineTrainer(Trainer):
yield loss.item(), accuracy.item()
-if __name__ == '__main__':
+def main(local_rank, global_rank):
args = parse_args_and_config()
config = SupBaselineConfig.from_args(args)
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainer = SupBaselineTrainer(
seed=args.seed,
checkpoint_dir=args.checkpoint_dir,
- device=device,
+ device=local_rank,
inf_mode=False,
num_iters=args.num_iters,
config=config,
backbone=args.backbone,
)
- loggers = trainer.init_logger(args.log_dir)
+ loggers = None
+ if global_rank == 0:
+ loggers = trainer.init_logger(args.log_dir)
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smooth)
- trainer.train(args.num_iters, loss_fn, loggers, device)
+ trainer.train(args.num_iters, loss_fn, loggers, local_rank)
+
+
+if __name__ == '__main__':
+ elastic_launch(main)