summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/layers.py16
-rw-r--r--models/tfa.py16
2 files changed, 23 insertions, 9 deletions
diff --git a/models/layers.py b/models/layers.py
index a0e35f0..2be93ad 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -24,3 +24,19 @@ class FocalConv2d(nn.Module):
z = x.split(split_size, dim=2)
z = torch.cat([self.conv(_) for _ in z], dim=2)
return z
+
+
+class BasicConv1d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int]],
+ **kwargs
+ ):
+ super(BasicConv1d, self).__init__()
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
+ bias=False, **kwargs)
+
+ def forward(self, x):
+ return self.conv(x)
diff --git a/models/tfa.py b/models/tfa.py
index 2e4e656..72bfcfb 100644
--- a/models/tfa.py
+++ b/models/tfa.py
@@ -3,6 +3,8 @@ import copy
import torch
from torch import nn as nn
+from models.layers import BasicConv1d
+
class TemporalFeatureAggregator(nn.Module):
def __init__(
@@ -17,11 +19,9 @@ class TemporalFeatureAggregator(nn.Module):
# MTB1
conv3x1 = nn.Sequential(
- nn.Conv1d(in_channels, hidden_dim,
- kernel_size=3, padding=1, bias=False),
+ BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1),
nn.LeakyReLU(inplace=True),
- nn.Conv1d(hidden_dim, in_channels,
- kernel_size=1, padding=0, bias=False)
+ BasicConv1d(hidden_dim, in_channels, kernel_size=1, padding=0)
)
self.conv1d3x1 = self._parted(conv3x1)
self.avg_pool3x1 = nn.AvgPool1d(kernel_size=3, stride=1, padding=1)
@@ -29,11 +29,9 @@ class TemporalFeatureAggregator(nn.Module):
# MTB2
conv3x3 = nn.Sequential(
- nn.Conv1d(in_channels, hidden_dim,
- kernel_size=3, padding=1, bias=False),
+ BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1),
nn.LeakyReLU(inplace=True),
- nn.Conv1d(hidden_dim, in_channels,
- kernel_size=3, padding=1, bias=False)
+ BasicConv1d(hidden_dim, in_channels, kernel_size=3, padding=1)
)
self.conv1d3x3 = self._parted(conv3x3)
self.avg_pool3x3 = nn.AvgPool1d(kernel_size=5, stride=1, padding=2)
@@ -76,4 +74,4 @@ class TemporalFeatureAggregator(nn.Module):
# Temporal Pooling
ret = (feature3x1 + feature3x3).max(-1)[0]
- return ret \ No newline at end of file
+ return ret