diff options
Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 28 |
1 files changed, 17 insertions, 11 deletions
diff --git a/models/model.py b/models/model.py index 46d7c4c..42064fe 100644 --- a/models/model.py +++ b/models/model.py @@ -163,7 +163,7 @@ class Model: ) else: # Different margins self.triplet_loss = JointBatchTripletLoss( - self.rgb_pn.hpm_num_parts, + self.rgb_pn.module.hpm_num_parts, triplet_is_hard, triplet_is_mean, triplet_margins ) else: # Soft margins @@ -226,13 +226,15 @@ class Model: embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) # Duplicate labels for each part - y = y.repeat(self.rgb_pn.num_total_parts, 1) - trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y) + y = y.repeat(self.rgb_pn.module.num_total_parts, 1) + trip_loss, dist, num_non_zero = self.triplet_loss( + embedding.contiguous(), y + ) losses = torch.cat(( ae_losses.mean(0), torch.stack(( - trip_loss[:self.rgb_pn.hpm_num_parts].mean(), - trip_loss[self.rgb_pn.hpm_num_parts:].mean() + trip_loss[:self.rgb_pn.module.hpm_num_parts].mean(), + trip_loss[self.rgb_pn.module.hpm_num_parts:].mean() )) )) loss = losses.sum() @@ -254,28 +256,32 @@ class Model: # None-zero losses in batch if num_non_zero is not None: self.writer.add_scalars('Loss/non-zero counts', { - 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(), - 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean() + 'HPM': num_non_zero[ + :self.rgb_pn.module.hpm_num_parts].mean(), + 'PartNet': num_non_zero[ + self.rgb_pn.module.hpm_num_parts:].mean() }, self.curr_iter) # Embedding distance - mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0) + mean_hpm_dist = dist[:self.rgb_pn.module.hpm_num_parts].mean(0) self._add_ranked_scalars( 'Embedding/HPM distance', mean_hpm_dist, num_pos_pairs, num_pairs, self.curr_iter ) - mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0) + mean_pa_dist = dist[self.rgb_pn.module.hpm_num_parts:].mean(0) self._add_ranked_scalars( 'Embedding/ParNet distance', mean_pa_dist, num_pos_pairs, num_pairs, self.curr_iter ) # Embedding norm - mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0) + mean_hpm_embedding = embedding[ + :self.rgb_pn.module.hpm_num_parts].mean(0) mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/HPM norm', mean_hpm_norm, self.k, self.pr * self.k, self.curr_iter ) - mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0) + mean_pa_embedding = embedding[ + self.rgb_pn.module.hpm_num_parts:].mean(0) mean_pa_norm = mean_pa_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/PartNet norm', mean_pa_norm, |