diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-07-14 09:13:05 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-07-14 09:13:05 +0800 |
commit | 80f132510161cd8ae75ba153d24030a27b772815 (patch) | |
tree | 4c065132280ccc7e3081a8f2aa90a07babe48de0 /libs/utils.py | |
parent | 819e69812f1cde332143ee968564c7d83abffe78 (diff) |
Use `dataclasses.fields()` instead of undocumented `__dataclass_fields__`
Diffstat (limited to 'libs/utils.py')
-rw-r--r-- | libs/utils.py | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/libs/utils.py b/libs/utils.py index caacb38..23964bc 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -1,3 +1,4 @@ +import dataclasses import os import random from abc import ABC, abstractmethod @@ -110,14 +111,20 @@ class Trainer(ABC): cudnn.deterministic = True def init_logger(self, log_dir): - csv_batch_log_name = os.path.join(log_dir, 'batch-log.csv') - csv_batch_logger = init_csv_logger(CSV_BATCH_LOGGER, csv_batch_log_name, - list(self.BatchLogRecord.__dataclass_fields__.keys())) + csv_batch_log_fname = os.path.join(log_dir, 'batch-log.csv') + csv_batch_logger = init_csv_logger( + name=CSV_BATCH_LOGGER, + filename=csv_batch_log_fname, + metric_names=[f.name for f in dataclasses.fields(self.BatchLogRecord)] + ) csv_epoch_logger = None if not self._inf_mode: - csv_epoch_log_name = os.path.join(log_dir, 'epoch-log.csv') - csv_epoch_logger = init_csv_logger(CSV_EPOCH_LOGGER, csv_epoch_log_name, - list(self.EpochLogRecord.__dataclass_fields__.keys())) + csv_epoch_log_fname = os.path.join(log_dir, 'epoch-log.csv') + csv_epoch_logger = init_csv_logger( + name=CSV_EPOCH_LOGGER, + filename=csv_epoch_log_fname, + metric_names=[f.name for f in dataclasses.fields(self.EpochLogRecord)] + ) tb_logger = SummaryWriter(os.path.join(log_dir, 'runs')) return Loggers(csv_batch_logger, csv_epoch_logger, tb_logger) |