diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/rgb_part_net.py | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 5012765..a58be39 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -5,6 +5,7 @@ import torch.nn as nn import torch.nn.functional as F from models import AutoEncoder, HorizontalPyramidMatching, PartNet +from utils.triplet_loss import BatchAllTripletLoss class RGBPartNet(nn.Module): @@ -24,7 +25,7 @@ class RGBPartNet(nn.Module): tfa_squeeze_ratio: int = 4, tfa_num_parts: int = 16, embedding_dims: int = 256, - triplet_margin: int = 0.2 + triplet_margin: float = 0.2 ): super().__init__() self.ae = AutoEncoder( @@ -43,8 +44,10 @@ class RGBPartNet(nn.Module): empty_fc = torch.empty(total_parts, out_channels, embedding_dims) self.fc_mat = nn.Parameter(empty_fc) + self.ba_triplet_loss = BatchAllTripletLoss(triplet_margin) + def fc(self, x): - return torch.matmul(x, self.fc_mat) + return x @ self.fc_mat def forward(self, x_c1, x_c2, y=None): # Step 0: Swap batch_size and time dimensions for next step @@ -72,10 +75,10 @@ class RGBPartNet(nn.Module): x = self.fc(x) if self.training: - # TODO Implement Batch All triplet loss function - batch_all_triplet_loss = torch.tensor(0.) - loss = torch.sum(torch.stack((*losses, batch_all_triplet_loss))) - return loss + batch_all_triplet_loss = self.ba_triplet_loss(x, y) + losses = (*losses, batch_all_triplet_loss) + loss = torch.sum(torch.stack(losses)) + return loss, (loss.item() for loss in losses) else: return x |