diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-07-13 23:29:39 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-07-13 23:29:39 +0800 |
commit | 819e69812f1cde332143ee968564c7d83abffe78 (patch) | |
tree | 73661f7c4e81d4b56e8b79876c0909b021361b83 | |
parent | f48a03fe9e76efe1c79ced0e45a986862682d036 (diff) |
Implement abstract `Trainer` with static config checking
-rw-r--r-- | libs/utils.py | 284 |
1 files changed, 227 insertions, 57 deletions
diff --git a/libs/utils.py b/libs/utils.py index 3df1574..caacb38 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -1,71 +1,241 @@ -import logging import os +import random +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Iterable, Callable -EPOCH_LOGGER = 'epoch_logger' -BATCH_LOGGER = 'batch_logger' +import torch +from torch.backends import cudnn +from torch.utils.data import Dataset, DataLoader, RandomSampler +from torch.utils.tensorboard import SummaryWriter +from libs.logging import CSV_EPOCH_LOGGER, CSV_BATCH_LOGGER, BaseBatchLogRecord, BaseEpochLogRecord, Loggers, \ + init_csv_logger, csv_logger, tensorboard_logger -class FileHandlerWithHeader(logging.FileHandler): - def __init__(self, filename, header, mode='a', - encoding=None, delay=False, errors=None): - self.header = header - self.file_pre_exists = os.path.exists(filename) +@dataclass +class BaseConfig: + @dataclass + class DatasetConfig: + dataset: str - super(FileHandlerWithHeader, self).__init__( - filename, mode, encoding, delay, errors + @dataclass + class DataLoaderConfig: + batch_size: int + num_worker: int + + @dataclass + class OptimConfig: + optim: str + lr: float + + @dataclass + class SchedConfig: + sched: None + + dataset_config: DatasetConfig + dataloader_config: DataLoaderConfig + optim_config: OptimConfig + sched_config: SchedConfig + + +class Trainer(ABC): + def __init__( + self, + seed: int, + checkpoint_dir: str, + device: torch.device, + inf_mode: bool, + num_iters: int, + config: BaseConfig, + ): + self._args = locals() + self._set_seed(seed) + + train_set, test_set = self._prepare_dataset(config.dataset_config) + train_loader, test_loader = self._create_dataloader( + train_set, test_set, inf_mode, config.dataloader_config + ) + + models = self._init_models(config.dataset_config.dataset) + models = {n: m.to(device) for n, m in models} + optims = dict(self._configure_optimizers(models.items(), config.optim_config)) + last_metrics = self._auto_load_checkpoint( + checkpoint_dir, inf_mode, **(models | optims) ) - if not delay and self.stream is not None and not self.file_pre_exists: - self.stream.write(f'{header}\n') - - def emit(self, record): - if self.stream is None: - self.stream = self._open() - if not self.file_pre_exists: - self.stream.write(f'{self.header}\n') - - logging.FileHandler.emit(self, record) - - -def setup_logging(name="log", - filename=None, - stream_log_level="INFO", - file_log_level="INFO"): - logger = logging.getLogger(name) - logger.setLevel("INFO") - formatter = logging.Formatter( - '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S' - ) - stream_handler = logging.StreamHandler() - stream_handler.setLevel(getattr(logging, stream_log_level)) - stream_handler.setFormatter(formatter) - logger.addHandler(stream_handler) - if filename is not None: - header = 'time,logger,' - if name == BATCH_LOGGER: - header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr' - elif name == EPOCH_LOGGER: - header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy' + + if last_metrics is None: + last_iter = -1 + restore_iter = 0 + elif isinstance(last_metrics, BaseEpochLogRecord): + last_iter = last_metrics.epoch * len(train_loader) - 1 + restore_iter = last_metrics.epoch + elif isinstance(last_metrics, BaseBatchLogRecord): + last_iter = last_metrics.global_batch - 1 + restore_iter = last_metrics.global_batch + else: + raise NotImplementedError(f"Unknown log type: '{type(last_metrics)}'") + if not inf_mode: + num_iters *= len(train_loader) + scheds = dict(self._configure_scheduler( + optims.items(), last_iter, num_iters, config.sched_config, + )) + + self._custom_init_fn(config) + + self.restore_iter = restore_iter + self.train_loader = train_loader + self.test_loader = test_loader + self.models = models + self.optims = optims + self.scheds = scheds + self._inf_mode = inf_mode + self._checkpoint_dir = checkpoint_dir + + @dataclass + class BatchLogRecord(BaseBatchLogRecord): + pass + + @dataclass + class EpochLogRecord(BaseEpochLogRecord): + pass + + @staticmethod + def _set_seed(seed): + if seed in {-1, None, ''}: + cudnn.benchmark = True + else: + random.seed(seed) + torch.manual_seed(seed) + cudnn.deterministic = True + + def init_logger(self, log_dir): + csv_batch_log_name = os.path.join(log_dir, 'batch-log.csv') + csv_batch_logger = init_csv_logger(CSV_BATCH_LOGGER, csv_batch_log_name, + list(self.BatchLogRecord.__dataclass_fields__.keys())) + csv_epoch_logger = None + if not self._inf_mode: + csv_epoch_log_name = os.path.join(log_dir, 'epoch-log.csv') + csv_epoch_logger = init_csv_logger(CSV_EPOCH_LOGGER, csv_epoch_log_name, + list(self.EpochLogRecord.__dataclass_fields__.keys())) + tb_logger = SummaryWriter(os.path.join(log_dir, 'runs')) + + return Loggers(csv_batch_logger, csv_epoch_logger, tb_logger) + + def dump_args(self, exclude=frozenset()) -> dict: + return {k: v for k, v in self._args.items() if k not in {'self'} | exclude} + + @staticmethod + @abstractmethod + def _prepare_dataset(dataset_config: BaseConfig.DatasetConfig) -> tuple[Dataset, Dataset]: + train_set = Dataset() + test_set = Dataset() + return train_set, test_set + + @staticmethod + def _create_dataloader( + train_set: Dataset, test_set: Dataset, + inf_mode: bool, dataloader_config: BaseConfig.DataLoaderConfig + ) -> tuple[DataLoader, DataLoader]: + if inf_mode: + inf_sampler = RandomSampler(train_set, + replacement=True, + num_samples=int(1e20)) + train_loader = DataLoader(train_set, + sampler=inf_sampler, + batch_size=dataloader_config.batch_size, + num_workers=dataloader_config.num_worker) else: - raise NotImplementedError(f"Logger '{name}' is not implemented.") + train_loader = DataLoader(train_set, + shuffle=True, + batch_size=dataloader_config.batch_size, + num_workers=dataloader_config.num_worker) + test_loader = DataLoader(test_set, + shuffle=False, + batch_size=dataloader_config.batch_size, + num_workers=dataloader_config.num_worker) + + return train_loader, test_loader + + @staticmethod + @abstractmethod + def _init_models(dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: + model = torch.nn.Module() + yield 'model_name', model + + @staticmethod + @abstractmethod + def _configure_optimizers( + models: Iterable[tuple[str, torch.nn.Module]], + optim_config: BaseConfig.OptimConfig + ) -> Iterable[tuple[str, torch.optim.Optimizer]]: + for model_name, model in models: + optim = torch.optim.Optimizer([model.state_dict()], {}) + yield f"{model_name}_optim", optim + + def _auto_load_checkpoint( + self, + checkpoint_dir: str, + inf_mode: bool, + **modules + ) -> None | BaseEpochLogRecord | BaseEpochLogRecord: + if not os.path.exists(checkpoint_dir): + return None + checkpoint_files = os.listdir(checkpoint_dir) + if not checkpoint_files: + return None + iter2checkpoint = {int(os.path.splitext(checkpoint_file)[0]): checkpoint_file + for checkpoint_file in checkpoint_files} + restore_iter = max(iter2checkpoint.keys()) + latest_checkpoint = iter2checkpoint[restore_iter] + checkpoint = torch.load(os.path.join(checkpoint_dir, latest_checkpoint)) + for module_name in modules.keys(): + module_state_dict = checkpoint[f"{module_name}_state_dict"] + modules[module_name].load_state_dict(module_state_dict) + + last_metrics = {k: v for k, v in checkpoint.items() + if not k.endswith('state_dict')} + if inf_mode: + last_metrics = self.BatchLogRecord(**last_metrics) + else: + last_metrics = self.EpochLogRecord(**last_metrics) + + return last_metrics - os.makedirs(os.path.dirname(filename), exist_ok=True) - file_handler = FileHandlerWithHeader(filename, header) - file_handler.setLevel(getattr(logging, file_log_level)) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - return logger + @staticmethod + @abstractmethod + def _configure_scheduler( + optims: Iterable[tuple[str, torch.optim.Optimizer]], + last_iter: int, num_iters: int, sched_config: BaseConfig.SchedConfig, + ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler] + | tuple[str, None]]: + for optim_name, optim in optims: + sched = torch.optim.lr_scheduler._LRScheduler(optim, -1) + yield f"{optim_name}_sched", sched + def _custom_init_fn(self, config: BaseConfig): + pass -def training_log(name): - def log_this(function): - logger = logging.getLogger(name) + @staticmethod + @csv_logger + @tensorboard_logger + def log(loggers: Loggers, metrics: BaseBatchLogRecord | BaseEpochLogRecord): + return loggers, metrics - def wrapper(*args, **kwargs): - output = function(*args, **kwargs) - logger.info(','.join(map(str, output.values()))) - return output + def save_checkpoint(self, metrics: BaseEpochLogRecord | BaseBatchLogRecord): + os.makedirs(self._checkpoint_dir, exist_ok=True) + checkpoint_name = os.path.join(self._checkpoint_dir, f"{metrics.epoch:06d}.pt") + models_state_dict = {f"{model_name}_state_dict": model.state_dict() + for model_name, model in self.models.items()} + optims_state_dict = {f"{optim_name}_state_dict": optim.state_dict() + for optim_name, optim in self.optims.items()} + checkpoint = metrics.__dict__ | models_state_dict | optims_state_dict + torch.save(checkpoint, checkpoint_name) - return wrapper + @abstractmethod + def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device): + pass - return log_this + @abstractmethod + def eval(self, loss_fn: Callable, device: torch.device): + pass |