diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-13 23:38:43 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-13 23:38:43 +0800 |
commit | 957a2a46e7725184776c3c72860e8215164cc4ef (patch) | |
tree | 43e098595db4ee332bca5f6caecfbd02369debbe /simclr/evaluate.py | |
parent | 1b8f01ce9706905c36c6f11ed9deac8548ad7341 (diff) |
Implement distributed data parallel via torch elastic launcher
Diffstat (limited to 'simclr/evaluate.py')
-rw-r--r-- | simclr/evaluate.py | 63 |
1 files changed, 39 insertions, 24 deletions
diff --git a/simclr/evaluate.py b/simclr/evaluate.py index 5c41b84..23bd299 100644 --- a/simclr/evaluate.py +++ b/simclr/evaluate.py @@ -6,6 +6,7 @@ from typing import Iterable, Callable import sys import torch +import torch.distributed as dist import yaml from torch.utils.data import Dataset from torchvision.datasets import CIFAR10, CIFAR100, ImageNet @@ -16,7 +17,7 @@ sys.path.insert(0, path) from libs.optimizers import LARS from libs.logging import Loggers, BaseBatchLogRecord, BaseEpochLogRecord -from libs.utils import BaseConfig +from libs.utils import BaseConfig, elastic_launch from simclr.main import SimCLRTrainer, SimCLRConfig from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50 @@ -190,7 +191,7 @@ class SimCLREvalTrainer(SimCLRTrainer): if k in self.models['backbone'].state_dict()} self.models['backbone'].load_state_dict(backbone_state_dict) - def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device): + def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: int): backbone, classifier = self.models.values() optim_b, optim_c = self.optims.values() sched_b, sched_c = self.scheds.values() @@ -218,24 +219,33 @@ class SimCLREvalTrainer(SimCLRTrainer): if self.finetune: optim_b.step() optim_c.step() - self.log(logger, self.BatchLogRecord( - batch, num_batches, global_batch, iter_, num_iters, - optim_c.param_groups[0]['lr'], train_loss.item() - )) + + if logger is not None: + self.log(logger, self.BatchLogRecord( + batch, num_batches, global_batch, iter_, num_iters, + optim_c.param_groups[0]['lr'], train_loss.item() + )) + dist.barrier() + if (iter_ + 1) % (num_iters // 10) == 0: - metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) - eval_loss = metrics[0].item() - eval_accuracy = metrics[1].item() - epoch_log = self.EpochLogRecord(iter_, num_iters, - eval_loss, eval_accuracy) - self.log(logger, epoch_log) - self.save_checkpoint(epoch_log) - if sched_b is not None and self.finetune: - sched_b.step() - if sched_c is not None: - sched_c.step() - - def eval(self, loss_fn: Callable, device: torch.device): + # TODO Gather results from other workers + metrics = torch.Tensor(list(self.eval(loss_fn, device))) + if logger is not None: + metrics_mean = metrics.mean(0) + eval_loss = metrics_mean[0].item() + eval_accuracy = metrics_mean[1].item() + epoch_log = self.EpochLogRecord(iter_, num_iters, + eval_loss, eval_accuracy) + self.log(logger, epoch_log) + self.save_checkpoint(epoch_log) + dist.barrier() + + if sched_b is not None and self.finetune: + sched_b.step() + if sched_c is not None: + sched_c.step() + + def eval(self, loss_fn: Callable, device: int): backbone, classifier = self.models.values() backbone.eval() classifier.eval() @@ -250,14 +260,13 @@ class SimCLREvalTrainer(SimCLRTrainer): yield loss.item(), accuracy.item() -if __name__ == '__main__': +def main(local_rank, global_rank): args = parse_args_and_config() config = SimCLREvalConfig.from_args(args) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') trainer = SimCLREvalTrainer( seed=args.seed, checkpoint_dir=args.checkpoint_dir, - device=device, + device=local_rank, inf_mode=False, num_iters=args.num_iters, config=config, @@ -267,5 +276,11 @@ if __name__ == '__main__': finetune=args.finetune, ) - loggers = trainer.init_logger(args.log_dir) - trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, device) + loggers = None + if global_rank == 0: + loggers = trainer.init_logger(args.log_dir) + trainer.train(args.num_iters, torch.nn.CrossEntropyLoss(), loggers, local_rank) + + +if __name__ == '__main__': + elastic_launch(main) |