diff options
Diffstat (limited to 'libs')
-rw-r--r-- | libs/logging.py | 165 |
1 files changed, 165 insertions, 0 deletions
diff --git a/libs/logging.py b/libs/logging.py new file mode 100644 index 0000000..38e523d --- /dev/null +++ b/libs/logging.py @@ -0,0 +1,165 @@ +import logging +import os +from dataclasses import dataclass + +from torch.utils.tensorboard import SummaryWriter + +CSV_EPOCH_LOGGER = 'csv_epoch_logger' +CSV_BATCH_LOGGER = 'csv_batch_logger' + + +class FileHandlerWithHeader(logging.FileHandler): + + def __init__(self, filename, header, mode='a', + encoding=None, delay=False, errors=None): + self.header = header + self.file_pre_exists = os.path.exists(filename) + + super(FileHandlerWithHeader, self).__init__( + filename, mode, encoding, delay, errors + ) + if not delay and self.stream is not None and not self.file_pre_exists: + self.stream.write(f'{header}\n') + + def emit(self, record): + if self.stream is None: + self.stream = self._open() + if not self.file_pre_exists: + self.stream.write(f'{self.header}\n') + + logging.FileHandler.emit(self, record) + + +def setup_logging(name="log", + filename=None, + stream_log_level="INFO", + file_log_level="INFO"): + logger = logging.getLogger(name) + logger.setLevel("INFO") + formatter = logging.Formatter( + '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S' + ) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(getattr(logging, stream_log_level)) + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + if filename is not None: + header = 'time,logger,' + if name == CSV_BATCH_LOGGER: + header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr' + elif name == CSV_EPOCH_LOGGER: + header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy' + else: + raise NotImplementedError(f"Logger '{name}' is not implemented.") + + os.makedirs(os.path.dirname(filename), exist_ok=True) + file_handler = FileHandlerWithHeader(filename, header) + file_handler.setLevel(getattr(logging, file_log_level)) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + return logger + + +def training_log(name): + def log_this(function): + logger = logging.getLogger(name) + + def wrapper(*args, **kwargs): + output = function(*args, **kwargs) + logger.info(','.join(map(str, output.values()))) + return output + + return wrapper + + return log_this + + +@dataclass +class BaseBatchLogRecord: + batch: int + num_batches: int + global_batch: int + epoch: int + num_epochs: int + + +@dataclass +class BaseEpochLogRecord: + epoch: int + num_epochs: int + + +@dataclass +class Loggers: + csv_batch: logging.Logger + csv_epoch: logging.Logger | None + tensorboard: SummaryWriter + + +def init_csv_logger(name="log", + filename="log.csv", + metric_names=None, + stream_log_level="INFO", + file_log_level="INFO"): + logger = logging.getLogger(name) + logger.setLevel("INFO") + formatter = logging.Formatter( + '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S' + ) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(getattr(logging, stream_log_level)) + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + + header = ['time', 'logger'] + if metric_names: + header += metric_names + + header = ','.join(header) + os.makedirs(os.path.dirname(filename), exist_ok=True) + file_handler = FileHandlerWithHeader(filename, header) + file_handler.setLevel(getattr(logging, file_log_level)) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + +def csv_logger(function): + def wrapper(*args, **kwargs): + loggers, metrics = function(*args, **kwargs) + if isinstance(metrics, BaseEpochLogRecord): + logger = loggers.csv_epoch + elif isinstance(metrics, BaseBatchLogRecord): + logger = loggers.csv_batch + else: + raise NotImplementedError(f"Unknown log type: '{type(metrics)}'") + + logger.info(','.join(map(str, metrics.__dict__.values()))) + return loggers, metrics + + return wrapper + + +def tensorboard_logger(function): + def wrapper(*args, **kwargs): + loggers, metrics = function(*args, **kwargs) + if isinstance(metrics, BaseBatchLogRecord): + metrics_exclude = BaseBatchLogRecord.__annotations__.keys() + global_step = metrics.global_batch + elif isinstance(metrics, BaseEpochLogRecord): + metrics_exclude = BaseEpochLogRecord.__annotations__.keys() + global_step = metrics.epoch + else: + raise NotImplementedError(f"Unknown log type: '{type(metrics)}'") + + logger = loggers.tensorboard + for metric_name, metric_value in metrics.__dict__.items(): + if metric_name not in metrics_exclude: + if isinstance(metric_value, float): + logger.add_scalar(metric_name, metric_value, global_step + 1) + else: + NotImplementedError(f"Unsupported type: '{type(metric_value)}'") + return loggers, metrics + + return wrapper |