aboutsummaryrefslogtreecommitdiff
path: root/libs/utils.py
blob: 3df157488a4162f9bd6cece9b901ba87ec3372eb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import functools
import logging
import os

# Logger names recognised by ``setup_logging``; each one selects the CSV
# header that is written at the top of that logger's log file.
EPOCH_LOGGER: str = "epoch_logger"
BATCH_LOGGER: str = "batch_logger"


class FileHandlerWithHeader(logging.FileHandler):
    """File handler that writes a header line into an empty log file.

    The header (e.g. a CSV column list) is written whenever the underlying
    stream is opened and the file contains no data.  That covers eager
    opening (``delay=False``), deferred opening on the first record
    (``delay=True``), re-opening after :meth:`close`, and ``mode='w'``
    truncating a pre-existing file.  A file that already has content never
    receives a second header.

    Args:
        filename: Path of the log file.
        header: Text written, followed by a newline, as the first line of an
            empty log file.
        mode, encoding, delay, errors: Forwarded unchanged to
            :class:`logging.FileHandler`.
    """

    def __init__(self, filename, header, mode='a',
                 encoding=None, delay=False, errors=None):
        self.header = header
        # Kept for backward compatibility with code that inspects it; the
        # header decision itself is made from the file's contents at open time.
        self.file_pre_exists = os.path.exists(filename)

        # ``self.header`` must be set before this call: with ``delay=False``
        # the base constructor opens the file immediately via ``self._open()``.
        super().__init__(filename, mode, encoding, delay, errors)

    def _open(self):
        """Open the log stream and write the header if the file is empty."""
        stream = super()._open()
        # Inspect the opened file rather than a flag captured at construction,
        # so re-opening after close() does not duplicate the header and empty
        # or truncated files still get one.
        if os.fstat(stream.fileno()).st_size == 0:
            stream.write(f'{self.header}\n')
        return stream


def setup_logging(name="log",
                  filename=None,
                  stream_log_level="INFO",
                  file_log_level="INFO"):
    logger = logging.getLogger(name)
    logger.setLevel("INFO")
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
    )
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(getattr(logging, stream_log_level))
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    if filename is not None:
        header = 'time,logger,'
        if name == BATCH_LOGGER:
            header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
        elif name == EPOCH_LOGGER:
            header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
        else:
            raise NotImplementedError(f"Logger '{name}' is not implemented.")

        os.makedirs(os.path.dirname(filename), exist_ok=True)
        file_handler = FileHandlerWithHeader(filename, header)
        file_handler.setLevel(getattr(logging, file_log_level))
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger


def training_log(name):
    """Return a decorator that logs the values of a function's returned dict.

    After each call of the decorated function, the values of its return value
    are joined with commas (via ``str``) and logged at INFO level on the
    logger *name*; the return value is passed through unchanged.  The
    decorated function is assumed to return a mapping (anything with
    ``.values()``); the iteration order of those values becomes the CSV
    column order.

    Args:
        name: Name of the logger that receives the records.

    Returns:
        A decorator preserving the wrapped function's metadata
        (``__name__``, ``__doc__``, ...).
    """
    def log_this(function):
        logger = logging.getLogger(name)

        # functools.wraps keeps the original name/docstring so introspection,
        # debugging and stack traces still refer to the decorated function.
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            output = function(*args, **kwargs)
            logger.info(','.join(map(str, output.values())))
            return output

        return wrapper

    return log_this