summaryrefslogtreecommitdiff
path: root/models/hpm.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/hpm.py')
-rw-r--r--models/hpm.py54
1 files changed, 54 insertions, 0 deletions
diff --git a/models/hpm.py b/models/hpm.py
new file mode 100644
index 0000000..f387154
--- /dev/null
+++ b/models/hpm.py
@@ -0,0 +1,54 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from torchvision.models import resnet50
+
+from models.layers import HorizontalPyramidPooling
+
+
+class HorizontalPyramidMatching(nn.Module):
+ def __init__(
+ self,
+ scales: Tuple[int] = (1, 2, 4, 8),
+ out_channels: int = 256,
+ use_avg_pool: bool = False,
+ **kwargs
+ ):
+ super().__init__()
+ self.scales = scales
+ self.out_channels = out_channels
+ self.use_avg_pool = use_avg_pool
+
+ self.backbone = resnet50(pretrained=True)
+ self.in_channels = self.backbone.layer4[-1].conv1.in_channels
+
+ self.pyramids = [
+ self._make_pyramid(scale, **kwargs) for scale in self.scales
+ ]
+
+ def _make_pyramid(self, scale: int, **kwargs):
+ pyramid = [HorizontalPyramidPooling(self.in_channels,
+ self.out_channels,
+ use_avg_pool=self.use_avg_pool,
+ **kwargs)
+ for _ in range(scale)]
+ return pyramid
+
+ def forward(self, x):
+ x = self.backbone(x)
+ n, c, h, w = x.size()
+
+ feature = []
+ for pyramid_index, pyramid in enumerate(self.pyramids):
+ h_per_hpp = h // self.scales[pyramid_index]
+ for hpp_index, hpp in enumerate(pyramid):
+ h_filter = torch.arange(hpp_index * h_per_hpp,
+ (hpp_index + 1) * h_per_hpp)
+ x_slice = x[:, :, h_filter, :]
+ x_slice = hpp(x_slice)
+ x_slice = x_slice.view(n, -1)
+ feature.append(x_slice)
+
+ feature = torch.cat(feature, dim=1)
+ return feature