diff options
Diffstat (limited to 'models/layers.py')
-rw-r--r-- | models/layers.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/models/layers.py b/models/layers.py index a0e35f0..2be93ad 100644 --- a/models/layers.py +++ b/models/layers.py @@ -24,3 +24,19 @@ class FocalConv2d(nn.Module): z = x.split(split_size, dim=2) z = torch.cat([self.conv(_) for _ in z], dim=2) return z + + +class BasicConv1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int]], + **kwargs + ): + super(BasicConv1d, self).__init__() + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, + bias=False, **kwargs) + + def forward(self, x): + return self.conv(x) |