Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--  models/rgb_part_net.py  121
1 file changed, 43 insertions(+), 78 deletions(-)
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 4a82da3..d3f8ade 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -1,9 +1,8 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from models.auto_encoder import AutoEncoder
-from models.hpm import HorizontalPyramidMatching
-from models.part_net import PartNet


 class RGBPartNet(nn.Module):
@@ -13,12 +12,6 @@ class RGBPartNet(nn.Module):
             ae_in_size: tuple[int, int] = (64, 48),
             ae_feature_channels: int = 64,
             f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
-            hpm_scales: tuple[int, ...] = (1, 2, 4),
-            hpm_use_avg_pool: bool = True,
-            hpm_use_max_pool: bool = True,
-            tfa_squeeze_ratio: int = 4,
-            tfa_num_parts: int = 16,
-            embedding_dims: tuple[int] = (256, 256),
             image_log_on: bool = False
     ):
         super().__init__()
@@ -29,94 +22,66 @@ class RGBPartNet(nn.Module):
         self.ae = AutoEncoder(
             ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
         )
-        self.pn_in_channels = ae_feature_channels * 2
-        self.hpm = HorizontalPyramidMatching(
-            self.pn_in_channels, embedding_dims[0], hpm_scales,
-            hpm_use_avg_pool, hpm_use_max_pool
-        )
-        self.pn = PartNet(self.pn_in_channels, embedding_dims[1],
-                          tfa_num_parts, tfa_squeeze_ratio)
-
-        self.num_parts = self.hpm.num_parts + tfa_num_parts

     def forward(self, x_c1, x_c2=None):
-        # Step 1: Disentanglement
-        # n, t, c, h, w
-        ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2)
-
-        # Step 2.a: Static Gait Feature Aggregation & HPM
-        # n, c, h, w
-        x_c = self.hpm(x_c)
-        # p, n, d
-
-        # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
-        # n, t, c, h, w
-        x_p = self.pn(x_p)
-        # p, n, d
+        losses, features, images = self._disentangle(x_c1, x_c2)

         if self.training:
-            return x_c, x_p, ae_losses, images
+            losses = torch.stack(losses)
+            return losses, features, images
         else:
-            return x_c, x_p
+            return features

     def _disentangle(self, x_c1_t2, x_c2_t2=None):
         n, t, c, h, w = x_c1_t2.size()
-        device = x_c1_t2.device
         if self.training:
             x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
             ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
-            # Decode features
-            x_c = self._decode_cano_feature(f_c_, n, t, device)
-            x_p_ = self._decode_pose_feature(f_p_, n, t, device)
-            x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4)
+            f_a = f_a_.view(n, t, -1)
+            f_c = f_c_.view(n, t, -1)
+            f_p = f_p_.view(n, t, -1)

             i_a, i_c, i_p = None, None, None
             if self.image_log_on:
                 with torch.no_grad():
-                    i_a = self._decode_appr_feature(f_a_, n, t, device)
-                    # Continue decoding canonical features
-                    i_c = self.ae.decoder.trans_conv3(x_c)
-                    i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c))
-                    i_p_ = self.ae.decoder.trans_conv3(x_p_)
-                    i_p_ = torch.sigmoid(self.ae.decoder.trans_conv4(i_p_))
+                    x_a, i_a = self._separate_decode(
+                        f_a.mean(1),
+                        torch.zeros_like(f_c[:, 0, :]),
+                        torch.zeros_like(f_p[:, 0, :])
+                    )
+                    x_c, i_c = self._separate_decode(
+                        torch.zeros_like(f_a[:, 0, :]),
+                        f_c.mean(1),
+                        torch.zeros_like(f_p[:, 0, :]),
+                    )
+                    x_p_, i_p_ = self._separate_decode(
+                        torch.zeros_like(f_a_),
+                        torch.zeros_like(f_c_),
+                        f_p_
+                    )
+                    x_p = tuple(_x_p.view(n, t, *_x_p.size()[1:]) for _x_p in x_p_)
                     i_p = i_p_.view(n, t, c, h, w)

-            return (x_c, x_p), losses, (i_a, i_c, i_p)
+            return losses, (x_a, x_c, x_p), (i_a, i_c, i_p)
         else:  # evaluating
             f_c_, f_p_ = self.ae(x_c1_t2)
-            x_c = self._decode_cano_feature(f_c_, n, t, device)
-            x_p_ = self._decode_pose_feature(f_p_, n, t, device)
-            x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4)
-            return (x_c, x_p), None, None
-
-    def _decode_appr_feature(self, f_a_, n, t, device):
-        # Decode appearance features
-        f_a = f_a_.view(n, t, -1)
-        x_a = self.ae.decoder(
-            f_a.mean(1),
-            torch.zeros((n, self.f_c_dim), device=device),
-            torch.zeros((n, self.f_p_dim), device=device)
-        )
-        return x_a
-
-    def _decode_cano_feature(self, f_c_, n, t, device):
-        # Decode average canonical features to higher dimension
-        f_c = f_c_.view(n, t, -1)
-        x_c = self.ae.decoder(
-            torch.zeros((n, self.f_a_dim), device=device),
-            f_c.mean(1),
-            torch.zeros((n, self.f_p_dim), device=device),
-            is_feature_map=True
-        )
-        return x_c
-
-    def _decode_pose_feature(self, f_p_, n, t, device):
-        # Decode pose features to images
-        x_p_ = self.ae.decoder(
-            torch.zeros((n * t, self.f_a_dim), device=device),
-            torch.zeros((n * t, self.f_c_dim), device=device),
-            f_p_,
-            is_feature_map=True
+            f_c = f_c_.view(n, t, -1)
+            f_p = f_p_.view(n, t, -1)
+            return (f_c, f_p), None, None
+
+    def _separate_decode(self, f_a, f_c, f_p):
+        x_1 = torch.cat((f_a, f_c, f_p), dim=1)
+        x_1 = self.ae.decoder.fc(x_1).view(
+            -1,
+            self.ae.decoder.feature_channels * 8,
+            self.ae.decoder.h_0,
+            self.ae.decoder.w_0
         )
-        return x_p_
+        x_1 = F.relu(x_1, inplace=True)
+        x_2 = self.ae.decoder.trans_conv1(x_1)
+        x_3 = self.ae.decoder.trans_conv2(x_2)
+        x_4 = self.ae.decoder.trans_conv3(x_3)
+        image = torch.sigmoid(self.ae.decoder.trans_conv4(x_4))
+        x = (x_1, x_2, x_3, x_4)
+        return x, image
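
The zero-masking pattern introduced by _separate_decode above can be exercised in isolation. The sketch below is not the repository's AutoEncoder: ToyDecoder is a hypothetical stand-in whose fc, trans_conv1..trans_conv4, feature_channels, h_0 and w_0 attributes merely mirror the names the new method relies on, with dimensions assumed from the defaults in __init__ (f_a_c_p_dims = (128, 128, 64), ae_in_size = (64, 48)). It decodes one disentangled code at a time by zeroing the other two before concatenation, which is how the image-logging branch obtains per-factor reconstructions.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for the repository's AutoEncoder decoder."""

    def __init__(self, f_a_dim=128, f_c_dim=128, f_p_dim=64,
                 feature_channels=64, h_0=4, w_0=3):
        super().__init__()
        self.feature_channels = feature_channels
        self.h_0, self.w_0 = h_0, w_0
        # Project the concatenated codes to an initial feature map
        self.fc = nn.Linear(f_a_dim + f_c_dim + f_p_dim,
                            feature_channels * 8 * h_0 * w_0)
        # Four transposed convolutions, each doubling the spatial size (4x3 -> 64x48)
        self.trans_conv1 = nn.ConvTranspose2d(feature_channels * 8,
                                              feature_channels * 4, 4, 2, 1)
        self.trans_conv2 = nn.ConvTranspose2d(feature_channels * 4,
                                              feature_channels * 2, 4, 2, 1)
        self.trans_conv3 = nn.ConvTranspose2d(feature_channels * 2,
                                              feature_channels, 4, 2, 1)
        self.trans_conv4 = nn.ConvTranspose2d(feature_channels, 3, 4, 2, 1)


def separate_decode(decoder, f_a, f_c, f_p):
    """Mirror of RGBPartNet._separate_decode, run against the toy decoder."""
    x_1 = torch.cat((f_a, f_c, f_p), dim=1)
    x_1 = F.relu(decoder.fc(x_1).view(-1, decoder.feature_channels * 8,
                                      decoder.h_0, decoder.w_0), inplace=True)
    x_2 = decoder.trans_conv1(x_1)
    x_3 = decoder.trans_conv2(x_2)
    x_4 = decoder.trans_conv3(x_3)
    image = torch.sigmoid(decoder.trans_conv4(x_4))
    return (x_1, x_2, x_3, x_4), image


if __name__ == '__main__':
    decoder = ToyDecoder()
    n = 2
    f_a = torch.randn(n, 128)  # appearance code
    f_c = torch.randn(n, 128)  # canonical code
    f_p = torch.randn(n, 64)   # pose code
    # Pose-only decode: zero the other two codes, as in the diff above
    feature_maps, i_p = separate_decode(decoder,
                                        torch.zeros_like(f_a),
                                        torch.zeros_like(f_c),
                                        f_p)
    print([fm.size() for fm in feature_maps])  # feature maps at increasing resolution
    print(i_p.size())                          # torch.Size([2, 3, 64, 48])

Feeding zeros for two of the three codes isolates the remaining factor's contribution through the shared decoder, so the logged i_a, i_c and i_p images show what the appearance, canonical and pose codes each encode; the sketch assumes image_log_on=True so that these per-factor decodes are actually computed.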