aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--libs/logging.py165
1 files changed, 165 insertions, 0 deletions
diff --git a/libs/logging.py b/libs/logging.py
new file mode 100644
index 0000000..38e523d
--- /dev/null
+++ b/libs/logging.py
@@ -0,0 +1,165 @@
import functools
import logging
import os
from dataclasses import dataclass

from torch.utils.tensorboard import SummaryWriter
+
+CSV_EPOCH_LOGGER = 'csv_epoch_logger'
+CSV_BATCH_LOGGER = 'csv_batch_logger'
+
+
class FileHandlerWithHeader(logging.FileHandler):
    """`logging.FileHandler` that writes a one-line header when the target
    file is first created.

    The header is written only if the file did not exist before the handler
    opened it, so appending to an existing log (mode='a', the default) never
    duplicates it.
    """

    def __init__(self, filename, header, mode='a',
                 encoding=None, delay=False, errors=None):
        # Record existence *before* FileHandler opens (and thus creates)
        # the file, so we know whether a header is still needed.
        self.header = header
        self.file_pre_exists = os.path.exists(filename)

        super(FileHandlerWithHeader, self).__init__(
            filename, mode, encoding, delay, errors
        )
        if not delay and self.stream is not None and not self.file_pre_exists:
            self.stream.write(f'{header}\n')
            # BUG FIX: mark the header as written. Previously this flag was
            # never updated, so emit() re-wrote the header before every
            # single record, corrupting the CSV output.
            self.file_pre_exists = True

    def emit(self, record):
        """Emit a record, writing the header first when a delayed stream is
        being opened for a brand-new file."""
        if self.stream is None:
            self.stream = self._open()
        if not self.file_pre_exists:
            self.stream.write(f'{self.header}\n')
            # Write the header at most once (see __init__).
            self.file_pre_exists = True

        logging.FileHandler.emit(self, record)
+
+
def setup_logging(name="log",
                  filename=None,
                  stream_log_level="INFO",
                  file_log_level="INFO"):
    """Configure and return a CSV-style logger.

    A stream handler is always attached. If *filename* is given, a
    `FileHandlerWithHeader` is attached as well, with a header matching the
    known logger *name* (batch or epoch CSV logger).

    Args:
        name: Logger name; must be CSV_BATCH_LOGGER or CSV_EPOCH_LOGGER
            when *filename* is given (the header depends on it).
        filename: Optional CSV file path; parent directories are created.
        stream_log_level: Level name (e.g. "INFO") for the stream handler.
        file_log_level: Level name for the file handler.

    Returns:
        The configured `logging.Logger`.

    Raises:
        NotImplementedError: if *filename* is set and *name* is unknown.
    """
    logger = logging.getLogger(name)
    # BUG FIX: the logger level was hardcoded to "INFO", which silently
    # discarded records below INFO even when a handler asked for DEBUG.
    # Filter at the most verbose level any handler accepts instead.
    logger.setLevel(min(getattr(logging, stream_log_level),
                        getattr(logging, file_log_level)))
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
    )
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(getattr(logging, stream_log_level))
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    if filename is not None:
        header = 'time,logger,'
        if name == CSV_BATCH_LOGGER:
            header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
        elif name == CSV_EPOCH_LOGGER:
            header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
        else:
            raise NotImplementedError(f"Logger '{name}' is not implemented.")

        # BUG FIX: os.makedirs('') raises FileNotFoundError when the
        # filename has no directory component; only create real dirs.
        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        file_handler = FileHandlerWithHeader(filename, header)
        file_handler.setLevel(getattr(logging, file_log_level))
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger
+
+
def training_log(name):
    """Decorator factory: log the decorated function's dict-valued return
    as one CSV row on the logger called *name*.

    The decorated function must return a mapping; its values are joined
    with commas in insertion order (matching the CSV header columns).
    """
    def log_this(function):
        logger = logging.getLogger(name)

        # BUG FIX: without functools.wraps the wrapper clobbered the
        # decorated function's __name__/__doc__/metadata.
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            output = function(*args, **kwargs)
            logger.info(','.join(map(str, output.values())))
            return output

        return wrapper

    return log_this
+
+
@dataclass
class BaseBatchLogRecord:
    """Per-batch bookkeeping fields common to all batch-level log records.

    Subclasses add metric fields; `tensorboard_logger` uses this base
    class's annotations to separate bookkeeping fields from metrics.
    """
    batch: int          # batch index within the current epoch
    num_batches: int    # total number of batches per epoch
    global_batch: int   # batch counter across all epochs (used as TB step)
    epoch: int          # current epoch index
    num_epochs: int     # total number of epochs in the run
