diff options
| author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-20 14:26:31 +0800 | 
|---|---|---|
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-20 14:29:31 +0800 | 
| commit | 850c525772969823eef6083e8018ac43a1e87e4c (patch) | |
| tree | 16f1d48561dead30d31ab2b72918fb1e9fef5665 /models/rgb_part_net.py | |
| parent | a31eb135c8cc3a8737fabca54fe3d5791f293753 (diff) | |
| parent | 50eedae4f320c446544772fb2b0abcbce1be7590 (diff) | |
Merge branch 'master' into data_parallel
# Conflicts:
#	models/model.py
Diffstat (limited to 'models/rgb_part_net.py')
| -rw-r--r-- | models/rgb_part_net.py | 20 | 
1 files changed, 3 insertions, 17 deletions
| diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 67acac3..408bca0 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -4,7 +4,6 @@ import torch.nn as nn  from models.auto_encoder import AutoEncoder  from models.hpm import HorizontalPyramidMatching  from models.part_net import PartNet -from utils.triplet_loss import BatchAllTripletLoss  class RGBPartNet(nn.Module): @@ -25,7 +24,6 @@ class RGBPartNet(nn.Module):              tfa_squeeze_ratio: int = 4,              tfa_num_parts: int = 16,              embedding_dims: int = 256, -            triplet_margins: tuple[float, float] = (0.2, 0.2),              image_log_on: bool = False      ):          super().__init__() @@ -50,17 +48,13 @@ class RGBPartNet(nn.Module):                                 out_channels, embedding_dims)          self.fc_mat = nn.Parameter(empty_fc) -        (hpm_margin, pn_margin) = triplet_margins -        self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin) -        self.pn_ba_trip = BatchAllTripletLoss(pn_margin) -      def fc(self, x):          return x @ self.fc_mat -    def forward(self, x_c1, x_c2=None, y=None): +    def forward(self, x_c1, x_c2=None):          # Step 1: Disentanglement          # n, t, c, h, w -        ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2) +        ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2)          # Step 2.a: Static Gait Feature Aggregation & HPM          # n, c, h, w @@ -77,15 +71,7 @@ class RGBPartNet(nn.Module):          x = self.fc(x)          if self.training: -            y = y.T -            hpm_ba_trip = self.hpm_ba_trip( -                x[:self.hpm_num_parts], y[:self.hpm_num_parts] -            ) -            pn_ba_trip = self.pn_ba_trip( -                x[self.hpm_num_parts:], y[self.hpm_num_parts:] -            ) -            losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip)) -            return losses, images +            return x, ae_losses, images          else:              return x.unsqueeze(1).view(-1) | 
