aboutsummaryrefslogtreecommitdiff
path: root/libs/utils.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-10 15:17:06 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-10 15:17:06 +0800
commitf9b51f152e94896a122325fa2c48b0ce10c881c5 (patch)
treef777f3f9451318c0ed16428b8302fc3f02a08bb6 /libs/utils.py
parentebb2f93ac01f40d00968daaf9a2ad96c24ce7ab3 (diff)
Add default schedulers
Diffstat (limited to 'libs/utils.py')
-rw-r--r--libs/utils.py24
1 files changed, 21 insertions, 3 deletions
diff --git a/libs/utils.py b/libs/utils.py
index 5e8529b..c019ba9 100644
--- a/libs/utils.py
+++ b/libs/utils.py
@@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
from libs.logging import CSV_EPOCH_LOGGER, CSV_BATCH_LOGGER, BaseBatchLogRecord, BaseEpochLogRecord, Loggers, \
init_csv_logger, csv_logger, tensorboard_logger
+from libs.schedulers import LinearWarmupAndCosineAnneal, LinearLR
@dataclass
@@ -33,6 +34,7 @@ class BaseConfig:
@dataclass
class SchedConfig:
sched: None
+ warmup_iters: int
dataset_config: DatasetConfig
dataloader_config: DataLoaderConfig
@@ -221,15 +223,31 @@ class Trainer(ABC):
return last_metrics
@staticmethod
- @abstractmethod
def _configure_scheduler(
optims: Iterable[tuple[str, torch.optim.Optimizer]],
last_iter: int, num_iters: int, sched_config: BaseConfig.SchedConfig,
) -> Iterable[tuple[str, torch.optim.lr_scheduler._LRScheduler]
| tuple[str, None]]:
for optim_name, optim in optims:
- sched = torch.optim.lr_scheduler._LRScheduler(optim, -1)
- yield f"{optim_name}_sched", sched
+ if sched_config.sched == 'warmup-anneal':
+ scheduler = LinearWarmupAndCosineAnneal(
+ optim,
+ warm_up=sched_config.warmup_iters / num_iters,
+ T_max=num_iters,
+ last_epoch=last_iter,
+ )
+ elif sched_config.sched == 'linear':
+ scheduler = LinearLR(
+ optim,
+ num_epochs=num_iters,
+ last_epoch=last_iter
+ )
+ elif sched_config.sched in {None, '', 'const'}:
+ scheduler = None
+ else:
+ raise NotImplementedError(f"Unimplemented scheduler: '{sched_config.sched}'")
+
+ yield f"{optim_name}_sched", scheduler
def _custom_init_fn(self, config: BaseConfig):
pass