diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-20 14:48:16 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-20 14:48:16 +0800 |
commit | 9b1828be1db7fd1be8731a7cec66162de9145285 (patch) | |
tree | 9efb5a37856f34e333457e9d7ab2aaa8ba811cf6 /models/rgb_part_net.py | |
parent | e33c22e556ed64e1c1fdb011d78a124d1489ad15 (diff) | |
parent | c538919cb69e35a46811aef0b23baefe6a4c499c (diff) |
Merge branch 'python3.8' into data_parallel_py3.8
# Conflicts:
# models/model.py
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 20 |
1 files changed, 3 insertions, 17 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 2af990e..15b69f9 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -6,7 +6,6 @@ import torch.nn as nn from models.auto_encoder import AutoEncoder from models.hpm import HorizontalPyramidMatching from models.part_net import PartNet -from utils.triplet_loss import BatchAllTripletLoss class RGBPartNet(nn.Module): @@ -27,7 +26,6 @@ class RGBPartNet(nn.Module): tfa_squeeze_ratio: int = 4, tfa_num_parts: int = 16, embedding_dims: int = 256, - triplet_margins: Tuple[float, float] = (0.2, 0.2), image_log_on: bool = False ): super().__init__() @@ -52,17 +50,13 @@ class RGBPartNet(nn.Module): out_channels, embedding_dims) self.fc_mat = nn.Parameter(empty_fc) - (hpm_margin, pn_margin) = triplet_margins - self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin) - self.pn_ba_trip = BatchAllTripletLoss(pn_margin) - def fc(self, x): return x @ self.fc_mat - def forward(self, x_c1, x_c2=None, y=None): + def forward(self, x_c1, x_c2=None): # Step 1: Disentanglement # n, t, c, h, w - ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2) + ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2) # Step 2.a: Static Gait Feature Aggregation & HPM # n, c, h, w @@ -79,15 +73,7 @@ class RGBPartNet(nn.Module): x = self.fc(x) if self.training: - y = y.T - hpm_ba_trip = self.hpm_ba_trip( - x[:self.hpm_num_parts], y[:self.hpm_num_parts] - ) - pn_ba_trip = self.pn_ba_trip( - x[self.hpm_num_parts:], y[self.hpm_num_parts:] - ) - losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip)) - return losses, images + return x, ae_losses, images else: return x.unsqueeze(1).view(-1) |