diff options
-rw-r--r-- | supervised/utils.py | 29 |
1 files changed, 28 insertions, 1 deletions
diff --git a/supervised/utils.py b/supervised/utils.py index 95d865a..fde86eb 100644 --- a/supervised/utils.py +++ b/supervised/utils.py @@ -5,6 +5,25 @@ EPOCH_LOGGER = 'epoch_logger' BATCH_LOGGER = 'batch_logger' +class FileHandlerWithHeader(logging.FileHandler): + + def __init__(self, filename, header, mode='a', encoding=None, delay=0): + self.header = header + self.file_pre_exists = os.path.exists(filename) + + logging.FileHandler.__init__(self, filename, mode, encoding, delay) + if not delay and self.stream is not None: + self.stream.write(f'{header}\n') + + def emit(self, record): + if self.stream is None: + self.stream = self._open() + if not self.file_pre_exists: + self.stream.write(f'{self.header}\n') + + logging.FileHandler.emit(self, record) + + def setup_logging(name="log", filename=None, stream_log_level="INFO", @@ -19,8 +38,16 @@ def setup_logging(name="log", stream_handler.setFormatter(formatter) logger.addHandler(stream_handler) if filename is not None: + header = 'time,logger,' + if name == BATCH_LOGGER: + header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr' + elif name == EPOCH_LOGGER: + header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy' + else: + raise NotImplementedError(f"Logger '{name}' is not implemented.") + os.makedirs(os.path.dirname(filename), exist_ok=True) - file_handler = logging.FileHandler(filename) + file_handler = FileHandlerWithHeader(filename, header) file_handler.setLevel(getattr(logging, file_log_level)) file_handler.setFormatter(formatter) logger.addHandler(file_handler) |