aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--libs/utils.py19
1 files changed, 13 insertions, 6 deletions
diff --git a/libs/utils.py b/libs/utils.py
index caacb38..23964bc 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -1,3 +1,4 @@
+import dataclasses
import os
import random
from abc import ABC, abstractmethod
@@ -110,14 +111,20 @@ class Trainer(ABC):
cudnn.deterministic = True
def init_logger(self, log_dir):
- csv_batch_log_name = os.path.join(log_dir, 'batch-log.csv')
- csv_batch_logger = init_csv_logger(CSV_BATCH_LOGGER, csv_batch_log_name,
- list(self.BatchLogRecord.__dataclass_fields__.keys()))
+ csv_batch_log_fname = os.path.join(log_dir, 'batch-log.csv')
+ csv_batch_logger = init_csv_logger(
+ name=CSV_BATCH_LOGGER,
+ filename=csv_batch_log_fname,
+ metric_names=[f.name for f in dataclasses.fields(self.BatchLogRecord)]
+ )
csv_epoch_logger = None
if not self._inf_mode:
- csv_epoch_log_name = os.path.join(log_dir, 'epoch-log.csv')
- csv_epoch_logger = init_csv_logger(CSV_EPOCH_LOGGER, csv_epoch_log_name,
- list(self.EpochLogRecord.__dataclass_fields__.keys()))
+ csv_epoch_log_fname = os.path.join(log_dir, 'epoch-log.csv')
+ csv_epoch_logger = init_csv_logger(
+ name=CSV_EPOCH_LOGGER,
+ filename=csv_epoch_log_fname,
+ metric_names=[f.name for f in dataclasses.fields(self.EpochLogRecord)]
+ )
tb_logger = SummaryWriter(os.path.join(log_dir, 'runs'))
return Loggers(csv_batch_logger, csv_epoch_logger, tb_logger)