Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--  models/rgb_part_net.py  37
1 file changed, 20 insertions, 17 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index b0169e3..06cbf28 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -39,26 +39,26 @@ class RGBPartNet(nn.Module):
         self.num_parts = self.hpm.num_parts + tfa_num_parts

     def forward(self, x_c1, x_c2=None):
-        # Step 1: Disentanglement
-        # n, t, c, h, w
-        (f_a, f_c, f_p), ae_losses = self._disentangle(x_c1, x_c2)
+        if self.training:
+            # Step 1: Disentanglement
+            # n, t, c, h, w
+            (f_a, f_c, f_p), f_loss = self._disentangle(x_c1, x_c2)

-        # Step 2.a: Static Gait Feature Aggregation & HPM
-        # n, c, h, w
-        f_c_mean = f_c.mean(1)
-        x_c = self.hpm(f_c_mean)
-        # p, n, d
+            # Step 2.a: Static Gait Feature Aggregation & HPM
+            # n, t, c, h, w
+            x_c, f_c_loss = self.hpm(f_c, *f_loss[1])
+            # p, n, d / p, n, t, c

-        # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
-        # n, t, c, h, w
-        x_p = self.pn(f_p)
-        # p, n, d
+            # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
+            # n, t, c, h, w
+            x_p, f_p_loss = self.pn(f_p, f_loss[2])
+            # p, n, d / p, n, t, c

-        if self.training:
             i_a, i_c, i_p = None, None, None
             if self.image_log_on:
                 with torch.no_grad():
                     f_a_mean = f_a.mean(1)
+                    f_c_mean = f_c.mean(1)
                     i_a = self.ae.decoder(
                         f_a_mean,
                         torch.zeros_like(f_c_mean),
@@ -77,15 +77,18 @@ class RGBPartNet(nn.Module):
                                     device=f_c.device),
                         f_p.view(-1, *f_p_size[2:])
                     ).view(x_c1.size())
-            return x_c, x_p, ae_losses, (i_a, i_c, i_p)
-        else:
+            return (x_c, x_p), (f_loss[0], f_c_loss, f_p_loss), (i_a, i_c, i_p)
+        else:  # Evaluating
+            f_c, f_p = self._disentangle(x_c1, x_c2)
+            x_c = self.hpm(f_c)
+            x_p = self.pn(f_p)
             return x_c, x_p

     def _disentangle(self, x_c1_t2, x_c2_t2=None):
         if self.training:
             x_c1_t1 = x_c1_t2[:, torch.randperm(x_c1_t2.size(1)), :, :, :]
-            features, losses = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
-            return features, losses
+            features, f_loss = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
+            return features, f_loss
         else:  # evaluating
             features = self.ae(x_c1_t2)
             return features, None
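
A minimal usage sketch of the forward() interface after this change; it assumes RGBPartNet can be built with its constructor defaults, and the tensor sizes below are placeholders rather than values from this diff. The training-mode unpacking mirrors the new return statement; evaluation mode returns only the two feature groups, per the eval branch.

    import torch
    from models.rgb_part_net import RGBPartNet

    model = RGBPartNet()  # constructor defaults assumed
    x_c1 = torch.randn(2, 30, 3, 64, 32)  # n, t, c, h, w (placeholder sizes)
    x_c2 = torch.randn(2, 30, 3, 64, 32)  # second condition, training only

    # Training: paired features, the three loss groups, and optional log images
    # (i_a, i_c, i_p stay None unless image_log_on is set)
    model.train()
    (x_c, x_p), (ae_loss, f_c_loss, f_p_loss), (i_a, i_c, i_p) = model(x_c1, x_c2)

    # Evaluating: only the aggregated features are returned
    model.eval()
    with torch.no_grad():
        x_c, x_p = model(x_c1)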