aboutsummaryrefslogtreecommitdiff
path: root/libs/utils.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-07-14 14:54:11 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-07-14 14:55:00 +0800
commitcc807c751e2c14ef9a88e5c5be00b4eb082e705b (patch)
tree04af8607e2906df68d3e77edbd2658353fb044a0 /libs/utils.py
parentb9d83e80b946437bb8dc0b586488fa756f52d732 (diff)
Refactor baseline with trainer
Diffstat (limited to 'libs/utils.py')
-rw-r--r--libs/utils.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/libs/utils.py b/libs/utils.py
index 77e6cf1..bc45a12 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -39,6 +39,20 @@ class BaseConfig:
optim_config: OptimConfig
sched_config: SchedConfig
+ @staticmethod
+ def _config_from_args(args, dcls):
+ return dcls(**{f.name: getattr(args, f.name)
+ for f in dataclasses.fields(dcls)})
+
+ @classmethod
+ def from_args(cls, args):
+ dataset_config = cls._config_from_args(args, cls.DatasetConfig)
+ dataloader_config = cls._config_from_args(args, cls.DataLoaderConfig)
+ optim_config = cls._config_from_args(args, cls.OptimConfig)
+ sched_config = cls._config_from_args(args, cls.SchedConfig)
+
+ return cls(dataset_config, dataloader_config, optim_config, sched_config)
+
class Trainer(ABC):
def __init__(