summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-20 14:19:30 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-20 14:19:30 +0800
commit50eedae4f320c446544772fb2b0abcbce1be7590 (patch)
tree93a5b252a6d11b0c677d0700522d928d59916a9e /models/rgb_part_net.py
parent4aa9044122878a8e2b887a8b170c036983431559 (diff)
Separate triplet loss from model
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py20
1 files changed, 3 insertions, 17 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 67acac3..408bca0 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -4,7 +4,6 @@ import torch.nn as nn
from models.auto_encoder import AutoEncoder
from models.hpm import HorizontalPyramidMatching
from models.part_net import PartNet
-from utils.triplet_loss import BatchAllTripletLoss
class RGBPartNet(nn.Module):
@@ -25,7 +24,6 @@ class RGBPartNet(nn.Module):
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
embedding_dims: int = 256,
- triplet_margins: tuple[float, float] = (0.2, 0.2),
image_log_on: bool = False
):
super().__init__()
@@ -50,17 +48,13 @@ class RGBPartNet(nn.Module):
out_channels, embedding_dims)
self.fc_mat = nn.Parameter(empty_fc)
- (hpm_margin, pn_margin) = triplet_margins
- self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin)
- self.pn_ba_trip = BatchAllTripletLoss(pn_margin)
-
def fc(self, x):
return x @ self.fc_mat
- def forward(self, x_c1, x_c2=None, y=None):
+ def forward(self, x_c1, x_c2=None):
# Step 1: Disentanglement
# n, t, c, h, w
- ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2)
+ ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2)
# Step 2.a: Static Gait Feature Aggregation & HPM
# n, c, h, w
@@ -77,15 +71,7 @@ class RGBPartNet(nn.Module):
x = self.fc(x)
if self.training:
- y = y.T
- hpm_ba_trip = self.hpm_ba_trip(
- x[:self.hpm_num_parts], y[:self.hpm_num_parts]
- )
- pn_ba_trip = self.pn_ba_trip(
- x[self.hpm_num_parts:], y[self.hpm_num_parts:]
- )
- losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip))
- return losses, images
+ return x, ae_losses, images
else:
return x.unsqueeze(1).view(-1)