Diffstat (limited to 'models')
-rw-r--r--  models/rgb_part_net.py  15
1 file changed, 9 insertions(+), 6 deletions(-)
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 5012765..a58be39 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -5,6 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
from models import AutoEncoder, HorizontalPyramidMatching, PartNet
+from utils.triplet_loss import BatchAllTripletLoss
class RGBPartNet(nn.Module):
@@ -24,7 +25,7 @@ class RGBPartNet(nn.Module):
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
embedding_dims: int = 256,
- triplet_margin: int = 0.2
+ triplet_margin: float = 0.2
):
super().__init__()
self.ae = AutoEncoder(
@@ -43,8 +44,10 @@ class RGBPartNet(nn.Module):
empty_fc = torch.empty(total_parts, out_channels, embedding_dims)
self.fc_mat = nn.Parameter(empty_fc)
+ self.ba_triplet_loss = BatchAllTripletLoss(triplet_margin)
+
def fc(self, x):
- return torch.matmul(x, self.fc_mat)
+ return x @ self.fc_mat
def forward(self, x_c1, x_c2, y=None):
# Step 0: Swap batch_size and time dimensions for next step
@@ -72,10 +75,10 @@ class RGBPartNet(nn.Module):
x = self.fc(x)
if self.training:
- # TODO Implement Batch All triplet loss function
- batch_all_triplet_loss = torch.tensor(0.)
- loss = torch.sum(torch.stack((*losses, batch_all_triplet_loss)))
- return loss
+ batch_all_triplet_loss = self.ba_triplet_loss(x, y)
+ losses = (*losses, batch_all_triplet_loss)
+ loss = torch.sum(torch.stack(losses))
+ return loss, (loss.item() for loss in losses)
else:
return x
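Note: the commit imports BatchAllTripletLoss from utils.triplet_loss, but that module is not part of this diff. Below is a minimal sketch of a batch-all triplet loss that is consistent with how it is called here (self.ba_triplet_loss(x, y) on per-part embeddings, returning a scalar that can be stacked with the other losses). The tensor shapes, the Euclidean distance metric, and the averaging over active triplets are assumptions for illustration, not the repository's actual implementation.

    import torch
    import torch.nn as nn


    class BatchAllTripletLoss(nn.Module):
        """Batch-all triplet loss over per-part embeddings (sketch).

        Assumes x has shape (num_parts, batch_size, embedding_dims) and
        y is a (batch_size,) tensor of integer identity labels.
        """

        def __init__(self, margin: float = 0.2):
            super().__init__()
            self.margin = margin

        def forward(self, x, y):
            p, n, _ = x.size()
            # Pairwise Euclidean distances within each part: (p, n, n)
            dist = torch.cdist(x, x)

            # Valid anchor-positive pairs share a label (excluding self-pairs);
            # valid anchor-negative pairs have different labels.
            same = y.unsqueeze(0) == y.unsqueeze(1)                       # (n, n)
            eye = torch.eye(n, dtype=torch.bool, device=y.device)
            positive_mask = same & ~eye
            negative_mask = ~same

            # All (anchor, positive, negative) combinations: (p, n, n, n)
            ap = dist.unsqueeze(3)            # d(anchor, positive)
            an = dist.unsqueeze(2)            # d(anchor, negative)
            triplet = torch.relu(ap - an + self.margin)

            # Keep only combinations that form valid triplets.
            valid = positive_mask.unsqueeze(2) & negative_mask.unsqueeze(1)  # (n, n, n)
            losses = triplet * valid.unsqueeze(0)

            # Batch-all convention: average over triplets with non-zero loss.
            num_active = (losses > 0).sum().clamp(min=1)
            return losses.sum() / num_active

With this sketch, self.ba_triplet_loss(x, y) yields a scalar tensor, so torch.stack((*losses, batch_all_triplet_loss)) in the training branch works as long as the autoencoder losses are scalars as well.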