summaryrefslogtreecommitdiff
path: root/models/tfa.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-23 18:59:08 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-23 18:59:08 +0800
commit96f345d25237c7e616ea5f524a2fc2d340ed8aff (patch)
tree9791c394d147f39d45ecab2a1ea5f83b396f3568 /models/tfa.py
parent74a3df70a47630b7e95abc09197d23de9b81d4de (diff)
Split modules to different files
Diffstat (limited to 'models/tfa.py')
-rw-r--r--models/tfa.py79
1 files changed, 79 insertions, 0 deletions
diff --git a/models/tfa.py b/models/tfa.py
new file mode 100644
index 0000000..2e4e656
--- /dev/null
+++ b/models/tfa.py
@@ -0,0 +1,79 @@
+import copy
+
+import torch
+from torch import nn as nn
+
+
+class TemporalFeatureAggregator(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ squeeze: int = 4,
+ num_part: int = 16
+ ):
+ super(TemporalFeatureAggregator, self).__init__()
+ hidden_dim = in_channels // squeeze
+ self.num_part = num_part
+
+ # MTB1
+ conv3x1 = nn.Sequential(
+ nn.Conv1d(in_channels, hidden_dim,
+ kernel_size=3, padding=1, bias=False),
+ nn.LeakyReLU(inplace=True),
+ nn.Conv1d(hidden_dim, in_channels,
+ kernel_size=1, padding=0, bias=False)
+ )
+ self.conv1d3x1 = self._parted(conv3x1)
+ self.avg_pool3x1 = nn.AvgPool1d(kernel_size=3, stride=1, padding=1)
+ self.max_pool3x1 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1)
+
+ # MTB2
+ conv3x3 = nn.Sequential(
+ nn.Conv1d(in_channels, hidden_dim,
+ kernel_size=3, padding=1, bias=False),
+ nn.LeakyReLU(inplace=True),
+ nn.Conv1d(hidden_dim, in_channels,
+ kernel_size=3, padding=1, bias=False)
+ )
+ self.conv1d3x3 = self._parted(conv3x3)
+ self.avg_pool3x3 = nn.AvgPool1d(kernel_size=5, stride=1, padding=2)
+ self.max_pool3x3 = nn.MaxPool1d(kernel_size=5, stride=1, padding=2)
+
+ def _parted(self, module: nn.Module):
+ """Duplicate module `part_num` times."""
+ return nn.ModuleList([copy.deepcopy(module)
+ for _ in range(self.num_part)])
+
+ def forward(self, x):
+ """
+ Input: x, [p, n, c, s]
+ """
+ p, n, c, s = x.size()
+ feature = x.split(1, 0)
+ x = x.view(-1, c, s)
+
+ # MTB1: ConvNet1d & Sigmoid
+ logits3x1 = torch.cat(
+ [conv(_.squeeze(0)).unsqueeze(0)
+ for conv, _ in zip(self.conv1d3x1, feature)], dim=0
+ )
+ scores3x1 = torch.sigmoid(logits3x1)
+ # MTB1: Template Function
+ feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x)
+ feature3x1 = feature3x1.view(p, n, c, s)
+ feature3x1 = feature3x1 * scores3x1
+
+ # MTB2: ConvNet1d & Sigmoid
+ logits3x3 = torch.cat(
+ [conv(_.squeeze(0)).unsqueeze(0)
+ for conv, _ in zip(self.conv1d3x3, feature)], dim=0
+ )
+ scores3x3 = torch.sigmoid(logits3x3)
+ # MTB2: Template Function
+ feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x)
+ feature3x3 = feature3x3.view(p, n, c, s)
+ feature3x3 = feature3x3 * scores3x3
+
+ # Temporal Pooling
+ ret = (feature3x1 + feature3x3).max(-1)[0]
+ return ret \ No newline at end of file