summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--models/model.py28
1 files changed, 17 insertions, 11 deletions
diff --git a/models/model.py b/models/model.py
index 2eeaf5e..70596ad 100644
--- a/models/model.py
+++ b/models/model.py
@@ -163,7 +163,7 @@ class Model:
)
else: # Different margins
self.triplet_loss = JointBatchTripletLoss(
- self.rgb_pn.hpm_num_parts,
+ self.rgb_pn.module.hpm_num_parts,
triplet_is_hard, triplet_is_mean, triplet_margins
)
else: # Soft margins
@@ -226,13 +226,15 @@ class Model:
embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2)
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
- y = y.repeat(self.rgb_pn.num_total_parts, 1)
- trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
+ y = y.repeat(self.rgb_pn.module.num_total_parts, 1)
+ trip_loss, dist, num_non_zero = self.triplet_loss(
+ embedding.contiguous(), y
+ )
losses = torch.cat((
ae_losses.mean(0),
torch.stack((
- trip_loss[:self.rgb_pn.hpm_num_parts].mean(),
- trip_loss[self.rgb_pn.hpm_num_parts:].mean()
+ trip_loss[:self.rgb_pn.module.hpm_num_parts].mean(),
+ trip_loss[self.rgb_pn.module.hpm_num_parts:].mean()
))
))
loss = losses.sum()
@@ -254,28 +256,32 @@ class Model:
# None-zero losses in batch
if num_non_zero is not None:
self.writer.add_scalars('Loss/non-zero counts', {
- 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(),
- 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean()
+ 'HPM': num_non_zero[
+ :self.rgb_pn.module.hpm_num_parts].mean(),
+ 'PartNet': num_non_zero[
+ self.rgb_pn.module.hpm_num_parts:].mean()
}, self.curr_iter)
# Embedding distance
- mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0)
+ mean_hpm_dist = dist[:self.rgb_pn.module.hpm_num_parts].mean(0)
self._add_ranked_scalars(
'Embedding/HPM distance', mean_hpm_dist,
num_pos_pairs, num_pairs, self.curr_iter
)
- mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0)
+ mean_pa_dist = dist[self.rgb_pn.module.hpm_num_parts:].mean(0)
self._add_ranked_scalars(
'Embedding/ParNet distance', mean_pa_dist,
num_pos_pairs, num_pairs, self.curr_iter
)
# Embedding norm
- mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0)
+ mean_hpm_embedding = embedding[
+ :self.rgb_pn.module.hpm_num_parts].mean(0)
mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/HPM norm', mean_hpm_norm,
self.k, self.pr * self.k, self.curr_iter
)
- mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0)
+ mean_pa_embedding = embedding[
+ self.rgb_pn.module.hpm_num_parts:].mean(0)
mean_pa_norm = mean_pa_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/PartNet norm', mean_pa_norm,