diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-20 14:42:45 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-20 14:43:06 +0800 |
commit | c538919cb69e35a46811aef0b23baefe6a4c499c (patch) | |
tree | bee9a9582dfbb60053a6dd53f1a958abaa9dd8d5 /models/rgb_part_net.py | |
parent | 969030864495e7c2b419400fd81ee0fad83de41e (diff) | |
parent | 820d3dec284f38e6a3089dad5277bc3f6c5123bf (diff) |
Merge branch 'master' into python3.8
# Conflicts:
# models/model.py
# models/rgb_part_net.py
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 20 |
1 files changed, 3 insertions, 17 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 2af990e..15b69f9 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -6,7 +6,6 @@ import torch.nn as nn from models.auto_encoder import AutoEncoder from models.hpm import HorizontalPyramidMatching from models.part_net import PartNet -from utils.triplet_loss import BatchAllTripletLoss class RGBPartNet(nn.Module): @@ -27,7 +26,6 @@ class RGBPartNet(nn.Module): tfa_squeeze_ratio: int = 4, tfa_num_parts: int = 16, embedding_dims: int = 256, - triplet_margins: Tuple[float, float] = (0.2, 0.2), image_log_on: bool = False ): super().__init__() @@ -52,17 +50,13 @@ class RGBPartNet(nn.Module): out_channels, embedding_dims) self.fc_mat = nn.Parameter(empty_fc) - (hpm_margin, pn_margin) = triplet_margins - self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin) - self.pn_ba_trip = BatchAllTripletLoss(pn_margin) - def fc(self, x): return x @ self.fc_mat - def forward(self, x_c1, x_c2=None, y=None): + def forward(self, x_c1, x_c2=None): # Step 1: Disentanglement # n, t, c, h, w - ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2) + ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2) # Step 2.a: Static Gait Feature Aggregation & HPM # n, c, h, w @@ -79,15 +73,7 @@ class RGBPartNet(nn.Module): x = self.fc(x) if self.training: - y = y.T - hpm_ba_trip = self.hpm_ba_trip( - x[:self.hpm_num_parts], y[:self.hpm_num_parts] - ) - pn_ba_trip = self.pn_ba_trip( - x[self.hpm_num_parts:], y[self.hpm_num_parts:] - ) - losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip)) - return losses, images + return x, ae_losses, images else: return x.unsqueeze(1).view(-1) |