summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py83
1 files changed, 45 insertions, 38 deletions
diff --git a/models/model.py b/models/model.py
index 912d0b9..5f079b8 100644
--- a/models/model.py
+++ b/models/model.py
@@ -150,8 +150,18 @@ class Model:
self.rgb_pn = self.rgb_pn.to(self.device)
self.optimizer = optim.Adam([
{'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
+ {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
+ {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
+ {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
], **optim_hp)
- self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, **sched_hp)
+ sched_gamma = sched_hp.get('gamma', 0.9)
+ sched_step_size = sched_hp.get('step_size', 500)
+ self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
+ lambda epoch: sched_gamma ** (epoch // sched_step_size),
+ lambda epoch: 0 if epoch < start_iter else 1,
+ lambda epoch: 0 if epoch < start_iter else 1,
+ lambda epoch: 0 if epoch < start_iter else 1,
+ ])
self.writer = SummaryWriter(self._log_name)
self.rgb_pn.train()
@@ -168,19 +178,10 @@ class Model:
# Training start
start_time = datetime.now()
running_loss = torch.zeros(5, device=self.device)
- print(f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}",
- f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} LR(s)")
+ print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
+ f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
+ f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
for (batch_c1, batch_c2) in dataloader:
- if self.curr_iter == start_iter:
- self.optimizer.add_param_group(
- {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}
- )
- self.optimizer.add_param_group(
- {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp}
- )
- self.optimizer.add_param_group(
- {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
- )
self.curr_iter += 1
# Zero the parameter gradients
self.optimizer.zero_grad()
@@ -198,35 +199,43 @@ class Model:
# Write losses to TensorBoard
self.writer.add_scalar('Loss/all', loss, self.curr_iter)
self.writer.add_scalars('Loss/details', dict(zip([
- 'Cross reconstruction loss', 'Pose similarity loss',
- 'Canonical consistency loss', 'Batch All triplet loss (HPM)',
+ 'Cross reconstruction loss', 'Canonical consistency loss',
+ 'Pose similarity loss', 'Batch All triplet loss (HPM)',
'Batch All triplet loss (PartNet)'
], losses)), self.curr_iter)
- if self.image_log_on:
- (appearance_image, canonical_image, pose_image) = images
- self.writer.add_images(
- 'Canonical image', canonical_image, self.curr_iter
- )
- for i in range(self.pr * self.k):
- self.writer.add_images(
- f'Original image/batch {i}', x_c1[i], self.curr_iter
- )
- self.writer.add_images(
- f'Appearance image/batch {i}',
- appearance_image[:, i, :, :, :],
- self.curr_iter
- )
- self.writer.add_images(
- f'Pose image/batch {i}',
- pose_image[:, i, :, :, :],
- self.curr_iter
- )
if self.curr_iter % 100 == 0:
lrs = self.scheduler.get_last_lr()
- print(f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
+ # Write learning rates
+ self.writer.add_scalar(
+ 'Learning rate/Auto-encoder', lrs[0], self.curr_iter
+ )
+ self.writer.add_scalar(
+ 'Learning rate/Others', lrs[1], self.curr_iter
+ )
+ # Write disentangled images
+ if self.image_log_on:
+ i_a, i_c, i_p = images
+ self.writer.add_images(
+ 'Canonical image', i_c, self.curr_iter
+ )
+ for (i, (o, a, p)) in enumerate(zip(x_c1, i_a, i_p)):
+ self.writer.add_images(
+ f'Original image/batch {i}', o, self.curr_iter
+ )
+ self.writer.add_images(
+ f'Appearance image/batch {i}', a, self.curr_iter
+ )
+ self.writer.add_images(
+ f'Pose image/batch {i}', p, self.curr_iter
+ )
+ time_used = datetime.now() - start_time
+ remaining_minute, second = divmod(time_used.seconds, 60)
+ hour, minute = divmod(remaining_minute, 60)
+ print(f'{hour:02}:{minute:02}:{second:02}',
+ f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
'{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- ' '.join(('{:.3e}'.format(lr) for lr in lrs)))
+ '{:.3e} {:.3e}'.format(lrs[0], lrs[1]))
running_loss.zero_()
# Step scheduler
@@ -239,8 +248,6 @@ class Model:
'optim_state_dict': self.optimizer.state_dict(),
'loss': loss,
}, self._checkpoint_name)
- print(datetime.now() - start_time, 'used')
- start_time = datetime.now()
if self.curr_iter == self.total_iter:
self.writer.close()