summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/fpfe.py38
-rw-r--r--models/layers.py27
-rw-r--r--models/part_net.py (renamed from models/tfa.py)73
3 files changed, 96 insertions, 42 deletions
diff --git a/models/fpfe.py b/models/fpfe.py
deleted file mode 100644
index 28a0440..0000000
--- a/models/fpfe.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import torch.nn as nn
-
-from models.layers import FocalConv2d
-
-
-class FrameLevelPartFeatureExtractor(nn.Module):
-
- def __init__(self, in_channels: int):
- super(FrameLevelPartFeatureExtractor, self).__init__()
- nf = 32
-
- self.focal_conv1 = FocalConv2d(in_channels, nf, kernel_size=5,
- padding=2, halving=1)
- self.focal_conv2 = FocalConv2d(nf, nf, kernel_size=3,
- padding=1, halving=1)
- self.focal_conv3 = FocalConv2d(nf, nf * 2, kernel_size=3,
- padding=1, halving=4)
- self.focal_conv4 = FocalConv2d(nf * 2, nf * 2, kernel_size=3,
- padding=1, halving=4)
- self.focal_conv5 = FocalConv2d(nf * 2, nf * 4, kernel_size=3,
- padding=1, halving=8)
- self.focal_conv6 = FocalConv2d(nf * 4, nf * 4, kernel_size=3,
- padding=1, halving=8)
- self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
-
- def forward(self, x):
- x = self.focal_conv1(x)
- x = self.focal_conv2(x)
- x = self.max_pool(x)
-
- x = self.focal_conv3(x)
- x = self.focal_conv4(x)
- x = self.max_pool(x)
-
- x = self.focal_conv5(x)
- x = self.focal_conv6(x)
-
- return x
diff --git a/models/layers.py b/models/layers.py
index 62a3cc6..c69ae07 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -119,6 +119,33 @@ class FocalConv2d(BasicConv2d):
return F.leaky_relu(z, inplace=True)
+class FocalConv2dBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_sizes: tuple[int, int],
+ paddings: tuple[int, int],
+ halving: int,
+ use_pool: bool = True,
+ **kwargs
+ ):
+ super().__init__()
+ self.use_pool = use_pool
+ self.fconv1 = FocalConv2d(in_channels, out_channels, kernel_sizes[0],
+ halving, padding=paddings[0], **kwargs)
+ self.fconv2 = FocalConv2d(out_channels, out_channels, kernel_sizes[1],
+ halving, padding=paddings[1], **kwargs)
+ self.max_pool = nn.MaxPool2d(2)
+
+ def forward(self, x):
+ x = self.fconv1(x)
+ x = self.fconv2(x)
+ if self.use_pool:
+ x = self.max_pool(x)
+ return x
+
+
class BasicConv1d(nn.Module):
def __init__(
self,
diff --git a/models/tfa.py b/models/part_net.py
index b80328a..2116600 100644
--- a/models/tfa.py
+++ b/models/part_net.py
@@ -3,18 +3,45 @@ import copy
import torch
import torch.nn as nn
-from models.layers import BasicConv1d
+from models.layers import BasicConv1d, FocalConv2dBlock
+
+
+class FrameLevelPartFeatureExtractor(nn.Module):
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ feature_channels: int = 32,
+ kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)),
+ paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
+ halving: tuple[int, ...] = (0, 2, 3)
+ ):
+ super().__init__()
+ num_blocks = len(kernel_sizes)
+ out_channels = [feature_channels * 2 ** i for i in range(num_blocks)]
+ in_channels = [in_channels] + out_channels[:-1]
+ use_pools = [True] * (num_blocks - 1) + [False]
+ params = (in_channels, out_channels, kernel_sizes,
+ paddings, halving, use_pools)
+
+ self.fconv_blocks = [FocalConv2dBlock(*_params)
+ for _params in zip(*params)]
+
+ def forward(self, x):
+ for fconv_block in self.fconv_blocks:
+ x = fconv_block(x)
+ return x
class TemporalFeatureAggregator(nn.Module):
def __init__(
self,
in_channels: int,
- squeeze: int = 4,
+ squeeze_ratio: int = 4,
num_part: int = 16
):
- super(TemporalFeatureAggregator, self).__init__()
- hidden_dim = in_channels // squeeze
+ super().__init__()
+ hidden_dim = in_channels // squeeze_ratio
self.num_part = num_part
# MTB1
@@ -75,3 +102,41 @@ class TemporalFeatureAggregator(nn.Module):
# Temporal Pooling
ret = (feature3x1 + feature3x3).max(-1)[0]
return ret
+
+
+class PartNet(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 3,
+ feature_channels: int = 32,
+ kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)),
+ paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
+ halving: tuple[int, ...] = (0, 2, 3),
+ squeeze_ratio: int = 4,
+ num_part: int = 16
+ ):
+ super().__init__()
+ self.num_part = num_part
+ self.fpfe = FrameLevelPartFeatureExtractor(
+ in_channels, feature_channels, kernel_sizes, paddings, halving
+ )
+
+ num_fconv_blocks = len(self.fpfe.fconv_blocks)
+ tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1)
+ self.tfa = TemporalFeatureAggregator(
+ tfa_in_channels, squeeze_ratio, self.num_part
+ )
+
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
+
+ def forward(self, x):
+ x = self.fpfe(x)
+ n, t, c, h, w = x.size()
+ split_size = h // self.num_part
+ x = x.split(split_size, dim=3)
+ x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
+ x = [x_.view(n, t, c, -1) for x_ in x]
+ x = torch.cat(x, dim=3)
+ x = self.tfa(x)
+ return x