aboutsummaryrefslogtreecommitdiff
path: root/supervised/utils.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-03-17 20:10:42 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-03-17 20:10:42 +0800
commit9b2c25d3c927b7533e5d7d9665b67962e4c6934b (patch)
treeb4a80004ffdfe8165f5f8d077afb8b054d4472ce /supervised/utils.py
parent568569c764ffdd73cd660434df50d30d26203f63 (diff)
Move some utils to libs directory
Diffstat (limited to 'supervised/utils.py')
-rw-r--r--supervised/utils.py68
1 files changed, 0 insertions, 68 deletions
diff --git a/supervised/utils.py b/supervised/utils.py
deleted file mode 100644
index fde86eb..0000000
--- a/supervised/utils.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import logging
-import os
-
-EPOCH_LOGGER = 'epoch_logger'
-BATCH_LOGGER = 'batch_logger'
-
-
-class FileHandlerWithHeader(logging.FileHandler):
-
- def __init__(self, filename, header, mode='a', encoding=None, delay=0):
- self.header = header
- self.file_pre_exists = os.path.exists(filename)
-
- logging.FileHandler.__init__(self, filename, mode, encoding, delay)
- if not delay and self.stream is not None:
- self.stream.write(f'{header}\n')
-
- def emit(self, record):
- if self.stream is None:
- self.stream = self._open()
- if not self.file_pre_exists:
- self.stream.write(f'{self.header}\n')
-
- logging.FileHandler.emit(self, record)
-
-
-def setup_logging(name="log",
- filename=None,
- stream_log_level="INFO",
- file_log_level="INFO"):
- logger = logging.getLogger(name)
- logger.setLevel("INFO")
- formatter = logging.Formatter(
- '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
- )
- stream_handler = logging.StreamHandler()
- stream_handler.setLevel(getattr(logging, stream_log_level))
- stream_handler.setFormatter(formatter)
- logger.addHandler(stream_handler)
- if filename is not None:
- header = 'time,logger,'
- if name == BATCH_LOGGER:
- header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
- elif name == EPOCH_LOGGER:
- header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
- else:
- raise NotImplementedError(f"Logger '{name}' is not implemented.")
-
- os.makedirs(os.path.dirname(filename), exist_ok=True)
- file_handler = FileHandlerWithHeader(filename, header)
- file_handler.setLevel(getattr(logging, file_log_level))
- file_handler.setFormatter(formatter)
- logger.addHandler(file_handler)
- return logger
-
-
-def training_log(name):
- def log_this(function):
- logger = logging.getLogger(name)
-
- def wrapper(*args, **kwargs):
- output = function(*args, **kwargs)
- logger.info(','.join(map(str, output.values())))
- return output
-
- return wrapper
-
- return log_this