summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-20 14:48:16 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-20 14:48:16 +0800
commit9b1828be1db7fd1be8731a7cec66162de9145285 (patch)
tree9efb5a37856f34e333457e9d7ab2aaa8ba811cf6 /models/rgb_part_net.py
parente33c22e556ed64e1c1fdb011d78a124d1489ad15 (diff)
parentc538919cb69e35a46811aef0b23baefe6a4c499c (diff)
Merge branch 'python3.8' into data_parallel_py3.8
# Conflicts: # models/model.py
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py20
1 files changed, 3 insertions, 17 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 2af990e..15b69f9 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -6,7 +6,6 @@ import torch.nn as nn
from models.auto_encoder import AutoEncoder
from models.hpm import HorizontalPyramidMatching
from models.part_net import PartNet
-from utils.triplet_loss import BatchAllTripletLoss
class RGBPartNet(nn.Module):
@@ -27,7 +26,6 @@ class RGBPartNet(nn.Module):
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
embedding_dims: int = 256,
- triplet_margins: Tuple[float, float] = (0.2, 0.2),
image_log_on: bool = False
):
super().__init__()
@@ -52,17 +50,13 @@ class RGBPartNet(nn.Module):
out_channels, embedding_dims)
self.fc_mat = nn.Parameter(empty_fc)
- (hpm_margin, pn_margin) = triplet_margins
- self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin)
- self.pn_ba_trip = BatchAllTripletLoss(pn_margin)
-
def fc(self, x):
return x @ self.fc_mat
- def forward(self, x_c1, x_c2=None, y=None):
+ def forward(self, x_c1, x_c2=None):
# Step 1: Disentanglement
# n, t, c, h, w
- ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2)
+ ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2)
# Step 2.a: Static Gait Feature Aggregation & HPM
# n, c, h, w
@@ -79,15 +73,7 @@ class RGBPartNet(nn.Module):
x = self.fc(x)
if self.training:
- y = y.T
- hpm_ba_trip = self.hpm_ba_trip(
- x[:self.hpm_num_parts], y[:self.hpm_num_parts]
- )
- pn_ba_trip = self.pn_ba_trip(
- x[self.hpm_num_parts:], y[self.hpm_num_parts:]
- )
- losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip))
- return losses, images
+ return x, ae_losses, images
else:
return x.unsqueeze(1).view(-1)