aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--supervised/utils.py29
1 files changed, 28 insertions, 1 deletions
diff --git a/supervised/utils.py b/supervised/utils.py
index 95d865a..fde86eb 100644
--- a/supervised/utils.py
+++ b/supervised/utils.py
@@ -5,6 +5,25 @@ EPOCH_LOGGER = 'epoch_logger'
BATCH_LOGGER = 'batch_logger'
+class FileHandlerWithHeader(logging.FileHandler):
+
+ def __init__(self, filename, header, mode='a', encoding=None, delay=0):
+ self.header = header
+ self.file_pre_exists = os.path.exists(filename)
+
+ logging.FileHandler.__init__(self, filename, mode, encoding, delay)
+ if not delay and self.stream is not None:
+ self.stream.write(f'{header}\n')
+
+ def emit(self, record):
+ if self.stream is None:
+ self.stream = self._open()
+ if not self.file_pre_exists:
+ self.stream.write(f'{self.header}\n')
+
+ logging.FileHandler.emit(self, record)
+
+
def setup_logging(name="log",
filename=None,
stream_log_level="INFO",
@@ -19,8 +38,16 @@ def setup_logging(name="log",
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
if filename is not None:
+ header = 'time,logger,'
+ if name == BATCH_LOGGER:
+ header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
+ elif name == EPOCH_LOGGER:
+ header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
+ else:
+ raise NotImplementedError(f"Logger '{name}' is not implemented.")
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
- file_handler = logging.FileHandler(filename)
+ file_handler = FileHandlerWithHeader(filename, header)
file_handler.setLevel(getattr(logging, file_log_level))
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)