Diffstat (limited to 'models')
-rw-r--r--  models/auto_encoder.py | 31
-rw-r--r--  models/hpm.py          | 22
-rw-r--r--  models/rgb_part_net.py | 40
3 files changed, 57 insertions(+), 36 deletions(-)
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index c84061c..234111a 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -95,10 +95,13 @@ class Decoder(nn.Module):
self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
is_last_layer=True)
- def forward(self, f_appearance, f_canonical, f_pose):
+ def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False):
x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
x = self.fc(x)
x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True)
+ # Decode canonical features without transpose convolutions
+ if no_trans_conv:
+ return x
x = self.trans_conv1(x)
x = self.trans_conv2(x)
x = self.trans_conv3(x)
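The new flag short-circuits decoding right after the fully connected layer. A minimal shape sketch of the early-return path, assuming the default feature_channels = 64 and an illustrative batch size (pure torch, independent of this repository):

    import torch
    import torch.nn.functional as F

    n, feature_channels = 16, 64                        # illustrative batch size; default channels
    x = torch.randn(n, feature_channels * 8 * 4 * 2)    # what self.fc produces
    x = F.relu(x.view(-1, feature_channels * 8, 4, 2), inplace=True)
    print(x.shape)                                      # torch.Size([16, 512, 4, 2])
    # With no_trans_conv=True the decoder returns this map directly;
    # the transpose convolutions that upsample it to an image are skipped.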
@@ -131,16 +134,32 @@ class AutoEncoder(nn.Module):
def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y):
# t1 is a random time step
(f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
- (_, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
+ (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
(_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
xrecon_loss_t2 = self.mse_loss(x_c1_t2, x_c1_t2_)
- y_ = self.classifier(f_c_c1_t2)
+ y_ = self.classifier(f_c_c1_t2.contiguous())
cano_cons_loss_t2 = (self.mse_loss(f_c_c1_t1, f_c_c1_t2)
+ self.mse_loss(f_c_c1_t2, f_c_c2_t2)
- + self.xent_loss(y, y_))
+ + self.xent_loss(y_, y))
- return ((f_c_c1_t2, f_p_c1_t2, f_p_c2_t2),
- xrecon_loss_t2, cano_cons_loss_t2)
+ f_a_size, f_c_size, f_p_size = (
+ f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size()
+ )
+ # Decode canonical features for HPM
+ x_c_c1_t2 = self.decoder(
+ torch.zeros(f_a_size, device=x_c1_t1.device), f_c_c1_t1, torch.zeros(f_p_size, device=x_c1_t1.device),
+ no_trans_conv=True
+ )
+ # Decode pose features for Part Net
+ x_p_c1_t2 = self.decoder(
+ torch.zeros(f_a_size, device=x_c1_t1.device), torch.zeros(f_c_size, device=x_c1_t1.device), f_p_c1_t2
+ )
+
+ return (
+ (x_c_c1_t2, x_p_c1_t2),
+ (f_p_c1_t2, f_p_c2_t2),
+ (xrecon_loss_t2, cano_cons_loss_t2)
+ )
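For reference, a sketch of how the reshaped return value unpacks on the caller's side (mirroring the loop in rgb_part_net.py below; the surrounding variables ae, x_c1, x_c2, t1, t2, and y are assumed to exist as in RGBPartNet.forward):

    output = ae(x_c1[t1], x_c1[t2], x_c2[t2], y)
    (x_c_c1_t2, x_p_c1_t2) = output[0]                 # decoded feature map / pose image
    (f_p_c1_t2, f_p_c2_t2) = output[1]                 # pose features for the similarity loss
    (xrecon_loss_t2, cano_cons_loss_t2) = output[2]    # per-time-step losses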
diff --git a/models/hpm.py b/models/hpm.py
index 5553094..66503e3 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -1,6 +1,5 @@
import torch
import torch.nn as nn
-from torchvision.models import resnet50
from models.layers import HorizontalPyramidPooling
@@ -8,12 +7,11 @@ from models.layers import HorizontalPyramidPooling
class HorizontalPyramidMatching(nn.Module):
def __init__(
self,
- in_channels: int = 3,
+ in_channels: int,
out_channels: int = 128,
- scales: tuple[int, ...] = (1, 2, 4, 8),
+ scales: tuple[int, ...] = (1, 2, 4),
use_avg_pool: bool = True,
use_max_pool: bool = True,
- use_backbone: bool = False,
**kwargs
):
super().__init__()
@@ -22,11 +20,6 @@ class HorizontalPyramidMatching(nn.Module):
self.scales = scales
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
- self.use_backbone = use_backbone
-
- if self.use_backbone:
- self.backbone = resnet50(pretrained=True)
- self.in_channels = self.backbone.layer4[-1].conv1.in_channels
self.pyramids = nn.ModuleList([
self._make_pyramid(scale, **kwargs) for scale in self.scales
@@ -44,15 +37,10 @@ class HorizontalPyramidMatching(nn.Module):
return pyramid
def forward(self, x):
- # Flatten frames in all batches
+ # Flatten canonical features in all batches
t, n, c, h, w = x.size()
- x = x.view(-1, c, h, w)
-
- if self.use_backbone:
- # FIXME Inconsistent dimensions
- x = self.backbone(x)
+ x = x.view(t * n, c, h, w)
- t_n, _, h, _ = x.size()
feature = []
for pyramid_index, pyramid in enumerate(self.pyramids):
h_per_hpp = h // self.scales[pyramid_index]
@@ -61,7 +49,7 @@ class HorizontalPyramidMatching(nn.Module):
(hpp_index + 1) * h_per_hpp)
x_slice = x[:, :, h_filter, :]
x_slice = hpp(x_slice)
- x_slice = x_slice.view(t_n, -1)
+ x_slice = x_slice.view(t * n, -1)
feature.append(x_slice)
x = torch.stack(feature)
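A shape walk-through of the slimmed-down forward under the new defaults (scales (1, 2, 4); the 512-channel 4x2 maps come from the decoder's early-return path; clip length and batch size below are illustrative):

    import torch

    t, n, c, h, w = 30, 16, 512, 4, 2     # illustrative clip length and batch size
    x = torch.randn(t, n, c, h, w)
    x = x.view(t * n, c, h, w)            # flatten time into the batch, as above
    # Scales (1, 2, 4) cut the height into 1 + 2 + 4 = 7 horizontal strips;
    # each strip is pooled and projected, so the stacked output has shape
    # (7, t * n, out_channels).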
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 377c108..0ff8251 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -13,7 +13,7 @@ class RGBPartNet(nn.Module):
ae_in_channels: int = 3,
ae_feature_channels: int = 64,
f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
- hpm_scales: tuple[int, ...] = (1, 2, 4, 8),
+ hpm_scales: tuple[int, ...] = (1, 2, 4),
hpm_use_avg_pool: bool = True,
hpm_use_max_pool: bool = True,
fpfe_feature_channels: int = 32,
@@ -32,7 +32,7 @@ class RGBPartNet(nn.Module):
fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_part
)
self.hpm = HorizontalPyramidMatching(
- ae_in_channels, self.pn.tfa_in_channels, hpm_scales,
+ ae_feature_channels * 8, self.pn.tfa_in_channels, hpm_scales,
hpm_use_avg_pool, hpm_use_max_pool
)
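The first HPM argument changes meaning here: it is no longer the RGB channel count but the width of the decoder's intermediate map. A one-line sanity check, assuming the default ae_feature_channels = 64:

    ae_feature_channels = 64
    hpm_in_channels = ae_feature_channels * 8   # 512, matching the decoder's
                                                # view(-1, feature_channels * 8, 4, 2)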
@@ -54,38 +54,52 @@ class RGBPartNet(nn.Module):
# Step 1: Disentanglement
# t, n, c, h, w
num_frames = len(x_c1)
- f_c_c1, f_p_c1, f_p_c2 = [], [], []
+ # Decoded canonical features and pose images
+ x_c_c1, x_p_c1 = [], []
+ # Features required to calculate losses
+ f_p_c1, f_p_c2 = [], []
xrecon_loss, cano_cons_loss = torch.zeros(1), torch.zeros(1)
for t2 in range(num_frames):
t1 = random.randrange(num_frames)
output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y)
- (feature_t2, xrecon_loss_t2, cano_cons_loss_t2) = output
- (f_c_c1_t2, f_p_c1_t2, f_p_c2_t2) = feature_t2
- # Features for next step
- f_c_c1.append(f_c_c1_t2)
- f_p_c1.append(f_p_c1_t2)
+ (x_c1_t2, f_p_t2, losses) = output
+
+ # Decoded canonical feature map and pose image
+ (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
+ # Canonical features for HPM
+ x_c_c1.append(x_c_c1_t2)
+ # Pose image for Part Net
+ x_p_c1.append(x_p_c1_t2)
+
# Losses per time step
+ # Used in pose similarity loss
+ (f_p_c1_t2, f_p_c2_t2) = f_p_t2
+ f_p_c1.append(f_p_c1_t2)
f_p_c2.append(f_p_c2_t2)
+ # Cross reconstruction loss and canonical loss
+ (xrecon_loss_t2, cano_cons_loss_t2) = losses
xrecon_loss += xrecon_loss_t2
cano_cons_loss += cano_cons_loss_t2
- f_c_c1 = torch.stack(f_c_c1)
- f_p_c1 = torch.stack(f_p_c1)
+
+ x_c_c1 = torch.stack(x_c_c1)
+ x_p_c1 = torch.stack(x_p_c1)
# Step 2.a: HPM & Static Gait Feature Aggregation
# t, n, c, h, w
- x_c = self.hpm(f_c_c1)
+ x_c = self.hpm(x_c_c1)
# p, t, n, c
x_c = x_c.mean(dim=1)
# p, n, c
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
# t, n, c, h, w
- x_p = self.pn(f_p_c1)
+ x_p = self.pn(x_p_c1)
# p, n, c
# Step 3: Cat feature map together and calculate losses
- x = torch.cat(x_c, x_p)
+ x = torch.cat([x_c, x_p])
# Losses
+ f_p_c1 = torch.stack(f_p_c1)
f_p_c2 = torch.stack(f_p_c2)
pose_sim_loss = self.pose_sim_loss(f_p_c1, f_p_c2)
cano_cons_loss /= num_frames
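Finally, the corrected torch.cat call joins the static and dynamic descriptors along the part axis. A shape sketch with illustrative part counts (the real counts depend on hpm_scales and tfa_num_part):

    import torch

    p_hpm, p_pn, n, c = 7, 16, 16, 128    # illustrative: 1+2+4 HPM strips, 16 TFA parts
    x_c = torch.randn(p_hpm, n, c)        # HPM output after the mean over time
    x_p = torch.randn(p_pn, n, c)         # Part Net output
    x = torch.cat([x_c, x_p])             # (p_hpm + p_pn, n, c) = (23, 16, 128)
    # torch.cat takes a sequence of tensors; the old torch.cat(x_c, x_p)
    # passed x_p as the dim argument and would raise a TypeError.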