diff options
Diffstat (limited to 'supervised')
-rw-r--r-- | supervised/baseline.py | 53 |
1 files changed, 34 insertions, 19 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 4be1c97..9ce3cf0 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -6,6 +6,7 @@ from typing import Iterable, Callable import sys import torch +import torch.distributed as dist import yaml from torch.utils.data import Dataset from torchvision.datasets import CIFAR10, CIFAR100 @@ -15,7 +16,7 @@ path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) sys.path.insert(0, path) from libs.datautils import Clip -from libs.utils import Trainer, BaseConfig +from libs.utils import Trainer, BaseConfig, elastic_launch from libs.logging import BaseBatchLogRecord, BaseEpochLogRecord, Loggers from models import CIFARResNet50, CIFARViTTiny @@ -211,7 +212,7 @@ class SupBaselineTrainer(Trainer): yield f"{model_name}_optim", optimizer - def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device): + def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: int): model = self.models['model'] optim = self.optims['model_optim'] sched = self.scheds['model_optim_sched'] @@ -227,23 +228,32 @@ class SupBaselineTrainer(Trainer): train_loss = loss_fn(output, targets) train_loss.backward() optim.step() - self.log(logger, self.BatchLogRecord( - batch, num_batches, global_batch, iter_, num_iters, - optim.param_groups[0]['lr'], train_loss.item() - )) - metrics = torch.Tensor(list(self.eval(loss_fn, device))).mean(0) - eval_loss = metrics[0].item() - eval_accuracy = metrics[1].item() - epoch_log = self.EpochLogRecord(iter_, num_iters, - eval_loss, eval_accuracy) - self.log(logger, epoch_log) - self.save_checkpoint(epoch_log) + + if logger is not None: + self.log(logger, self.BatchLogRecord( + batch, num_batches, global_batch, iter_, num_iters, + optim.param_groups[0]['lr'], train_loss.item() + )) + dist.barrier() + + # TODO Gather results from other workers + metrics = torch.Tensor(list(self.eval(loss_fn, device))) + if logger is not None: + metrics_mean = metrics.mean(0) + eval_loss = metrics_mean[0].item() + eval_accuracy = metrics_mean[1].item() + epoch_log = self.EpochLogRecord(iter_, num_iters, + eval_loss, eval_accuracy) + self.log(logger, epoch_log) + self.save_checkpoint(epoch_log) + dist.barrier() + # Step after save checkpoint, otherwise the schedular will # one iter ahead after restore if sched is not None: sched.step() - def eval(self, loss_fn: Callable, device: torch.device): + def eval(self, loss_fn: Callable, device: int): model = self.models['model'] model.eval() with torch.no_grad(): @@ -256,20 +266,25 @@ class SupBaselineTrainer(Trainer): yield loss.item(), accuracy.item() -if __name__ == '__main__': +def main(local_rank, global_rank): args = parse_args_and_config() config = SupBaselineConfig.from_args(args) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') trainer = SupBaselineTrainer( seed=args.seed, checkpoint_dir=args.checkpoint_dir, - device=device, + device=local_rank, inf_mode=False, num_iters=args.num_iters, config=config, backbone=args.backbone, ) - loggers = trainer.init_logger(args.log_dir) + loggers = None + if global_rank == 0: + loggers = trainer.init_logger(args.log_dir) loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=args.label_smooth) - trainer.train(args.num_iters, loss_fn, loggers, device) + trainer.train(args.num_iters, loss_fn, loggers, local_rank) + + +if __name__ == '__main__': + elastic_launch(main) |