From 46624a615429232cee01be670d925dd593ceb6a3 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 23 Dec 2020 20:59:44 +0800 Subject: Refactor and refine auto-encoder 1. Wrap Conv2d 3x3-padding-1 to VGGConv2d 2. Wrap ConvTranspose2d 4x4-stride-4-padding-1 to DCGANConvTranspose2d 3. Turn off bias in conv since the employment of batch normalization --- models/auto_encoder.py | 22 ++++++++-------- models/layers.py | 68 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 1708bc9..bb4a377 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -2,6 +2,8 @@ import torch from torch import nn as nn from torch.nn import functional as F +from models.layers import VGGConv2d, DCGANConvTranspose2d + class Encoder(nn.Module): def __init__(self, in_channels: int, opt): @@ -12,20 +14,20 @@ class Encoder(nn.Module): # Cx[HxW] # Conv1 3x64x32 -> 64x64x32 - self.conv1 = nn.Conv2d(in_channels, nf, kernel_size=3, padding=1) + self.conv1 = VGGConv2d(in_channels, nf) self.batch_norm1 = nn.BatchNorm2d(nf) # MaxPool1 64x64x32 -> 64x32x16 self.max_pool1 = nn.AdaptiveMaxPool2d((32, 16)) # Conv2 64x32x16 -> 256x32x16 - self.conv2 = nn.Conv2d(nf, nf * 4, kernel_size=3, padding=1) + self.conv2 = VGGConv2d(nf, nf * 4) self.batch_norm2 = nn.BatchNorm2d(nf * 4) # MaxPool2 256x32x16 -> 256x16x8 self.max_pool2 = nn.AdaptiveMaxPool2d((16, 8)) # Conv3 256x16x8 -> 512x16x8 - self.conv3 = nn.Conv2d(nf * 4, nf * 8, kernel_size=3, padding=1) + self.conv3 = VGGConv2d(nf * 4, nf * 8) self.batch_norm3 = nn.BatchNorm2d(nf * 8) # Conv4 512x16x8 -> 512x16x8 (for large dataset) - self.conv4 = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, padding=1) + self.conv4 = VGGConv2d(nf * 8, nf * 8) self.batch_norm4 = nn.BatchNorm2d(nf * 8) # MaxPool3 512x16x8 -> 512x4x2 self.max_pool3 = nn.AdaptiveMaxPool2d((4, 2)) @@ -61,20 +63,16 @@ class Decoder(nn.Module): self.fc = nn.Linear(self.em_dim, nf * 8 * 2 * 4) self.batch_norm_fc = nn.BatchNorm1d(nf * 8 * 2 * 4) # TransConv1 512x4x2 -> 256x8x4 - self.trans_conv1 = nn.ConvTranspose2d(nf * 8, nf * 4, kernel_size=4, - stride=2, padding=1) + self.trans_conv1 = DCGANConvTranspose2d(nf * 8, nf * 4) self.batch_norm1 = nn.BatchNorm2d(nf * 4) # TransConv2 256x8x4 -> 128x16x8 - self.trans_conv2 = nn.ConvTranspose2d(nf * 4, nf * 2, kernel_size=4, - stride=2, padding=1) + self.trans_conv2 = DCGANConvTranspose2d(nf * 4, nf * 2) self.batch_norm2 = nn.BatchNorm2d(nf * 2) # TransConv3 128x16x8 -> 64x32x16 - self.trans_conv3 = nn.ConvTranspose2d(nf * 2, nf, kernel_size=4, - stride=2, padding=1) + self.trans_conv3 = DCGANConvTranspose2d(nf * 2, nf) self.batch_norm3 = nn.BatchNorm2d(nf) # TransConv4 3x32x16 - self.trans_conv4 = nn.ConvTranspose2d(nf, out_channels, kernel_size=4, - stride=2, padding=1) + self.trans_conv4 = DCGANConvTranspose2d(nf, out_channels) def forward(self, fa, fgs, fgd): x = torch.cat((fa, fgs, fgd), dim=1).view(-1, self.em_dim) diff --git a/models/layers.py b/models/layers.py index 2be93ad..f3ccbeb 100644 --- a/models/layers.py +++ b/models/layers.py @@ -4,6 +4,74 @@ import torch import torch.nn as nn +class BasicConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + **kwargs + ): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, + bias=False, **kwargs) + + def forward(self, x): + return self.conv(x) + + +class VGGConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]] = 3, + padding: int = 1, + **kwargs + ): + super(VGGConv2d, self).__init__() + self.conv = BasicConv2d(in_channels, out_channels, kernel_size, + padding=padding, **kwargs) + + def forward(self, x): + return self.conv(x) + + +class BasicConvTranspose2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + **kwargs + ): + super(BasicConvTranspose2d, self).__init__() + self.trans_conv = nn.ConvTranspose2d(in_channels, out_channels, + kernel_size, bias=False, **kwargs) + + def forward(self, x): + return self.trans_conv(x) + + +class DCGANConvTranspose2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]] = 4, + stride: int = 2, + padding: int = 1, + **kwargs + ): + super(DCGANConvTranspose2d).__init__() + self.trans_conv = BasicConvTranspose2d(in_channels, out_channels, + kernel_size, stride=stride, + padding=padding, **kwargs) + + def forward(self, x): + return self.trans_conv(x) + + class FocalConv2d(nn.Module): def __init__( self, -- cgit v1.2.3