diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-12 20:30:25 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-12 20:34:00 +0800 |
commit | 30b475c0a27e0f848743abf0f909607defc6a3ee (patch) | |
tree | aaab163d3d76a835c32ce5014ce62637550d0b0d /models/part_net.py | |
parent | 3d8fc322623ba61610fd206b9f52b406e85cae61 (diff) | |
parent | e83ae0bcb5c763636fd522c2712a3c8aef558f3c (diff) |
Merge branch 'data_parallel' into data_parallel_py3.8
# Conflicts:
# models/hpm.py
# models/model.py
# models/rgb_part_net.py
# utils/configuration.py
# utils/triplet_loss.py
Diffstat (limited to 'models/part_net.py')
-rw-r--r-- | models/part_net.py | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/models/part_net.py b/models/part_net.py index d9b954f..de19c8c 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -112,17 +112,21 @@ class PartNet(nn.Module): def __init__( self, in_channels: int = 128, + embedding_dims: int = 256, + num_parts: int = 16, squeeze_ratio: int = 4, - num_part: int = 16 ): super().__init__() - self.num_part = num_part - self.tfa = TemporalFeatureAggregator( - in_channels, squeeze_ratio, self.num_part - ) + self.num_part = num_parts self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) + self.tfa = TemporalFeatureAggregator( + in_channels, squeeze_ratio, self.num_part + ) + self.fc_mat = nn.Parameter( + torch.empty(num_parts, in_channels, embedding_dims) + ) def forward(self, x): n, t, c, h, w = x.size() @@ -139,4 +143,8 @@ class PartNet(nn.Module): # p, n, t, c x = self.tfa(x) + + # p, n, c + x = x @ self.fc_mat + # p, n, d return x |