summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py168
1 files changed, 53 insertions, 115 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 0e7d8b3..8ebcfd3 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -1,8 +1,5 @@
-import random
-
import torch
import torch.nn as nn
-import torch.nn.functional as F
from models.auto_encoder import AutoEncoder
from models.hpm import HorizontalPyramidMatching
@@ -59,24 +56,18 @@ class RGBPartNet(nn.Module):
return x @ self.fc_mat
def forward(self, x_c1, x_c2=None, y=None):
- # Step 0: Swap batch_size and time dimensions for next step
- # n, t, c, h, w
- x_c1 = x_c1.transpose(0, 1)
- if self.training:
- x_c2 = x_c2.transpose(0, 1)
-
# Step 1: Disentanglement
- # t, n, c, h, w
- ((x_c_c1, x_p_c1), images, losses) = self._disentangle(x_c1, x_c2)
+ # n, t, c, h, w
+ ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2)
# Step 2.a: Static Gait Feature Aggregation & HPM
# n, c, h, w
- x_c = self.hpm(x_c_c1)
+ x_c = self.hpm(x_c)
# p, n, c
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
- # t, n, c, h, w
- x_p = self.pn(x_p_c1)
+ # n, t, c, h, w
+ x_p = self.pn(x_p)
# p, n, c
# Step 3: Cat feature map together and fc
@@ -91,113 +82,60 @@ class RGBPartNet(nn.Module):
else:
return x.unsqueeze(1).view(-1)
- def _disentangle(self, x_c1, x_c2=None):
- t, n, c, h, w = x_c1.size()
- device = x_c1.device
+ def _disentangle(self, x_c1_t2, x_c2_t2=None):
+ n, t, c, h, w = x_c1_t2.size()
+ device = x_c1_t2.device
+ x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
if self.training:
- # Encoded appearance, canonical and pose features
- f_a_c1, f_c_c1, f_p_c1 = [], [], []
- # Features required to calculate losses
- f_p_c2 = []
- xrecon_loss, cano_cons_loss = [], []
- for t2 in range(t):
- t1 = random.randrange(t)
- output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2])
- (f_c1_t2, f_p_t2, losses) = output
-
- (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = f_c1_t2
- if self.image_log_on:
- f_a_c1.append(f_a_c1_t2)
- # Save canonical features and pose features
- f_c_c1.append(f_c_c1_t2)
- f_p_c1.append(f_p_c1_t2)
-
- # Losses per time step
- # Used in pose similarity loss
- (_, f_p_c2_t2) = f_p_t2
- f_p_c2.append(f_p_c2_t2)
-
- # Cross reconstruction loss and canonical loss
- (xrecon_loss_t2, cano_cons_loss_t2) = losses
- xrecon_loss.append(xrecon_loss_t2)
- cano_cons_loss.append(cano_cons_loss_t2)
- if self.image_log_on:
- f_a_c1 = torch.stack(f_a_c1)
- f_c_c1_mean = torch.stack(f_c_c1).mean(0)
- f_p_c1 = torch.stack(f_p_c1)
- f_p_c2 = torch.stack(f_p_c2)
-
+ ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
# Decode features
- appearance_image, canonical_image, pose_image = None, None, None
with torch.no_grad():
- # Decode average canonical features to higher dimension
- x_c_c1 = self.ae.decoder(
- torch.zeros((n, self.f_a_dim), device=device),
- f_c_c1_mean,
- torch.zeros((n, self.f_p_dim), device=device),
- cano_only=True
- )
- # Decode pose features to images
- f_p_c1_ = f_p_c1.view(t * n, -1)
- x_p_c1_ = self.ae.decoder(
- torch.zeros((t * n, self.f_a_dim), device=device),
- torch.zeros((t * n, self.f_c_dim), device=device),
- f_p_c1_
- )
- x_p_c1 = x_p_c1_.view(t, n, c, h, w)
+ x_c = self._decode_cano_feature(f_c_, n, t, device)
+ x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+ i_a, i_c, i_p = None, None, None
if self.image_log_on:
- # Decode appearance features
- f_a_c1_ = f_a_c1.view(t * n, -1)
- appearance_image_ = self.ae.decoder(
- f_a_c1_,
- torch.zeros((t * n, self.f_c_dim), device=device),
- torch.zeros((t * n, self.f_p_dim), device=device)
- )
- appearance_image = appearance_image_.view(t, n, c, h, w)
+ i_a = self._decode_appr_feature(f_a_, n, t, c, h, w, device)
# Continue decoding canonical features
- canonical_image = self.ae.decoder.trans_conv3(x_c_c1)
- canonical_image = torch.sigmoid(
- self.ae.decoder.trans_conv4(canonical_image)
- )
- pose_image = x_p_c1
-
- # Losses
- xrecon_loss = torch.sum(torch.stack(xrecon_loss))
- pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) * 10
- cano_cons_loss = torch.mean(torch.stack(cano_cons_loss))
+ i_c = self.ae.decoder.trans_conv3(x_c)
+ i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c))
+ i_p = x_p
- return ((x_c_c1, x_p_c1),
- (appearance_image, canonical_image, pose_image),
- (xrecon_loss, pose_sim_loss, cano_cons_loss))
+ return (x_c, x_p), losses, (i_a, i_c, i_p)
else: # evaluating
- x_c1_ = x_c1.view(t * n, c, h, w)
- (f_c_c1_, f_p_c1_) = self.ae(x_c1_)
-
- # Canonical features
- f_c_c1 = f_c_c1_.view(t, n, -1)
- f_c_c1_mean = f_c_c1.mean(0)
- x_c_c1 = self.ae.decoder(
- torch.zeros((n, self.f_a_dim)),
- f_c_c1_mean,
- torch.zeros((n, self.f_p_dim)),
- cano_only=True
- )
-
- # Pose features
- x_p_c1_ = self.ae.decoder(
- torch.zeros((t * n, self.f_a_dim)),
- torch.zeros((t * n, self.f_c_dim)),
- f_p_c1_
- )
- x_p_c1 = x_p_c1_.view(t, n, c, h, w)
-
- return (x_c_c1, x_p_c1), None, None
-
- @staticmethod
- def _pose_sim_loss(f_p_c1: torch.Tensor,
- f_p_c2: torch.Tensor) -> torch.Tensor:
- f_p_c1_mean = f_p_c1.mean(dim=0)
- f_p_c2_mean = f_p_c2.mean(dim=0)
- return F.mse_loss(f_p_c1_mean, f_p_c2_mean)
+ f_c_, f_p_ = self.ae(x_c1_t2)
+ x_c = self._decode_cano_feature(f_c_, n, t, device)
+ x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+ return (x_c, x_p), None, None
+
+ def _decode_appr_feature(self, f_a_, n, t, c, h, w, device):
+ # Decode appearance features
+ x_a_ = self.ae.decoder(
+ f_a_,
+ torch.zeros((n * t, self.f_c_dim), device=device),
+ torch.zeros((n * t, self.f_p_dim), device=device)
+ )
+ x_a = x_a_.view(n, t, c, h, w)
+ return x_a
+
+ def _decode_cano_feature(self, f_c_, n, t, device):
+ # Decode average canonical features to higher dimension
+ f_c = f_c_.view(n, t, -1)
+ x_c = self.ae.decoder(
+ torch.zeros((n, self.f_a_dim), device=device),
+ f_c.mean(1),
+ torch.zeros((n, self.f_p_dim), device=device),
+ cano_only=True
+ )
+ return x_c
+
+ def _decode_pose_feature(self, f_p_, n, t, c, h, w, device):
+ # Decode pose features to images
+ x_p_ = self.ae.decoder(
+ torch.zeros((n * t, self.f_a_dim), device=device),
+ torch.zeros((n * t, self.f_c_dim), device=device),
+ f_p_
+ )
+ x_p = x_p_.view(n, t, c, h, w)
+ return x_p