aboutsummaryrefslogtreecommitdiff
path: root/supervised/utils.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-03-17 17:26:29 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-03-17 17:26:29 +0800
commit48e0405bfbffa63d9e2a610c1cf982892a861e4a (patch)
tree57eadefa6f0a46ecaec22063c3114fa9f2e774c3 /supervised/utils.py
parent9723bf4b92b8c3fb779da6f4990ac702a68c7cc6 (diff)
Add CSV logger header
Diffstat (limited to 'supervised/utils.py')
-rw-r--r--supervised/utils.py29
1 files changed, 28 insertions, 1 deletions
diff --git a/supervised/utils.py b/supervised/utils.py
index 95d865a..fde86eb 100644
--- a/supervised/utils.py
+++ b/supervised/utils.py
@@ -5,6 +5,25 @@ EPOCH_LOGGER = 'epoch_logger'
BATCH_LOGGER = 'batch_logger'
+class FileHandlerWithHeader(logging.FileHandler):
+
+ def __init__(self, filename, header, mode='a', encoding=None, delay=0):
+ self.header = header
+ self.file_pre_exists = os.path.exists(filename)
+
+ logging.FileHandler.__init__(self, filename, mode, encoding, delay)
+ if not delay and self.stream is not None:
+ self.stream.write(f'{header}\n')
+
+ def emit(self, record):
+ if self.stream is None:
+ self.stream = self._open()
+ if not self.file_pre_exists:
+ self.stream.write(f'{self.header}\n')
+
+ logging.FileHandler.emit(self, record)
+
+
def setup_logging(name="log",
filename=None,
stream_log_level="INFO",
@@ -19,8 +38,16 @@ def setup_logging(name="log",
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
if filename is not None:
+ header = 'time,logger,'
+ if name == BATCH_LOGGER:
+ header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
+ elif name == EPOCH_LOGGER:
+ header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
+ else:
+ raise NotImplementedError(f"Logger '{name}' is not implemented.")
+
os.makedirs(os.path.dirname(filename), exist_ok=True)
- file_handler = logging.FileHandler(filename)
+ file_handler = FileHandlerWithHeader(filename, header)
file_handler.setLevel(getattr(logging, file_log_level))
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)