summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.py18
-rw-r--r--models/model.py27
-rw-r--r--utils/configuration.py5
3 files changed, 24 insertions, 26 deletions
diff --git a/config.py b/config.py
index 4c108e2..e70c2bd 100644
--- a/config.py
+++ b/config.py
@@ -72,8 +72,6 @@ config: Configuration = {
},
'optimizer': {
# Global parameters
- # Iteration start to optimize non-disentangling parts
- # 'start_iter': 0,
# Initial learning rate of Adam Optimizer
'lr': 1e-4,
# Coefficients used for computing running averages of
@@ -87,15 +85,15 @@ config: Configuration = {
# 'amsgrad': False,
# Local parameters (override global ones)
- # 'auto_encoder': {
- # 'weight_decay': 0.001
- # },
+ 'auto_encoder': {
+ 'weight_decay': 0.001
+ },
},
'scheduler': {
- # Period of learning rate decay
- 'step_size': 500,
- # Multiplicative factor of decay
- 'gamma': 1,
+ # Step start to decay
+ 'start_step': 15_000,
+ # Multiplicative factor of decay in the end
+ 'final_gamma': 0.001,
}
},
# Model metadata
@@ -109,6 +107,6 @@ config: Configuration = {
# Restoration iteration (multiple models, e.g. nm, bg and cl)
'restore_iters': (0, 0, 0),
# Total iteration for training (multiple models)
- 'total_iters': (80_000, 80_000, 80_000),
+ 'total_iters': (25_000, 25_000, 25_000),
},
}
diff --git a/models/model.py b/models/model.py
index b942eb8..497a0ea 100644
--- a/models/model.py
+++ b/models/model.py
@@ -147,7 +147,6 @@ class Model:
triplet_is_mean = model_hp.pop('triplet_is_mean', True)
triplet_margins = model_hp.pop('triplet_margins', None)
optim_hp: dict = self.hp.get('optimizer', {}).copy()
- start_iter = optim_hp.pop('start_iter', 0)
ae_optim_hp = optim_hp.pop('auto_encoder', {})
pn_optim_hp = optim_hp.pop('part_net', {})
hpm_optim_hp = optim_hp.pop('hpm', {})
@@ -184,14 +183,17 @@ class Model:
{'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
{'params': self.rgb_pn.fc_mat, **fc_optim_hp}
], **optim_hp)
- sched_gamma = sched_hp.get('gamma', 0.9)
- sched_step_size = sched_hp.get('step_size', 500)
+ sched_final_gamma = sched_hp.get('final_gamma', 0.001)
+ sched_start_step = sched_hp.get('start_step', 15_000)
+
+ def lr_lambda(epoch):
+ passed_step = epoch - sched_start_step
+ all_step = self.total_iter - sched_start_step
+ return sched_final_gamma ** (passed_step / all_step)
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
- lambda epoch: sched_gamma ** (epoch // sched_step_size),
- lambda epoch: 0 if epoch < start_iter else 1,
- lambda epoch: 0 if epoch < start_iter else 1,
- lambda epoch: 0 if epoch < start_iter else 1,
+ lr_lambda, lr_lambda, lr_lambda, lr_lambda
])
+
self.writer = SummaryWriter(self._log_name)
self.rgb_pn.train()
@@ -211,7 +213,7 @@ class Model:
running_loss = torch.zeros(5, device=self.device)
print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
- f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
+ f"{'BATripH':^8} {'BATripP':^8} {'LR':^9}")
for (batch_c1, batch_c2) in dataloader:
self.curr_iter += 1
# Zero the parameter gradients
@@ -282,10 +284,7 @@ class Model:
lrs = self.scheduler.get_last_lr()
# Write learning rates
self.writer.add_scalar(
- 'Learning rate/Auto-encoder', lrs[0], self.curr_iter
- )
- self.writer.add_scalar(
- 'Learning rate/Others', lrs[1], self.curr_iter
+ 'Learning rate', lrs[0], self.curr_iter
)
# Write disentangled images
if self.image_log_on:
@@ -309,7 +308,7 @@ class Model:
print(f'{hour:02}:{minute:02}:{second:02}',
f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
'{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- '{:.3e} {:.3e}'.format(lrs[0], lrs[1]))
+ f'{lrs[0]:.3e}')
running_loss.zero_()
# Step scheduler
@@ -385,6 +384,8 @@ class Model:
# Init models
model_hp: dict = self.hp.get('model', {}).copy()
+ model_hp.pop('triplet_is_hard', True)
+ model_hp.pop('triplet_is_mean', True)
model_hp.pop('triplet_margins', None)
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp)
# Try to accelerate computation using CUDA or others
diff --git a/utils/configuration.py b/utils/configuration.py
index 31eb243..0f8d9ff 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -57,7 +57,6 @@ class SubOptimizerHPConfiguration(TypedDict):
class OptimizerHPConfiguration(TypedDict):
- start_iter: int
lr: int
betas: tuple[float, float]
eps: float
@@ -70,8 +69,8 @@ class OptimizerHPConfiguration(TypedDict):
class SchedulerHPConfiguration(TypedDict):
- step_size: int
- gamma: float
+ start_step: int
+ final_gamma: float
class HyperparameterConfiguration(TypedDict):