1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
|
import functools
import logging
import numbers
import os
from dataclasses import dataclass

from torch.utils.tensorboard import SummaryWriter
# Identifiers for the epoch-level and batch-level CSV loggers. They are not
# referenced in the code shown here; presumably they are passed as ``name``
# to ``init_csv_logger`` by callers — confirm.
CSV_EPOCH_LOGGER, CSV_BATCH_LOGGER = 'csv_epoch_logger', 'csv_batch_logger'
class FileHandlerWithHeader(logging.FileHandler):
    """A ``logging.FileHandler`` that writes a header line into empty files.

    The header (e.g. a CSV column row) is written whenever the handler opens
    its file and the file is empty at that moment: a brand-new file, an
    existing zero-byte file, or a file truncated by ``mode='w'``. Opening a
    file that already has content (e.g. reopening in append mode after
    ``close()``) does not repeat the header.

    Args:
        filename: Path of the log file.
        header: Text written, followed by a newline, as the first line.
        mode, encoding, delay, errors: Forwarded to ``logging.FileHandler``.
    """

    def __init__(self, filename, header, mode='a',
                 encoding=None, delay=False, errors=None):
        # Both attributes must exist before the base initializer runs:
        # with ``delay=False`` it calls ``self._open()`` immediately.
        self.header = header
        # Kept for backward compatibility; only consulted as a fallback when
        # the opened stream cannot report its position.
        self.file_pre_exists = os.path.exists(filename)
        super().__init__(filename, mode=mode, encoding=encoding,
                         delay=delay, errors=errors)

    def _open(self):
        """Open the log file, writing the header first if the file is empty.

        ``logging.FileHandler`` calls this both from ``__init__`` (when not
        delayed) and from ``emit`` whenever the stream is not open, so every
        (re)open passes through here exactly once.
        """
        stream = super()._open()
        if self._needs_header(stream):
            stream.write(f'{self.header}\n')
        return stream

    def _needs_header(self, stream):
        """Return True when *stream* is positioned at the start of its file."""
        try:
            # Append mode positions the stream at end-of-file and 'w' truncates,
            # so offset 0 means the file is currently empty.
            return stream.tell() == 0
        except (OSError, ValueError):
            # Non-seekable target: fall back to the construction-time check.
            return not self.file_pre_exists
@dataclass
class BaseBatchLogRecord:
    """Base record for per-batch metrics.

    Subclasses are expected to append metric fields after these counters.
    ``csv_logger`` writes all attribute values, in attribute order, as one CSV
    row. ``tensorboard_logger`` skips the fields declared here and uses
    ``global_batch + 1`` as the TensorBoard step.

    Field order matters: it defines the positional ``__init__`` signature
    and the CSV column order.
    """
    batch: int  # presumably the batch index within the current epoch — confirm
    num_batches: int  # presumably the number of batches per epoch — confirm
    global_batch: int  # presumably a running batch count across epochs; TensorBoard step is this + 1
    epoch: int  # presumably the current epoch index — confirm
    num_epochs: int  # presumably the total number of epochs — confirm
@dataclass
class BaseEpochLogRecord:
    """Base record for per-epoch metrics.

    Subclasses are expected to append metric fields after these counters.
    ``csv_logger`` writes all attribute values, in attribute order, as one CSV
    row. ``tensorboard_logger`` skips the fields declared here and uses
    ``epoch + 1`` as the TensorBoard step.

    Field order matters: it defines the positional ``__init__`` signature
    and the CSV column order.
    """
    epoch: int  # presumably a 0-based epoch index (TensorBoard step is epoch + 1) — confirm
    num_epochs: int  # presumably the total number of epochs — confirm
@dataclass
class Loggers:
    """Bundle of logging back-ends read by the ``csv_logger`` and
    ``tensorboard_logger`` decorators."""
    csv_batch: logging.Logger  # receives one CSV row per BaseBatchLogRecord
    csv_epoch: logging.Logger | None  # receives one CSV row per BaseEpochLogRecord; may be None per the annotation
    tensorboard: SummaryWriter  # target of add_scalar calls for metric fields
