diff options
Diffstat (limited to 'libs/utils.py')
-rw-r--r-- | libs/utils.py | 71 |
1 files changed, 59 insertions, 12 deletions
diff --git a/libs/utils.py b/libs/utils.py index 63ea116..90ae48f 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -6,10 +6,15 @@ from dataclasses import dataclass from typing import Iterable, Callable import torch +from torch import distributed as dist from torch.backends import cudnn -from torch.utils.data import Dataset, DataLoader, RandomSampler +from torch.distributed import rpc as rpc +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import BatchSampler from torch.utils.tensorboard import SummaryWriter +from libs.datautils import ContinuousSampler from libs.logging import CSV_EPOCH_LOGGER, CSV_BATCH_LOGGER, BaseBatchLogRecord, BaseEpochLogRecord, Loggers, \ init_csv_logger, csv_logger, tensorboard_logger from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR @@ -61,7 +66,7 @@ class Trainer(ABC): self, seed: int, checkpoint_dir: str, - device: torch.device, + device: int, inf_mode: bool, num_iters: int, config: BaseConfig, @@ -75,7 +80,7 @@ class Trainer(ABC): ) models = self._init_models(config.dataset_config.dataset) - models = {n: m.to(device) for n, m in models} + models = dict(self._models_to_devices(models, device)) optims = dict(self._configure_optimizers(models.items(), config.optim_config)) last_metrics = self._auto_load_checkpoint( checkpoint_dir, inf_mode, **(models | optims) @@ -155,22 +160,25 @@ class Trainer(ABC): train_set: Dataset, test_set: Dataset, inf_mode: bool, dataloader_config: BaseConfig.DataLoaderConfig ) -> tuple[DataLoader, DataLoader]: + train_sampler = DistributedSampler(train_set, shuffle=True) + test_sampler = DistributedSampler(test_set, shuffle=False) if inf_mode: - inf_sampler = RandomSampler(train_set, - replacement=True, - num_samples=int(1e20)) + inf_train_sampler = ContinuousSampler( + BatchSampler(train_sampler, + dataloader_config.batch_size, + drop_last=True) + ) train_loader = DataLoader(train_set, - sampler=inf_sampler, - batch_size=dataloader_config.batch_size, + batch_sampler=inf_train_sampler, num_workers=dataloader_config.num_workers) else: train_loader = DataLoader(train_set, - shuffle=True, batch_size=dataloader_config.batch_size, + sampler=train_sampler, num_workers=dataloader_config.num_workers) test_loader = DataLoader(test_set, - shuffle=False, batch_size=dataloader_config.batch_size, + sampler=test_sampler, num_workers=dataloader_config.num_workers) return train_loader, test_loader @@ -182,6 +190,18 @@ class Trainer(ABC): yield 'model_name', model @staticmethod + def _models_to_devices( + models: Iterable[tuple[str, torch.nn.Module]], + device: int, + ) -> Iterable[tuple[str, torch.nn.Module]]: + for name, model in models: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = model.to(device) + model = torch.nn.parallel.DistributedDataParallel(model, [device]) + + yield name, model + + @staticmethod @abstractmethod def _configure_optimizers( models: Iterable[tuple[str, torch.nn.Module]], @@ -267,9 +287,36 @@ class Trainer(ABC): torch.save(checkpoint, checkpoint_name) @abstractmethod - def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device): + def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: int): pass @abstractmethod - def eval(self, loss_fn: Callable, device: torch.device): + def eval(self, loss_fn: Callable, device: int): pass + + +def elastic_launch(func): + dist.init_process_group(backend='nccl') + local_rank = int(os.environ['LOCAL_RANK']) + local_world_size = int(os.environ['LOCAL_WORLD_SIZE']) + ngpu_per_proc = torch.cuda.device_count() // local_world_size + + assert ngpu_per_proc == 1 + + global_rank = dist.get_rank() + global_world_size = dist.get_world_size() + rpc.init_rpc( + f"worker{global_rank}", + rank=global_rank, + world_size=global_world_size, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions( + device_maps={f"worker{callee}": { + caller: callee for caller in range(global_world_size) + } for callee in range(global_world_size)} + ) + ) + + func(local_rank, global_rank) + + rpc.shutdown() + dist.destroy_process_group() |