summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-02 20:10:14 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-02 20:10:14 +0800
commit09c2af6f4881bb4a2ed5e3685b8ce65ad41695cb (patch)
tree3e7e9fba6d1fc04f3d976a2aec9a9cc71e712845 /models/model.py
parentab09d9bd0fe92d97f340feef1e2bbbcd33468953 (diff)
Fix DataParallel specific bugs
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py7
1 files changed, 3 insertions, 4 deletions
diff --git a/models/model.py b/models/model.py
index 70596ad..a8e0316 100644
--- a/models/model.py
+++ b/models/model.py
@@ -227,11 +227,10 @@ class Model:
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
y = y.repeat(self.rgb_pn.module.num_total_parts, 1)
- trip_loss, dist, num_non_zero = self.triplet_loss(
- embedding.contiguous(), y
- )
+ embedding = embedding.transpose(0, 1)
+ trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
losses = torch.cat((
- ae_losses.mean(0),
+ ae_losses.view(-1, 3).mean(0),
torch.stack((
trip_loss[:self.rgb_pn.module.hpm_num_parts].mean(),
trip_loss[self.rgb_pn.module.hpm_num_parts:].mean()