def init_csv_logger(name="log",
                    filename="log.csv",
                    metric_names=None,
                    stream_log_level="INFO",
                    file_log_level="INFO"):
    """Attach a stream handler and a CSV file handler to a logger.

    Every record is formatted as ``<time>.<msecs>,<logger name>,<message>``
    and sent both to a ``logging.StreamHandler`` (stderr by default) and to
    ``filename``. The file starts with the header row
    ``time,logger[,<metric_names>...]`` when it is empty.

    Args:
        name: Name passed to ``logging.getLogger``.
        filename: CSV file path. Missing parent directories are created; a
            bare file name is resolved against the current directory.
        metric_names: Optional sequence of column names appended to the header.
        stream_log_level: Level for the stream handler, as a level name
            (e.g. ``"INFO"``) or an int.
        file_log_level: Level for the file handler, as a name or an int.

    Returns:
        The ``logging.Logger`` named *name*.

    Raises:
        ValueError: If a level name is not a registered logging level.

    Note:
        ``logging.getLogger`` returns the same object for the same name, and
        this function adds new handlers on every call, so calling it twice
        with one name duplicates output.
    """
    logger = logging.getLogger(name)
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
    )

    stream_handler = logging.StreamHandler()
    # ``Handler.setLevel`` accepts both level names and integer levels.
    stream_handler.setLevel(stream_log_level)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    header = ','.join(['time', 'logger', *(metric_names or [])])
    directory = os.path.dirname(filename)
    if directory:
        # A bare file name has no directory part, and os.makedirs('') raises.
        os.makedirs(directory, exist_ok=True)
    file_handler = FileHandlerWithHeader(filename, header)
    file_handler.setLevel(file_log_level)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Never stricter than INFO (the previous fixed level), but lowered when
    # a handler asks for more verbose output such as DEBUG; otherwise those
    # records would be filtered out by the logger before reaching it.
    logger.setLevel(min(logging.INFO, stream_handler.level, file_handler.level))
    return logger
def csv_logger(function):
    """Decorate *function* so the metrics record it returns is logged as CSV.

    The decorated function must return a ``(loggers, metrics)`` pair, where
    ``loggers`` exposes ``csv_epoch`` and ``csv_batch`` loggers (e.g. a
    ``Loggers`` instance) and ``metrics`` is a ``BaseEpochLogRecord`` or
    ``BaseBatchLogRecord`` (or a subclass). The record's attribute values are
    joined with commas, in attribute order, and logged at INFO on the
    matching logger. The pair is returned unchanged.

    Raises:
        NotImplementedError: If ``metrics`` is neither record type.
    """
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        loggers, metrics = function(*args, **kwargs)
        if isinstance(metrics, BaseEpochLogRecord):
            logger = loggers.csv_epoch
        elif isinstance(metrics, BaseBatchLogRecord):
            logger = loggers.csv_batch
        else:
            raise NotImplementedError(f"Unknown log type: '{type(metrics)}'")
        # The epoch CSV logger is optional (``Logger | None``); with no
        # logger configured there is nowhere to write the row.
        if logger is not None:
            logger.info(','.join(map(str, metrics.__dict__.values())))
        return loggers, metrics
    return wrapper
def tensorboard_logger(function):
    """Decorate *function* so the metrics it returns are written to TensorBoard.

    The decorated function must return a ``(loggers, metrics)`` pair. Every
    attribute of ``metrics`` that is not a field of its base record class is
    written with ``loggers.tensorboard.add_scalar``. The step is
    ``global_batch + 1`` for batch records and ``epoch + 1`` for epoch
    records. The pair is returned unchanged.

    Real-number values (``int``, ``float`` and other ``numbers.Real`` such as
    numpy scalars) are logged. ``None`` values are skipped, so optional
    metrics may be left unset.

    Raises:
        NotImplementedError: If ``metrics`` is neither record type, or a
            metric value has an unsupported type.
    """
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        loggers, metrics = function(*args, **kwargs)
        if isinstance(metrics, BaseBatchLogRecord):
            metrics_exclude = BaseBatchLogRecord.__annotations__.keys()
            global_step = metrics.global_batch
        elif isinstance(metrics, BaseEpochLogRecord):
            metrics_exclude = BaseEpochLogRecord.__annotations__.keys()
            global_step = metrics.epoch
        else:
            raise NotImplementedError(f"Unknown log type: '{type(metrics)}'")
        writer = loggers.tensorboard
        for metric_name, metric_value in metrics.__dict__.items():
            # Base-record counters are bookkeeping, not metrics.
            if metric_name in metrics_exclude or metric_value is None:
                continue
            if not isinstance(metric_value, numbers.Real):
                # Fail loudly rather than silently dropping a metric.
                raise NotImplementedError(
                    f"Unsupported type: '{type(metric_value)}'"
                )
            writer.add_scalar(metric_name, metric_value, global_step + 1)
        return loggers, metrics
    return wrapper
|