summaryrefslogtreecommitdiff
path: root/models/layers.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/layers.py')
-rw-r--r--models/layers.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/models/layers.py b/models/layers.py
index 62a3cc6..c69ae07 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -119,6 +119,33 @@ class FocalConv2d(BasicConv2d):
return F.leaky_relu(z, inplace=True)
+class FocalConv2dBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_sizes: tuple[int, int],
+ paddings: tuple[int, int],
+ halving: int,
+ use_pool: bool = True,
+ **kwargs
+ ):
+ super().__init__()
+ self.use_pool = use_pool
+ self.fconv1 = FocalConv2d(in_channels, out_channels, kernel_sizes[0],
+ halving, padding=paddings[0], **kwargs)
+ self.fconv2 = FocalConv2d(out_channels, out_channels, kernel_sizes[1],
+ halving, padding=paddings[1], **kwargs)
+ self.max_pool = nn.MaxPool2d(2)
+
+ def forward(self, x):
+ x = self.fconv1(x)
+ x = self.fconv2(x)
+ if self.use_pool:
+ x = self.max_pool(x)
+ return x
+
+
class BasicConv1d(nn.Module):
def __init__(
self,