aboutsummaryrefslogtreecommitdiff
path: root/libs/utils.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-07-13 23:29:39 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-07-13 23:29:39 +0800
commit819e69812f1cde332143ee968564c7d83abffe78 (patch)
tree73661f7c4e81d4b56e8b79876c0909b021361b83 /libs/utils.py
parentf48a03fe9e76efe1c79ced0e45a986862682d036 (diff)
Implement abstract `Trainer` with static config checking
Diffstat (limited to 'libs/utils.py')
-rw-r--r--libs/utils.py284
1 files changed, 227 insertions, 57 deletions
diff --git a/libs/utils.py b/libs/utils.py
index 3df1574..caacb38 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -1,71 +1,241 @@
-import logging
import os
+import random
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Iterable, Callable
-EPOCH_LOGGER = 'epoch_logger'
-BATCH_LOGGER = 'batch_logger'
+import torch
+from torch.backends import cudnn
+from torch.utils.data import Dataset, DataLoader, RandomSampler
+from torch.utils.tensorboard import SummaryWriter
+from libs.logging import CSV_EPOCH_LOGGER, CSV_BATCH_LOGGER, BaseBatchLogRecord, BaseEpochLogRecord, Loggers, \
+ init_csv_logger, csv_logger, tensorboard_logger
-class FileHandlerWithHeader(logging.FileHandler):
- def __init__(self, filename, header, mode='a',
- encoding=None, delay=False, errors=None):
- self.header = header
- self.file_pre_exists = os.path.exists(filename)
+@dataclass
+class BaseConfig:
+ @dataclass
+ class DatasetConfig:
+ dataset: str
- super(FileHandlerWithHeader, self).__init__(
- filename, mode, encoding, delay, errors
+ @dataclass
+ class DataLoaderConfig:
+ batch_size: int
+ num_worker: int
+
+ @dataclass
+ class OptimConfig:
+ optim: str
+ lr: float
+
+ @dataclass
+ class SchedConfig:
+ sched: None
+
+ dataset_config: DatasetConfig
+ dataloader_config: DataLoaderConfig
+ optim_config: OptimConfig
+ sched_config: SchedConfig
+
+
+class Trainer(ABC):
+ def __init__(
+ self,
+ seed: int,
+ checkpoint_dir: str,
+ device: torch.device,
+ inf_mode: bool,
+ num_iters: int,
+ config: BaseConfig,
+ ):
+ self._args = locals()
+ self._set_seed(seed)
+
+ train_set, test_set = self._prepare_dataset(config.dataset_config)
+ train_loader, test_loader = self._create_dataloader(
+ train_set, test_set, inf_mode, config.dataloader_config
+ )
+
+ models = self._init_models(config.dataset_config.dataset)
+ models = {n: m.to(device) for n, m in models}
+ optims = dict(self._configure_optimizers(models.items(), config.optim_config))
+ last_metrics = self._auto_load_checkpoint(
+ checkpoint_dir, inf_mode, **(models | optims)
)
- if not delay and self.stream is not None and not self.file_pre_exists:
- self.stream.write(f'{header}\n')
-
- def emit(self, record):
- if self.stream is None:
- self.stream = self._open()
- if not self.file_pre_exists:
- self.stream.write(f'{self.header}\n')
-
- logging.FileHandler.emit(self, record)
-
-
-def setup_logging(name="log",
- filename=None,
- stream_log_level="INFO",
- file_log_level="INFO"):
- logger = logging.getLogger(name)
- logger.setLevel("INFO")
- formatter = logging.Formatter(
- '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
- )
- stream_handler = logging.StreamHandler()
- stream_handler.setLevel(getattr(logging, stream_log_level))
- stream_handler.setFormatter(formatter)
- logger.addHandler(stream_handler)
- if filename is not None:
- header = 'time,logger,'
- if name == BATCH_LOGGER:
- header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
- elif name == EPOCH_LOGGER:
- header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
+
+ if last_metrics is None:
+ last_iter = -1
+ restore_iter = 0
+ elif isinstance(last_metrics, BaseEpochLogRecord):
+ last_iter = last_metrics.epoch * len(train_loader) - 1
+ restore_iter = last_metrics.epoch
+ elif isinstance(last_metrics, BaseBatchLogRecord):
+ last_iter = last_metrics.global_batch - 1
+ restore_iter = last_metrics.global_batch
+ else:
+ raise NotImplementedError(f"Unknown log type: '{type(last_metrics)}'")
+ if not inf_mode:
+ num_iters *= len(train_loader)
+ scheds = dict(self._configure_scheduler(
+ optims.items(), last_iter, num_iters, config.sched_config,
+ ))
+
+ self._custom_init_fn(config)
+
+ self.restore_iter = restore_iter
+ self.train_loader = train_loader
+ self.test_loader = test_loader
+ self.models = models
+ self.optims = optims
+ self.scheds = scheds
+ self._inf_mode = inf_mode
+ self._checkpoint_dir = checkpoint_dir
+
+ @dataclass
+ class BatchLogRecord(BaseBatchLogRecord):
+ pass
+
+ @dataclass
+ class EpochLogRecord(BaseEpochLogRecord):
+ pass
+
+ @staticmethod
+ def _set_seed(seed):
+ if seed in {-1, None, ''}:
+ cudnn.benchmark = True
+ else:
+ random.seed(seed)
+ torch.manual_seed(seed)
+ cudnn.deterministic = True
+
+ def init_logger(self, log_dir):
+ csv_batch_log_name = os.path.join(log_dir, 'batch-log.csv')
+ csv_batch_logger = init_csv_logger(CSV_BATCH_LOGGER, csv_batch_log_name,
+ list(self.BatchLogRecord.__dataclass_fields__.keys()))
+ csv_epoch_logger = None
+ if not self._inf_mode:
+ csv_epoch_log_name = os.path.join(log_dir, 'epoch-log.csv')
+ csv_epoch_logger = init_csv_logger(CSV_EPOCH_LOGGER, csv_epoch_log_name,
+ list(self.EpochLogRecord.__dataclass_fields__.keys()))
+ tb_logger = SummaryWriter(os.path.join(log_dir, 'runs'))
+
+ return Loggers(csv_batch_logger, csv_epoch_logger, tb_logger)
+
+ def dump_args(self, exclude=frozenset()) -> dict:
+ return {k: v for k, v in self._args.items() if k not in {'self'} | exclude}
+
+ @staticmethod
+ @abstractmethod
+ def _prepare_dataset(dataset_config: BaseConfig.DatasetConfig) -> tuple[Dataset, Dataset]:
+ train_set = Dataset()
+ test_set = Dataset()
+ return train_set, test_set
+
+ @staticmethod
+ def _create_dataloader(
+ train_set: Dataset, test_set: Dataset,
+ inf_mode: bool, dataloader_config: BaseConfig.DataLoaderConfig
+ ) -> tuple[DataLoader, DataLoader]:
+ if inf_mode:
+ inf_sampler = RandomSampler(train_set,
+ replacement=True,
+ num_samples=int(1e20))
+ train_loader = DataLoader(train_set,
+ sampler=inf_sampler,
+ batch_size=dataloader_config.batch_size,
+ num_workers=dataloader_config.num_worker)
else:
- raise NotImplementedError(f"Logger '{name}' is not implemented.")
+ train_loader = DataLoader(train_set,
+ shuffle=True,
+ batch_size=dataloader_config.batch_size,
+ num_workers=dataloader_config.num_worker)
+ test_loader = DataLoader(test_set,
+ shuffle=False,
+ batch_size=dataloader_config.batch_size,
+ num_workers=dataloader_config.num_worker)
+
+ return train_loader, test_loader
+
+ @staticmethod
+ @abstractmethod
+ def _init_models(dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
+ model = torch.nn.Module()
+ yield 'model_name', model
+
+ @staticmethod
+ @abstractmethod
+ def _configure_optimizers(
+ models: Iterable[tuple[str, torch.nn.Module]],
+ optim_config: BaseConfig.OptimConfig
+ ) -> Iterable[tuple[str, torch.optim.Optimizer]]:
+ for model_name, model in models:
+ optim = torch.optim.Optimizer([model.state_dict()], {})
+ yield f"{model_name}_optim", optim
+
+ def _auto_load_checkpoint(
+ self,
+ checkpoint_dir: str,
+ inf_mode: bool,
+ **modules
+ ) -> None | BaseEpochLogRecord | BaseEpochLogRecord:
+ if not os.path.exists(checkpoint_dir):
+ return None
+ checkpoint_files = os.listdir(checkpoint_dir)
+ if not checkpoint_files:
+ return None
+ iter2checkpoint = {int(os.path.splitext(checkpoint_file)[0]): checkpoint_file
+ for checkpoint_file in checkpoint_files}
+ restore_iter = max(iter2checkpoint.keys())
+ latest_checkpoint = iter2checkpoint[restore_iter]
+ checkpoint = torch.load(os.path.join(checkpoint_dir, latest_checkpoint))
+ for module_name in modules.keys():
+ module_state_dict = checkpoint[f"{module_name}_state_dict"]
+ modules[module_name].load_state_dict(module_state_dict)
+
+ last_metrics = {k: v for k, v in checkpoint.items()
+ if not k.endswith('state_dict')}
+ if inf_mode:
+ last_metrics = self.BatchLogRecord(**last_metrics)
+ else:
+ last_metrics = self.EpochLogRecord(**last_metrics)
+
+ return last_metrics
- os.makedirs(os.path.dirname(filename), exist_ok=True)
- file_handler = FileHandlerWithHeader(filename, header)
- file_handler.setLevel(getattr(logging, file_log_level))
- file_handler.setFormatter(formatter)
- logger.addHandler(file_handler)
- return logger
+ @staticmethod
+ @abstractmethod
+ def _configure_scheduler(
+ optims: Iterable[tuple[str, torch.optim.Optimizer]],
+ last_iter: int, num_iters: int, sched_config: BaseConfig.SchedConfig,
+ ) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler]
+ | tuple[str, None]]:
+ for optim_name, optim in optims:
+ sched = torch.optim.lr_scheduler._LRScheduler(optim, -1)
+ yield f"{optim_name}_sched", sched
+ def _custom_init_fn(self, config: BaseConfig):
+ pass
-def training_log(name):
- def log_this(function):
- logger = logging.getLogger(name)
+ @staticmethod
+ @csv_logger
+ @tensorboard_logger
+ def log(loggers: Loggers, metrics: BaseBatchLogRecord | BaseEpochLogRecord):
+ return loggers, metrics
- def wrapper(*args, **kwargs):
- output = function(*args, **kwargs)
- logger.info(','.join(map(str, output.values())))
- return output
+ def save_checkpoint(self, metrics: BaseEpochLogRecord | BaseBatchLogRecord):
+ os.makedirs(self._checkpoint_dir, exist_ok=True)
+ checkpoint_name = os.path.join(self._checkpoint_dir, f"{metrics.epoch:06d}.pt")
+ models_state_dict = {f"{model_name}_state_dict": model.state_dict()
+ for model_name, model in self.models.items()}
+ optims_state_dict = {f"{optim_name}_state_dict": optim.state_dict()
+ for optim_name, optim in self.optims.items()}
+ checkpoint = metrics.__dict__ | models_state_dict | optims_state_dict
+ torch.save(checkpoint, checkpoint_name)
- return wrapper
+ @abstractmethod
+ def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device):
+ pass
- return log_this
+ @abstractmethod
+ def eval(self, loss_fn: Callable, device: torch.device):
+ pass