diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-07 15:46:34 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-07 15:46:34 +0800 |
commit | bc82188377a8dc3068cd7b2ba5f02ed8bd4f2a54 (patch) | |
tree | c3265640aa0e37fe3c7a890ac5c415e32c287c77 /models/model.py | |
parent | 5e8947fbc90e1d67dadae36d32330a280d057267 (diff) |
Revert cross-reconstruction loss factor and make image log steps adjustable
Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 34 |
1 files changed, 18 insertions, 16 deletions
diff --git a/models/model.py b/models/model.py index ceadb92..9dbf609 100644 --- a/models/model.py +++ b/models/model.py @@ -79,6 +79,7 @@ class Model: self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None self.image_log_on = system_config.get('image_log_on', False) + self.image_log_steps = system_config.get('image_log_steps', 100) self.val_size = system_config.get('val_size', 10) self.CASIAB_GALLERY_SELECTOR = { @@ -196,8 +197,8 @@ class Model: triplet_is_hard, triplet_is_mean, None ) - self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 - self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr + self.num_pairs = (self.pr * self.k - 1) * (self.pr * self.k) // 2 + self.num_pos_pairs = (self.k * (self.k - 1) // 2) * self.pr # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) @@ -278,24 +279,25 @@ class Model: 'Train', embed_c, embed_p, hpm_result, pn_result, loss, losses ) - if self.curr_iter % 100 == 99: - # Write disentangled images - if self.image_log_on: - i_a, i_c, i_p = images + # Write disentangled images + if self.image_log_on and self.curr_iter % self.image_log_steps \ + == self.image_log_steps - 1: + i_a, i_c, i_p = images + self.writer.add_images( + 'Appearance image', i_a, self.curr_iter + ) + self.writer.add_images( + 'Canonical image', i_c, self.curr_iter + ) + for i, (o, p) in enumerate(zip(x_c1, i_p)): self.writer.add_images( - 'Appearance image', i_a, self.curr_iter + f'Original image/batch {i}', o, self.curr_iter ) self.writer.add_images( - 'Canonical image', i_c, self.curr_iter + f'Pose image/batch {i}', p, self.curr_iter ) - for i, (o, p) in enumerate(zip(x_c1, i_p)): - self.writer.add_images( - f'Original image/batch {i}', o, self.curr_iter - ) - self.writer.add_images( - f'Pose image/batch {i}', p, self.curr_iter - ) + if self.curr_iter % 100 == 99: # Validation embed_c = self._flatten_embedding(embed_c) embed_p = self._flatten_embedding(embed_p) @@ -352,7 +354,7 @@ class Model: def _write_embedding(self, tag, embed, x, y): frame = x[:, 0, :, :, :].cpu() n, c, h, w = frame.size() - padding = torch.zeros(n, c, h, (h-w) // 2) + padding = torch.zeros(n, c, h, (h - w) // 2) padded_frame = torch.cat((padding, frame, padding), dim=-1) self.writer.add_embedding( embed, |