1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
|
import functools
import logging
import os
from dataclasses import dataclass

from torch.utils.tensorboard import SummaryWriter
# Well-known logger names; setup_logging() switches on these to pick the
# CSV header for the epoch-level vs. batch-level log file.
CSV_EPOCH_LOGGER = 'csv_epoch_logger'
CSV_BATCH_LOGGER = 'csv_batch_logger'
class FileHandlerWithHeader(logging.FileHandler):
    """A ``logging.FileHandler`` that writes a header line (e.g. CSV column
    names) once at the top of a newly created log file.

    The header is written when the file did not exist before the handler was
    created: immediately on construction, or on first emit when ``delay=True``.
    """

    def __init__(self, filename, header, mode='a',
                 encoding=None, delay=False, errors=None):
        self.header = header
        # Remember whether the file already existed; an existing file is
        # assumed to already carry its header.
        self.file_pre_exists = os.path.exists(filename)
        super().__init__(filename, mode, encoding, delay, errors)
        if not delay and self.stream is not None and not self.file_pre_exists:
            self.stream.write(f'{header}\n')
            # Mark the header as written so a later close()/reopen cycle
            # does not duplicate it (the original left this flag False).
            self.file_pre_exists = True

    def emit(self, record):
        """Write the header (once) before the first record on a delayed/
        reopened stream, then delegate to FileHandler.emit."""
        if self.stream is None:
            self.stream = self._open()
            if not self.file_pre_exists:
                self.stream.write(f'{self.header}\n')
                self.file_pre_exists = True
        logging.FileHandler.emit(self, record)
