From 5a95c94e9f250001d0007b5ac238505d0a5f23b5 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 27 Dec 2020 20:20:12 +0800 Subject: Refine auto-encoder 1. Wrap fully connected layers 2. Introduce hyperparameter tuning in constructor --- models/auto_encoder.py | 119 ++++++++++++++++++++++++++++++------------------- models/layers.py | 16 +++++++ 2 files changed, 89 insertions(+), 46 deletions(-) (limited to 'models') diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 4a8aaa5..701f299 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -2,34 +2,48 @@ import torch import torch.nn as nn import torch.nn.functional as F -from models.layers import VGGConv2d, DCGANConvTranspose2d +from models.layers import VGGConv2d, DCGANConvTranspose2d, BasicLinear class Encoder(nn.Module): - def __init__(self, in_channels: int, opt): - super(Encoder, self).__init__() - self.opt = opt - self.em_dim = opt.em_dim - nf = 64 + """Squeeze input feature map to lower dimension""" - # Cx[HxW] - # Conv1 3x64x32 -> 64x64x32 - self.conv1 = VGGConv2d(in_channels, nf) - # MaxPool1 64x64x32 -> 64x32x16 + def __init__( + self, + in_channels: int = 3, + feature_channels: int = 64, + output_dims: tuple[int, int, int] = (128, 128, 64) + ): + super().__init__() + self.feature_channels = feature_channels + # Appearance features, canonical features, pose features + (self.f_a_dim, self.f_c_dim, self.f_p_dim) = output_dims + + # Conv1 in_channels x 64 x 32 + # -> feature_map_size x 64 x 32 + self.conv1 = VGGConv2d(in_channels, feature_channels) + # MaxPool1 feature_map_size x 64 x 32 + # -> feature_map_size x 32 x 16 self.max_pool1 = nn.AdaptiveMaxPool2d((32, 16)) - # Conv2 64x32x16 -> 256x32x16 - self.conv2 = VGGConv2d(nf, nf * 4) - # MaxPool2 256x32x16 -> 256x16x8 + # Conv2 feature_map_size x 32 x 16 + # -> (feature_map_size*4) x 32 x 16 + self.conv2 = VGGConv2d(feature_channels, feature_channels * 4) + # MaxPool2 (feature_map_size*4) x 32 x 16 + # -> (feature_map_size*4) x 16 x 8 self.max_pool2 = nn.AdaptiveMaxPool2d((16, 8)) - # Conv3 256x16x8 -> 512x16x8 - self.conv3 = VGGConv2d(nf * 4, nf * 8) - # Conv4 512x16x8 -> 512x16x8 (for large dataset) - self.conv4 = VGGConv2d(nf * 8, nf * 8) - # MaxPool3 512x16x8 -> 512x4x2 + # Conv3 (feature_map_size*4) x 16 x 8 + # -> (feature_map_size*8) x 16 x 8 + self.conv3 = VGGConv2d(feature_channels * 4, feature_channels * 8) + # Conv4 (feature_map_size*8) x 16 x 8 + # -> (feature_map_size*8) x 16 x 8 (for large dataset) + self.conv4 = VGGConv2d(feature_channels * 8, feature_channels * 8) + # MaxPool3 (feature_map_size*8) x 16 x 8 + # -> (feature_map_size*8) x 4 x 2 self.max_pool3 = nn.AdaptiveMaxPool2d((4, 2)) - # FC 512*4*2 -> 320 - self.fc = nn.Linear(nf * 8 * 2 * 4, self.em_dim, bias=False) - self.bn_fc = nn.BatchNorm1d(self.em_dim) + + embedding_dim = sum(output_dims) + # FC (feature_map_size*8) * 4 * 2 -> 320 + self.fc = BasicLinear(feature_channels * 8 * 2 * 4, embedding_dim) def forward(self, x): x = self.conv1(x) @@ -39,39 +53,52 @@ class Encoder(nn.Module): x = self.conv3(x) x = self.conv4(x) x = self.max_pool3(x) - x = x.view(-1, (64 * 8) * 2 * 4) - embedding = self.bn_fc(self.fc(x)) + x = x.view(-1, (self.feature_channels * 8) * 2 * 4) + embedding = self.fc(x) - fa, fgs, fgd = embedding.split( - (self.opt.fa_dim, self.opt.fg_dim / 2, self.opt.fg_dim / 2), dim=1 + f_appearance, f_canonical, f_pose = embedding.split( + (self.f_a_dim, self.f_c_dim, self.f_p_dim), dim=1 ) - return fa, fgs, fgd + return f_appearance, f_canonical, f_pose class Decoder(nn.Module): - def __init__(self, out_channels: int, opt): - super(Decoder, self).__init__() - self.em_dim = opt.em_dim - nf = 64 + """Upscale embedding to original image""" + + def __init__( + self, + out_channels: int, + feature_channels: int = 64, + input_dims: tuple[int, int, int] = (128, 128, 64) + ): + super().__init__() + self.feature_channels = feature_channels + + embedding_dim = sum(input_dims) + # FC 320 -> (feature_map_size*8) * 4 * 2 + self.fc = BasicLinear(embedding_dim, feature_channels * 8 * 2 * 4) - # Cx[HxW] - # FC 320 -> 512*4*2 - self.fc = nn.Linear(self.em_dim, nf * 8 * 2 * 4) - self.bn_fc = nn.BatchNorm1d(nf * 8 * 2 * 4) - # TransConv1 512x4x2 -> 256x8x4 - self.trans_conv1 = DCGANConvTranspose2d(nf * 8, nf * 4) - # TransConv2 256x8x4 -> 128x16x8 - self.trans_conv2 = DCGANConvTranspose2d(nf * 4, nf * 2) - # TransConv3 128x16x8 -> 64x32x16 - self.trans_conv3 = DCGANConvTranspose2d(nf * 2, nf) - # TransConv4 3x32x16 - self.trans_conv4 = DCGANConvTranspose2d(nf, out_channels, + # TransConv1 (feature_map_size*8) x 4 x 2 + # -> (feature_map_size*4) x 8 x 4 + self.trans_conv1 = DCGANConvTranspose2d(feature_channels * 8, + feature_channels * 4) + # TransConv2 (feature_map_size*4) x 8 x 4 + # -> (feature_map_size*2) x 16 x 8 + self.trans_conv2 = DCGANConvTranspose2d(feature_channels * 4, + feature_channels * 2) + # TransConv3 (feature_map_size*2) x 16 x 8 + # -> feature_map_size x 32 x 16 + self.trans_conv3 = DCGANConvTranspose2d(feature_channels * 2, + feature_channels) + # TransConv4 feature_map_size x 32 x 16 + # -> in_channels x 64 x 32 + self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels, is_last_layer=True) - def forward(self, fa, fgs, fgd): - x = torch.cat((fa, fgs, fgd), dim=1).view(-1, self.em_dim) - x = self.bn_fc(self.fc(x)) - x = F.relu(x.view(-1, 64 * 8, 4, 2), True) + def forward(self, f_appearance, f_canonical, f_pose): + x = torch.cat((f_appearance, f_canonical, f_pose), dim=1) + x = self.fc(x) + x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True) x = self.trans_conv1(x) x = self.trans_conv2(x) x = self.trans_conv3(x) diff --git a/models/layers.py b/models/layers.py index cba6e47..f824078 100644 --- a/models/layers.py +++ b/models/layers.py @@ -83,6 +83,22 @@ class DCGANConvTranspose2d(BasicConvTranspose2d): return super().forward(x) +class BasicLinear(nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + ): + super().__init__() + self.fc = nn.Linear(in_features, out_features, bias=False) + self.bn = nn.BatchNorm1d(out_features) + + def forward(self, x): + x = self.fc(x) + x = self.bn(x) + return x + + class FocalConv2d(BasicConv2d): def __init__( self, -- cgit v1.2.3