summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-05 20:08:22 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-05 20:08:22 +0800
commit06e9a53673fb193f8287d9d9b95463a5d1b044bb (patch)
treea9ba1f37ee3688beee21864616d9b672f1e76d9d /models/rgb_part_net.py
parentecb8d8d750cd4a81494feb5dcb582641f73d67ff (diff)
Calculate losses outside modules
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index cdf579b..7785bb7 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -50,7 +50,7 @@ class RGBPartNet(nn.Module):
def forward(self, x_c1, x_c2=None):
# Step 1: Disentanglement
# n, t, c, h, w
- ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2)
+ ((x_c, x_p), images, f_loss) = self._disentangle(x_c1, x_c2)
# Step 2.a: Static Gait Feature Aggregation & HPM
# n, c, h, w
@@ -67,7 +67,7 @@ class RGBPartNet(nn.Module):
x = self.fc(x)
if self.training:
- return x.transpose(0, 1), ae_losses, images
+ return x.transpose(0, 1), images, f_loss
else:
return x.unsqueeze(1).view(-1)
@@ -76,7 +76,7 @@ class RGBPartNet(nn.Module):
device = x_c1_t2.device
if self.training:
x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
- ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
+ (f_a_, f_c_, f_p_), f_loss = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
# Decode features
x_c = self._decode_cano_feature(f_c_, n, t, device)
x_p_ = self._decode_pose_feature(f_p_, n, t, device)
@@ -93,7 +93,7 @@ class RGBPartNet(nn.Module):
i_p_ = torch.sigmoid(self.ae.decoder.trans_conv4(i_p_))
i_p = i_p_.view(n, t, c, h, w)
- return (x_c, x_p), losses, (i_a, i_c, i_p)
+ return (x_c, x_p), (i_a, i_c, i_p), f_loss
else: # evaluating
f_c_, f_p_ = self.ae(x_c1_t2)