summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.py12
-rw-r--r--models/model.py27
-rw-r--r--utils/configuration.py2
3 files changed, 27 insertions, 14 deletions
diff --git a/config.py b/config.py
index df7c64c..726d06e 100644
--- a/config.py
+++ b/config.py
@@ -41,7 +41,7 @@ config: Configuration = {
# Batch size (pr, k)
# `pr` denotes number of persons
# `k` denotes number of sequences per person
- 'batch_size': (4, 6),
+ 'batch_size': (4, 5),
# Number of workers of Dataloader
'num_workers': 4,
# Faster data transfer from RAM to GPU if enabled
@@ -88,18 +88,22 @@ config: Configuration = {
# Local parameters (override global ones)
# 'auto_encoder': {
- # 'weight_decay': 0.001
+ # 'lr': 1e-3
# },
},
'scheduler': {
# Step start to decay
'start_step': 500,
+ # Step stop decaying
+ # 'stop_step': 30_000,
# Multiplicative factor of decay in the end
'final_gamma': 0.01,
# Local parameters (override global ones)
- # 'hpm': {
- # 'final_gamma': 0.001
+ # 'auto_encoder': {
+ # 'start_step': 0,
+ # 'stop_step': 500,
+ # 'final_gamma': 0.5
# }
}
},
diff --git a/models/model.py b/models/model.py
index 9dbf609..45067e6 100644
--- a/models/model.py
+++ b/models/model.py
@@ -212,24 +212,31 @@ class Model:
], **optim_hp)
# Scheduler
- start_step = sched_hp.get('start_step', 15_000)
+ start_step = sched_hp.get('start_step', 0)
+ stop_step = sched_hp.get('stop_step', self.total_iter)
final_gamma = sched_hp.get('final_gamma', 0.001)
ae_start_step = ae_sched_hp.get('start_step', start_step)
+ ae_stop_step = ae_sched_hp.get('stop_step', stop_step)
ae_final_gamma = ae_sched_hp.get('final_gamma', final_gamma)
- ae_all_step = self.total_iter - ae_start_step
+ ae_all_step = ae_stop_step - ae_start_step
hpm_start_step = hpm_sched_hp.get('start_step', start_step)
+ hpm_stop_step = hpm_sched_hp.get('stop_step', stop_step)
hpm_final_gamma = hpm_sched_hp.get('final_gamma', final_gamma)
- hpm_all_step = self.total_iter - hpm_start_step
+ hpm_all_step = hpm_stop_step - hpm_start_step
pn_start_step = pn_sched_hp.get('start_step', start_step)
+ pn_stop_step = pn_sched_hp.get('stop_step', stop_step)
pn_final_gamma = pn_sched_hp.get('final_gamma', final_gamma)
- pn_all_step = self.total_iter - pn_start_step
+ pn_all_step = pn_stop_step - pn_start_step
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
- lambda t: ae_final_gamma ** ((t - ae_start_step) / ae_all_step)
- if t > ae_start_step else 1,
- lambda t: hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step)
- if t > hpm_start_step else 1,
- lambda t: pn_final_gamma ** ((t - pn_start_step) / pn_all_step)
- if t > pn_start_step else 1,
+ lambda t: 1 if t <= ae_start_step
+ else ae_final_gamma ** ((t - ae_start_step) / ae_all_step)
+ if ae_start_step < t <= ae_stop_step else ae_final_gamma,
+ lambda t: 1 if t <= hpm_start_step
+ else hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step)
+ if hpm_start_step < t <= hpm_stop_step else hpm_final_gamma,
+ lambda t: 1 if t <= pn_start_step
+ else pn_final_gamma ** ((t - pn_start_step) / pn_all_step)
+ if pn_start_step < t <= pn_stop_step else pn_final_gamma,
])
self.writer = SummaryWriter(self._log_name)
diff --git a/utils/configuration.py b/utils/configuration.py
index 608d413..579d2f2 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -66,11 +66,13 @@ class OptimizerHPConfiguration(TypedDict):
class SubSchedulerHPConfiguration(TypedDict):
start_step: int
+ stop_step: int
final_gamma: float
class SchedulerHPConfiguration(TypedDict):
start_step: int
+ stop_step: int
final_gamma: float
auto_encoder: SubSchedulerHPConfiguration
hpm: SubSchedulerHPConfiguration