diff options
-rw-r--r-- | libs/criteria.py | 19 | ||||
-rw-r--r-- | libs/datautils.py | 11 | ||||
-rw-r--r-- | libs/utils.py | 71 | ||||
-rw-r--r-- | simclr/evaluate.py | 63 | ||||
-rw-r--r-- | simclr/main.py | 67 | ||||
-rw-r--r-- | supervised/baseline.py | 53 |
6 files changed, 200 insertions, 84 deletions
diff --git a/libs/criteria.py b/libs/criteria.py index baa36ce..7d367c1 100644 --- a/libs/criteria.py +++ b/libs/criteria.py @@ -1,4 +1,6 @@ import torch +import torch.distributed as dist +import torch.distributed.rpc as rpc from torch import nn, Tensor from torch.nn import functional as F @@ -8,10 +10,21 @@ class InfoNCELoss(nn.Module): super().__init__() self.temp = temp + @staticmethod + def _norm_and_stack(feat: Tensor) -> Tensor: + local_feat_norm = F.normalize(feat) + local_feat_norm_stack = torch.stack(local_feat_norm.chunk(2)) + + return local_feat_norm_stack + def forward(self, feature: Tensor) -> tuple[Tensor, Tensor]: - bz = feature.size(0) // 2 - feat_norm = F.normalize(feature) - feat1_norm, feat2_norm = feat_norm.split(bz) + feat_norm = torch.cat([ + rpc.rpc_sync(f"worker{i}", self._norm_and_stack, (feature,)) + for i in range(dist.get_world_size()) + ], dim=1) + bz = feat_norm.size(1) + + feat1_norm, feat2_norm = feat_norm[0], feat_norm[1] logits = feat1_norm @ feat2_norm.T pos_logits_mask = torch.eye(bz, dtype=torch.bool) pos_logits = logits[pos_logits_mask].unsqueeze(-1) diff --git a/libs/datautils.py b/libs/datautils.py index 6a7c506..53222a8 100644 --- a/libs/datautils.py +++ b/libs/datautils.py @@ -125,3 +125,14 @@ class TwinTransform: v1 = self.transform(x) v2 = self.transform(x) return v1, v2 + + +class ContinuousSampler(torch.utils.data.sampler.Sampler): + def __init__(self, sampler): + super(ContinuousSampler, self).__init__(sampler) + self.base_sampler = sampler + + def __iter__(self): + while True: + for batch in self.base_sampler: + yield batch diff --git a/libs/utils.py b/libs/utils.py index 63ea116..90ae48f 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -6,10 +6,15 @@ from dataclasses import dataclass from typing import Iterable, Callable import torch +from torch import distributed as dist from torch.backends import cudnn -from torch.utils.data import Dataset, DataLoader, RandomSampler +from torch.distributed import rpc as rpc +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import BatchSampler from torch.utils.tensorboard import SummaryWriter +from libs.datautils import ContinuousSampler from libs.logging import CSV_EPOCH_LOGGER, CSV_BATCH_LOGGER, BaseBatchLogRecord, BaseEpochLogRecord, Loggers, \ init_csv_logger, csv_logger, tensorboard_logger from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR @@ -61,7 +66,7 @@ class Trainer(ABC): self, seed: int, checkpoint_dir: str, - device: torch.device, + device: int, inf_mode: bool, num_iters: int, config: BaseConfig, @@ -75,7 +80,7 @@ class Trainer(ABC): ) models = self._init_models(config.dataset_config.dataset) - models = {n: m.to(device) for n, m in models} + models = dict(self._models_to_devices(models, device)) optims = dict(self._configure_optimizers(models.items(), config.optim_config)) last_metrics = self._auto_load_checkpoint( checkpoint_dir, inf_mode, **(models | optims) @@ -155,22 +160,25 @@ class Trainer(ABC): train_set: Dataset, test_set: Dataset, inf_mode: bool, dataloader_config: BaseConfig.DataLoaderConfig ) -> tuple[DataLoader, DataLoader]: + train_sampler = DistributedSampler(train_set, shuffle=True) + test_sampler = DistributedSampler(test_set, shuffle=False) if inf_mode: - inf_sampler = RandomSampler(train_set, - replacement=True, - num_samples=int(1e20)) + inf_train_sampler = ContinuousSampler( + BatchSampler(train_sampler, + dataloader_config.batch_size, + drop_last=True) + ) train_loader = DataLoader(train_set, - sampler=inf_sampler, - batch_size=dataloader_config.batch_size, + batch_sampler=inf_train_sampler, num_workers=dataloader_config.num_workers) else: train_loader = DataLoader(train_set, - shuffle=True, batch_size=dataloader_config.batch_size, + sampler=train_sampler, num_workers=dataloader_config.num_workers) test_loader = DataLoader(test_set, - shuffle=False, batch_size=dataloader_config.batch_size, + sampler=test_sampler, num_workers=dataloader_config.num_workers) return train_loader, test_loader @@ -182,6 +190,18 @@ class Trainer(ABC): yield 'model_name', model @staticmethod + def _models_to_devices( + models: Iterable[tuple[str, torch.nn.Module]], + device: int, + ) -> Iterable[tuple[str, torch.nn.Module]]: + for name, model in models: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = model.to(device) + model = torch.nn.parallel.DistributedDataParallel(model, [device]) + + yield name, model + + @staticmethod @abstractmethod def _configure_optimizers( models: Iterable[tuple[str, torch.nn.Module]], @@ -267,9 +287,36 @@ class Trainer(ABC): torch.save(checkpoint, checkpoint_name) @abstractmethod - def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device): + def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: int): pass @abstractmethod - def eval(self, loss_fn: Callable, device: torch.device): + def eval(self, loss_fn: Callable, device: int): pass + + +def elastic_launch(func): + dist.init_process_group(backend='nccl') + local_rank = int(os.environ['LOCAL_RANK']) + local_world_size = int(os.environ['LOCAL_WORLD_SIZE']) + ngpu_per_proc = torch.cuda.device_count() // local_world_size + + assert ngpu_per_proc == 1 + + global_rank = dist.get_rank() + global_world_size = dist.get_world_size() + rpc.init_rpc( + f"worker{global_rank}", + rank=global_rank, + world_size=global_world_size, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions( + device_maps={f"worker{callee}": { + caller: callee for caller in range(global_world_size) + } for callee in range(global_world_size)} + ) + ) + + func(local_rank, global_rank) + + rpc.shutdown() + dist.destroy_process_group() diff --git a/simclr/evaluate.py b/simclr/evaluate.py index 5c41b84..23bd299 100644 --- a/simclr/evaluate.py +++ b/simclr/evaluate.py @@ -6,6 +6,7 @@ from typing import Iterable, Callable import sys import torch +import torch.distributed as dist import yaml from torch.utils.data import Dataset from torchvision.datasets import CIFAR10, CIFAR100, ImageNet @@ -16,7 +17,7 @@ sys.path.insert(0, path) from libs.optimizers import LARS from libs.logging import Loggers, BaseBatchLogRecord, BaseEpochLogRecord -from libs.utils import BaseConfig +from libs.utils import BaseConfig, elastic_launch from simclr.main import SimCLRTrainer, SimCLRConfig from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50 @@ -190,7 +191,7 @@ class SimCLREvalTrainer(SimCLRTrainer): if k in self.models['backbone'].state_dict()} self.models['backbone'].load_state_dict(backbone_state_dict) - def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device): + def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: int): backbone, classifier = self.models.values() optim_b, optim_c = self.optims.values() sched_b, sched_c = self.scheds.values() @@ -218,24 +219,33 @@ class SimCLREvalTrainer(SimCLRTrainer): if self.finetune: optim_b.step() optim_c.step() - self.log(logger, self.BatchLogRecord( - batch, num_batches, global_batch, iter_, num_iters, - optim_c.param_groups[0]['lr'], train_loss.item() - )) + + if logger is not None: + self.log(logger, self.BatchLogRecord( + batch, num_batches, global_batch, iter_, num_iters, + optim_c.param_groups[0]['lr'], train_loss.item() + )) + dist.barrier() + if (iter_ + 1) % (num_iters // 10) == 0: - metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) - eval_loss = metrics[0].item() - eval_accuracy = metrics[1].item() - epoch_log = self.EpochLogRecord(iter_, num_iters, - eval_loss, eval_accuracy) - self.log(logger, epoch_log) - self.save_checkpoint(epoch_log) - if sched_b is not None and self.finetune: - sched_b.step() - if sched_c is not None: - sched_c.step() - - def eval(self, loss_fn: Callable, device: torch.device): + # TODO Gather results from other workers + metrics = torch.Tensor(list(self.eval(loss_fn, device))) + if logger is not None: + metrics_mean = metrics.mean(0) + eval_loss = metrics_mean[0].item() + eval_accuracy = metrics_mean[1].item() + epoch_log = self.EpochLogRecord(iter_, num_iters, + eval_loss, eval_accuracy) + self.log(logger, epoch_log) + self.save_checkpoint(epoch_log) + dist.barrier() + + if sched_b is not None and self.finetune: + sched_b.step() + if sched_c is not None: + sched_c.step() + + def eval(self, loss_fn: Callable, device: int): backbone, classifier = self.models.values() backbone.eval() classifier.eval() @@ -250,14 +260,13 @@ class SimCLREvalTrainer(SimCLRTrainer): yield loss.item(), accuracy.item() -if __name__ == '__main__': +def main(local_rank, global_rank): args = parse_args_and_config() config = SimCLREvalConfig.from_args(args) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') trainer = SimCLREvalTrainer( seed=args.seed, checkpoint_dir=args.checkpoint_dir, - device=device, + device=local_rank, inf_mode=False, num_iters=args.num_iters, config=config, @@ -267,5 +276,11 @@ if __name__ == '__main__': finetune=args.finetune, ) - loggers = trainer.init_logger(args.log_dir) - trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, device) + loggers = None + if global_rank == 0: + loggers = trainer.init_logger(args.log_dir) + trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, local_rank) + + +if __name__ == '__main__': + elastic_launch(main) diff --git a/simclr/main.py b/simclr/main.py index 69e2ab2..3456d41 100644 --- a/simclr/main.py +++ b/simclr/main.py @@ -1,11 +1,12 @@ import argparse import os -import sys from dataclasses import dataclass from pathlib import Path from typing import Iterable, Callable +import sys import torch +import torch.distributed as dist import yaml from torch.utils.data import Dataset from torchvision.datasets import CIFAR10, CIFAR100, ImageNet @@ -17,8 +18,7 @@ sys.path.insert(0, path) from libs.criteria import InfoNCELoss from libs.datautils import color_distortion, Clip, RandomGaussianBlur, TwinTransform from libs.optimizers import LARS -from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR -from libs.utils import Trainer, BaseConfig +from libs.utils import Trainer, BaseConfig, elastic_launch from libs.logging import BaseBatchLogRecord, Loggers from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50 @@ -254,11 +254,12 @@ class SimCLRTrainer(Trainer): self.optims = {n: LARS(o) if config.optim_config.optim == 'lars' else o for n, o in self.optims.items()} - def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device): + def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: int): model = self.models['model'] optim = self.optims['model_optim'] sched = self.scheds['model_optim_sched'] train_loader = iter(self.train_loader) + model.train() for iter_ in range(self.restore_iter, num_iters): input_, _ = next(train_loader) @@ -268,28 +269,37 @@ class SimCLRTrainer(Trainer): train_loss, train_accuracy = loss_fn(output) train_loss.backward() optim.step() - self.log(logger, self.BatchLogRecord( - iter_, num_iters, iter_, iter_, num_iters, - optim.param_groups[0]['lr'], - train_loss.item(), train_accuracy.item(), - eval_loss=None, eval_accuracy=None, - )) - if (iter_ + 1) % (num_iters // 100) == 0: - metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) - eval_loss = metrics[0].item() - eval_accuracy = metrics[1].item() - eval_log = self.BatchLogRecord( + + if logger is not None: + self.log(logger, self.BatchLogRecord( iter_, num_iters, iter_, iter_, num_iters, - lr=None, train_loss=None, train_accuracy=None, - eval_loss=eval_loss, eval_accuracy=eval_accuracy, - ) - self.log(logger, eval_log) - self.save_checkpoint(eval_log) + optim.param_groups[0]['lr'], + train_loss.item(), train_accuracy.item(), + eval_loss=None, eval_accuracy=None, + )) + dist.barrier() + + if (iter_ + 1) % (num_iters // 100) == 0: + # TODO Gather results from other workers + metrics = torch.Tensor(list(self.eval(loss_fn, device))) + if logger is not None: + metrics_mean = metrics.mean(0) + eval_loss = metrics_mean[0].item() + eval_accuracy = metrics_mean[1].item() + eval_log = self.BatchLogRecord( + iter_, num_iters, iter_, iter_, num_iters, + lr=None, train_loss=None, train_accuracy=None, + eval_loss=eval_loss, eval_accuracy=eval_accuracy, + ) + self.log(logger, eval_log) + self.save_checkpoint(eval_log) model.train() + dist.barrier() + if sched is not None: sched.step() - def eval(self, loss_fn: Callable, device: torch.device): + def eval(self, loss_fn: Callable, device: int): model = self.models['model'] model.eval() with torch.no_grad(): @@ -300,14 +310,13 @@ class SimCLRTrainer(Trainer): yield loss.item(), accuracy.item() -if __name__ == '__main__': +def main(local_rank, global_rank): args = parse_args_and_config() config = SimCLRConfig.from_args(args) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') trainer = SimCLRTrainer( seed=args.seed, checkpoint_dir=args.checkpoint_dir, - device=device, + device=local_rank, inf_mode=True, num_iters=args.num_iters, config=config, @@ -315,5 +324,11 @@ if __name__ == '__main__': out_dim=args.out_dim, ) - loggers = trainer.init_logger(args.log_dir) - trainer.train(args.num_iters, InfoNCELoss(args.temp), loggers, device) + loggers = None + if global_rank == 0: + loggers = trainer.init_logger(args.log_dir) + trainer.train(args.num_iters, InfoNCELoss(args.temp), loggers, local_rank) + + +if __name__ == '__main__': + elastic_launch(main) diff --git a/supervised/baseline.py b/supervised/baseline.py index 4be1c97..9ce3cf0 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -6,6 +6,7 @@ from typing import Iterable, Callable import sys import torch +import torch.distributed as dist import yaml from torch.utils.data import Dataset from torchvision.datasets import CIFAR10, CIFAR100 @@ -15,7 +16,7 @@ path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) sys.path.insert(0, path) from libs.datautils import Clip -from libs.utils import Trainer, BaseConfig +from libs.utils import Trainer, BaseConfig, elastic_launch from libs.logging import BaseBatchLogRecord, BaseEpochLogRecord, Loggers from models import CIFARResNet50, CIFARViTTiny @@ -211,7 +212,7 @@ class SupBaselineTrainer(Trainer): yield f"{model_name}_optim", optimizer - def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device): + def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: int): model = self.models['model'] optim = self.optims['model_optim'] sched = self.scheds['model_optim_sched'] @@ -227,23 +228,32 @@ class SupBaselineTrainer(Trainer): train_loss = loss_fn(output, targets) train_loss.backward() optim.step() - self.log(logger, self.BatchLogRecord( - batch, num_batches, global_batch, iter_, num_iters, - optim.param_groups[0]['lr'], train_loss.item() - )) - metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) - eval_loss = metrics[0].item() - eval_accuracy = metrics[1].item() - epoch_log = self.EpochLogRecord(iter_, num_iters, - eval_loss, eval_accuracy) - self.log(logger, epoch_log) - self.save_checkpoint(epoch_log) + + if logger is not None: + self.log(logger, self.BatchLogRecord( + batch, num_batches, global_batch, iter_, num_iters, + optim.param_groups[0]['lr'], train_loss.item() + )) + dist.barrier() + + # TODO Gather results from other workers + metrics = torch.Tensor(list(self.eval(loss_fn, device))) + if logger is not None: + metrics_mean = metrics.mean(0) + eval_loss = metrics_mean[0].item() + eval_accuracy = metrics_mean[1].item() + epoch_log = self.EpochLogRecord(iter_, num_iters, + eval_loss, eval_accuracy) + self.log(logger, epoch_log) + self.save_checkpoint(epoch_log) + dist.barrier() + # Step after save checkpoint, otherwise the schedular will # one iter ahead after restore if sched is not None: sched.step() - def eval(self, loss_fn: Callable, device: torch.device): + def eval(self, loss_fn: Callable, device: int): model = self.models['model'] model.eval() with torch.no_grad(): @@ -256,20 +266,25 @@ class SupBaselineTrainer(Trainer): yield loss.item(), accuracy.item() -if __name__ == '__main__': +def main(local_rank, global_rank): args = parse_args_and_config() config = SupBaselineConfig.from_args(args) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') trainer = SupBaselineTrainer( seed=args.seed, checkpoint_dir=args.checkpoint_dir, - device=device, + device=local_rank, inf_mode=False, num_iters=args.num_iters, config=config, backbone=args.backbone, ) - loggers = trainer.init_logger(args.log_dir) + loggers = None + if global_rank == 0: + loggers = trainer.init_logger(args.log_dir) loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smooth) - trainer.train(args.num_iters, loss_fn, loggers, device) + trainer.train(args.num_iters, loss_fn, loggers, local_rank) + + +if __name__ == '__main__': + elastic_launch(main) |