From 80f132510161cd8ae75ba153d24030a27b772815 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 14 Jul 2022 09:13:05 +0800 Subject: Use `dataclasses.fields()` instead of undocumented `__dataclass_fields__` --- libs/utils.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/libs/utils.py b/libs/utils.py index caacb38..23964bc 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -1,3 +1,4 @@ +import dataclasses import os import random from abc import ABC, abstractmethod @@ -110,14 +111,20 @@ class Trainer(ABC): cudnn.deterministic = True def init_logger(self, log_dir): - csv_batch_log_name = os.path.join(log_dir, 'batch-log.csv') - csv_batch_logger = init_csv_logger(CSV_BATCH_LOGGER, csv_batch_log_name, - list(self.BatchLogRecord.__dataclass_fields__.keys())) + csv_batch_log_fname = os.path.join(log_dir, 'batch-log.csv') + csv_batch_logger = init_csv_logger( + name=CSV_BATCH_LOGGER, + filename=csv_batch_log_fname, + metric_names=[f.name for f in dataclasses.fields(self.BatchLogRecord)] + ) csv_epoch_logger = None if not self._inf_mode: - csv_epoch_log_name = os.path.join(log_dir, 'epoch-log.csv') - csv_epoch_logger = init_csv_logger(CSV_EPOCH_LOGGER, csv_epoch_log_name, - list(self.EpochLogRecord.__dataclass_fields__.keys())) + csv_epoch_log_fname = os.path.join(log_dir, 'epoch-log.csv') + csv_epoch_logger = init_csv_logger( + name=CSV_EPOCH_LOGGER, + filename=csv_epoch_log_fname, + metric_names=[f.name for f in dataclasses.fields(self.EpochLogRecord)] + ) tb_logger = SummaryWriter(os.path.join(log_dir, 'runs')) return Loggers(csv_batch_logger, csv_epoch_logger, tb_logger) -- cgit v1.2.3