aboutsummaryrefslogtreecommitdiff
path: root/libs/logging.py
blob: 6dfe5f92b3e948a36fd90f1f76a09f0d64c42474 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import functools
import logging
import os
from dataclasses import dataclass

from torch.utils.tensorboard import SummaryWriter

# Well-known logger names; presumably passed as ``name`` to
# ``init_csv_logger`` / ``logging.getLogger`` by callers outside this
# file — TODO(review): confirm against call sites.
CSV_EPOCH_LOGGER = 'csv_epoch_logger'
CSV_BATCH_LOGGER = 'csv_batch_logger'


class FileHandlerWithHeader(logging.FileHandler):
    """A ``logging.FileHandler`` that writes a header line into new files.

    The header (e.g. a CSV column row) is written exactly once, and only
    when the target file did not already exist when the handler was
    created; a pre-existing file is assumed to contain its header.
    """

    def __init__(self, filename, header, mode='a',
                 encoding=None, delay=False, errors=None):
        # Record existence *before* the base class opens (and thereby
        # creates) the file.
        self.header = header
        self.file_pre_exists = os.path.exists(filename)

        super().__init__(filename, mode, encoding, delay, errors)

        # With delay=False the stream is already open here, so a brand-new
        # file gets its header immediately.
        if not delay and self.stream is not None and not self.file_pre_exists:
            self.stream.write(f'{header}\n')

    def emit(self, record):
        # With delay=True the stream is opened lazily on the first emit;
        # that is the moment to prepend the header for a new file.
        if self.stream is None:
            self.stream = self._open()
            if not self.file_pre_exists:
                self.stream.write(f'{self.header}\n')

        super().emit(record)


@dataclass
class BaseBatchLogRecord:
    """Bookkeeping fields shared by all per-batch log records.

    Subclasses are expected to add metric fields; the logging decorators
    in this module treat these base fields as step context (they are
    excluded from TensorBoard scalar logging) and any extra fields as
    metric values.
    """
    batch: int         # batch index within the current epoch
    num_batches: int   # number of batches per epoch
    global_batch: int  # batch index across all epochs (used as the TB step)
    epoch: int         # current epoch index
    num_epochs: int    # total number of epochs


@dataclass
class BaseEpochLogRecord:
    """Bookkeeping fields shared by all per-epoch log records.

    Subclasses add metric fields; the logging decorators in this module
    exclude these base fields from TensorBoard scalar logging and use
    ``epoch`` as the global step.
    """
    epoch: int       # current epoch index (used as the TB step)
    num_epochs: int  # total number of epochs


@dataclass
class Loggers:
    """Bundle of the logging sinks consumed by the decorators below."""
    csv_batch: logging.Logger          # CSV logger receiving per-batch rows
    csv_epoch: logging.Logger | None   # CSV logger for per-epoch rows; may be None
    tensorboard: SummaryWriter         # TensorBoard writer for scalar metrics


def init_csv_logger(name="log",
                    filename="log.csv",
                    metric_names=None,
                    stream_log_level="INFO",
                    file_log_level="INFO"):
    """Create a logger that writes CSV rows to the console and to a file.

    Each record is formatted as ``time,logger_name,message``; the file
    additionally starts with a header row (via ``FileHandlerWithHeader``)
    when it is created for the first time.

    Args:
        name: logger name; also emitted as the second CSV column.
        filename: CSV file path; missing parent directories are created.
        metric_names: optional list of metric column names appended to the
            fixed ``time,logger`` header.
        stream_log_level: level name for the console handler (e.g. "INFO").
        file_log_level: level name for the file handler.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    stream_level = getattr(logging, stream_log_level)
    file_level = getattr(logging, file_log_level)
    # BUGFIX: the logger level was hard-coded to "INFO", so a handler set
    # to DEBUG never received records — the logger must be at least as
    # permissive as its most permissive handler.
    logger.setLevel(min(stream_level, file_level))

    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
    )
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(stream_level)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    header = ['time', 'logger']
    if metric_names:
        header += metric_names
    header = ','.join(header)

    # BUGFIX: os.makedirs('') raises FileNotFoundError, which happened for
    # any bare filename — including the default "log.csv". Only create
    # directories when there is a directory component.
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    file_handler = FileHandlerWithHeader(filename, header)
    file_handler.setLevel(file_level)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger


def csv_logger(function):
    """Decorator that mirrors a step function's metrics to a CSV logger.

    The wrapped ``function`` must return ``(loggers, metrics)`` where
    ``loggers`` is a ``Loggers`` bundle and ``metrics`` is an instance of
    ``BaseBatchLogRecord`` or ``BaseEpochLogRecord`` (or a subclass).
    All field values of ``metrics`` are joined with commas and written as
    one row to the matching CSV logger. The ``(loggers, metrics)`` pair is
    returned unchanged so decorators can be stacked.

    Raises:
        NotImplementedError: if ``metrics`` is not a known record type.
    """

    # BUGFIX: functools.wraps preserves the wrapped function's
    # __name__/__doc__, which the bare wrapper discarded.
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        loggers, metrics = function(*args, **kwargs)
        # Epoch records must be checked first: every record is dispatched
        # to the epoch or batch CSV logger by its type.
        if isinstance(metrics, BaseEpochLogRecord):
            logger = loggers.csv_epoch
        elif isinstance(metrics, BaseBatchLogRecord):
            logger = loggers.csv_batch
        else:
            raise NotImplementedError(f"Unknown log type: '{type(metrics)}'")

        logger.info(','.join(map(str, metrics.__dict__.values())))
        return loggers, metrics

    return wrapper


def tensorboard_logger(function):
    """Decorator that mirrors a step function's metrics to TensorBoard.

    The wrapped ``function`` must return ``(loggers, metrics)`` where
    ``metrics`` is a ``BaseBatchLogRecord``/``BaseEpochLogRecord``
    (sub)instance. Fields belonging to the base record type are step
    bookkeeping and are skipped; every remaining numeric field is written
    with ``SummaryWriter.add_scalar``, with the first ``_`` in the field
    name mapped to ``/`` (so ``loss_train`` becomes section ``loss``,
    tag ``train``). The ``(loggers, metrics)`` pair is returned unchanged.

    Raises:
        NotImplementedError: for an unknown record type or a non-numeric
            metric value.
    """

    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        loggers, metrics = function(*args, **kwargs)
        if isinstance(metrics, BaseBatchLogRecord):
            metrics_exclude = BaseBatchLogRecord.__annotations__.keys()
            global_step = metrics.global_batch
        elif isinstance(metrics, BaseEpochLogRecord):
            metrics_exclude = BaseEpochLogRecord.__annotations__.keys()
            global_step = metrics.epoch
        else:
            raise NotImplementedError(f"Unknown log type: '{type(metrics)}'")

        logger = loggers.tensorboard
        for metric_name, metric_value in metrics.__dict__.items():
            if metric_name in metrics_exclude:
                continue
            # int is accepted alongside float so integer-valued metrics
            # are logged instead of being silently dropped.
            if isinstance(metric_value, (int, float)):
                # Steps are written 1-based (global_step + 1), preserving
                # the original convention — TODO(review): confirm callers
                # expect 1-based TensorBoard steps.
                logger.add_scalar(metric_name.replace('_', '/', 1),
                                  metric_value, global_step + 1)
            else:
                # BUGFIX: the original constructed this exception without
                # `raise`, silently ignoring unsupported metric values.
                raise NotImplementedError(
                    f"Unsupported type: '{type(metric_value)}'")
        return loggers, metrics

    return wrapper