summaryrefslogtreecommitdiff
path: root/models/layers.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/layers.py')
-rw-r--r--models/layers.py49
1 files changed, 29 insertions, 20 deletions
diff --git a/models/layers.py b/models/layers.py
index f3ccbeb..2dc87a1 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -1,6 +1,7 @@
from typing import Union, Tuple
import torch
+import torch.nn.functional as F
import torch.nn as nn
@@ -12,15 +13,18 @@ class BasicConv2d(nn.Module):
kernel_size: Union[int, Tuple[int, int]],
**kwargs
):
- super(BasicConv2d, self).__init__()
+ super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
bias=False, **kwargs)
+ self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
- return self.conv(x)
+ x = self.conv(x)
+ x = self.bn(x)
+ return F.relu(x, inplace=True)
-class VGGConv2d(nn.Module):
+class VGGConv2d(BasicConv2d):
def __init__(
self,
in_channels: int,
@@ -29,12 +33,13 @@ class VGGConv2d(nn.Module):
padding: int = 1,
**kwargs
):
- super(VGGConv2d, self).__init__()
- self.conv = BasicConv2d(in_channels, out_channels, kernel_size,
- padding=padding, **kwargs)
+ super().__init__(in_channels, out_channels, kernel_size,
+ padding=padding, **kwargs)
def forward(self, x):
- return self.conv(x)
+ x = self.conv(x)
+ x = self.bn(x)
+ return F.leaky_relu(x, 0.2, inplace=True)
class BasicConvTranspose2d(nn.Module):
@@ -45,15 +50,18 @@ class BasicConvTranspose2d(nn.Module):
kernel_size: Union[int, Tuple[int, int]],
**kwargs
):
- super(BasicConvTranspose2d, self).__init__()
+ super().__init__()
self.trans_conv = nn.ConvTranspose2d(in_channels, out_channels,
kernel_size, bias=False, **kwargs)
+ self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
- return self.trans_conv(x)
+ x = self.trans_conv(x)
+ x = self.bn(x)
+ return F.relu(x, inplace=True)
-class DCGANConvTranspose2d(nn.Module):
+class DCGANConvTranspose2d(BasicConvTranspose2d):
def __init__(
self,
in_channels: int,
@@ -61,18 +69,21 @@ class DCGANConvTranspose2d(nn.Module):
kernel_size: Union[int, Tuple[int, int]] = 4,
stride: int = 2,
padding: int = 1,
+ is_last_layer: bool = False,
**kwargs
):
- super(DCGANConvTranspose2d).__init__()
- self.trans_conv = BasicConvTranspose2d(in_channels, out_channels,
- kernel_size, stride=stride,
- padding=padding, **kwargs)
+ super().__init__(in_channels, out_channels, kernel_size,
+ stride=stride, padding=padding, **kwargs)
+ self.is_last_layer = is_last_layer
def forward(self, x):
- return self.trans_conv(x)
+ if self.is_last_layer:
+ return self.trans_conv(x)
+ else:
+ return super().forward(x)
-class FocalConv2d(nn.Module):
+class FocalConv2d(BasicConv2d):
def __init__(
self,
in_channels: int,
@@ -81,17 +92,15 @@ class FocalConv2d(nn.Module):
halving: int,
**kwargs
):
- super(FocalConv2d, self).__init__()
+ super().__init__(in_channels, out_channels, kernel_size, **kwargs)
self.halving = halving
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
- bias=False, **kwargs)
def forward(self, x):
h = x.size(2)
split_size = h // 2 ** self.halving
z = x.split(split_size, dim=2)
z = torch.cat([self.conv(_) for _ in z], dim=2)
- return z
+ return F.leaky_relu(z, inplace=True)
class BasicConv1d(nn.Module):