summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-09 16:51:08 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-09 16:51:08 +0800
commit9f3d4cc14ad36e515b56e86fb8e26f519bde831e (patch)
tree3d35bfdf799ec1811880696376e8f0969a07b561
parent99ddd7c142a4ec97cb8bd14b204651790b3cf4ee (diff)
Some optimizations
1. Scheduler will decay the learning rate of auto-encoder only 2. Write learning rate history to tensorboard 3. Reduce image log frequency
-rw-r--r--models/model.py76
1 files changed, 44 insertions, 32 deletions
diff --git a/models/model.py b/models/model.py
index 0418070..ee07615 100644
--- a/models/model.py
+++ b/models/model.py
@@ -153,8 +153,18 @@ class Model:
self.rgb_pn = self.rgb_pn.to(self.device)
self.optimizer = optim.Adam([
{'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
+ {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
+ {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
+ {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
], **optim_hp)
- self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, **sched_hp)
+ sched_gamma = sched_hp.get('gamma', 0.9)
+ sched_step_size = sched_hp.get('step_size', 500)
+ self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
+ lambda epoch: sched_gamma ** (epoch // sched_step_size),
+ lambda epoch: 0 if epoch < start_iter else 1,
+ lambda epoch: 0 if epoch < start_iter else 1,
+ lambda epoch: 0 if epoch < start_iter else 1,
+ ])
self.writer = SummaryWriter(self._log_name)
self.rgb_pn.train()
@@ -172,18 +182,8 @@ class Model:
start_time = datetime.now()
running_loss = torch.zeros(5, device=self.device)
print(f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}",
- f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} LR(s)")
+ f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
for (batch_c1, batch_c2) in dataloader:
- if self.curr_iter == start_iter:
- self.optimizer.add_param_group(
- {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}
- )
- self.optimizer.add_param_group(
- {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp}
- )
- self.optimizer.add_param_group(
- {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
- )
self.curr_iter += 1
# Zero the parameter gradients
self.optimizer.zero_grad()
@@ -205,31 +205,39 @@ class Model:
'Canonical consistency loss', 'Batch All triplet loss (HPM)',
'Batch All triplet loss (PartNet)'
], losses)), self.curr_iter)
- if self.image_log_on:
- (appearance_image, canonical_image, pose_image) = images
- self.writer.add_images(
- 'Canonical image', canonical_image, self.curr_iter
- )
- for i in range(self.pr * self.k):
- self.writer.add_images(
- f'Original image/batch {i}', x_c1[i], self.curr_iter
- )
- self.writer.add_images(
- f'Appearance image/batch {i}',
- appearance_image[:, i, :, :, :],
- self.curr_iter
- )
- self.writer.add_images(
- f'Pose image/batch {i}',
- pose_image[:, i, :, :, :],
- self.curr_iter
- )
if self.curr_iter % 100 == 0:
lrs = self.scheduler.get_last_lr()
+ # Write learning rates
+ self.writer.add_scalar(
+ 'Learning rate/Auto-encoder', lrs[0], self.curr_iter
+ )
+ self.writer.add_scalar(
+ 'Learning rate/Others', lrs[1], self.curr_iter
+ )
+ # Write disentangled images
+ if self.image_log_on:
+ (appearance_image, canonical_image, pose_image) = images
+ self.writer.add_images(
+ 'Canonical image', canonical_image, self.curr_iter
+ )
+ for i in range(self.pr * self.k):
+ self.writer.add_images(
+ f'Original image/batch {i}', x_c1[i], self.curr_iter
+ )
+ self.writer.add_images(
+ f'Appearance image/batch {i}',
+ appearance_image[:, i, :, :, :],
+ self.curr_iter
+ )
+ self.writer.add_images(
+ f'Pose image/batch {i}',
+ pose_image[:, i, :, :, :],
+ self.curr_iter
+ )
print(f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
'{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- ' '.join(('{:.3e}'.format(lr) for lr in lrs)))
+ '{:.3e} {:.3e}'.format(lrs[0], lrs[1]))
running_loss.zero_()
# Step scheduler
@@ -244,6 +252,10 @@ class Model:
}, self._checkpoint_name)
print(datetime.now() - start_time, 'used')
start_time = datetime.now()
+ print(
+ f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}",
+ f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} {'LRs':^19}"
+ )
if self.curr_iter == self.total_iter:
self.writer.close()