diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-23 20:59:44 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-23 20:59:44 +0800 |
commit | 46624a615429232cee01be670d925dd593ceb6a3 (patch) | |
tree | aff484f6538147420063fdc6037b044266066ade /models/layers.py | |
parent | b7db891e0756fb490466246cf802358b1265a0c9 (diff) |
Refactor and refine auto-encoder
1. Wrap Conv2d 3x3-padding-1 to VGGConv2d
2. Wrap ConvTranspose2d 4x4-stride-4-padding-1 to DCGANConvTranspose2d
3. Turn off bias in conv since the employment of batch normalization
Diffstat (limited to 'models/layers.py')
-rw-r--r-- | models/layers.py | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/models/layers.py b/models/layers.py index 2be93ad..f3ccbeb 100644 --- a/models/layers.py +++ b/models/layers.py @@ -4,6 +4,74 @@ import torch import torch.nn as nn +class BasicConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + **kwargs + ): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, + bias=False, **kwargs) + + def forward(self, x): + return self.conv(x) + + +class VGGConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]] = 3, + padding: int = 1, + **kwargs + ): + super(VGGConv2d, self).__init__() + self.conv = BasicConv2d(in_channels, out_channels, kernel_size, + padding=padding, **kwargs) + + def forward(self, x): + return self.conv(x) + + +class BasicConvTranspose2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + **kwargs + ): + super(BasicConvTranspose2d, self).__init__() + self.trans_conv = nn.ConvTranspose2d(in_channels, out_channels, + kernel_size, bias=False, **kwargs) + + def forward(self, x): + return self.trans_conv(x) + + +class DCGANConvTranspose2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]] = 4, + stride: int = 2, + padding: int = 1, + **kwargs + ): + super(DCGANConvTranspose2d).__init__() + self.trans_conv = BasicConvTranspose2d(in_channels, out_channels, + kernel_size, stride=stride, + padding=padding, **kwargs) + + def forward(self, x): + return self.trans_conv(x) + + class FocalConv2d(nn.Module): def __init__( self, |