diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-23 20:15:50 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-23 20:15:50 +0800 |
commit | b7db891e0756fb490466246cf802358b1265a0c9 (patch) | |
tree | 9e7418ecfb2b1206db7fd6af0fa53da7c1f3712d | |
parent | 7460ce2c8904a009f2f1139b11ec18faf208d6d2 (diff) |
Wrap Conv1d no bias layer
-rw-r--r-- | models/layers.py | 16 | ||||
-rw-r--r-- | models/tfa.py | 16 |
2 files changed, 23 insertions, 9 deletions
diff --git a/models/layers.py b/models/layers.py index a0e35f0..2be93ad 100644 --- a/models/layers.py +++ b/models/layers.py @@ -24,3 +24,19 @@ class FocalConv2d(nn.Module): z = x.split(split_size, dim=2) z = torch.cat([self.conv(_) for _ in z], dim=2) return z + + +class BasicConv1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int]], + **kwargs + ): + super(BasicConv1d, self).__init__() + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, + bias=False, **kwargs) + + def forward(self, x): + return self.conv(x) diff --git a/models/tfa.py b/models/tfa.py index 2e4e656..72bfcfb 100644 --- a/models/tfa.py +++ b/models/tfa.py @@ -3,6 +3,8 @@ import copy import torch from torch import nn as nn +from models.layers import BasicConv1d + class TemporalFeatureAggregator(nn.Module): def __init__( @@ -17,11 +19,9 @@ class TemporalFeatureAggregator(nn.Module): # MTB1 conv3x1 = nn.Sequential( - nn.Conv1d(in_channels, hidden_dim, - kernel_size=3, padding=1, bias=False), + BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1), nn.LeakyReLU(inplace=True), - nn.Conv1d(hidden_dim, in_channels, - kernel_size=1, padding=0, bias=False) + BasicConv1d(hidden_dim, in_channels, kernel_size=1, padding=0) ) self.conv1d3x1 = self._parted(conv3x1) self.avg_pool3x1 = nn.AvgPool1d(kernel_size=3, stride=1, padding=1) @@ -29,11 +29,9 @@ class TemporalFeatureAggregator(nn.Module): # MTB2 conv3x3 = nn.Sequential( - nn.Conv1d(in_channels, hidden_dim, - kernel_size=3, padding=1, bias=False), + BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1), nn.LeakyReLU(inplace=True), - nn.Conv1d(hidden_dim, in_channels, - kernel_size=3, padding=1, bias=False) + BasicConv1d(hidden_dim, in_channels, kernel_size=3, padding=1) ) self.conv1d3x3 = self._parted(conv3x3) self.avg_pool3x3 = nn.AvgPool1d(kernel_size=5, stride=1, padding=2) @@ -76,4 +74,4 @@ class TemporalFeatureAggregator(nn.Module): # Temporal Pooling ret = (feature3x1 + feature3x3).max(-1)[0] - return ret
\ No newline at end of file + return ret |