summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-23 22:19:51 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-23 22:19:51 +0800
commit507e1d163aaa6ea4be23e7f08ff6ce0ef58c830b (patch)
tree1e3c722bb63e3873464296121ec290bd3e64ad14 /models/rgb_part_net.py
parent59ccf61fed4d95b7fe91bb9552f0deb2f2c75b76 (diff)
Remove the third term in canonical consistency loss
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 2cc0958..755d5dc 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -13,7 +13,6 @@ from utils.triplet_loss import BatchAllTripletLoss
class RGBPartNet(nn.Module):
def __init__(
self,
- num_class: int = 74,
ae_in_channels: int = 3,
ae_feature_channels: int = 64,
f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
@@ -31,7 +30,7 @@ class RGBPartNet(nn.Module):
):
super().__init__()
self.ae = AutoEncoder(
- num_class, ae_in_channels, ae_feature_channels, f_a_c_p_dims
+ ae_in_channels, ae_feature_channels, f_a_c_p_dims
)
self.pn = PartNet(
ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
@@ -60,7 +59,7 @@ class RGBPartNet(nn.Module):
# Step 1: Disentanglement
# t, n, c, h, w
- ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2, y)
+ ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2)
# Step 2.a: HPM & Static Gait Feature Aggregation
# t, n, c, h, w
@@ -85,7 +84,7 @@ class RGBPartNet(nn.Module):
else:
return x.unsqueeze(1).view(-1)
- def _disentangle(self, x_c1, x_c2=None, y=None):
+ def _disentangle(self, x_c1, x_c2=None):
t, n, c, h, w = x_c1.size()
if self.training:
# Decoded canonical features and Pose images
@@ -95,7 +94,7 @@ class RGBPartNet(nn.Module):
xrecon_loss, cano_cons_loss = [], []
for t2 in range(t):
t1 = random.randrange(t)
- output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2], y)
+ output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2])
(x_c1_t2, f_p_t2, losses) = output
# Decoded features or image