1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
|
import functools
import logging
import os
# Logger name for per-epoch records; setup_logging uses it to pick the
# epoch CSV header when a log file is requested.
EPOCH_LOGGER = 'epoch_logger'
# Logger name for per-batch records; setup_logging uses it to pick the
# batch CSV header when a log file is requested.
BATCH_LOGGER = 'batch_logger'
class FileHandlerWithHeader(logging.FileHandler):
    """File handler that puts a header line at the top of the log file.

    Each time the handler opens its file it checks whether the file is
    empty (a new file, an existing empty file, or one just truncated by a
    ``'w'`` mode). Only then is the header written, so it is always the
    first line and is never repeated in the middle of a file that is being
    appended to.

    Attributes:
        header: Header text, stored without the trailing newline.
        file_pre_exists: Whether ``filename`` existed when the handler was
            constructed. Kept for callers that inspect it; the header
            decision itself is based on the size of the opened file.
    """

    def __init__(self, filename, header, mode='a', encoding=None, delay=0):
        """Create the handler and, unless opening is delayed, emit the header.

        Args:
            filename: Path of the log file.
            header: Line written at the top of an empty log file.
            mode: File mode forwarded to ``logging.FileHandler``.
            encoding: Text encoding forwarded to ``logging.FileHandler``.
            delay: If true, the file is opened on the first ``emit`` call.
        """
        self.header = header
        self.file_pre_exists = os.path.exists(filename)
        logging.FileHandler.__init__(self, filename, mode, encoding, delay)
        # Without delay the parent constructor has already opened the stream.
        if self.stream is not None:
            self._write_header_if_empty()

    def _write_header_if_empty(self):
        """Write the header when the currently open file holds no data."""
        # Checking the opened file (not the path before opening) covers both
        # append to an existing log and truncation by a 'w' mode.
        if os.fstat(self.stream.fileno()).st_size == 0:
            self.stream.write(f'{self.header}\n')

    def emit(self, record):
        """Write ``record``, opening the file (and its header) if needed."""
        if self.stream is None:
            self.stream = self._open()
            self._write_header_if_empty()
        logging.FileHandler.emit(self, record)
def setup_logging(name="log",
                  filename=None,
                  stream_log_level="INFO",
                  file_log_level="INFO"):
    """Configure and return the logger called ``name``.

    A stream handler is always attached. When ``filename`` is given, a
    ``FileHandlerWithHeader`` is attached as well, with a CSV header chosen
    by ``name``. Both handlers format records as
    ``<time>.<msecs>,<logger name>,<message>``.

    Each call adds new handlers to the logger; handlers from earlier calls
    with the same ``name`` are not removed.

    Args:
        name: Logger name. Must equal ``BATCH_LOGGER`` or ``EPOCH_LOGGER``
            when ``filename`` is given.
        filename: Path of the log file, or ``None`` for stream logging only.
        stream_log_level: Name of a ``logging`` level constant (e.g.
            ``"INFO"``) for the stream handler.
        file_log_level: Name of a ``logging`` level constant for the file
            handler; only used when ``filename`` is given.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        NotImplementedError: If ``filename`` is given and ``name`` is not one
            of the supported logger names.
        AttributeError: If a level name is not an attribute of ``logging``.
    """
    header = None
    if filename is not None:
        # Validate before touching the logger so that a rejected name does
        # not leave a half-configured logger behind.
        header = 'time,logger,'
        if name == BATCH_LOGGER:
            header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
        elif name == EPOCH_LOGGER:
            header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
        else:
            raise NotImplementedError(f"Logger '{name}' is not implemented.")

    stream_level = getattr(logging, stream_log_level)
    # The logger never gets stricter than INFO, but is lowered when a handler
    # asks for more detail; otherwise such records would be dropped by the
    # logger before reaching that handler.
    levels = [logging.INFO, stream_level]

    logger = logging.getLogger(name)
    formatter = logging.Formatter(
        '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
    )

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(stream_level)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    if filename is not None:
        file_level = getattr(logging, file_log_level)
        levels.append(file_level)

        # A bare file name has an empty dirname, which os.makedirs rejects.
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)

        file_handler = FileHandlerWithHeader(filename, header)
        file_handler.setLevel(file_level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    logger.setLevel(min(levels))
    return logger
def training_log(name):
    """Build a decorator that logs a function's returned values as one CSV line.

    The decorated function is called normally. The values of its return
    mapping are then converted with ``str``, joined with commas and logged
    at INFO level on the logger called ``name``. The return value is passed
    through unchanged.

    Args:
        name: Name of the logger that receives the record.

    Returns:
        A decorator. The function it wraps must return an object with a
        ``values()`` method (e.g. a dict); the order of the logged columns
        is the iteration order of those values.
    """
    def log_this(function):
        # Looked up once, at decoration time.
        logger = logging.getLogger(name)

        # Preserve the wrapped function's __name__, __doc__, etc.
        @functools.wraps(function)
        def wrapper(*args, **kwargs):
            output = function(*args, **kwargs)
            logger.info(','.join(map(str, output.values())))
            return output
        return wrapper
    return log_this
|