aboutsummaryrefslogtreecommitdiff
path: root/simclr/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'simclr/main.py')
-rw-r--r--simclr/main.py67
1 files changed, 41 insertions, 26 deletions
diff --git a/simclr/main.py b/simclr/main.py
index 69e2ab2..3456d41 100644
--- a/simclr/main.py
+++ b/simclr/main.py
@@ -1,11 +1,12 @@
import argparse
import os
-import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Callable
+import sys
import torch
+import torch.distributed as dist
import yaml
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10, CIFAR100, ImageNet
@@ -17,8 +18,7 @@ sys.path.insert(0, path)
from libs.criteria import InfoNCELoss
from libs.datautils import color_distortion, Clip, RandomGaussianBlur, TwinTransform
from libs.optimizers import LARS
-from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR
-from libs.utils import Trainer, BaseConfig
+from libs.utils import Trainer, BaseConfig, elastic_launch
from libs.logging import BaseBatchLogRecord, Loggers
from simclr.models import CIFARSimCLRResNet50, ImageNetSimCLRResNet50
@@ -254,11 +254,12 @@ class SimCLRTrainer(Trainer):
self.optims = {n: LARS(o) if config.optim_config.optim == 'lars' else o
for n, o in self.optims.items()}
- def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device):
+ def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: int):
model = self.models['model']
optim = self.optims['model_optim']
sched = self.scheds['model_optim_sched']
train_loader = iter(self.train_loader)
+
model.train()
for iter_ in range(self.restore_iter, num_iters):
input_, _ = next(train_loader)
@@ -268,28 +269,37 @@ class SimCLRTrainer(Trainer):
train_loss, train_accuracy = loss_fn(output)
train_loss.backward()
optim.step()
- self.log(logger, self.BatchLogRecord(
- iter_, num_iters, iter_, iter_, num_iters,
- optim.param_groups[0]['lr'],
- train_loss.item(), train_accuracy.item(),
- eval_loss=None, eval_accuracy=None,
- ))
- if (iter_ + 1) % (num_iters // 100) == 0:
- metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0)
- eval_loss = metrics[0].item()
- eval_accuracy = metrics[1].item()
- eval_log = self.BatchLogRecord(
+
+ if logger is not None:
+ self.log(logger, self.BatchLogRecord(
iter_, num_iters, iter_, iter_, num_iters,
- lr=None, train_loss=None, train_accuracy=None,
- eval_loss=eval_loss, eval_accuracy=eval_accuracy,
- )
- self.log(logger, eval_log)
- self.save_checkpoint(eval_log)
+ optim.param_groups[0]['lr'],
+ train_loss.item(), train_accuracy.item(),
+ eval_loss=None, eval_accuracy=None,
+ ))
+ dist.barrier()
+
+ if (iter_ + 1) % (num_iters // 100) == 0:
+ # TODO Gather results from other workers
+ metrics = torch.Tensor(list(self.eval(loss_fn, device)))
+ if logger is not None:
+ metrics_mean = metrics.mean(0)
+ eval_loss = metrics_mean[0].item()
+ eval_accuracy = metrics_mean[1].item()
+ eval_log = self.BatchLogRecord(
+ iter_, num_iters, iter_, iter_, num_iters,
+ lr=None, train_loss=None, train_accuracy=None,
+ eval_loss=eval_loss, eval_accuracy=eval_accuracy,
+ )
+ self.log(logger, eval_log)
+ self.save_checkpoint(eval_log)
model.train()
+ dist.barrier()
+
if sched is not None:
sched.step()
- def eval(self, loss_fn: Callable, device: torch.device):
+ def eval(self, loss_fn: Callable, device: int):
model = self.models['model']
model.eval()
with torch.no_grad():
@@ -300,14 +310,13 @@ class SimCLRTrainer(Trainer):
yield loss.item(), accuracy.item()
-if __name__ == '__main__':
+def main(local_rank, global_rank):
args = parse_args_and_config()
config = SimCLRConfig.from_args(args)
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainer = SimCLRTrainer(
seed=args.seed,
checkpoint_dir=args.checkpoint_dir,
- device=device,
+ device=local_rank,
inf_mode=True,
num_iters=args.num_iters,
config=config,
@@ -315,5 +324,11 @@ if __name__ == '__main__':
out_dim=args.out_dim,
)
- loggers = trainer.init_logger(args.log_dir)
- trainer.train(args.num_iters, InfoNCELoss(args.temp), loggers, device)
+ loggers = None
+ if global_rank == 0:
+ loggers = trainer.init_logger(args.log_dir)
+ trainer.train(args.num_iters, InfoNCELoss(args.temp), loggers, local_rank)
+
+
+if __name__ == '__main__':
+ elastic_launch(main)