diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/auto_encoder.py | 103 | ||||
-rw-r--r-- | models/rgb_part_net.py | 104 |
2 files changed, 78 insertions, 129 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 7f0eb6c..1028767 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from models.layers import VGGConv2d, DCGANConvTranspose2d, BasicLinear +from models.layers import VGGConv2d, DCGANConvTranspose2d class Encoder(nn.Module): @@ -15,14 +15,12 @@ class Encoder(nn.Module): in_channels: int = 3, frame_size: Tuple[int, int] = (64, 48), feature_channels: int = 64, - output_dims: Tuple[int, int, int] = (128, 128, 64) + output_dims: Tuple[int, int, int] = (192, 192, 128) ): super().__init__() - self.feature_channels = feature_channels h_0, w_0 = frame_size h_1, w_1 = h_0 // 2, w_0 // 2 h_2, w_2 = h_1 // 2, w_1 // 2 - self.feature_size = self.h_3, self.w_3 = h_2 // 4, w_2 // 4 # Appearance features, canonical features, pose features (self.f_a_dim, self.f_c_dim, self.f_p_dim) = output_dims @@ -44,15 +42,6 @@ class Encoder(nn.Module): # Conv4 feature_map_size*8 x H//4 x W//4 # -> feature_map_size*8 x H//4 x W//4 (for large dataset) self.conv4 = VGGConv2d(feature_channels * 8, feature_channels * 8) - # MaxPool3 feature_map_size*8 x H//4 x W//4 - # -> feature_map_size*8 x H//16 x W//16 - self.max_pool3 = nn.AdaptiveMaxPool2d(self.feature_size) - - embedding_dim = sum(output_dims) - # FC feature_map_size*8 * H//16 * W//16 -> embedding_dim - self.fc = BasicLinear( - (feature_channels * 8) * self.h_3 * self.w_3, embedding_dim - ) def forward(self, x): x = self.conv1(x) @@ -61,11 +50,7 @@ class Encoder(nn.Module): x = self.max_pool2(x) x = self.conv3(x) x = self.conv4(x) - x = self.max_pool3(x) - x = x.view(-1, (self.feature_channels * 8) * self.h_3 * self.w_3) - embedding = self.fc(x) - - f_appearance, f_canonical, f_pose = embedding.split( + f_appearance, f_canonical, f_pose = x.split( (self.f_a_dim, self.f_c_dim, self.f_p_dim), dim=1 ) return f_appearance, f_canonical, f_pose @@ -76,47 +61,39 @@ class Decoder(nn.Module): def __init__( self, - input_dims: Tuple[int, int, int] = (128, 128, 64), feature_channels: int = 64, - feature_size: Tuple[int, int] = (4, 3), out_channels: int = 3, ): super().__init__() self.feature_channels = feature_channels - self.h_0, self.w_0 = feature_size - embedding_dim = sum(input_dims) - # FC 320 -> feature_map_size*8 * H * W - self.fc = BasicLinear( - embedding_dim, (feature_channels * 8) * self.h_0 * self.w_0 - ) - - # TransConv1 feature_map_size*8 x H x W - # -> feature_map_size*4 x H*2 x W*2 + # TransConv1 feature_map_size*8 x H x W + # -> feature_map_size*4 x H x W self.trans_conv1 = DCGANConvTranspose2d(feature_channels * 8, - feature_channels * 4) - # TransConv2 feature_map_size*4 x H*2 x W*2 - # -> feature_map_size*2 x H*4 x W*4 + feature_channels * 4, + kernel_size=3, + stride=1, + padding=1) + # TransConv2 feature_map_size*4 x H x W + # -> feature_map_size*2 x H*2 x W*2 self.trans_conv2 = DCGANConvTranspose2d(feature_channels * 4, feature_channels * 2) - # TransConv3 feature_map_size*2 x H*4 x W*4 - # -> feature_map_size x H*8 x W*8 + # TransConv3 feature_map_size*2 x H*2 x W*2 + # -> feature_map_size x H*2 x W*2 self.trans_conv3 = DCGANConvTranspose2d(feature_channels * 2, - feature_channels) - # TransConv4 feature_map_size x H*8 x W*8 - # -> in_channels x H*16 x W*16 + feature_channels, + kernel_size=3, + stride=1, + padding=1) + # TransConv4 feature_map_size x H*2 x W*2 + # -> in_channels x H*4 x W*4 self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels, is_last_layer=True) - def forward(self, f_appearance, f_canonical, f_pose, is_feature_map=False): + def forward(self, f_appearance, f_canonical, f_pose): x = torch.cat((f_appearance, f_canonical, f_pose), dim=1) - x = self.fc(x) - x = x.view(-1, self.feature_channels * 8, self.h_0, self.w_0) - x = F.relu(x, inplace=True) x = self.trans_conv1(x) x = self.trans_conv2(x) - if is_feature_map: - return x x = self.trans_conv3(x) x = torch.sigmoid(self.trans_conv4(x)) @@ -129,13 +106,13 @@ class AutoEncoder(nn.Module): channels: int = 3, frame_size: Tuple[int, int] = (64, 48), feature_channels: int = 64, - embedding_dims: Tuple[int, int, int] = (128, 128, 64) + embedding_dims: Tuple[int, int, int] = (192, 192, 128) ): super().__init__() + self.embedding_dims = embedding_dims self.encoder = Encoder(channels, frame_size, feature_channels, embedding_dims) - self.decoder = Decoder(embedding_dims, feature_channels, - self.encoder.feature_size, channels) + self.decoder = Decoder(feature_channels, channels) def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() @@ -143,37 +120,41 @@ class AutoEncoder(nn.Module): x_c1_t2_ = x_c1_t2.view(n * t, c, h, w) (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_) = self.encoder(x_c1_t2_) + f_size = [torch.Size([n, t, embedding_dim, h // 4, w // 4]) + for embedding_dim in self.embedding_dims] + f_a_c1_t2 = f_a_c1_t2_.view(f_size[0]) + f_c_c1_t2 = f_c_c1_t2_.view(f_size[1]) + f_p_c1_t2 = f_p_c1_t2_.view(f_size[2]) + if self.training: # t1 is random time step, c2 is another condition - x_c1_t1 = x_c1_t1.view(n * t, c, h, w) - (f_a_c1_t1_, f_c_c1_t1_, _) = self.encoder(x_c1_t1) - x_c2_t2 = x_c2_t2.view(n * t, c, h, w) - (_, f_c_c2_t2_, f_p_c2_t2_) = self.encoder(x_c2_t2) + x_c1_t1_ = x_c1_t1.view(n * t, c, h, w) + (f_a_c1_t1_, f_c_c1_t1_, _) = self.encoder(x_c1_t1_) + x_c2_t2_ = x_c2_t2.view(n * t, c, h, w) + (_, f_c_c2_t2_, f_p_c2_t2_) = self.encoder(x_c2_t2_) x_c1_t2_pred_ = self.decoder(f_a_c1_t1_, f_c_c1_t1_, f_p_c1_t2_) x_c1_t2_pred = x_c1_t2_pred_.view(n, t, c, h, w) xrecon_loss = torch.stack([ - F.mse_loss(x_c1_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :]) + F.mse_loss(x_c1_t2[:, i], x_c1_t2_pred[:, i]) for i in range(t) ]).sum() - f_c_c1_t1 = f_c_c1_t1_.view(n, t, -1) - f_c_c1_t2 = f_c_c1_t2_.view(n, t, -1) - f_c_c2_t2 = f_c_c2_t2_.view(n, t, -1) + f_c_c1_t1 = f_c_c1_t1_.view(f_size[1]) + f_c_c2_t2 = f_c_c2_t2_.view(f_size[1]) cano_cons_loss = torch.stack([ - F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :]) - + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :]) + F.mse_loss(f_c_c1_t1[:, i], f_c_c1_t2[:, i]) + + F.mse_loss(f_c_c1_t2[:, i], f_c_c2_t2[:, i]) for i in range(t) ]).mean() - f_p_c1_t2 = f_p_c1_t2_.view(n, t, -1) - f_p_c2_t2 = f_p_c2_t2_.view(n, t, -1) + f_p_c2_t2 = f_p_c2_t2_.view(f_size[2]) pose_sim_loss = F.mse_loss(f_p_c1_t2.mean(1), f_p_c2_t2.mean(1)) return ( - (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_), - (xrecon_loss, cano_cons_loss, pose_sim_loss * 10) + (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2), + (xrecon_loss / 10, cano_cons_loss, pose_sim_loss * 10) ) else: # evaluating - return f_c_c1_t2_, f_p_c1_t2_ + return f_a_c1_t2, f_c_c1_t2, f_p_c1_t2 diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index fcd8fbc..811a711 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -14,7 +14,7 @@ class RGBPartNet(nn.Module): ae_in_channels: int = 3, ae_in_size: Tuple[int, int] = (64, 48), ae_feature_channels: int = 64, - f_a_c_p_dims: Tuple[int, int, int] = (128, 128, 64), + f_a_c_p_dims: Tuple[int, int, int] = (192, 192, 128), hpm_scales: Tuple[int, ...] = (1, 2, 4), hpm_use_avg_pool: bool = True, hpm_use_max_pool: bool = True, @@ -25,100 +25,68 @@ class RGBPartNet(nn.Module): ): super().__init__() self.h, self.w = ae_in_size - (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims self.image_log_on = image_log_on self.ae = AutoEncoder( ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims ) - self.pn_in_channels = ae_feature_channels * 2 self.hpm = HorizontalPyramidMatching( - self.pn_in_channels, embedding_dims[0], hpm_scales, + f_a_c_p_dims[1], embedding_dims[0], hpm_scales, hpm_use_avg_pool, hpm_use_max_pool ) - self.pn = PartNet(self.pn_in_channels, embedding_dims[1], - tfa_num_parts, tfa_squeeze_ratio) + self.pn = PartNet( + f_a_c_p_dims[2], embedding_dims[1], tfa_num_parts, tfa_squeeze_ratio + ) self.num_parts = self.hpm.num_parts + tfa_num_parts def forward(self, x_c1, x_c2=None): # Step 1: Disentanglement # n, t, c, h, w - ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2) + (f_a, f_c, f_p), ae_losses = self._disentangle(x_c1, x_c2) # Step 2.a: Static Gait Feature Aggregation & HPM # n, c, h, w - x_c = self.hpm(x_c) + f_c_mean = f_c.mean(1) + x_c = self.hpm(f_c_mean) # p, n, d # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) # n, t, c, h, w - x_p = self.pn(x_p) + x_p = self.pn(f_p) # p, n, d if self.training: - return x_c, x_p, ae_losses, images + i_a, i_c, i_p = None, None, None + if self.image_log_on: + f_a_mean = f_a.mean(1) + i_a = self.ae.decoder( + f_a_mean, + torch.zeros_like(f_c_mean), + torch.zeros_like(f_p[:, 0]) + ) + i_c = self.ae.decoder( + torch.zeros_like(f_a_mean), + f_c_mean, + torch.zeros_like(f_p[:, 0]) + ) + f_p_size = f_p.size() + i_p = self.ae.decoder( + torch.zeros(f_p_size[0] * f_p_size[1], *f_a.shape[2:], + device=f_a.device), + torch.zeros(f_p_size[0] * f_p_size[1], *f_c.shape[2:], + device=f_c.device), + f_p.view(-1, *f_p_size[2:]) + ).view(x_c1.size()) + return x_c, x_p, ae_losses, (i_a, i_c, i_p) else: return x_c, x_p def _disentangle(self, x_c1_t2, x_c2_t2=None): - n, t, c, h, w = x_c1_t2.size() - device = x_c1_t2.device if self.training: - x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :] - ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) - # Decode features - x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p_ = self._decode_pose_feature(f_p_, n, t, device) - x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) - - i_a, i_c, i_p = None, None, None - if self.image_log_on: - with torch.no_grad(): - i_a = self._decode_appr_feature(f_a_, n, t, device) - # Continue decoding canonical features - i_c = self.ae.decoder.trans_conv3(x_c) - i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c)) - i_p_ = self.ae.decoder.trans_conv3(x_p_) - i_p_ = torch.sigmoid(self.ae.decoder.trans_conv4(i_p_)) - i_p = i_p_.view(n, t, c, h, w) - - return (x_c, x_p), losses, (i_a, i_c, i_p) - + x_c1_t1 = x_c1_t2[:, torch.randperm(x_c1_t2.size(1)), :, :, :] + features, losses = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) + return features, losses else: # evaluating - f_c_, f_p_ = self.ae(x_c1_t2) - x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p_ = self._decode_pose_feature(f_p_, n, t, device) - x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) - return (x_c, x_p), None, None - - def _decode_appr_feature(self, f_a_, n, t, device): - # Decode appearance features - f_a = f_a_.view(n, t, -1) - x_a = self.ae.decoder( - f_a.mean(1), - torch.zeros((n, self.f_c_dim), device=device), - torch.zeros((n, self.f_p_dim), device=device) - ) - return x_a - - def _decode_cano_feature(self, f_c_, n, t, device): - # Decode average canonical features to higher dimension - f_c = f_c_.view(n, t, -1) - x_c = self.ae.decoder( - torch.zeros((n, self.f_a_dim), device=device), - f_c.mean(1), - torch.zeros((n, self.f_p_dim), device=device), - is_feature_map=True - ) - return x_c - - def _decode_pose_feature(self, f_p_, n, t, device): - # Decode pose features to images - x_p_ = self.ae.decoder( - torch.zeros((n * t, self.f_a_dim), device=device), - torch.zeros((n * t, self.f_c_dim), device=device), - f_p_, - is_feature_map=True - ) - return x_p_ + features = self.ae(x_c1_t2) + return features, None |