diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/auto_encoder.py | 4 | ||||
-rw-r--r-- | models/part_net.py | 18 | ||||
-rw-r--r-- | models/rgb_part_net.py | 37 |
3 files changed, 24 insertions, 35 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 1ef7494..e6a3e60 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -106,14 +106,14 @@ class Decoder(nn.Module): self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels, is_last_layer=True) - def forward(self, f_appearance, f_canonical, f_pose, cano_only=False): + def forward(self, f_appearance, f_canonical, f_pose, is_feature_map=False): x = torch.cat((f_appearance, f_canonical, f_pose), dim=1) x = self.fc(x) x = x.view(-1, self.feature_channels * 8, self.h_0, self.w_0) x = F.relu(x, inplace=True) x = self.trans_conv1(x) x = self.trans_conv2(x) - if cano_only: + if is_feature_map: return x x = self.trans_conv3(x) x = torch.sigmoid(self.trans_conv4(x)) diff --git a/models/part_net.py b/models/part_net.py index 62a2bac..29cf9cd 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -110,32 +110,22 @@ class TemporalFeatureAggregator(nn.Module): class PartNet(nn.Module): def __init__( self, - in_channels: int = 3, - feature_channels: int = 32, - kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)), - paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)), - halving: tuple[int, ...] = (0, 2, 3), + in_channels: int = 128, squeeze_ratio: int = 4, num_part: int = 16 ): super().__init__() self.num_part = num_part - self.fpfe = FrameLevelPartFeatureExtractor( - in_channels, feature_channels, kernel_sizes, paddings, halving - ) - - num_fconv_blocks = len(self.fpfe.fconv_blocks) - self.tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1) self.tfa = TemporalFeatureAggregator( - self.tfa_in_channels, squeeze_ratio, self.num_part + in_channels, squeeze_ratio, self.num_part ) self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) def forward(self, x): - n, t, _, _, _ = x.size() - x = self.fpfe(x) + n, t, c, h, w = x.size() + x = x.view(n * t, c, h, w) # n * t x c x h x w # Horizontal Pooling diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 408bca0..936ec46 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -17,16 +17,13 @@ class RGBPartNet(nn.Module): hpm_scales: tuple[int, ...] = (1, 2, 4), hpm_use_avg_pool: bool = True, hpm_use_max_pool: bool = True, - fpfe_feature_channels: int = 32, - fpfe_kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)), - fpfe_paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)), - fpfe_halving: tuple[int, ...] = (0, 2, 3), tfa_squeeze_ratio: int = 4, tfa_num_parts: int = 16, embedding_dims: int = 256, image_log_on: bool = False ): super().__init__() + self.h, self.w = ae_in_size (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims self.hpm_num_parts = sum(hpm_scales) self.image_log_on = image_log_on @@ -34,18 +31,17 @@ class RGBPartNet(nn.Module): self.ae = AutoEncoder( ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims ) + self.pn_in_channels = ae_feature_channels * 2 self.pn = PartNet( - ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes, - fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_parts + self.pn_in_channels, tfa_squeeze_ratio, tfa_num_parts ) - out_channels = self.pn.tfa_in_channels self.hpm = HorizontalPyramidMatching( - ae_feature_channels * 2, out_channels, hpm_use_1x1conv, + ae_feature_channels * 2, self.pn_in_channels, hpm_use_1x1conv, hpm_scales, hpm_use_avg_pool, hpm_use_max_pool ) self.num_total_parts = self.hpm_num_parts + tfa_num_parts empty_fc = torch.empty(self.num_total_parts, - out_channels, embedding_dims) + self.pn_in_channels, embedding_dims) self.fc_mat = nn.Parameter(empty_fc) def fc(self, x): @@ -82,17 +78,20 @@ class RGBPartNet(nn.Module): if self.training: ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) # Decode features - with torch.no_grad(): - x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device) + x_c = self._decode_cano_feature(f_c_, n, t, device) + x_p_ = self._decode_pose_feature(f_p_, n, t, c, h, w, device) + x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) - i_a, i_c, i_p = None, None, None - if self.image_log_on: + i_a, i_c, i_p = None, None, None + if self.image_log_on: + with torch.no_grad(): i_a = self._decode_appr_feature(f_a_, n, t, device) # Continue decoding canonical features i_c = self.ae.decoder.trans_conv3(x_c) i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c)) - i_p = x_p + i_p_ = self.ae.decoder.trans_conv3(x_p_) + i_p_ = torch.sigmoid(self.ae.decoder.trans_conv4(i_p_)) + i_p = i_p_.view(n, t, c, h, w) return (x_c, x_p), losses, (i_a, i_c, i_p) @@ -119,7 +118,7 @@ class RGBPartNet(nn.Module): torch.zeros((n, self.f_a_dim), device=device), f_c.mean(1), torch.zeros((n, self.f_p_dim), device=device), - cano_only=True + is_feature_map=True ) return x_c @@ -128,7 +127,7 @@ class RGBPartNet(nn.Module): x_p_ = self.ae.decoder( torch.zeros((n * t, self.f_a_dim), device=device), torch.zeros((n * t, self.f_c_dim), device=device), - f_p_ + f_p_, + is_feature_map=True ) - x_p = x_p_.view(n, t, c, h, w) - return x_p + return x_p_ |