summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-02 19:10:08 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-02 19:10:08 +0800
commit02aaefaba26b6842d2feb403edfd71aaa75904da (patch)
tree18744854e8d80e0239c0b2f3e7eaf39bc0a7974e /models
parentde8561d1d053730c5af03e1d06850efb60865d3c (diff)
Correct feature dims after disentanglement and HPM backbone removal
1. Features used in HPM is decoded canonical embedding without transpose convolution 2. Decode pose embedding to image for Part Net 3. Backbone seems to be redundant, we can use feature map given by auto-decoder
Diffstat (limited to 'models')
-rw-r--r--models/auto_encoder.py31
-rw-r--r--models/hpm.py22
-rw-r--r--models/rgb_part_net.py40
3 files changed, 57 insertions, 36 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index c84061c..234111a 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -95,10 +95,13 @@ class Decoder(nn.Module):
self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
is_last_layer=True)
- def forward(self, f_appearance, f_canonical, f_pose):
+ def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False):
x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
x = self.fc(x)
x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True)
+ # Decode canonical features without transpose convolutions
+ if no_trans_conv:
+ return x
x = self.trans_conv1(x)
x = self.trans_conv2(x)
x = self.trans_conv3(x)
@@ -131,16 +134,32 @@ class AutoEncoder(nn.Module):
def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y):
# t1 is random time step
(f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
- (_, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
+ (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
(_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
xrecon_loss_t2 = self.mse_loss(x_c1_t2, x_c1_t2_)
- y_ = self.classifier(f_c_c1_t2)
+ y_ = self.classifier(f_c_c1_t2.contiguous())
cano_cons_loss_t2 = (self.mse_loss(f_c_c1_t1, f_c_c1_t2)
+ self.mse_loss(f_c_c1_t2, f_c_c2_t2)
- + self.xent_loss(y, y_))
+ + self.xent_loss(y_, y))
- return ((f_c_c1_t2, f_p_c1_t2, f_p_c2_t2),
- xrecon_loss_t2, cano_cons_loss_t2)
+ f_a_size, f_c_size, f_p_size = (
+ f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size()
+ )
+ # Decode canonical features for HPM
+ x_c_c1_t2 = self.decoder(
+ torch.zeros(f_a_size), f_c_c1_t1, torch.zeros(f_p_size),
+ no_trans_conv=True
+ )
+ # Decode pose features for Part Net
+ x_p_c1_t2 = self.decoder(
+ torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2
+ )
+
+ return (
+ (x_c_c1_t2, x_p_c1_t2),
+ (f_p_c1_t2, f_p_c2_t2),
+ (xrecon_loss_t2, cano_cons_loss_t2)
+ )
diff --git a/models/hpm.py b/models/hpm.py
index 5553094..66503e3 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -1,6 +1,5 @@
import torch
import torch.nn as nn
-from torchvision.models import resnet50
from models.layers import HorizontalPyramidPooling
@@ -8,12 +7,11 @@ from models.layers import HorizontalPyramidPooling
class HorizontalPyramidMatching(nn.Module):
def __init__(
self,
- in_channels: int = 3,
+ in_channels: int,
out_channels: int = 128,
- scales: tuple[int, ...] = (1, 2, 4, 8),
+ scales: tuple[int, ...] = (1, 2, 4),
use_avg_pool: bool = True,
use_max_pool: bool = True,
- use_backbone: bool = False,
**kwargs
):
super().__init__()
@@ -22,11 +20,6 @@ class HorizontalPyramidMatching(nn.Module):
self.scales = scales
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
- self.use_backbone = use_backbone
-
- if self.use_backbone:
- self.backbone = resnet50(pretrained=True)
- self.in_channels = self.backbone.layer4[-1].conv1.in_channels
self.pyramids = nn.ModuleList([
self._make_pyramid(scale, **kwargs) for scale in self.scales
@@ -44,15 +37,10 @@ class HorizontalPyramidMatching(nn.Module):
return pyramid
def forward(self, x):
- # Flatten frames in all batches
+ # Flatten canonical features in all batches
t, n, c, h, w = x.size()
- x = x.view(-1, c, h, w)
-
- if self.use_backbone:
- # FIXME Inconsistent dimensions
- x = self.backbone(x)
+ x = x.view(t * n, c, h, w)
- t_n, _, h, _ = x.size()
feature = []
for pyramid_index, pyramid in enumerate(self.pyramids):
h_per_hpp = h // self.scales[pyramid_index]
@@ -61,7 +49,7 @@ class HorizontalPyramidMatching(nn.Module):
(hpp_index + 1) * h_per_hpp)
x_slice = x[:, :, h_filter, :]
x_slice = hpp(x_slice)
- x_slice = x_slice.view(t_n, -1)
+ x_slice = x_slice.view(t * n, -1)
feature.append(x_slice)
x = torch.stack(feature)
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 377c108..0ff8251 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -13,7 +13,7 @@ class RGBPartNet(nn.Module):
ae_in_channels: int = 3,
ae_feature_channels: int = 64,
f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
- hpm_scales: tuple[int, ...] = (1, 2, 4, 8),
+ hpm_scales: tuple[int, ...] = (1, 2, 4),
hpm_use_avg_pool: bool = True,
hpm_use_max_pool: bool = True,
fpfe_feature_channels: int = 32,
@@ -32,7 +32,7 @@ class RGBPartNet(nn.Module):
fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_part
)
self.hpm = HorizontalPyramidMatching(
- ae_in_channels, self.pn.tfa_in_channels, hpm_scales,
+ ae_feature_channels * 8, self.pn.tfa_in_channels, hpm_scales,
hpm_use_avg_pool, hpm_use_max_pool
)
@@ -54,38 +54,52 @@ class RGBPartNet(nn.Module):
# Step 1: Disentanglement
# t, n, c, h, w
num_frames = len(x_c1)
- f_c_c1, f_p_c1, f_p_c2 = [], [], []
+ # Decoded canonical features and Pose images
+ x_c_c1, x_p_c1 = [], []
+ # Features required to calculate losses
+ f_p_c1, f_p_c2 = [], []
xrecon_loss, cano_cons_loss = torch.zeros(1), torch.zeros(1)
for t2 in range(num_frames):
t1 = random.randrange(num_frames)
output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y)
- (feature_t2, xrecon_loss_t2, cano_cons_loss_t2) = output
- (f_c_c1_t2, f_p_c1_t2, f_p_c2_t2) = feature_t2
- # Features for next step
- f_c_c1.append(f_c_c1_t2)
- f_p_c1.append(f_p_c1_t2)
+ (x_c1_t2, f_p_t2, losses) = output
+
+ # Decoded features or image
+ (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
+ # Canonical Features for HPM
+ x_c_c1.append(x_c_c1_t2)
+ # Pose image for Part Net
+ x_p_c1.append(x_p_c1_t2)
+
# Losses per time step
+ # Used in pose similarity loss
+ (f_p_c1_t2, f_p_c2_t2) = f_p_t2
+ f_p_c1.append(f_p_c1_t2)
f_p_c2.append(f_p_c2_t2)
+ # Cross reconstruction loss and canonical loss
+ (xrecon_loss_t2, cano_cons_loss_t2) = losses
xrecon_loss += xrecon_loss_t2
cano_cons_loss += cano_cons_loss_t2
- f_c_c1 = torch.stack(f_c_c1)
- f_p_c1 = torch.stack(f_p_c1)
+
+ x_c_c1 = torch.stack(x_c_c1)
+ x_p_c1 = torch.stack(x_p_c1)
# Step 2.a: HPM & Static Gait Feature Aggregation
# t, n, c, h, w
- x_c = self.hpm(f_c_c1)
+ x_c = self.hpm(x_c_c1)
# p, t, n, c
x_c = x_c.mean(dim=1)
# p, n, c
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
# t, n, c, h, w
- x_p = self.pn(f_p_c1)
+ x_p = self.pn(x_p_c1)
# p, n, c
# Step 3: Cat feature map together and calculate losses
- x = torch.cat(x_c, x_p)
+ x = torch.cat([x_c, x_p])
# Losses
+ f_p_c1 = torch.stack(f_p_c1)
f_p_c2 = torch.stack(f_p_c2)
pose_sim_loss = self.pose_sim_loss(f_p_c1, f_p_c2)
cano_cons_loss /= num_frames