summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py37
1 files changed, 16 insertions, 21 deletions
diff --git a/models/model.py b/models/model.py
index ee07615..70c43b3 100644
--- a/models/model.py
+++ b/models/model.py
@@ -181,8 +181,9 @@ class Model:
# Training start
start_time = datetime.now()
running_loss = torch.zeros(5, device=self.device)
- print(f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}",
- f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
+ print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
+ f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
+ f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
for (batch_c1, batch_c2) in dataloader:
self.curr_iter += 1
# Zero the parameter gradients
@@ -201,8 +202,8 @@ class Model:
# Write losses to TensorBoard
self.writer.add_scalar('Loss/all', loss, self.curr_iter)
self.writer.add_scalars('Loss/details', dict(zip([
- 'Cross reconstruction loss', 'Pose similarity loss',
- 'Canonical consistency loss', 'Batch All triplet loss (HPM)',
+ 'Cross reconstruction loss', 'Canonical consistency loss',
+ 'Pose similarity loss', 'Batch All triplet loss (HPM)',
'Batch All triplet loss (PartNet)'
], losses)), self.curr_iter)
@@ -217,25 +218,25 @@ class Model:
)
# Write disentangled images
if self.image_log_on:
- (appearance_image, canonical_image, pose_image) = images
+ i_a, i_c, i_p = images
self.writer.add_images(
- 'Canonical image', canonical_image, self.curr_iter
+ 'Canonical image', i_c, self.curr_iter
)
- for i in range(self.pr * self.k):
+ for (i, (o, a, p)) in enumerate(zip(x_c1, i_a, i_p)):
self.writer.add_images(
- f'Original image/batch {i}', x_c1[i], self.curr_iter
+ f'Original image/batch {i}', o, self.curr_iter
)
self.writer.add_images(
- f'Appearance image/batch {i}',
- appearance_image[:, i, :, :, :],
- self.curr_iter
+ f'Appearance image/batch {i}', a, self.curr_iter
)
self.writer.add_images(
- f'Pose image/batch {i}',
- pose_image[:, i, :, :, :],
- self.curr_iter
+ f'Pose image/batch {i}', p, self.curr_iter
)
- print(f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
+ time_used = datetime.now() - start_time
+ remaining_minute, second = divmod(time_used.seconds, 60)
+ hour, minute = divmod(remaining_minute, 60)
+ print(f'{hour:02}:{minute:02}:{second:02}',
+ f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
'{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
'{:.3e} {:.3e}'.format(lrs[0], lrs[1]))
running_loss.zero_()
@@ -250,12 +251,6 @@ class Model:
'optim_state_dict': self.optimizer.state_dict(),
'loss': loss,
}, self._checkpoint_name)
- print(datetime.now() - start_time, 'used')
- start_time = datetime.now()
- print(
- f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}",
- f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} {'LRs':^19}"
- )
if self.curr_iter == self.total_iter:
self.writer.close()