summaryrefslogtreecommitdiff
path: root/models/hpm.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-31 21:00:01 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-31 21:00:01 +0800
commit86421e899c87976d8559795979415e3fae2bd7ed (patch)
treec4bfa828da64cc71b43ed12d3fd78be8a8930181 /models/hpm.py
parent57275a210b93f9ffd30a53e22c3c28f49f228d14 (diff)
Implement some parts of RGB-GaitPart wrapper
1. Triplet loss function and weight init function haven't been implemented yet 2. Tuplize features returned by the auto-encoder for later unpacking 3. Correct a comment error in the auto-encoder 4. Swap the batch_size dim and the time dim in HPM and PartNet to avoid a redundant transpose 5. Find backbone problems in HPM and disable the backbone temporarily 6. Make the feature structure produced by HPM consistent with that produced by PartNet 7. Fix average pooling dimension issue and incorrect view change in HP
Diffstat (limited to 'models/hpm.py')
-rw-r--r--models/hpm.py29
1 files changed, 18 insertions, 11 deletions
diff --git a/models/hpm.py b/models/hpm.py
index 4a1f1a4..5553094 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -8,20 +8,25 @@ from models.layers import HorizontalPyramidPooling
class HorizontalPyramidMatching(nn.Module):
def __init__(
self,
+ in_channels: int = 3,
+ out_channels: int = 128,
scales: tuple[int, ...] = (1, 2, 4, 8),
- out_channels: int = 256,
use_avg_pool: bool = True,
use_max_pool: bool = True,
+ use_backbone: bool = False,
**kwargs
):
super().__init__()
- self.scales = scales
+ self.in_channels = in_channels
self.out_channels = out_channels
+ self.scales = scales
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
+ self.use_backbone = use_backbone
- self.backbone = resnet50(pretrained=True)
- self.in_channels = self.backbone.layer4[-1].conv1.in_channels
+ if self.use_backbone:
+ self.backbone = resnet50(pretrained=True)
+ self.in_channels = self.backbone.layer4[-1].conv1.in_channels
self.pyramids = nn.ModuleList([
self._make_pyramid(scale, **kwargs) for scale in self.scales
@@ -40,12 +45,14 @@ class HorizontalPyramidMatching(nn.Module):
def forward(self, x):
# Flatten frames in all batches
- n, t, c, h, w = x.size()
+ t, n, c, h, w = x.size()
x = x.view(-1, c, h, w)
- x = self.backbone(x)
- n, c, h, w = x.size()
+ if self.use_backbone:
+ # FIXME Inconsistent dimensions
+ x = self.backbone(x)
+ t_n, _, h, _ = x.size()
feature = []
for pyramid_index, pyramid in enumerate(self.pyramids):
h_per_hpp = h // self.scales[pyramid_index]
@@ -54,11 +61,11 @@ class HorizontalPyramidMatching(nn.Module):
(hpp_index + 1) * h_per_hpp)
x_slice = x[:, :, h_filter, :]
x_slice = hpp(x_slice)
- x_slice = x_slice.view(n, -1)
+ x_slice = x_slice.view(t_n, -1)
feature.append(x_slice)
- x = torch.cat(feature, dim=1)
+ x = torch.stack(feature)
# Unfold frames to original batch
- _, d = x.size()
- x = x.view(n, t, d)
+ p, _, c = x.size()
+ x = x.view(p, t, n, c)
return x