From cc807c751e2c14ef9a88e5c5be00b4eb082e705b Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 14 Jul 2022 14:54:11 +0800 Subject: Refactor baseline with trainer --- libs/logging.py | 44 -------------------------------------------- libs/utils.py | 14 ++++++++++++++ 2 files changed, 14 insertions(+), 44 deletions(-) (limited to 'libs') diff --git a/libs/logging.py b/libs/logging.py index 38e523d..3969ffa 100644 --- a/libs/logging.py +++ b/libs/logging.py @@ -30,50 +30,6 @@ class FileHandlerWithHeader(logging.FileHandler): logging.FileHandler.emit(self, record) -def setup_logging(name="log", - filename=None, - stream_log_level="INFO", - file_log_level="INFO"): - logger = logging.getLogger(name) - logger.setLevel("INFO") - formatter = logging.Formatter( - '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S' - ) - stream_handler = logging.StreamHandler() - stream_handler.setLevel(getattr(logging, stream_log_level)) - stream_handler.setFormatter(formatter) - logger.addHandler(stream_handler) - if filename is not None: - header = 'time,logger,' - if name == CSV_BATCH_LOGGER: - header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr' - elif name == CSV_EPOCH_LOGGER: - header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy' - else: - raise NotImplementedError(f"Logger '{name}' is not implemented.") - - os.makedirs(os.path.dirname(filename), exist_ok=True) - file_handler = FileHandlerWithHeader(filename, header) - file_handler.setLevel(getattr(logging, file_log_level)) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - return logger - - -def training_log(name): - def log_this(function): - logger = logging.getLogger(name) - - def wrapper(*args, **kwargs): - output = function(*args, **kwargs) - logger.info(','.join(map(str, output.values()))) - return output - - return wrapper - - return log_this - - @dataclass class BaseBatchLogRecord: batch: int diff --git a/libs/utils.py b/libs/utils.py index 77e6cf1..bc45a12 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -39,6 +39,20 @@ class BaseConfig: optim_config: OptimConfig sched_config: SchedConfig + @staticmethod + def _config_from_args(args, dcls): + return dcls(**{f.name: getattr(args, f.name) + for f in dataclasses.fields(dcls)}) + + @classmethod + def from_args(cls, args): + dataset_config = cls._config_from_args(args, cls.DatasetConfig) + dataloader_config = cls._config_from_args(args, cls.DataLoaderConfig) + optim_config = cls._config_from_args(args, cls.OptimConfig) + sched_config = cls._config_from_args(args, cls.SchedConfig) + + return cls(dataset_config, dataloader_config, optim_config, sched_config) + class Trainer(ABC): def __init__( -- cgit v1.2.3