aboutsummaryrefslogtreecommitdiff
path: root/libs
diff options
context:
space:
mode:
Diffstat (limited to 'libs')
-rw-r--r--libs/logging.py44
-rw-r--r--libs/utils.py14
2 files changed, 14 insertions, 44 deletions
diff --git a/libs/logging.py b/libs/logging.py
index 38e523d..3969ffa 100644
--- a/libs/logging.py
+++ b/libs/logging.py
@@ -30,50 +30,6 @@ class FileHandlerWithHeader(logging.FileHandler):
logging.FileHandler.emit(self, record)
-def setup_logging(name="log",
- filename=None,
- stream_log_level="INFO",
- file_log_level="INFO"):
- logger = logging.getLogger(name)
- logger.setLevel("INFO")
- formatter = logging.Formatter(
- '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
- )
- stream_handler = logging.StreamHandler()
- stream_handler.setLevel(getattr(logging, stream_log_level))
- stream_handler.setFormatter(formatter)
- logger.addHandler(stream_handler)
- if filename is not None:
- header = 'time,logger,'
- if name == CSV_BATCH_LOGGER:
- header += 'batch,n_batches,global_batch,epoch,n_epochs,train_loss,lr'
- elif name == CSV_EPOCH_LOGGER:
- header += 'epoch,n_epochs,train_loss,test_loss,test_accuracy'
- else:
- raise NotImplementedError(f"Logger '{name}' is not implemented.")
-
- os.makedirs(os.path.dirname(filename), exist_ok=True)
- file_handler = FileHandlerWithHeader(filename, header)
- file_handler.setLevel(getattr(logging, file_log_level))
- file_handler.setFormatter(formatter)
- logger.addHandler(file_handler)
- return logger
-
-
-def training_log(name):
- def log_this(function):
- logger = logging.getLogger(name)
-
- def wrapper(*args, **kwargs):
- output = function(*args, **kwargs)
- logger.info(','.join(map(str, output.values())))
- return output
-
- return wrapper
-
- return log_this
-
-
@dataclass
class BaseBatchLogRecord:
batch: int
diff --git a/libs/utils.py b/libs/utils.py
index 77e6cf1..bc45a12 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -39,6 +39,20 @@ class BaseConfig:
optim_config: OptimConfig
sched_config: SchedConfig
+ @staticmethod
+ def _config_from_args(args, dcls):
+ return dcls(**{f.name: getattr(args, f.name)
+ for f in dataclasses.fields(dcls)})
+
+ @classmethod
+ def from_args(cls, args):
+ dataset_config = cls._config_from_args(args, cls.DatasetConfig)
+ dataloader_config = cls._config_from_args(args, cls.DataLoaderConfig)
+ optim_config = cls._config_from_args(args, cls.OptimConfig)
+ sched_config = cls._config_from_args(args, cls.SchedConfig)
+
+ return cls(dataset_config, dataloader_config, optim_config, sched_config)
+
class Trainer(ABC):
def __init__(