From 4ad6d0af2f4e6868576582e4b86b096a2e95e6fa Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 24 Dec 2020 15:42:42 +0800 Subject: Change the usage of layers and reorganize relations of layers 1. Add batch normalization and activation to layers 2. VGGConv2d and FocalConv2d inherits to BasicConv2d; DCGANConvTranspose2d inherits to BasicConvTranspose2d --- models/auto_encoder.py | 36 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 21 deletions(-) (limited to 'models/auto_encoder.py') diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 1be878f..6483bd9 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -15,36 +15,32 @@ class Encoder(nn.Module): # Cx[HxW] # Conv1 3x64x32 -> 64x64x32 self.conv1 = VGGConv2d(in_channels, nf) - self.batch_norm1 = nn.BatchNorm2d(nf) # MaxPool1 64x64x32 -> 64x32x16 self.max_pool1 = nn.AdaptiveMaxPool2d((32, 16)) # Conv2 64x32x16 -> 256x32x16 self.conv2 = VGGConv2d(nf, nf * 4) - self.batch_norm2 = nn.BatchNorm2d(nf * 4) # MaxPool2 256x32x16 -> 256x16x8 self.max_pool2 = nn.AdaptiveMaxPool2d((16, 8)) # Conv3 256x16x8 -> 512x16x8 self.conv3 = VGGConv2d(nf * 4, nf * 8) - self.batch_norm3 = nn.BatchNorm2d(nf * 8) # Conv4 512x16x8 -> 512x16x8 (for large dataset) self.conv4 = VGGConv2d(nf * 8, nf * 8) - self.batch_norm4 = nn.BatchNorm2d(nf * 8) # MaxPool3 512x16x8 -> 512x4x2 self.max_pool3 = nn.AdaptiveMaxPool2d((4, 2)) # FC 512*4*2 -> 320 - self.fc = nn.Linear(nf * 8 * 2 * 4, self.em_dim) - self.batch_norm_fc = nn.BatchNorm1d(self.em_dim) + self.fc = nn.Linear(nf * 8 * 2 * 4, self.em_dim, bias=False) + self.bn_fc = nn.BatchNorm1d(self.em_dim) def forward(self, x): - x = F.leaky_relu(self.batch_norm1(self.conv1(x)), 0.2, inplace=True) + x = self.conv1(x) x = self.max_pool1(x) - x = F.leaky_relu(self.batch_norm2(self.conv2(x)), 0.2, inplace=True) + x = self.conv2(x) x = self.max_pool2(x) - x = F.leaky_relu(self.batch_norm3(self.conv3(x)), 0.2, inplace=True) - x = F.leaky_relu(self.batch_norm4(self.conv4(x)), 0.2, inplace=True) + x = self.conv3(x) + x = self.conv4(x) x = self.max_pool3(x) x = x.view(-1, (64 * 8) * 2 * 4) - embedding = self.batch_norm_fc(self.fc(x)) + embedding = self.bn_fc(self.fc(x)) fa, fgs, fgd = embedding.split( (self.opt.fa_dim, self.opt.fg_dim / 2, self.opt.fg_dim / 2), dim=1 @@ -61,26 +57,24 @@ class Decoder(nn.Module): # Cx[HxW] # FC 320 -> 512*4*2 self.fc = nn.Linear(self.em_dim, nf * 8 * 2 * 4) - self.batch_norm_fc = nn.BatchNorm1d(nf * 8 * 2 * 4) + self.bn_fc = nn.BatchNorm1d(nf * 8 * 2 * 4) # TransConv1 512x4x2 -> 256x8x4 self.trans_conv1 = DCGANConvTranspose2d(nf * 8, nf * 4) - self.batch_norm1 = nn.BatchNorm2d(nf * 4) # TransConv2 256x8x4 -> 128x16x8 self.trans_conv2 = DCGANConvTranspose2d(nf * 4, nf * 2) - self.batch_norm2 = nn.BatchNorm2d(nf * 2) # TransConv3 128x16x8 -> 64x32x16 self.trans_conv3 = DCGANConvTranspose2d(nf * 2, nf) - self.batch_norm3 = nn.BatchNorm2d(nf) # TransConv4 3x32x16 - self.trans_conv4 = DCGANConvTranspose2d(nf, out_channels) + self.trans_conv4 = DCGANConvTranspose2d(nf, out_channels, + is_last_layer=True) def forward(self, fa, fgs, fgd): x = torch.cat((fa, fgs, fgd), dim=1).view(-1, self.em_dim) - x = F.relu(self.batch_norm_fc(self.fc(x)), True) - x = x.view(-1, 64 * 8, 4, 2) - x = F.relu(self.batch_norm1(self.trans_conv1(x)), True) - x = F.relu(self.batch_norm2(self.trans_conv2(x)), True) - x = F.relu(self.batch_norm3(self.trans_conv3(x)), True) + x = self.bn_fc(self.fc(x)) + x = F.relu(x.view(-1, 64 * 8, 4, 2), True) + x = self.trans_conv1(x) + x = self.trans_conv2(x) + x = self.trans_conv3(x) x = F.sigmoid(self.trans_conv4(x)) return x -- cgit v1.2.3