diff options
Diffstat (limited to 'simclr/main.py')
-rw-r--r-- | simclr/main.py | 67 |
1 files changed, 41 insertions, 26 deletions
diff --git a/simclr/main.py b/simclr/main.py index 69e2ab2..3456d41 100644 --- a/simclr/main.py +++ b/simclr/main.py @@ -1,11 +1,12 @@ import argparse import os -import sys from dataclasses import dataclass from pathlib import Path from typing import Iterable, Callable +import sys import torch +import torch.distributed as dist import yaml from torch.utils.data import Dataset from torchvision.datasets import CIFAR10, CIFAR100, ImageNet @@ -17,8 +18,7 @@ sys.path.insert(0, path) from libs.criteria import InfoNCELoss from libs.datautils import color_distortion, Clip, RandomGaussianBlur, TwinTransform from libs.optimizers import LARS -from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR -from libs.utils import Trainer, BaseConfig +from libs.utils import Trainer, BaseConfig, elastic_launch from libs.logging import BaseBatchLogRecord, Loggers from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50 @@ -254,11 +254,12 @@ class SimCLRTrainer(Trainer): self.optims = {n: LARS(o) if config.optim_config.optim == 'lars' else o for n, o in self.optims.items()} - def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device): + def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: int): model = self.models['model'] optim = self.optims['model_optim'] sched = self.scheds['model_optim_sched'] train_loader = iter(self.train_loader) + model.train() for iter_ in range(self.restore_iter, num_iters): input_, _ = next(train_loader) @@ -268,28 +269,37 @@ class SimCLRTrainer(Trainer): train_loss, train_accuracy = loss_fn(output) train_loss.backward() optim.step() - self.log(logger, self.BatchLogRecord( - iter_, num_iters, iter_, iter_, num_iters, - optim.param_groups[0]['lr'], - train_loss.item(), train_accuracy.item(), - eval_loss=None, eval_accuracy=None, - )) - if (iter_ + 1) % (num_iters // 100) == 0: - metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) - eval_loss = metrics[0].item() - eval_accuracy = metrics[1].item() - eval_log = self.BatchLogRecord( + + if logger is not None: + self.log(logger, self.BatchLogRecord( iter_, num_iters, iter_, iter_, num_iters, - lr=None, train_loss=None, train_accuracy=None, - eval_loss=eval_loss, eval_accuracy=eval_accuracy, - ) - self.log(logger, eval_log) - self.save_checkpoint(eval_log) + optim.param_groups[0]['lr'], + train_loss.item(), train_accuracy.item(), + eval_loss=None, eval_accuracy=None, + )) + dist.barrier() + + if (iter_ + 1) % (num_iters // 100) == 0: + # TODO Gather results from other workers + metrics = torch.Tensor(list(self.eval(loss_fn, device))) + if logger is not None: + metrics_mean = metrics.mean(0) + eval_loss = metrics_mean[0].item() + eval_accuracy = metrics_mean[1].item() + eval_log = self.BatchLogRecord( + iter_, num_iters, iter_, iter_, num_iters, + lr=None, train_loss=None, train_accuracy=None, + eval_loss=eval_loss, eval_accuracy=eval_accuracy, + ) + self.log(logger, eval_log) + self.save_checkpoint(eval_log) model.train() + dist.barrier() + if sched is not None: sched.step() - def eval(self, loss_fn: Callable, device: torch.device): + def eval(self, loss_fn: Callable, device: int): model = self.models['model'] model.eval() with torch.no_grad(): @@ -300,14 +310,13 @@ class SimCLRTrainer(Trainer): yield loss.item(), accuracy.item() -if __name__ == '__main__': +def main(local_rank, global_rank): args = parse_args_and_config() config = SimCLRConfig.from_args(args) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') trainer = SimCLRTrainer( seed=args.seed, checkpoint_dir=args.checkpoint_dir, - device=device, + device=local_rank, inf_mode=True, num_iters=args.num_iters, config=config, @@ -315,5 +324,11 @@ if __name__ == '__main__': out_dim=args.out_dim, ) - loggers = trainer.init_logger(args.log_dir) - trainer.train(args.num_iters, InfoNCELoss(args.temp), loggers, device) + loggers = None + if global_rank == 0: + loggers = trainer.init_logger(args.log_dir) + trainer.train(args.num_iters, InfoNCELoss(args.temp), loggers, local_rank) + + +if __name__ == '__main__': + elastic_launch(main) |