From 02aaefaba26b6842d2feb403edfd71aaa75904da Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 2 Jan 2021 19:10:08 +0800 Subject: Correct feature dims after disentanglement and HPM backbone removal 1. Features used in HPM is decoded canonical embedding without transpose convolution 2. Decode pose embedding to image for Part Net 3. Backbone seems to be redundant, we can use feature map given by auto-decoder --- models/hpm.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) (limited to 'models/hpm.py') diff --git a/models/hpm.py b/models/hpm.py index 5553094..66503e3 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -from torchvision.models import resnet50 from models.layers import HorizontalPyramidPooling @@ -8,12 +7,11 @@ from models.layers import HorizontalPyramidPooling class HorizontalPyramidMatching(nn.Module): def __init__( self, - in_channels: int = 3, + in_channels: int, out_channels: int = 128, - scales: tuple[int, ...] = (1, 2, 4, 8), + scales: tuple[int, ...] = (1, 2, 4), use_avg_pool: bool = True, use_max_pool: bool = True, - use_backbone: bool = False, **kwargs ): super().__init__() @@ -22,11 +20,6 @@ class HorizontalPyramidMatching(nn.Module): self.scales = scales self.use_avg_pool = use_avg_pool self.use_max_pool = use_max_pool - self.use_backbone = use_backbone - - if self.use_backbone: - self.backbone = resnet50(pretrained=True) - self.in_channels = self.backbone.layer4[-1].conv1.in_channels self.pyramids = nn.ModuleList([ self._make_pyramid(scale, **kwargs) for scale in self.scales @@ -44,15 +37,10 @@ class HorizontalPyramidMatching(nn.Module): return pyramid def forward(self, x): - # Flatten frames in all batches + # Flatten canonical features in all batches t, n, c, h, w = x.size() - x = x.view(-1, c, h, w) - - if self.use_backbone: - # FIXME Inconsistent dimensions - x = self.backbone(x) + x = x.view(t * n, c, h, w) - t_n, _, h, _ = x.size() feature = [] for pyramid_index, pyramid in enumerate(self.pyramids): h_per_hpp = h // self.scales[pyramid_index] @@ -61,7 +49,7 @@ class HorizontalPyramidMatching(nn.Module): (hpp_index + 1) * h_per_hpp) x_slice = x[:, :, h_filter, :] x_slice = hpp(x_slice) - x_slice = x_slice.view(t_n, -1) + x_slice = x_slice.view(t * n, -1) feature.append(x_slice) x = torch.stack(feature) -- cgit v1.2.3