aboutsummaryrefslogtreecommitdiff
path: root/libs/logging.py
blob: 38e523dcd6a83b0904bd97b464b3490d3b2cd544 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import functools
import logging
import os
from dataclasses import dataclass

from torch.utils.tensorboard import SummaryWriter

CSV_EPOCH_LOGGER = 'csv_epoch_logger'
CSV_BATCH_LOGGER = 'csv_batch_logger'


class FileHandlerWithHeader(logging.FileHandler):
    """A ``logging.FileHandler`` that writes a header as the first line of a
    freshly created log file.

    The header is written when the target file did not exist before the
    handler was constructed, or when a ``'w'`` mode truncates an existing
    file; appending to an existing log never duplicates the header.
    """

    def __init__(self, filename, header, mode='a',
                 encoding=None, delay=False, errors=None):
        # Header text to emit as the first line of a fresh log file.
        self.header = header
        # Record existence *before* FileHandler opens (and possibly creates)
        # the file, so we know whether the header still needs writing.
        self.file_pre_exists = os.path.exists(filename)

        super().__init__(filename, mode, encoding, delay, errors)
        # Bug fix: a 'w' mode truncates any pre-existing file, so the header
        # must be written in that case too, not only for brand-new files.
        needs_header = not self.file_pre_exists or 'w' in mode
        if not delay and self.stream is not None and needs_header:
            self.stream.write(f'{header}\n')
            # Mark the header as written so emit() never repeats it.
            self.file_pre_exists = True

    def emit(self, record):
        # With delay=True the stream is opened lazily here, on the first
        # record; write the header before that record if still needed.
        if self.stream is None:
            self.stream = self._open()
            if not self.file_pre_exists:
                self.stream.write(f'{self.header}\n')
                self.file_pre_exists = True

        logging.FileHandler.emit(self, record)


def setup_logging(name="log",
                  filename=None,
                  stream_log_level="INFO",
                  file_log_level="INFO"):
    """Create a logger with a stream handler and, optionally, a CSV file
    handler whose header is chosen from the logger name.

    Args:
        name: Logger name; must be ``CSV_BATCH_LOGGER`` or ``CSV_EPOCH_LOGGER``
            when ``filename`` is given, since the CSV header depends on it.
        filename: Path of the CSV log file; ``None`` disables file logging.
        stream_log_level: Level name (e.g. ``"INFO"``) for the stream handler.
        file_log_level: Level name for the file handler.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        NotImplementedError: If ``filename`` is given but ``name`` is not one
            of the known CSV logger names.
    """
    logger = logging.getLogger(name)
    logger.setLevel("INFO")
    # CSV-friendly format: timestamp with milliseconds, logger name, message.
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
    )
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(getattr(logging, stream_log_level))
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    if filename is not None:
        header = 'time,logger,'
        if name == CSV_BATCH_LOGGER:
            header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
        elif name == CSV_EPOCH_LOGGER:
            header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
        else:
            raise NotImplementedError(f"Logger '{name}' is not implemented.")

        # Bug fix: os.makedirs('') raises FileNotFoundError when the filename
        # has no directory component; only create directories when there are any.
        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        file_handler = FileHandlerWithHeader(filename, header)
        file_handler.setLevel(getattr(logging, file_log_level))
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger


def training_log(name):
    """Decorator factory: log the decorated function's dict output as CSV.

    The wrapped function must return a mapping; its values are joined with
    commas and logged at INFO level to the logger called ``name``.  The
    output is returned unchanged so decorators can be stacked.
    """
    def log_this(function):
        logger = logging.getLogger(name)

        # Fix: preserve the wrapped function's name/docstring for debugging.
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            output = function(*args, **kwargs)
            # dict preserves insertion order, so columns line up with the
            # CSV header written by the file handler.
            logger.info(','.join(map(str, output.values())))
            return output

        return wrapper

    return log_this


@dataclass
class BaseBatchLogRecord:
    """Positional bookkeeping for one training batch.

    Base class for batch-level log records; subclasses are expected to add
    metric fields (e.g. losses) after these.
    """
    batch: int         # index of the batch within the current epoch
    num_batches: int   # total number of batches per epoch
    global_batch: int  # batch counter across all epochs
    epoch: int         # current epoch index
    num_epochs: int    # total number of epochs


