summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-23 20:59:44 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-23 20:59:44 +0800
commit46624a615429232cee01be670d925dd593ceb6a3 (patch)
treeaff484f6538147420063fdc6037b044266066ade
parentb7db891e0756fb490466246cf802358b1265a0c9 (diff)
Refactor and refine auto-encoder
1. Wrap Conv2d 3x3-padding-1 to VGGConv2d 2. Wrap ConvTranspose2d 4x4-stride-4-padding-1 to DCGANConvTranspose2d 3. Turn off bias in conv since the employment of batch normalization
-rw-r--r--models/auto_encoder.py22
-rw-r--r--models/layers.py68
2 files changed, 78 insertions, 12 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 1708bc9..bb4a377 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -2,6 +2,8 @@ import torch
from torch import nn as nn
from torch.nn import functional as F
+from models.layers import VGGConv2d, DCGANConvTranspose2d
+
class Encoder(nn.Module):
def __init__(self, in_channels: int, opt):
@@ -12,20 +14,20 @@ class Encoder(nn.Module):
# Cx[HxW]
# Conv1 3x64x32 -> 64x64x32
- self.conv1 = nn.Conv2d(in_channels, nf, kernel_size=3, padding=1)
+ self.conv1 = VGGConv2d(in_channels, nf)
self.batch_norm1 = nn.BatchNorm2d(nf)
# MaxPool1 64x64x32 -> 64x32x16
self.max_pool1 = nn.AdaptiveMaxPool2d((32, 16))
# Conv2 64x32x16 -> 256x32x16
- self.conv2 = nn.Conv2d(nf, nf * 4, kernel_size=3, padding=1)
+ self.conv2 = VGGConv2d(nf, nf * 4)
self.batch_norm2 = nn.BatchNorm2d(nf * 4)
# MaxPool2 256x32x16 -> 256x16x8
self.max_pool2 = nn.AdaptiveMaxPool2d((16, 8))
# Conv3 256x16x8 -> 512x16x8
- self.conv3 = nn.Conv2d(nf * 4, nf * 8, kernel_size=3, padding=1)
+ self.conv3 = VGGConv2d(nf * 4, nf * 8)
self.batch_norm3 = nn.BatchNorm2d(nf * 8)
# Conv4 512x16x8 -> 512x16x8 (for large dataset)
- self.conv4 = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, padding=1)
+ self.conv4 = VGGConv2d(nf * 8, nf * 8)
self.batch_norm4 = nn.BatchNorm2d(nf * 8)
# MaxPool3 512x16x8 -> 512x4x2
self.max_pool3 = nn.AdaptiveMaxPool2d((4, 2))
@@ -61,20 +63,16 @@ class Decoder(nn.Module):
self.fc = nn.Linear(self.em_dim, nf * 8 * 2 * 4)
self.batch_norm_fc = nn.BatchNorm1d(nf * 8 * 2 * 4)
# TransConv1 512x4x2 -> 256x8x4
- self.trans_conv1 = nn.ConvTranspose2d(nf * 8, nf * 4, kernel_size=4,
- stride=2, padding=1)
+ self.trans_conv1 = DCGANConvTranspose2d(nf * 8, nf * 4)
self.batch_norm1 = nn.BatchNorm2d(nf * 4)
# TransConv2 256x8x4 -> 128x16x8
- self.trans_conv2 = nn.ConvTranspose2d(nf * 4, nf * 2, kernel_size=4,
- stride=2, padding=1)
+ self.trans_conv2 = DCGANConvTranspose2d(nf * 4, nf * 2)
self.batch_norm2 = nn.BatchNorm2d(nf * 2)
# TransConv3 128x16x8 -> 64x32x16
- self.trans_conv3 = nn.ConvTranspose2d(nf * 2, nf, kernel_size=4,
- stride=2, padding=1)
+ self.trans_conv3 = DCGANConvTranspose2d(nf * 2, nf)
self.batch_norm3 = nn.BatchNorm2d(nf)
# TransConv4 3x32x16
- self.trans_conv4 = nn.ConvTranspose2d(nf, out_channels, kernel_size=4,
- stride=2, padding=1)
+ self.trans_conv4 = DCGANConvTranspose2d(nf, out_channels)
def forward(self, fa, fgs, fgd):
x = torch.cat((fa, fgs, fgd), dim=1).view(-1, self.em_dim)
diff --git a/models/layers.py b/models/layers.py
index 2be93ad..f3ccbeb 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -4,6 +4,74 @@ import torch
import torch.nn as nn
+class BasicConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ **kwargs
+ ):
+ super(BasicConv2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
+ bias=False, **kwargs)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class VGGConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]] = 3,
+ padding: int = 1,
+ **kwargs
+ ):
+ super(VGGConv2d, self).__init__()
+ self.conv = BasicConv2d(in_channels, out_channels, kernel_size,
+ padding=padding, **kwargs)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class BasicConvTranspose2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ **kwargs
+ ):
+ super(BasicConvTranspose2d, self).__init__()
+ self.trans_conv = nn.ConvTranspose2d(in_channels, out_channels,
+ kernel_size, bias=False, **kwargs)
+
+ def forward(self, x):
+ return self.trans_conv(x)
+
+
+class DCGANConvTranspose2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]] = 4,
+ stride: int = 2,
+ padding: int = 1,
+ **kwargs
+ ):
+ super(DCGANConvTranspose2d).__init__()
+ self.trans_conv = BasicConvTranspose2d(in_channels, out_channels,
+ kernel_size, stride=stride,
+ padding=padding, **kwargs)
+
+ def forward(self, x):
+ return self.trans_conv(x)
+
+
class FocalConv2d(nn.Module):
def __init__(
self,