diff options
-rw-r--r-- | models/model.py | 7 | ||||
-rw-r--r-- | models/rgb_part_net.py | 2 |
2 files changed, 4 insertions, 5 deletions
diff --git a/models/model.py b/models/model.py index 70596ad..a8e0316 100644 --- a/models/model.py +++ b/models/model.py @@ -227,11 +227,10 @@ class Model: y = batch_c1['label'].to(self.device) # Duplicate labels for each part y = y.repeat(self.rgb_pn.module.num_total_parts, 1) - trip_loss, dist, num_non_zero = self.triplet_loss( - embedding.contiguous(), y - ) + embedding = embedding.transpose(0, 1) + trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y) losses = torch.cat(( - ae_losses.mean(0), + ae_losses.view(-1, 3).mean(0), torch.stack(( trip_loss[:self.rgb_pn.module.hpm_num_parts].mean(), trip_loss[self.rgb_pn.module.hpm_num_parts:].mean() diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 8a0f3a7..cdf579b 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -67,7 +67,7 @@ class RGBPartNet(nn.Module): x = self.fc(x) if self.training: - return x, ae_losses, images + return x.transpose(0, 1), ae_losses, images else: return x.unsqueeze(1).view(-1) |