summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-04-07 15:46:34 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-04-07 15:46:34 +0800
commitbc82188377a8dc3068cd7b2ba5f02ed8bd4f2a54 (patch)
treec3265640aa0e37fe3c7a890ac5c415e32c287c77 /models
parent5e8947fbc90e1d67dadae36d32330a280d057267 (diff)
Revert cross-reconstruction loss factor and make image log steps adjustable
Diffstat (limited to 'models')
-rw-r--r--models/auto_encoder.py2
-rw-r--r--models/model.py34
2 files changed, 19 insertions, 17 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 49cb78b..96dfdb3 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -152,7 +152,7 @@ class AutoEncoder(nn.Module):
return (
(f_a_c1_t2, f_c_c1_t2, f_p_c1_t2),
- (xrecon_loss / 10, cano_cons_loss, pose_sim_loss * 10)
+ (xrecon_loss, cano_cons_loss, pose_sim_loss * 10)
)
else: # evaluating
return f_a_c1_t2, f_c_c1_t2, f_p_c1_t2
diff --git a/models/model.py b/models/model.py
index ceadb92..9dbf609 100644
--- a/models/model.py
+++ b/models/model.py
@@ -79,6 +79,7 @@ class Model:
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
self.image_log_on = system_config.get('image_log_on', False)
+ self.image_log_steps = system_config.get('image_log_steps', 100)
self.val_size = system_config.get('val_size', 10)
self.CASIAB_GALLERY_SELECTOR = {
@@ -196,8 +197,8 @@ class Model:
triplet_is_hard, triplet_is_mean, None
)
- self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
- self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
+ self.num_pairs = (self.pr * self.k - 1) * (self.pr * self.k) // 2
+ self.num_pos_pairs = (self.k * (self.k - 1) // 2) * self.pr
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
@@ -278,24 +279,25 @@ class Model:
'Train', embed_c, embed_p, hpm_result, pn_result, loss, losses
)
- if self.curr_iter % 100 == 99:
- # Write disentangled images
- if self.image_log_on:
- i_a, i_c, i_p = images
+ # Write disentangled images
+ if self.image_log_on and self.curr_iter % self.image_log_steps \
+ == self.image_log_steps - 1:
+ i_a, i_c, i_p = images
+ self.writer.add_images(
+ 'Appearance image', i_a, self.curr_iter
+ )
+ self.writer.add_images(
+ 'Canonical image', i_c, self.curr_iter
+ )
+ for i, (o, p) in enumerate(zip(x_c1, i_p)):
self.writer.add_images(
- 'Appearance image', i_a, self.curr_iter
+ f'Original image/batch {i}', o, self.curr_iter
)
self.writer.add_images(
- 'Canonical image', i_c, self.curr_iter
+ f'Pose image/batch {i}', p, self.curr_iter
)
- for i, (o, p) in enumerate(zip(x_c1, i_p)):
- self.writer.add_images(
- f'Original image/batch {i}', o, self.curr_iter
- )
- self.writer.add_images(
- f'Pose image/batch {i}', p, self.curr_iter
- )
+ if self.curr_iter % 100 == 99:
# Validation
embed_c = self._flatten_embedding(embed_c)
embed_p = self._flatten_embedding(embed_p)
@@ -352,7 +354,7 @@ class Model:
def _write_embedding(self, tag, embed, x, y):
frame = x[:, 0, :, :, :].cpu()
n, c, h, w = frame.size()
- padding = torch.zeros(n, c, h, (h-w) // 2)
+ padding = torch.zeros(n, c, h, (h - w) // 2)
padded_frame = torch.cat((padding, frame, padding), dim=-1)
self.writer.add_embedding(
embed,