diff options
Diffstat (limited to 'supervised/utils.py')
-rw-r--r-- | supervised/utils.py | 46 |
1 files changed, 40 insertions, 6 deletions
diff --git a/supervised/utils.py b/supervised/utils.py index c477544..95d865a 100644 --- a/supervised/utils.py +++ b/supervised/utils.py @@ -1,7 +1,41 @@ -def training_log(func): - def wrapper(*args, **kwargs): - result = func(*args, **kwargs) - print(result) - return result +import logging +import os - return wrapper +EPOCH_LOGGER = 'epoch_logger' +BATCH_LOGGER = 'batch_logger' + + +def setup_logging(name="log", + filename=None, + stream_log_level="INFO", + file_log_level="INFO"): + logger = logging.getLogger(name) + logger.setLevel("INFO") + formatter = logging.Formatter( + '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S' + ) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(getattr(logging, stream_log_level)) + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + if filename is not None: + os.makedirs(os.path.dirname(filename), exist_ok=True) + file_handler = logging.FileHandler(filename) + file_handler.setLevel(getattr(logging, file_log_level)) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + return logger + + +def training_log(name): + def log_this(function): + logger = logging.getLogger(name) + + def wrapper(*args, **kwargs): + output = function(*args, **kwargs) + logger.info(','.join(map(str, output.values()))) + return output + + return wrapper + + return log_this |