From cc807c751e2c14ef9a88e5c5be00b4eb082e705b Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 14 Jul 2022 14:54:11 +0800 Subject: Refactor baseline with trainer --- libs/utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) (limited to 'libs/utils.py') diff --git a/libs/utils.py b/libs/utils.py index 77e6cf1..bc45a12 100644 --- a/libs/utils.py +++ b/libs/utils.py @@ -39,6 +39,20 @@ class BaseConfig: optim_config: OptimConfig sched_config: SchedConfig + @staticmethod + def _config_from_args(args, dcls): + return dcls(**{f.name: getattr(args, f.name) + for f in dataclasses.fields(dcls)}) + + @classmethod + def from_args(cls, args): + dataset_config = cls._config_from_args(args, cls.DatasetConfig) + dataloader_config = cls._config_from_args(args, cls.DataLoaderConfig) + optim_config = cls._config_from_args(args, cls.OptimConfig) + sched_config = cls._config_from_args(args, cls.SchedConfig) + + return cls(dataset_config, dataloader_config, optim_config, sched_config) + class Trainer(ABC): def __init__( -- cgit v1.2.3