def setup_logging(name="log",
                  filename=None,
                  stream_log_level="INFO",
                  file_log_level="INFO"):
    """Create a logger with a stream handler and, optionally, a CSV file
    handler whose header is chosen from the logger *name*.

    Args:
        name: Logger name; must be ``CSV_BATCH_LOGGER`` or ``CSV_EPOCH_LOGGER``
            when *filename* is given, since it selects the CSV header.
        filename: Optional path of the CSV log file; parent directories are
            created as needed.
        stream_log_level: Level name (e.g. "INFO") for the stream handler.
        file_log_level: Level name for the file handler.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        NotImplementedError: if *filename* is given but *name* is not one of
            the two known CSV logger names.
    """
    logger = logging.getLogger(name)
    stream_level = getattr(logging, stream_log_level)
    file_level = getattr(logging, file_log_level)
    # The logger must be at least as verbose as its most verbose handler;
    # the original hard-coded "INFO", which silently suppressed DEBUG handlers.
    logger.setLevel(min(stream_level, file_level))
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
    )
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(stream_level)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    if filename is not None:
        header = 'time,logger,'
        if name == CSV_BATCH_LOGGER:
            header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
        elif name == CSV_EPOCH_LOGGER:
            header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
        else:
            raise NotImplementedError(f"Logger '{name}' is not implemented.")
        dirname = os.path.dirname(filename)
        if dirname:
            # os.makedirs('') raises FileNotFoundError when filename has no
            # directory part (e.g. a bare "log.csv") — guard against that.
            os.makedirs(dirname, exist_ok=True)
        file_handler = FileHandlerWithHeader(filename, header)
        file_handler.setLevel(file_level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger
def training_log(name):
    """Decorator factory: log the decorated function's return value to the
    logger called *name* as one comma-separated line of its ``.values()``.

    The wrapped function must return a mapping (or any object exposing
    ``.values()``); the return value is passed through unchanged.
    """
    def log_this(function):
        logger = logging.getLogger(name)

        @functools.wraps(function)  # preserve __name__/__doc__ of the wrapped fn
        def wrapper(*args, **kwargs):
            output = function(*args, **kwargs)
            logger.info(','.join(map(str, output.values())))
            return output
        return wrapper
    return log_this
@dataclass
class BaseBatchLogRecord:
    """Base record for per-batch log entries; subclasses append metric fields.

    Fields listed here are treated as bookkeeping (not metrics) by
    tensorboard_logger, which excludes them from scalar logging.
    """
    batch: int  # batch index within the current epoch
    num_batches: int  # batches per epoch
    global_batch: int  # batch counter across epochs; used as TensorBoard global step
    epoch: int
    num_epochs: int
@dataclass
class BaseEpochLogRecord:
    """Base record for per-epoch log entries; subclasses append metric fields.

    ``epoch`` doubles as the TensorBoard global step in tensorboard_logger.
    """
    epoch: int
    num_epochs: int
@dataclass
class Loggers:
    """Bundle of logging sinks threaded through the csv/tensorboard decorators."""
    csv_batch: logging.Logger
    csv_epoch: logging.Logger | None  # may be None when no epoch-level CSV log exists
    tensorboard: SummaryWriter
def init_csv_logger(name="log",
                    filename="log.csv",
                    metric_names=None,
                    stream_log_level="INFO",
                    file_log_level="INFO"):
    """Create a logger that mirrors records to the stream and a CSV file.

    Args:
        name: Logger name (also appears in the CSV ``logger`` column).
        filename: Path of the CSV file; parent directories are created.
        metric_names: Optional list of metric column names appended to the
            fixed ``time,logger`` header columns.
        stream_log_level: Level name (e.g. "INFO") for the stream handler.
        file_log_level: Level name for the file handler.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    stream_level = getattr(logging, stream_log_level)
    file_level = getattr(logging, file_log_level)
    # Logger level must not filter out records its handlers would accept;
    # the original hard-coded "INFO", which silently suppressed DEBUG handlers.
    logger.setLevel(min(stream_level, file_level))
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
    )
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(stream_level)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    header = ['time', 'logger']
    if metric_names:
        header += metric_names
    header = ','.join(header)
    dirname = os.path.dirname(filename)
    if dirname:
        # os.makedirs('') raises FileNotFoundError; the DEFAULT filename
        # "log.csv" has no directory part, so the original crashed here.
        os.makedirs(dirname, exist_ok=True)
    file_handler = FileHandlerWithHeader(filename, header)
    file_handler.setLevel(file_level)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return logger
def csv_logger(function):
    """Decorator: write the metrics record returned by *function* to the
    matching CSV logger in the returned ``Loggers`` bundle.

    *function* must return ``(loggers, metrics)`` where ``metrics`` is an
    instance (or subclass) of ``BaseEpochLogRecord``/``BaseBatchLogRecord``;
    the pair is returned unchanged.

    Raises:
        NotImplementedError: if *metrics* is neither record type.
    """
    @functools.wraps(function)  # keep the wrapped function's metadata
    def wrapper(*args, **kwargs):
        loggers, metrics = function(*args, **kwargs)
        if isinstance(metrics, BaseEpochLogRecord):
            logger = loggers.csv_epoch
        elif isinstance(metrics, BaseBatchLogRecord):
            logger = loggers.csv_batch
        else:
            raise NotImplementedError(f"Unknown log type: '{type(metrics)}'")
        # Dataclass __dict__ preserves field declaration order, so the CSV
        # row lines up with the header written by the file handler.
        logger.info(','.join(map(str, metrics.__dict__.values())))
        return loggers, metrics
    return wrapper
def tensorboard_logger(function):
    """Decorator: mirror the metric fields of the record returned by
    *function* to TensorBoard via ``loggers.tensorboard.add_scalar``.

    Bookkeeping fields declared on the Base*LogRecord classes are excluded;
    only the subclass-added metric fields are logged. The global step is
    ``global_batch`` for batch records and ``epoch`` for epoch records.

    Raises:
        NotImplementedError: for an unknown record type, or for a metric
            value that is not a numeric scalar.
    """
    @functools.wraps(function)  # keep the wrapped function's metadata
    def wrapper(*args, **kwargs):
        loggers, metrics = function(*args, **kwargs)
        if isinstance(metrics, BaseBatchLogRecord):
            metrics_exclude = BaseBatchLogRecord.__annotations__.keys()
            global_step = metrics.global_batch
        elif isinstance(metrics, BaseEpochLogRecord):
            metrics_exclude = BaseEpochLogRecord.__annotations__.keys()
            global_step = metrics.epoch
        else:
            raise NotImplementedError(f"Unknown log type: '{type(metrics)}'")
        logger = loggers.tensorboard
        for metric_name, metric_value in metrics.__dict__.items():
            if metric_name in metrics_exclude:
                continue
            # Accept any numeric scalar (the original only matched float,
            # silently dropping ints because it also forgot to raise).
            if isinstance(metric_value, (int, float)):
                logger.add_scalar(metric_name, metric_value, global_step + 1)
            else:
                # Bug fix: the original constructed this exception but
                # never raised it, so bad values vanished silently.
                raise NotImplementedError(
                    f"Unsupported type: '{type(metric_value)}'"
                )
        return loggers, metrics
    return wrapper
|