+
+
@dataclass
class BaseEpochLogRecord:
    """Per-epoch bookkeeping fields common to all epoch-level log records.

    Subclasses add metric fields; `tensorboard_logger` uses this base
    class's annotations to separate bookkeeping fields from metrics.
    """
    epoch: int       # current epoch index (used as the TB step)
    num_epochs: int  # total number of epochs in the run
+
+
@dataclass
class Loggers:
    """Bundle of logging sinks passed through the log decorators."""
    csv_batch: logging.Logger          # per-batch CSV logger (always present)
    csv_epoch: logging.Logger | None   # per-epoch CSV logger; may be None,
                                       # presumably when epoch-level CSV
                                       # logging is unused — confirm callers
    tensorboard: SummaryWriter         # TensorBoard writer for scalar metrics
+
+
def init_csv_logger(name="log",
                    filename="log.csv",
                    metric_names=None,
                    stream_log_level="INFO",
                    file_log_level="INFO"):
    """Create a logger that mirrors CSV rows to the console and *filename*.

    Args:
        name: Logger name (also appears in the 'logger' CSV column).
        filename: CSV file path; parent directories are created if needed.
        metric_names: Optional list of column names appended after the
            fixed 'time,logger' header prefix.
        stream_log_level: Level name (e.g. "INFO") for the stream handler.
        file_log_level: Level name for the file handler.

    Returns:
        The configured `logging.Logger`.
    """
    logger = logging.getLogger(name)
    # BUG FIX: the logger level was hardcoded to "INFO", which silently
    # discarded records below INFO even when a handler asked for DEBUG.
    logger.setLevel(min(getattr(logging, stream_log_level),
                        getattr(logging, file_log_level)))
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
    )
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(getattr(logging, stream_log_level))
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    header = ['time', 'logger']
    if metric_names:
        header += metric_names

    header = ','.join(header)
    # BUG FIX: with the default filename "log.csv", os.path.dirname()
    # returns '' and os.makedirs('') raises FileNotFoundError; only create
    # directories when there actually is a directory component.
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    file_handler = FileHandlerWithHeader(filename, header)
    file_handler.setLevel(getattr(logging, file_log_level))
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger
+
+
def csv_logger(function):
    """Decorator: after *function* runs, append its metrics record as one
    CSV row to the matching CSV logger.

    *function* must return ``(loggers, metrics)`` where ``loggers`` is a
    `Loggers` bundle and ``metrics`` is a `BaseBatchLogRecord` or
    `BaseEpochLogRecord` (sub)instance.

    Raises:
        NotImplementedError: if the metrics record type is unknown.
    """
    # BUG FIX: without functools.wraps the wrapper clobbered the decorated
    # function's __name__/__doc__/metadata.
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        loggers, metrics = function(*args, **kwargs)
        # Dispatch on the concrete record type to pick the right sink.
        if isinstance(metrics, BaseEpochLogRecord):
            logger = loggers.csv_epoch
        elif isinstance(metrics, BaseBatchLogRecord):
            logger = loggers.csv_batch
        else:
            raise NotImplementedError(f"Unknown log type: '{type(metrics)}'")

        # Dataclass attributes preserve declaration order, matching the
        # CSV header written by the file handler.
        logger.info(','.join(map(str, metrics.__dict__.values())))
        return loggers, metrics

    return wrapper
+
+
def tensorboard_logger(function):
    """Decorator: after *function* runs, mirror the float metric fields of
    the returned log record to TensorBoard as scalars.

    *function* must return ``(loggers, metrics)``. Bookkeeping fields
    declared on the Base*LogRecord base class are excluded; only the extra
    fields added by subclasses are written via ``add_scalar``.

    Raises:
        NotImplementedError: for unknown record types or non-float metric
        values.
    """
    # BUG FIX: without functools.wraps the wrapper clobbered the decorated
    # function's __name__/__doc__/metadata.
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        loggers, metrics = function(*args, **kwargs)
        if isinstance(metrics, BaseBatchLogRecord):
            metrics_exclude = BaseBatchLogRecord.__annotations__.keys()
            global_step = metrics.global_batch
        elif isinstance(metrics, BaseEpochLogRecord):
            metrics_exclude = BaseEpochLogRecord.__annotations__.keys()
            global_step = metrics.epoch
        else:
            raise NotImplementedError(f"Unknown log type: '{type(metrics)}'")

        logger = loggers.tensorboard
        for metric_name, metric_value in metrics.__dict__.items():
            if metric_name not in metrics_exclude:
                if isinstance(metric_value, float):
                    logger.add_scalar(metric_name, metric_value, global_step + 1)
                else:
                    # BUG FIX: the exception was previously constructed but
                    # never raised, so unsupported metric types were
                    # silently dropped instead of reported.
                    raise NotImplementedError(
                        f"Unsupported type: '{type(metric_value)}'")
        return loggers, metrics

    return wrapper