summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.py4
-rw-r--r--models/auto_encoder.py103
-rw-r--r--models/rgb_part_net.py104
3 files changed, 80 insertions, 131 deletions
diff --git a/config.py b/config.py
index 25846a2..b643c75 100644
--- a/config.py
+++ b/config.py
@@ -9,7 +9,7 @@ config = {
# Recorde disentangled image or not
'image_log_on': False,
# The number of subjects for validating (Part of testing set)
- 'val_size': 10,
+ 'val_size': 20,
},
# Dataset settings
'dataset': {
@@ -49,7 +49,7 @@ config = {
# Auto-encoder feature channels coefficient
'ae_feature_channels': 64,
# Appearance, canonical and pose feature dimensions
- 'f_a_c_p_dims': (192, 192, 96),
+ 'f_a_c_p_dims': (192, 192, 128),
# HPM pyramid scales, of which sum is number of parts
'hpm_scales': (1, 2, 4, 8),
# Global pooling method
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 7f0eb6c..1028767 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
-from models.layers import VGGConv2d, DCGANConvTranspose2d, BasicLinear
+from models.layers import VGGConv2d, DCGANConvTranspose2d
class Encoder(nn.Module):
@@ -15,14 +15,12 @@ class Encoder(nn.Module):
in_channels: int = 3,
frame_size: Tuple[int, int] = (64, 48),
feature_channels: int = 64,
- output_dims: Tuple[int, int, int] = (128, 128, 64)
+ output_dims: Tuple[int, int, int] = (192, 192, 128)
):
super().__init__()
- self.feature_channels = feature_channels
h_0, w_0 = frame_size
h_1, w_1 = h_0 // 2, w_0 // 2
h_2, w_2 = h_1 // 2, w_1 // 2
- self.feature_size = self.h_3, self.w_3 = h_2 // 4, w_2 // 4
# Appearance features, canonical features, pose features
(self.f_a_dim, self.f_c_dim, self.f_p_dim) = output_dims
@@ -44,15 +42,6 @@ class Encoder(nn.Module):
# Conv4 feature_map_size*8 x H//4 x W//4
# -> feature_map_size*8 x H//4 x W//4 (for large dataset)
self.conv4 = VGGConv2d(feature_channels * 8, feature_channels * 8)
- # MaxPool3 feature_map_size*8 x H//4 x W//4
- # -> feature_map_size*8 x H//16 x W//16
- self.max_pool3 = nn.AdaptiveMaxPool2d(self.feature_size)
-
- embedding_dim = sum(output_dims)
- # FC feature_map_size*8 * H//16 * W//16 -> embedding_dim
- self.fc = BasicLinear(
- (feature_channels * 8) * self.h_3 * self.w_3, embedding_dim
- )
def forward(self, x):
x = self.conv1(x)
@@ -61,11 +50,7 @@ class Encoder(nn.Module):
x = self.max_pool2(x)
x = self.conv3(x)
x = self.conv4(x)
- x = self.max_pool3(x)
- x = x.view(-1, (self.feature_channels * 8) * self.h_3 * self.w_3)
- embedding = self.fc(x)
-
- f_appearance, f_canonical, f_pose = embedding.split(
+ f_appearance, f_canonical, f_pose = x.split(
(self.f_a_dim, self.f_c_dim, self.f_p_dim), dim=1
)
return f_appearance, f_canonical, f_pose
@@ -76,47 +61,39 @@ class Decoder(nn.Module):
def __init__(
self,
- input_dims: Tuple[int, int, int] = (128, 128, 64),
feature_channels: int = 64,
- feature_size: Tuple[int, int] = (4, 3),
out_channels: int = 3,
):
super().__init__()
self.feature_channels = feature_channels
- self.h_0, self.w_0 = feature_size
- embedding_dim = sum(input_dims)
- # FC 320 -> feature_map_size*8 * H * W
- self.fc = BasicLinear(
- embedding_dim, (feature_channels * 8) * self.h_0 * self.w_0
- )
-
- # TransConv1 feature_map_size*8 x H x W
- # -> feature_map_size*4 x H*2 x W*2
+ # TransConv1 feature_map_size*8 x H x W
+ # -> feature_map_size*4 x H x W
self.trans_conv1 = DCGANConvTranspose2d(feature_channels * 8,
- feature_channels * 4)
- # TransConv2 feature_map_size*4 x H*2 x W*2
- # -> feature_map_size*2 x H*4 x W*4
+ feature_channels * 4,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ # TransConv2 feature_map_size*4 x H x W
+ # -> feature_map_size*2 x H*2 x W*2
self.trans_conv2 = DCGANConvTranspose2d(feature_channels * 4,
feature_channels * 2)
- # TransConv3 feature_map_size*2 x H*4 x W*4
- # -> feature_map_size x H*8 x W*8
+ # TransConv3 feature_map_size*2 x H*2 x W*2
+ # -> feature_map_size x H*2 x W*2
self.trans_conv3 = DCGANConvTranspose2d(feature_channels * 2,
- feature_channels)
- # TransConv4 feature_map_size x H*8 x W*8
- # -> in_channels x H*16 x W*16
+ feature_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ # TransConv4 feature_map_size x H*2 x W*2
+ # -> in_channels x H*4 x W*4
self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
is_last_layer=True)
- def forward(self, f_appearance, f_canonical, f_pose, is_feature_map=False):
+ def forward(self, f_appearance, f_canonical, f_pose):
x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
- x = self.fc(x)
- x = x.view(-1, self.feature_channels * 8, self.h_0, self.w_0)
- x = F.relu(x, inplace=True)
x = self.trans_conv1(x)
x = self.trans_conv2(x)
- if is_feature_map:
- return x
x = self.trans_conv3(x)
x = torch.sigmoid(self.trans_conv4(x))
@@ -129,13 +106,13 @@ class AutoEncoder(nn.Module):
channels: int = 3,
frame_size: Tuple[int, int] = (64, 48),
feature_channels: int = 64,
- embedding_dims: Tuple[int, int, int] = (128, 128, 64)
+ embedding_dims: Tuple[int, int, int] = (192, 192, 128)
):
super().__init__()
+ self.embedding_dims = embedding_dims
self.encoder = Encoder(channels, frame_size,
feature_channels, embedding_dims)
- self.decoder = Decoder(embedding_dims, feature_channels,
- self.encoder.feature_size, channels)
+ self.decoder = Decoder(feature_channels, channels)
def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()
@@ -143,37 +120,41 @@ class AutoEncoder(nn.Module):
x_c1_t2_ = x_c1_t2.view(n * t, c, h, w)
(f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_) = self.encoder(x_c1_t2_)
+ f_size = [torch.Size([n, t, embedding_dim, h // 4, w // 4])
+ for embedding_dim in self.embedding_dims]
+ f_a_c1_t2 = f_a_c1_t2_.view(f_size[0])
+ f_c_c1_t2 = f_c_c1_t2_.view(f_size[1])
+ f_p_c1_t2 = f_p_c1_t2_.view(f_size[2])
+
if self.training:
# t1 is random time step, c2 is another condition
- x_c1_t1 = x_c1_t1.view(n * t, c, h, w)
- (f_a_c1_t1_, f_c_c1_t1_, _) = self.encoder(x_c1_t1)
- x_c2_t2 = x_c2_t2.view(n * t, c, h, w)
- (_, f_c_c2_t2_, f_p_c2_t2_) = self.encoder(x_c2_t2)
+ x_c1_t1_ = x_c1_t1.view(n * t, c, h, w)
+ (f_a_c1_t1_, f_c_c1_t1_, _) = self.encoder(x_c1_t1_)
+ x_c2_t2_ = x_c2_t2.view(n * t, c, h, w)
+ (_, f_c_c2_t2_, f_p_c2_t2_) = self.encoder(x_c2_t2_)
x_c1_t2_pred_ = self.decoder(f_a_c1_t1_, f_c_c1_t1_, f_p_c1_t2_)
x_c1_t2_pred = x_c1_t2_pred_.view(n, t, c, h, w)
xrecon_loss = torch.stack([
- F.mse_loss(x_c1_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :])
+ F.mse_loss(x_c1_t2[:, i], x_c1_t2_pred[:, i])
for i in range(t)
]).sum()
- f_c_c1_t1 = f_c_c1_t1_.view(n, t, -1)
- f_c_c1_t2 = f_c_c1_t2_.view(n, t, -1)
- f_c_c2_t2 = f_c_c2_t2_.view(n, t, -1)
+ f_c_c1_t1 = f_c_c1_t1_.view(f_size[1])
+ f_c_c2_t2 = f_c_c2_t2_.view(f_size[1])
cano_cons_loss = torch.stack([
- F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
- + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :])
+ F.mse_loss(f_c_c1_t1[:, i], f_c_c1_t2[:, i])
+ + F.mse_loss(f_c_c1_t2[:, i], f_c_c2_t2[:, i])
for i in range(t)
]).mean()
- f_p_c1_t2 = f_p_c1_t2_.view(n, t, -1)
- f_p_c2_t2 = f_p_c2_t2_.view(n, t, -1)
+ f_p_c2_t2 = f_p_c2_t2_.view(f_size[2])
pose_sim_loss = F.mse_loss(f_p_c1_t2.mean(1), f_p_c2_t2.mean(1))
return (
- (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_),
- (xrecon_loss, cano_cons_loss, pose_sim_loss * 10)
+ (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2),
+ (xrecon_loss / 10, cano_cons_loss, pose_sim_loss * 10)
)
else: # evaluating
- return f_c_c1_t2_, f_p_c1_t2_
+ return f_a_c1_t2, f_c_c1_t2, f_p_c1_t2
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index fcd8fbc..811a711 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -14,7 +14,7 @@ class RGBPartNet(nn.Module):
ae_in_channels: int = 3,
ae_in_size: Tuple[int, int] = (64, 48),
ae_feature_channels: int = 64,
- f_a_c_p_dims: Tuple[int, int, int] = (128, 128, 64),
+ f_a_c_p_dims: Tuple[int, int, int] = (192, 192, 128),
hpm_scales: Tuple[int, ...] = (1, 2, 4),
hpm_use_avg_pool: bool = True,
hpm_use_max_pool: bool = True,
@@ -25,100 +25,68 @@ class RGBPartNet(nn.Module):
):
super().__init__()
self.h, self.w = ae_in_size
- (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
self.image_log_on = image_log_on
self.ae = AutoEncoder(
ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
)
- self.pn_in_channels = ae_feature_channels * 2
self.hpm = HorizontalPyramidMatching(
- self.pn_in_channels, embedding_dims[0], hpm_scales,
+ f_a_c_p_dims[1], embedding_dims[0], hpm_scales,
hpm_use_avg_pool, hpm_use_max_pool
)
- self.pn = PartNet(self.pn_in_channels, embedding_dims[1],
- tfa_num_parts, tfa_squeeze_ratio)
+ self.pn = PartNet(
+ f_a_c_p_dims[2], embedding_dims[1], tfa_num_parts, tfa_squeeze_ratio
+ )
self.num_parts = self.hpm.num_parts + tfa_num_parts
def forward(self, x_c1, x_c2=None):
# Step 1: Disentanglement
# n, t, c, h, w
- ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2)
+ (f_a, f_c, f_p), ae_losses = self._disentangle(x_c1, x_c2)
# Step 2.a: Static Gait Feature Aggregation & HPM
# n, c, h, w
- x_c = self.hpm(x_c)
+ f_c_mean = f_c.mean(1)
+ x_c = self.hpm(f_c_mean)
# p, n, d
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
# n, t, c, h, w
- x_p = self.pn(x_p)
+ x_p = self.pn(f_p)
# p, n, d
if self.training:
- return x_c, x_p, ae_losses, images
+ i_a, i_c, i_p = None, None, None
+ if self.image_log_on:
+ f_a_mean = f_a.mean(1)
+ i_a = self.ae.decoder(
+ f_a_mean,
+ torch.zeros_like(f_c_mean),
+ torch.zeros_like(f_p[:, 0])
+ )
+ i_c = self.ae.decoder(
+ torch.zeros_like(f_a_mean),
+ f_c_mean,
+ torch.zeros_like(f_p[:, 0])
+ )
+ f_p_size = f_p.size()
+ i_p = self.ae.decoder(
+ torch.zeros(f_p_size[0] * f_p_size[1], *f_a.shape[2:],
+ device=f_a.device),
+ torch.zeros(f_p_size[0] * f_p_size[1], *f_c.shape[2:],
+ device=f_c.device),
+ f_p.view(-1, *f_p_size[2:])
+ ).view(x_c1.size())
+ return x_c, x_p, ae_losses, (i_a, i_c, i_p)
else:
return x_c, x_p
def _disentangle(self, x_c1_t2, x_c2_t2=None):
- n, t, c, h, w = x_c1_t2.size()
- device = x_c1_t2.device
if self.training:
- x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
- ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
- # Decode features
- x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p_ = self._decode_pose_feature(f_p_, n, t, device)
- x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4)
-
- i_a, i_c, i_p = None, None, None
- if self.image_log_on:
- with torch.no_grad():
- i_a = self._decode_appr_feature(f_a_, n, t, device)
- # Continue decoding canonical features
- i_c = self.ae.decoder.trans_conv3(x_c)
- i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c))
- i_p_ = self.ae.decoder.trans_conv3(x_p_)
- i_p_ = torch.sigmoid(self.ae.decoder.trans_conv4(i_p_))
- i_p = i_p_.view(n, t, c, h, w)
-
- return (x_c, x_p), losses, (i_a, i_c, i_p)
-
+ x_c1_t1 = x_c1_t2[:, torch.randperm(x_c1_t2.size(1)), :, :, :]
+ features, losses = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
+ return features, losses
else: # evaluating
- f_c_, f_p_ = self.ae(x_c1_t2)
- x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p_ = self._decode_pose_feature(f_p_, n, t, device)
- x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4)
- return (x_c, x_p), None, None
-
- def _decode_appr_feature(self, f_a_, n, t, device):
- # Decode appearance features
- f_a = f_a_.view(n, t, -1)
- x_a = self.ae.decoder(
- f_a.mean(1),
- torch.zeros((n, self.f_c_dim), device=device),
- torch.zeros((n, self.f_p_dim), device=device)
- )
- return x_a
-
- def _decode_cano_feature(self, f_c_, n, t, device):
- # Decode average canonical features to higher dimension
- f_c = f_c_.view(n, t, -1)
- x_c = self.ae.decoder(
- torch.zeros((n, self.f_a_dim), device=device),
- f_c.mean(1),
- torch.zeros((n, self.f_p_dim), device=device),
- is_feature_map=True
- )
- return x_c
-
- def _decode_pose_feature(self, f_p_, n, t, device):
- # Decode pose features to images
- x_p_ = self.ae.decoder(
- torch.zeros((n * t, self.f_a_dim), device=device),
- torch.zeros((n * t, self.f_c_dim), device=device),
- f_p_,
- is_feature_map=True
- )
- return x_p_
+ features = self.ae(x_c1_t2)
+ return features, None