diff options
| author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-12 13:59:05 +0800 | 
|---|---|---|
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-12 13:59:05 +0800 | 
| commit | d63b267dd15388dd323d9b8672cdb9461b96c885 (patch) | |
| tree | 5095fc80fb93b946e4cfdee88258ab4fd49a8275 /models/part_net.py | |
| parent | 08911dcb80ecb769972c2d2659c8ad152bbeb447 (diff) | |
| parent | c74df416b00f837ba051f3947be92f76e7afbd88 (diff) | |
Merge branch 'master' into python3.8
# Conflicts:
#	models/hpm.py
#	models/rgb_part_net.py
#	utils/configuration.py
#	utils/triplet_loss.py
Diffstat (limited to 'models/part_net.py')
| -rw-r--r-- | models/part_net.py | 18 | 
1 files changed, 13 insertions, 5 deletions
| diff --git a/models/part_net.py b/models/part_net.py index d9b954f..de19c8c 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -112,17 +112,21 @@ class PartNet(nn.Module):      def __init__(              self,              in_channels: int = 128, +            embedding_dims: int = 256, +            num_parts: int = 16,              squeeze_ratio: int = 4, -            num_part: int = 16      ):          super().__init__() -        self.num_part = num_part -        self.tfa = TemporalFeatureAggregator( -            in_channels, squeeze_ratio, self.num_part -        ) +        self.num_part = num_parts          self.avg_pool = nn.AdaptiveAvgPool2d(1)          self.max_pool = nn.AdaptiveMaxPool2d(1) +        self.tfa = TemporalFeatureAggregator( +            in_channels, squeeze_ratio, self.num_part +        ) +        self.fc_mat = nn.Parameter( +            torch.empty(num_parts, in_channels, embedding_dims) +        )      def forward(self, x):          n, t, c, h, w = x.size() @@ -139,4 +143,8 @@ class PartNet(nn.Module):          # p, n, t, c          x = self.tfa(x) + +        # p, n, c +        x = x @ self.fc_mat +        # p, n, d          return x | 
