diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-05-17 17:19:20 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-05-17 17:19:20 +0800 |
commit | 2b5bc2350df8ebb8e4e5722f5563f2e44823ac34 (patch) | |
tree | 94816e528ed7ac1732c44877181540fd5a8ad810 /models | |
parent | fde8d8fc0e4d986be37b26b468b0ac9d53d7bcd9 (diff) |
Reduce channels in the auto-encoder and add more layers
Diffstat (limited to 'models')
-rw-r--r-- | models/auto_encoder.py | 86 | ||||
-rw-r--r-- | models/model.py | 2 |
2 files changed, 57 insertions, 31 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index dc7843a..aa55a42 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -12,42 +12,53 @@ class Encoder(nn.Module): self, in_channels: int = 3, frame_size: tuple[int, int] = (64, 48), - feature_channels: int = 64, - output_dims: tuple[int, int, int] = (192, 192, 128) + feature_channels: int = 32, + output_dims: tuple[int, int, int] = (48, 48, 32) ): super().__init__() h_0, w_0 = frame_size h_1, w_1 = h_0 // 2, w_0 // 2 - h_2, w_2 = h_1 // 2, w_1 // 2 + h_2, w_2 = h_1 // 4, w_1 // 4 # Appearance features, canonical features, pose features (self.f_a_dim, self.f_c_dim, self.f_p_dim) = output_dims # Conv1 in_channels x H x W # -> feature_map_size x H x W - self.conv1 = VGGConv2d(in_channels, feature_channels) + self.conv1 = VGGConv2d(in_channels, + feature_channels, + kernel_size=5, + padding=2) + # Conv2 feature_map_size x H x W + # -> feature_map_size x H x W + self.conv2 = VGGConv2d(feature_channels, feature_channels) # MaxPool1 feature_map_size x H x W # -> feature_map_size x H//2 x W//2 self.max_pool1 = nn.AdaptiveMaxPool2d((h_1, w_1)) - # Conv2 feature_map_size x H//2 x W//2 - # -> feature_map_size*4 x H//2 x W//2 - self.conv2 = VGGConv2d(feature_channels, feature_channels * 4) - # MaxPool2 feature_map_size*4 x H//2 x W//2 - # -> feature_map_size*4 x H//4 x W//4 + # Conv3 feature_map_size x H//2 x W//2 + # -> feature_map_size*2 x H//2 x W//2 + self.conv3 = VGGConv2d(feature_channels, feature_channels * 2) + # Conv4 feature_map_size*2 x H//2 x W//2 + # -> feature_map_size*2 x H//2 x W//2 + self.conv4 = VGGConv2d(feature_channels * 2, feature_channels * 2) + # MaxPool2 feature_map_size*2 x H//2 x W//2 + # -> feature_map_size*2 x H//8 x W//8 self.max_pool2 = nn.AdaptiveMaxPool2d((h_2, w_2)) - # Conv3 feature_map_size*4 x H//4 x W//4 - # -> feature_map_size*8 x H//4 x W//4 - self.conv3 = VGGConv2d(feature_channels * 4, feature_channels * 8) - # Conv4 feature_map_size*8 x H//4 x W//4 - # -> feature_map_size*8 x H//4 x W//4 (for large dataset) - self.conv4 = VGGConv2d(feature_channels * 8, feature_channels * 8) + # Conv5 feature_map_size*2 x H//8 x W//8 + # -> feature_map_size*4 x H//8 x W//8 + self.conv5 = VGGConv2d(feature_channels * 2, feature_channels * 4) + # Conv6 feature_map_size*4 x H//8 x W//8 + # -> feature_map_size*4 x H//8 x W//8 + self.conv6 = VGGConv2d(feature_channels * 4, feature_channels * 4) def forward(self, x): x = self.conv1(x) - x = self.max_pool1(x) x = self.conv2(x) - x = self.max_pool2(x) + x = self.max_pool1(x) x = self.conv3(x) x = self.conv4(x) + x = self.max_pool2(x) + x = self.conv5(x) + x = self.conv6(x) f_appearance, f_canonical, f_pose = x.split( (self.f_a_dim, self.f_c_dim, self.f_p_dim), dim=1 ) @@ -59,41 +70,56 @@ class Decoder(nn.Module): def __init__( self, - feature_channels: int = 64, + feature_channels: int = 32, out_channels: int = 3, ): super().__init__() self.feature_channels = feature_channels - # TransConv1 feature_map_size*8 x H x W + # TransConv1 feature_map_size*4 x H x W # -> feature_map_size*4 x H x W - self.trans_conv1 = DCGANConvTranspose2d(feature_channels * 8, + self.trans_conv1 = DCGANConvTranspose2d(feature_channels * 4, feature_channels * 4, kernel_size=3, stride=1, padding=1) - # TransConv2 feature_map_size*4 x H x W + # TransConv2 feature_map_size*4 x H x W # -> feature_map_size*2 x H*2 x W*2 self.trans_conv2 = DCGANConvTranspose2d(feature_channels * 4, feature_channels * 2) # TransConv3 feature_map_size*2 x H*2 x W*2 - # -> feature_map_size x H*2 x W*2 + # -> feature_map_size*2 x H*2 x W*2 self.trans_conv3 = DCGANConvTranspose2d(feature_channels * 2, + feature_channels * 2, + kernel_size=3, + stride=1, + padding=1) + # TransConv4 feature_map_size*2 x H*2 x W*2 + # -> feature_map_size x H*4 x W*4 + self.trans_conv4 = DCGANConvTranspose2d(feature_channels * 2, + feature_channels) + + # TransConv5 feature_map_size x H*4 x W*4 + # -> feature_map_size x H*4 x W*4 + self.trans_conv5 = DCGANConvTranspose2d(feature_channels, feature_channels, kernel_size=3, stride=1, padding=1) - # TransConv4 feature_map_size x H*2 x W*2 - # -> in_channels x H*4 x W*4 - self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels, - is_last_layer=True) + + # TransConv6 feature_map_size x H*4 x W*4 + # -> out_channels x H*8 x W*8 + self.trans_conv6 = DCGANConvTranspose2d(feature_channels, + out_channels) def forward(self, f_appearance, f_canonical, f_pose): x = torch.cat((f_appearance, f_canonical, f_pose), dim=1) x = self.trans_conv1(x) x = self.trans_conv2(x) x = self.trans_conv3(x) - x = torch.sigmoid(self.trans_conv4(x)) + x = self.trans_conv4(x) + x = self.trans_conv5(x) + x = torch.sigmoid(self.trans_conv6(x)) return x @@ -103,8 +129,8 @@ class AutoEncoder(nn.Module): self, channels: int = 3, frame_size: tuple[int, int] = (64, 48), - feature_channels: int = 64, - embedding_dims: tuple[int, int, int] = (192, 192, 128) + feature_channels: int = 32, + embedding_dims: tuple[int, int, int] = (48, 48, 32) ): super().__init__() self.embedding_dims = embedding_dims @@ -118,7 +144,7 @@ class AutoEncoder(nn.Module): x_c1_t2_ = x_c1_t2.view(n * t, c, h, w) (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_) = self.encoder(x_c1_t2_) - f_size = [torch.Size([n, t, embedding_dim, h // 4, w // 4]) + f_size = [torch.Size([n, t, embedding_dim, h // 8, w // 8]) for embedding_dim in self.embedding_dims] f_a_c1_t2 = f_a_c1_t2_.view(f_size[0]) f_c_c1_t2 = f_c_c1_t2_.view(f_size[1]) diff --git a/models/model.py b/models/model.py index 6118bdf..bf23428 100644 --- a/models/model.py +++ b/models/model.py @@ -366,7 +366,7 @@ class Model: for _f_p_c1_t2, _f_p_c2_t2 in zip(*f_loss[2]) ]).sum() - return xrecon_loss, cano_cons_loss * 10, pose_sim_loss * 100 + return xrecon_loss / 10, cano_cons_loss, pose_sim_loss * 10 def _classification_loss(self, embedding, y): # Duplicate labels for each part |