summaryrefslogtreecommitdiff
path: root/models/layers.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/layers.py')
-rw-r--r--models/layers.py68
1 files changed, 68 insertions, 0 deletions
diff --git a/models/layers.py b/models/layers.py
index 2be93ad..f3ccbeb 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -4,6 +4,74 @@ import torch
import torch.nn as nn
+class BasicConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ **kwargs
+ ):
+ super(BasicConv2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
+ bias=False, **kwargs)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class VGGConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]] = 3,
+ padding: int = 1,
+ **kwargs
+ ):
+ super(VGGConv2d, self).__init__()
+ self.conv = BasicConv2d(in_channels, out_channels, kernel_size,
+ padding=padding, **kwargs)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class BasicConvTranspose2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ **kwargs
+ ):
+ super(BasicConvTranspose2d, self).__init__()
+ self.trans_conv = nn.ConvTranspose2d(in_channels, out_channels,
+ kernel_size, bias=False, **kwargs)
+
+ def forward(self, x):
+ return self.trans_conv(x)
+
+
+class DCGANConvTranspose2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]] = 4,
+ stride: int = 2,
+ padding: int = 1,
+ **kwargs
+ ):
+ super(DCGANConvTranspose2d).__init__()
+ self.trans_conv = BasicConvTranspose2d(in_channels, out_channels,
+ kernel_size, stride=stride,
+ padding=padding, **kwargs)
+
+ def forward(self, x):
+ return self.trans_conv(x)
+
+
class FocalConv2d(nn.Module):
def __init__(
self,