diff options
Diffstat (limited to 'models/layers.py')
-rw-r--r-- | models/layers.py | 49 |
1 files changed, 29 insertions, 20 deletions
diff --git a/models/layers.py b/models/layers.py index f3ccbeb..2dc87a1 100644 --- a/models/layers.py +++ b/models/layers.py @@ -1,6 +1,7 @@ from typing import Union, Tuple import torch +import torch.nn.functional as F import torch.nn as nn @@ -12,15 +13,18 @@ class BasicConv2d(nn.Module): kernel_size: Union[int, Tuple[int, int]], **kwargs ): - super(BasicConv2d, self).__init__() + super().__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels) def forward(self, x): - return self.conv(x) + x = self.conv(x) + x = self.bn(x) + return F.relu(x, inplace=True) -class VGGConv2d(nn.Module): +class VGGConv2d(BasicConv2d): def __init__( self, in_channels: int, @@ -29,12 +33,13 @@ class VGGConv2d(nn.Module): padding: int = 1, **kwargs ): - super(VGGConv2d, self).__init__() - self.conv = BasicConv2d(in_channels, out_channels, kernel_size, - padding=padding, **kwargs) + super().__init__(in_channels, out_channels, kernel_size, + padding=padding, **kwargs) def forward(self, x): - return self.conv(x) + x = self.conv(x) + x = self.bn(x) + return F.leaky_relu(x, 0.2, inplace=True) class BasicConvTranspose2d(nn.Module): @@ -45,15 +50,18 @@ class BasicConvTranspose2d(nn.Module): kernel_size: Union[int, Tuple[int, int]], **kwargs ): - super(BasicConvTranspose2d, self).__init__() + super().__init__() self.trans_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, bias=False, **kwargs) + self.bn = nn.BatchNorm2d(out_channels) def forward(self, x): - return self.trans_conv(x) + x = self.trans_conv(x) + x = self.bn(x) + return F.relu(x, inplace=True) -class DCGANConvTranspose2d(nn.Module): +class DCGANConvTranspose2d(BasicConvTranspose2d): def __init__( self, in_channels: int, @@ -61,18 +69,21 @@ class DCGANConvTranspose2d(nn.Module): kernel_size: Union[int, Tuple[int, int]] = 4, stride: int = 2, padding: int = 1, + is_last_layer: bool = False, **kwargs ): - super(DCGANConvTranspose2d).__init__() - self.trans_conv = BasicConvTranspose2d(in_channels, out_channels, - kernel_size, stride=stride, - padding=padding, **kwargs) + super().__init__(in_channels, out_channels, kernel_size, + stride=stride, padding=padding, **kwargs) + self.is_last_layer = is_last_layer def forward(self, x): - return self.trans_conv(x) + if self.is_last_layer: + return self.trans_conv(x) + else: + return super().forward(x) -class FocalConv2d(nn.Module): +class FocalConv2d(BasicConv2d): def __init__( self, in_channels: int, @@ -81,17 +92,15 @@ class FocalConv2d(nn.Module): halving: int, **kwargs ): - super(FocalConv2d, self).__init__() + super().__init__(in_channels, out_channels, kernel_size, **kwargs) self.halving = halving - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, - bias=False, **kwargs) def forward(self, x): h = x.size(2) split_size = h // 2 ** self.halving z = x.split(split_size, dim=2) z = torch.cat([self.conv(_) for _ in z], dim=2) - return z + return F.leaky_relu(z, inplace=True) class BasicConv1d(nn.Module): |