diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-23 22:19:51 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-23 22:19:51 +0800 |
commit | 507e1d163aaa6ea4be23e7f08ff6ce0ef58c830b (patch) | |
tree | 1e3c722bb63e3873464296121ec290bd3e64ad14 /models/rgb_part_net.py | |
parent | 59ccf61fed4d95b7fe91bb9552f0deb2f2c75b76 (diff) |
Remove the third term in canonical consistency loss
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 9 |
1 file changed, 4 insertions, 5 deletions
```diff
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 2cc0958..755d5dc 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -13,7 +13,6 @@ from utils.triplet_loss import BatchAllTripletLoss
 class RGBPartNet(nn.Module):
     def __init__(
             self,
-            num_class: int = 74,
             ae_in_channels: int = 3,
             ae_feature_channels: int = 64,
             f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
@@ -31,7 +30,7 @@ class RGBPartNet(nn.Module):
     ):
         super().__init__()
         self.ae = AutoEncoder(
-            num_class, ae_in_channels, ae_feature_channels, f_a_c_p_dims
+            ae_in_channels, ae_feature_channels, f_a_c_p_dims
         )
         self.pn = PartNet(
             ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
@@ -60,7 +59,7 @@ class RGBPartNet(nn.Module):
 
         # Step 1: Disentanglement
         # t, n, c, h, w
-        ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2, y)
+        ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2)
 
         # Step 2.a: HPM & Static Gait Feature Aggregation
         # t, n, c, h, w
@@ -85,7 +84,7 @@ class RGBPartNet(nn.Module):
         else:
             return x.unsqueeze(1).view(-1)
 
-    def _disentangle(self, x_c1, x_c2=None, y=None):
+    def _disentangle(self, x_c1, x_c2=None):
         t, n, c, h, w = x_c1.size()
         if self.training:
             # Decoded canonical features and Pose images
@@ -95,7 +94,7 @@ class RGBPartNet(nn.Module):
             xrecon_loss, cano_cons_loss = [], []
             for t2 in range(t):
                 t1 = random.randrange(t)
-                output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2], y)
+                output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2])
                 (x_c1_t2, f_p_t2, losses) = output
 
                 # Decoded features or image
```