summaryrefslogtreecommitdiff
path: root/models/part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-21 19:00:30 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-21 19:00:30 +0800
commitc52fdc2748e272a5195303299a9739291be32281 (patch)
tree4c62e48d2d79f18e5e526d63f6b4d1f81c9dcd3d /models/part_net.py
parent820d3dec284f38e6a3089dad5277bc3f6c5123bf (diff)
Remove FConv blocks
Diffstat (limited to 'models/part_net.py')
-rw-r--r--models/part_net.py18
1 files changed, 4 insertions, 14 deletions
diff --git a/models/part_net.py b/models/part_net.py
index 62a2bac..29cf9cd 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -110,32 +110,22 @@ class TemporalFeatureAggregator(nn.Module):
class PartNet(nn.Module):
def __init__(
self,
- in_channels: int = 3,
- feature_channels: int = 32,
- kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)),
- paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
- halving: tuple[int, ...] = (0, 2, 3),
+ in_channels: int = 128,
squeeze_ratio: int = 4,
num_part: int = 16
):
super().__init__()
self.num_part = num_part
- self.fpfe = FrameLevelPartFeatureExtractor(
- in_channels, feature_channels, kernel_sizes, paddings, halving
- )
-
- num_fconv_blocks = len(self.fpfe.fconv_blocks)
- self.tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1)
self.tfa = TemporalFeatureAggregator(
- self.tfa_in_channels, squeeze_ratio, self.num_part
+ in_channels, squeeze_ratio, self.num_part
)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
def forward(self, x):
- n, t, _, _, _ = x.size()
- x = self.fpfe(x)
+ n, t, c, h, w = x.size()
+ x = x.view(n * t, c, h, w)
# n * t x c x h x w
# Horizontal Pooling