aboutsummaryrefslogtreecommitdiff
path: root/libs/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'libs/utils.py')
-rw-r--r--libs/utils.py68
1 files changed, 68 insertions, 0 deletions
diff --git a/libs/utils.py b/libs/utils.py
new file mode 100644
index 0000000..fde86eb
--- /dev/null
+++ b/libs/utils.py
@@ -0,0 +1,68 @@
+import logging
+import os
+
+EPOCH_LOGGER = 'epoch_logger'
+BATCH_LOGGER = 'batch_logger'
+
+
+class FileHandlerWithHeader(logging.FileHandler):
+
+ def __init__(self, filename, header, mode='a', encoding=None, delay=0):
+ self.header = header
+ self.file_pre_exists = os.path.exists(filename)
+
+ logging.FileHandler.__init__(self, filename, mode, encoding, delay)
+ if not delay and self.stream is not None:
+ self.stream.write(f'{header}\n')
+
+ def emit(self, record):
+ if self.stream is None:
+ self.stream = self._open()
+ if not self.file_pre_exists:
+ self.stream.write(f'{self.header}\n')
+
+ logging.FileHandler.emit(self, record)
+
+
+def setup_logging(name="log",
+ filename=None,
+ stream_log_level="INFO",
+ file_log_level="INFO"):
+ logger = logging.getLogger(name)
+ logger.setLevel("INFO")
+ formatter = logging.Formatter(
+ '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
+ )
+ stream_handler = logging.StreamHandler()
+ stream_handler.setLevel(getattr(logging, stream_log_level))
+ stream_handler.setFormatter(formatter)
+ logger.addHandler(stream_handler)
+ if filename is not None:
+ header = 'time,logger,'
+ if name == BATCH_LOGGER:
+ header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
+ elif name == EPOCH_LOGGER:
+ header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
+ else:
+ raise NotImplementedError(f"Logger '{name}' is not implemented.")
+
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ file_handler = FileHandlerWithHeader(filename, header)
+ file_handler.setLevel(getattr(logging, file_log_level))
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
+ return logger
+
+
+def training_log(name):
+ def log_this(function):
+ logger = logging.getLogger(name)
+
+ def wrapper(*args, **kwargs):
+ output = function(*args, **kwargs)
+ logger.info(','.join(map(str, output.values())))
+ return output
+
+ return wrapper
+
+ return log_this