aboutsummaryrefslogtreecommitdiff
path: root/libs/utils.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-07-14 09:13:05 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-07-14 09:13:05 +0800
commit80f132510161cd8ae75ba153d24030a27b772815 (patch)
tree4c065132280ccc7e3081a8f2aa90a07babe48de0 /libs/utils.py
parent819e69812f1cde332143ee968564c7d83abffe78 (diff)
Use `dataclasses.fields()` instead of undocumented `__dataclass_fields__`
Diffstat (limited to 'libs/utils.py')
-rw-r--r--libs/utils.py19
1 files changed, 13 insertions, 6 deletions
diff --git a/libs/utils.py b/libs/utils.py
index caacb38..23964bc 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -1,3 +1,4 @@
+import dataclasses
import os
import random
from abc import ABC, abstractmethod
@@ -110,14 +111,20 @@ class Trainer(ABC):
cudnn.deterministic = True
def init_logger(self, log_dir):
- csv_batch_log_name = os.path.join(log_dir, 'batch-log.csv')
- csv_batch_logger = init_csv_logger(CSV_BATCH_LOGGER, csv_batch_log_name,
- list(self.BatchLogRecord.__dataclass_fields__.keys()))
+ csv_batch_log_fname = os.path.join(log_dir, 'batch-log.csv')
+ csv_batch_logger = init_csv_logger(
+ name=CSV_BATCH_LOGGER,
+ filename=csv_batch_log_fname,
+ metric_names=[f.name for f in dataclasses.fields(self.BatchLogRecord)]
+ )
csv_epoch_logger = None
if not self._inf_mode:
- csv_epoch_log_name = os.path.join(log_dir, 'epoch-log.csv')
- csv_epoch_logger = init_csv_logger(CSV_EPOCH_LOGGER, csv_epoch_log_name,
- list(self.EpochLogRecord.__dataclass_fields__.keys()))
+ csv_epoch_log_fname = os.path.join(log_dir, 'epoch-log.csv')
+ csv_epoch_logger = init_csv_logger(
+ name=CSV_EPOCH_LOGGER,
+ filename=csv_epoch_log_fname,
+ metric_names=[f.name for f in dataclasses.fields(self.EpochLogRecord)]
+ )
tb_logger = SummaryWriter(os.path.join(log_dir, 'runs'))
return Loggers(csv_batch_logger, csv_epoch_logger, tb_logger)