@dataclass
class BaseEpochLogRecord:
    """Positional bookkeeping for one training epoch.

    Base class for epoch-level log records; subclasses are expected to add
    metric fields after these.
    """
    epoch: int       # current epoch index
    num_epochs: int  # total number of epochs


@dataclass
class Loggers:
    """Bundle of the logger objects used during training."""
    csv_batch: logging.Logger          # batch-level CSV logger
    # NOTE(review): only csv_epoch is optional — presumably some runs log
    # batches without epoch summaries; confirm against callers.
    csv_epoch: logging.Logger | None   # epoch-level CSV logger, or None
    tensorboard: SummaryWriter         # TensorBoard summary writer


def init_csv_logger(name="log",
                    filename="log.csv",
                    metric_names=None,
                    stream_log_level="INFO",
                    file_log_level="INFO"):
    """Create a logger that mirrors records to the console and a CSV file.

    Args:
        name: Logger name (also written into each CSV row by the formatter).
        filename: Path of the CSV file; parent directories are created.
        metric_names: Optional list of column names appended to the fixed
            ``time,logger`` header.
        stream_log_level: Level name for the console handler.
        file_log_level: Level name for the file handler.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel("INFO")
    # CSV-friendly format: timestamp with milliseconds, logger name, message.
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
    )
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(getattr(logging, stream_log_level))
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    header = ['time', 'logger']
    if metric_names:
        header += metric_names

    header = ','.join(header)
    # Bug fix: with the default filename ("log.csv") os.path.dirname() is ''
    # and os.makedirs('') raises FileNotFoundError; only create directories
    # when the path actually has a directory component.
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    file_handler = FileHandlerWithHeader(filename, header)
    file_handler.setLevel(getattr(logging, file_log_level))
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger


def csv_logger(function):
    """Decorator: log the metrics returned by ``function`` to a CSV logger.

    ``function`` must return ``(loggers, metrics)`` where ``loggers`` is a
    ``Loggers`` bundle and ``metrics`` is a ``BaseEpochLogRecord`` or
    ``BaseBatchLogRecord`` (or subclass).  Every dataclass field value is
    joined with commas and logged at INFO level; the pair is returned
    unchanged so decorators can be stacked.

    Raises:
        NotImplementedError: If ``metrics`` is neither record type.
    """
    # Fix: preserve the wrapped function's identity, consistent with
    # training_log.
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        loggers, metrics = function(*args, **kwargs)
        if isinstance(metrics, BaseEpochLogRecord):
            logger = loggers.csv_epoch
        elif isinstance(metrics, BaseBatchLogRecord):
            logger = loggers.csv_batch
        else:
            raise NotImplementedError(f"Unknown log type: '{type(metrics)}'")

        # Field order follows the dataclass declaration, matching the header.
        logger.info(','.join(map(str, metrics.__dict__.values())))
        return loggers, metrics

    return wrapper


def tensorboard_logger(function):
    """Decorator: write scalar metrics returned by ``function`` to TensorBoard.

    ``function`` must return ``(loggers, metrics)``.  The bookkeeping fields
    declared on the base record classes are skipped; every remaining float
    field is written via ``SummaryWriter.add_scalar`` at a 1-based step.

    Raises:
        NotImplementedError: For unknown record types or non-float metric
            values.
    """
    # Fix: preserve the wrapped function's identity, consistent with
    # training_log.
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        loggers, metrics = function(*args, **kwargs)
        if isinstance(metrics, BaseBatchLogRecord):
            metrics_exclude = BaseBatchLogRecord.__annotations__.keys()
            global_step = metrics.global_batch
        elif isinstance(metrics, BaseEpochLogRecord):
            metrics_exclude = BaseEpochLogRecord.__annotations__.keys()
            global_step = metrics.epoch
        else:
            raise NotImplementedError(f"Unknown log type: '{type(metrics)}'")

        logger = loggers.tensorboard
        for metric_name, metric_value in metrics.__dict__.items():
            if metric_name not in metrics_exclude:
                if isinstance(metric_value, float):
                    # Steps are 1-based in the dashboards.
                    logger.add_scalar(metric_name, metric_value, global_step + 1)
                else:
                    # Bug fix: the exception was constructed but never raised,
                    # silently dropping unsupported metric types.
                    raise NotImplementedError(
                        f"Unsupported type: '{type(metric_value)}'"
                    )
        return loggers, metrics

    return wrapper