diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-27 20:20:12 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-27 20:20:12 +0800 |
commit | 5a95c94e9f250001d0007b5ac238505d0a5f23b5 (patch) | |
tree | 3e997b6d45963ee7a035204d51b22a0c0a100a86 /models/auto_encoder.py | |
parent | 343e2311b48b900c07072849570f82a5b90d2aa7 (diff) |
Refine auto-encoder
1. Wrap fully connected layers
2. Expose hyperparameters (feature channels, embedding dimensions) as constructor arguments
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r-- | models/auto_encoder.py | 119 |
1 file changed, 73 insertions, 46 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 4a8aaa5..701f299 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -2,34 +2,48 @@ import torch import torch.nn as nn import torch.nn.functional as F -from models.layers import VGGConv2d, DCGANConvTranspose2d +from models.layers import VGGConv2d, DCGANConvTranspose2d, BasicLinear class Encoder(nn.Module): - def __init__(self, in_channels: int, opt): - super(Encoder, self).__init__() - self.opt = opt - self.em_dim = opt.em_dim - nf = 64 + """Squeeze input feature map to lower dimension""" - # Cx[HxW] - # Conv1 3x64x32 -> 64x64x32 - self.conv1 = VGGConv2d(in_channels, nf) - # MaxPool1 64x64x32 -> 64x32x16 + def __init__( + self, + in_channels: int = 3, + feature_channels: int = 64, + output_dims: tuple[int, int, int] = (128, 128, 64) + ): + super().__init__() + self.feature_channels = feature_channels + # Appearance features, canonical features, pose features + (self.f_a_dim, self.f_c_dim, self.f_p_dim) = output_dims + + # Conv1 in_channels x 64 x 32 + # -> feature_map_size x 64 x 32 + self.conv1 = VGGConv2d(in_channels, feature_channels) + # MaxPool1 feature_map_size x 64 x 32 + # -> feature_map_size x 32 x 16 self.max_pool1 = nn.AdaptiveMaxPool2d((32, 16)) - # Conv2 64x32x16 -> 256x32x16 - self.conv2 = VGGConv2d(nf, nf * 4) - # MaxPool2 256x32x16 -> 256x16x8 + # Conv2 feature_map_size x 32 x 16 + # -> (feature_map_size*4) x 32 x 16 + self.conv2 = VGGConv2d(feature_channels, feature_channels * 4) + # MaxPool2 (feature_map_size*4) x 32 x 16 + # -> (feature_map_size*4) x 16 x 8 self.max_pool2 = nn.AdaptiveMaxPool2d((16, 8)) - # Conv3 256x16x8 -> 512x16x8 - self.conv3 = VGGConv2d(nf * 4, nf * 8) - # Conv4 512x16x8 -> 512x16x8 (for large dataset) - self.conv4 = VGGConv2d(nf * 8, nf * 8) - # MaxPool3 512x16x8 -> 512x4x2 + # Conv3 (feature_map_size*4) x 16 x 8 + # -> (feature_map_size*8) x 16 x 8 + self.conv3 = VGGConv2d(feature_channels * 4, feature_channels * 8) + # Conv4 
(feature_map_size*8) x 16 x 8 + # -> (feature_map_size*8) x 16 x 8 (for large dataset) + self.conv4 = VGGConv2d(feature_channels * 8, feature_channels * 8) + # MaxPool3 (feature_map_size*8) x 16 x 8 + # -> (feature_map_size*8) x 4 x 2 self.max_pool3 = nn.AdaptiveMaxPool2d((4, 2)) - # FC 512*4*2 -> 320 - self.fc = nn.Linear(nf * 8 * 2 * 4, self.em_dim, bias=False) - self.bn_fc = nn.BatchNorm1d(self.em_dim) + + embedding_dim = sum(output_dims) + # FC (feature_map_size*8) * 4 * 2 -> 320 + self.fc = BasicLinear(feature_channels * 8 * 2 * 4, embedding_dim) def forward(self, x): x = self.conv1(x) @@ -39,39 +53,52 @@ class Encoder(nn.Module): x = self.conv3(x) x = self.conv4(x) x = self.max_pool3(x) - x = x.view(-1, (64 * 8) * 2 * 4) - embedding = self.bn_fc(self.fc(x)) + x = x.view(-1, (self.feature_channels * 8) * 2 * 4) + embedding = self.fc(x) - fa, fgs, fgd = embedding.split( - (self.opt.fa_dim, self.opt.fg_dim / 2, self.opt.fg_dim / 2), dim=1 + f_appearance, f_canonical, f_pose = embedding.split( + (self.f_a_dim, self.f_c_dim, self.f_p_dim), dim=1 ) - return fa, fgs, fgd + return f_appearance, f_canonical, f_pose class Decoder(nn.Module): - def __init__(self, out_channels: int, opt): - super(Decoder, self).__init__() - self.em_dim = opt.em_dim - nf = 64 + """Upscale embedding to original image""" + + def __init__( + self, + out_channels: int, + feature_channels: int = 64, + input_dims: tuple[int, int, int] = (128, 128, 64) + ): + super().__init__() + self.feature_channels = feature_channels + + embedding_dim = sum(input_dims) + # FC 320 -> (feature_map_size*8) * 4 * 2 + self.fc = BasicLinear(embedding_dim, feature_channels * 8 * 2 * 4) - # Cx[HxW] - # FC 320 -> 512*4*2 - self.fc = nn.Linear(self.em_dim, nf * 8 * 2 * 4) - self.bn_fc = nn.BatchNorm1d(nf * 8 * 2 * 4) - # TransConv1 512x4x2 -> 256x8x4 - self.trans_conv1 = DCGANConvTranspose2d(nf * 8, nf * 4) - # TransConv2 256x8x4 -> 128x16x8 - self.trans_conv2 = DCGANConvTranspose2d(nf * 4, nf * 2) - # TransConv3 
128x16x8 -> 64x32x16 - self.trans_conv3 = DCGANConvTranspose2d(nf * 2, nf) - # TransConv4 3x32x16 - self.trans_conv4 = DCGANConvTranspose2d(nf, out_channels, + # TransConv1 (feature_map_size*8) x 4 x 2 + # -> (feature_map_size*4) x 8 x 4 + self.trans_conv1 = DCGANConvTranspose2d(feature_channels * 8, + feature_channels * 4) + # TransConv2 (feature_map_size*4) x 8 x 4 + # -> (feature_map_size*2) x 16 x 8 + self.trans_conv2 = DCGANConvTranspose2d(feature_channels * 4, + feature_channels * 2) + # TransConv3 (feature_map_size*2) x 16 x 8 + # -> feature_map_size x 32 x 16 + self.trans_conv3 = DCGANConvTranspose2d(feature_channels * 2, + feature_channels) + # TransConv4 feature_map_size x 32 x 16 + # -> in_channels x 64 x 32 + self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels, is_last_layer=True) - def forward(self, fa, fgs, fgd): - x = torch.cat((fa, fgs, fgd), dim=1).view(-1, self.em_dim) - x = self.bn_fc(self.fc(x)) - x = F.relu(x.view(-1, 64 * 8, 4, 2), True) + def forward(self, f_appearance, f_canonical, f_pose): + x = torch.cat((f_appearance, f_canonical, f_pose), dim=1) + x = self.fc(x) + x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True) x = self.trans_conv1(x) x = self.trans_conv2(x) x = self.trans_conv3(x) |