diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/auto_encoder.py | 81 | ||||
-rw-r--r-- | models/layers.py | 8 | ||||
-rw-r--r-- | models/model.py | 14 | ||||
-rw-r--r-- | models/rgb_part_net.py | 17 |
4 files changed, 70 insertions, 50 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 7b9b29f..e17caed 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -13,39 +13,46 @@ class Encoder(nn.Module): def __init__( self, in_channels: int = 3, + frame_size: Tuple[int, int] = (64, 48), feature_channels: int = 64, output_dims: Tuple[int, int, int] = (128, 128, 64) ): super().__init__() self.feature_channels = feature_channels + h_0, w_0 = frame_size + h_1, w_1 = h_0 // 2, w_0 // 2 + h_2, w_2 = h_1 // 2, w_1 // 2 + self.feature_size = self.h_3, self.w_3 = h_2 // 4, w_2 // 4 # Appearance features, canonical features, pose features (self.f_a_dim, self.f_c_dim, self.f_p_dim) = output_dims - # Conv1 in_channels x 64 x 32 - # -> feature_map_size x 64 x 32 + # Conv1 in_channels x H x W + # -> feature_map_size x H x W self.conv1 = VGGConv2d(in_channels, feature_channels) - # MaxPool1 feature_map_size x 64 x 32 - # -> feature_map_size x 32 x 16 - self.max_pool1 = nn.AdaptiveMaxPool2d((32, 16)) - # Conv2 feature_map_size x 32 x 16 - # -> (feature_map_size*4) x 32 x 16 + # MaxPool1 feature_map_size x H x W + # -> feature_map_size x H//2 x W//2 + self.max_pool1 = nn.AdaptiveMaxPool2d((h_1, w_1)) + # Conv2 feature_map_size x H//2 x W//2 + # -> feature_map_size*4 x H//2 x W//2 self.conv2 = VGGConv2d(feature_channels, feature_channels * 4) - # MaxPool2 (feature_map_size*4) x 32 x 16 - # -> (feature_map_size*4) x 16 x 8 - self.max_pool2 = nn.AdaptiveMaxPool2d((16, 8)) - # Conv3 (feature_map_size*4) x 16 x 8 - # -> (feature_map_size*8) x 16 x 8 + # MaxPool2 feature_map_size*4 x H//2 x W//2 + # -> feature_map_size*4 x H//4 x W//4 + self.max_pool2 = nn.AdaptiveMaxPool2d((h_2, w_2)) + # Conv3 feature_map_size*4 x H//4 x W//4 + # -> feature_map_size*8 x H//4 x W//4 self.conv3 = VGGConv2d(feature_channels * 4, feature_channels * 8) - # Conv4 (feature_map_size*8) x 16 x 8 - # -> (feature_map_size*8) x 16 x 8 (for large dataset) + # Conv4 feature_map_size*8 x H//4 x W//4 + # -> feature_map_size*8 x H//4 x W//4 (for large dataset) self.conv4 = VGGConv2d(feature_channels * 8, feature_channels * 8) - # MaxPool3 (feature_map_size*8) x 16 x 8 - # -> (feature_map_size*8) x 4 x 2 - self.max_pool3 = nn.AdaptiveMaxPool2d((4, 2)) + # MaxPool3 feature_map_size*8 x H//4 x W//4 + # -> feature_map_size*8 x H//16 x W//16 + self.max_pool3 = nn.AdaptiveMaxPool2d(self.feature_size) embedding_dim = sum(output_dims) - # FC (feature_map_size*8) * 4 * 2 -> 320 - self.fc = BasicLinear(feature_channels * 8 * 2 * 4, embedding_dim) + # FC feature_map_size*8 * H//16 * W//16 -> embedding_dim + self.fc = BasicLinear( + (feature_channels * 8) * self.h_3 * self.w_3, embedding_dim + ) def forward(self, x): x = self.conv1(x) @@ -55,7 +62,7 @@ class Encoder(nn.Module): x = self.conv3(x) x = self.conv4(x) x = self.max_pool3(x) - x = x.view(-1, (self.feature_channels * 8) * 2 * 4) + x = x.view(-1, (self.feature_channels * 8) * self.h_3 * self.w_3) embedding = self.fc(x) f_appearance, f_canonical, f_pose = embedding.split( @@ -71,36 +78,41 @@ class Decoder(nn.Module): self, input_dims: Tuple[int, int, int] = (128, 128, 64), feature_channels: int = 64, + feature_size: Tuple[int, int] = (4, 3), out_channels: int = 3, ): super().__init__() self.feature_channels = feature_channels + self.h_0, self.w_0 = feature_size embedding_dim = sum(input_dims) - # FC 320 -> (feature_map_size*8) * 4 * 2 - self.fc = BasicLinear(embedding_dim, feature_channels * 8 * 2 * 4) + # FC 320 -> feature_map_size*8 * H * W + self.fc = BasicLinear( + embedding_dim, (feature_channels * 8) * self.h_0 * self.w_0 + ) - # TransConv1 (feature_map_size*8) x 4 x 2 - # -> (feature_map_size*4) x 8 x 4 + # TransConv1 feature_map_size*8 x H x W + # -> feature_map_size*4 x H*2 x W*2 self.trans_conv1 = DCGANConvTranspose2d(feature_channels * 8, feature_channels * 4) - # TransConv2 (feature_map_size*4) x 8 x 4 - # -> (feature_map_size*2) x 16 x 8 + # TransConv2 feature_map_size*4 x H*2 x W*2 + # -> feature_map_size*2 x H*4 x W*4 self.trans_conv2 = DCGANConvTranspose2d(feature_channels * 4, feature_channels * 2) - # TransConv3 (feature_map_size*2) x 16 x 8 - # -> feature_map_size x 32 x 16 + # TransConv3 feature_map_size*2 x H*4 x W*4 + # -> feature_map_size x H*8 x W*8 self.trans_conv3 = DCGANConvTranspose2d(feature_channels * 2, feature_channels) - # TransConv4 feature_map_size x 32 x 16 - # -> in_channels x 64 x 32 + # TransConv4 feature_map_size x H*8 x W*8 + # -> in_channels x H*16 x W*16 self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels, is_last_layer=True) def forward(self, f_appearance, f_canonical, f_pose, cano_only=False): x = torch.cat((f_appearance, f_canonical, f_pose), dim=1) x = self.fc(x) - x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True) + x = x.view(-1, self.feature_channels * 8, self.h_0, self.w_0) + x = F.relu(x, inplace=True) x = self.trans_conv1(x) x = self.trans_conv2(x) if cano_only: @@ -115,12 +127,15 @@ class AutoEncoder(nn.Module): def __init__( self, channels: int = 3, + frame_size: Tuple[int, int] = (64, 48), feature_channels: int = 64, embedding_dims: Tuple[int, int, int] = (128, 128, 64) ): super().__init__() - self.encoder = Encoder(channels, feature_channels, embedding_dims) - self.decoder = Decoder(embedding_dims, feature_channels, channels) + self.encoder = Encoder(channels, frame_size, + feature_channels, embedding_dims) + self.decoder = Decoder(embedding_dims, feature_channels, + self.encoder.feature_size, channels) def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() diff --git a/models/layers.py b/models/layers.py index 98e4c10..ae61583 100644 --- a/models/layers.py +++ b/models/layers.py @@ -162,7 +162,7 @@ class BasicConv1d(nn.Module): return self.conv(x) -class HorizontalPyramidPooling(BasicConv2d): +class HorizontalPyramidPooling(nn.Module): def __init__( self, in_channels: int, @@ -172,8 +172,10 @@ class HorizontalPyramidPooling(BasicConv2d): use_max_pool: bool = False, **kwargs ): - super().__init__(in_channels, out_channels, kernel_size=1, **kwargs) + super().__init__() self.use_1x1conv = use_1x1conv + if use_1x1conv: + self.conv = BasicConv2d(in_channels, out_channels, 1, **kwargs) self.use_avg_pool = use_avg_pool self.use_max_pool = use_max_pool assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.' @@ -188,5 +190,5 @@ class HorizontalPyramidPooling(BasicConv2d): elif not self.use_avg_pool and self.use_max_pool: x = self.max_pool(x) if self.use_1x1conv: - x = super().forward(x) + x = self.conv(x) return x diff --git a/models/model.py b/models/model.py index 27c648c..a3b6d3a 100644 --- a/models/model.py +++ b/models/model.py @@ -52,6 +52,7 @@ class Model: self.is_train: bool = True self.in_channels: int = 3 + self.in_size: Tuple[int, int] = (64, 48) self.pr: Optional[int] = None self.k: Optional[int] = None @@ -144,7 +145,7 @@ class Model: hpm_optim_hp = optim_hp.pop('hpm', {}) fc_optim_hp = optim_hp.pop('fc', {}) sched_hp = self.hp.get('scheduler', {}) - self.rgb_pn = RGBPartNet(self.in_channels, **model_hp, + self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) @@ -220,16 +221,16 @@ class Model: if self.image_log_on: i_a, i_c, i_p = images self.writer.add_images( + 'Appearance image', i_a, self.curr_iter + ) + self.writer.add_images( 'Canonical image', i_c, self.curr_iter ) - for (i, (o, a, p)) in enumerate(zip(x_c1, i_a, i_p)): + for i, (o, p) in enumerate(zip(x_c1, i_p)): self.writer.add_images( f'Original image/batch {i}', o, self.curr_iter ) self.writer.add_images( - f'Appearance image/batch {i}', a, self.curr_iter - ) - self.writer.add_images( f'Pose image/batch {i}', p, self.curr_iter ) time_used = datetime.now() - start_time @@ -296,7 +297,7 @@ class Model: # Init models model_hp = self.hp.get('model', {}) - self.rgb_pn = RGBPartNet(ae_in_channels=self.in_channels, **model_hp) + self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) self.rgb_pn.eval() @@ -456,6 +457,7 @@ class Model: dataset_config: Dict ) -> Union[CASIAB]: self.in_channels = dataset_config.get('num_input_channels', 3) + self.in_size = dataset_config.get('frame_size', (64, 48)) self._dataset_sig = self._make_signature( dataset_config, popped_keys=['root_dir', 'cache_on'] diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 260eabd..2af990e 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -13,6 +13,7 @@ class RGBPartNet(nn.Module): def __init__( self, ae_in_channels: int = 3, + ae_in_size: Tuple[int, int] = (64, 48), ae_feature_channels: int = 64, f_a_c_p_dims: Tuple[int, int, int] = (128, 128, 64), hpm_use_1x1conv: bool = False, @@ -35,7 +36,7 @@ class RGBPartNet(nn.Module): self.image_log_on = image_log_on self.ae = AutoEncoder( - ae_in_channels, ae_feature_channels, f_a_c_p_dims + ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims ) self.pn = PartNet( ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes, @@ -103,7 +104,7 @@ class RGBPartNet(nn.Module): i_a, i_c, i_p = None, None, None if self.image_log_on: - i_a = self._decode_appr_feature(f_a_, n, t, c, h, w, device) + i_a = self._decode_appr_feature(f_a_, n, t, device) # Continue decoding canonical features i_c = self.ae.decoder.trans_conv3(x_c) i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c)) @@ -117,14 +118,14 @@ class RGBPartNet(nn.Module): x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device) return (x_c, x_p), None, None - def _decode_appr_feature(self, f_a_, n, t, c, h, w, device): + def _decode_appr_feature(self, f_a_, n, t, device): # Decode appearance features - x_a_ = self.ae.decoder( - f_a_, - torch.zeros((n * t, self.f_c_dim), device=device), - torch.zeros((n * t, self.f_p_dim), device=device) + f_a = f_a_.view(n, t, -1) + x_a = self.ae.decoder( + f_a.mean(1), + torch.zeros((n, self.f_c_dim), device=device), + torch.zeros((n, self.f_p_dim), device=device) ) - x_a = x_a_.view(n, t, c, h, w) return x_a def _decode_cano_feature(self, f_c_, n, t, device): |