summaryrefslogtreecommitdiff
path: root/models/part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-12 20:12:33 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-12 20:12:33 +0800
commite83ae0bcb5c763636fd522c2712a3c8aef558f3c (patch)
treeb80da057e4c4574ea95fa9f3d3b2fe8c999e3440 /models/part_net.py
parentf2f7713efa03a877bc96ced37314b4c4a6dc1963 (diff)
parent2ea916b2a963eae7d47151b41c8c78a578c402e2 (diff)
Merge branch 'master' into data_parallel
# Conflicts: # models/auto_encoder.py # models/model.py # models/rgb_part_net.py
Diffstat (limited to 'models/part_net.py')
-rw-r--r--models/part_net.py18
1 files changed, 13 insertions, 5 deletions
diff --git a/models/part_net.py b/models/part_net.py
index 29cf9cd..f2236bf 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -111,17 +111,21 @@ class PartNet(nn.Module):
def __init__(
self,
in_channels: int = 128,
+ embedding_dims: int = 256,
+ num_parts: int = 16,
squeeze_ratio: int = 4,
- num_part: int = 16
):
super().__init__()
- self.num_part = num_part
- self.tfa = TemporalFeatureAggregator(
- in_channels, squeeze_ratio, self.num_part
- )
+ self.num_part = num_parts
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
+ self.tfa = TemporalFeatureAggregator(
+ in_channels, squeeze_ratio, self.num_part
+ )
+ self.fc_mat = nn.Parameter(
+ torch.empty(num_parts, in_channels, embedding_dims)
+ )
def forward(self, x):
n, t, c, h, w = x.size()
@@ -138,4 +142,8 @@ class PartNet(nn.Module):
# p, n, t, c
x = self.tfa(x)
+
+ # p, n, c
+ x = x @ self.fc_mat
+ # p, n, d
return x