summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- models/hpm.py 54
-rw-r--r-- models/layers.py 21
2 files changed, 75 insertions, 0 deletions
diff --git a/models/hpm.py b/models/hpm.py
new file mode 100644
index 0000000..f387154
--- /dev/null
+++ b/models/hpm.py
@@ -0,0 +1,54 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+from torchvision.models import resnet50
+
+from models.layers import HorizontalPyramidPooling
+
+
+class HorizontalPyramidMatching(nn.Module):
+    """Extract horizontal-pyramid features from a ResNet-50 feature map.
+
+    The backbone feature map is cut along its height into ``scale`` equal
+    strips for every entry of ``scales``. Each strip is passed through its
+    own ``HorizontalPyramidPooling`` head, and the flattened head outputs
+    are concatenated into a single vector per sample.
+
+    Args:
+        scales: Number of horizontal strips at each pyramid level.
+        out_channels: Output channels requested from every pooling head.
+        use_avg_pool: Forwarded to every ``HorizontalPyramidPooling``.
+        **kwargs: Forwarded unchanged to every ``HorizontalPyramidPooling``.
+    """
+
+    def __init__(
+            self,
+            scales: Tuple[int, ...] = (1, 2, 4, 8),
+            out_channels: int = 256,
+            use_avg_pool: bool = False,
+            **kwargs
+    ):
+        super().__init__()
+        self.scales = scales
+        self.out_channels = out_channels
+        self.use_avg_pool = use_avg_pool
+
+        self.backbone = resnet50(pretrained=True)
+        # Channel count entering the last bottleneck of layer4; the heads
+        # are built to consume feature maps with this many channels.
+        self.in_channels = self.backbone.layer4[-1].conv1.in_channels
+
+        # nn.ModuleList (not a plain list) so that every head is registered
+        # as a submodule: its parameters then show up in .parameters() and
+        # state_dict(), and follow the model through .to()/.cuda().
+        self.pyramids = nn.ModuleList([
+            self._make_pyramid(scale, **kwargs) for scale in self.scales
+        ])
+
+    def _make_pyramid(self, scale: int, **kwargs):
+        """Build the ``scale`` independent pooling heads of one level."""
+        return nn.ModuleList([
+            HorizontalPyramidPooling(self.in_channels,
+                                     self.out_channels,
+                                     use_avg_pool=self.use_avg_pool,
+                                     **kwargs)
+            for _ in range(scale)
+        ])
+
+    def _backbone_features(self, x):
+        """Run the ResNet stem and residual stages, skipping avgpool/fc.
+
+        Calling ``self.backbone(x)`` directly would apply the classifier
+        head and return 2-D logits, which cannot be split along a height
+        axis.
+        """
+        backbone = self.backbone
+        x = backbone.conv1(x)
+        x = backbone.bn1(x)
+        x = backbone.relu(x)
+        x = backbone.maxpool(x)
+        x = backbone.layer1(x)
+        x = backbone.layer2(x)
+        x = backbone.layer3(x)
+        x = backbone.layer4(x)
+        return x
+
+    def forward(self, x):
+        """Return the concatenated strip features for a batch of images.
+
+        Args:
+            x: Image batch accepted by the ResNet-50 backbone.
+
+        Returns:
+            A 2-D tensor with one row per sample; the columns are the
+            flattened outputs of all ``sum(scales)`` heads, ordered by
+            pyramid level and then top to bottom. Each head presumably
+            contributes ``out_channels`` values (depends on BasicConv2d,
+            which is defined elsewhere).
+
+        Raises:
+            ValueError: If the feature map has fewer rows than a scale,
+                which would leave every strip of that level empty.
+        """
+        x = self._backbone_features(x)
+        n, _, h, _ = x.size()
+
+        feature = []
+        for scale, pyramid in zip(self.scales, self.pyramids):
+            # Integer division: rows left over at the bottom are ignored.
+            h_per_hpp = h // scale
+            if h_per_hpp == 0:
+                raise ValueError(
+                    f"feature map height {h} is smaller than pyramid "
+                    f"scale {scale}"
+                )
+            for hpp_index, hpp in enumerate(pyramid):
+                start = hpp_index * h_per_hpp
+                # Plain slicing selects the same rows as an index tensor
+                # without materialising a copy.
+                x_slice = x[:, :, start:start + h_per_hpp, :]
+                x_slice = hpp(x_slice)
+                feature.append(x_slice.view(n, -1))
+
+        return torch.cat(feature, dim=1)
diff --git a/models/layers.py b/models/layers.py
index 2dc87a1..9b17205 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -117,3 +117,24 @@ class BasicConv1d(nn.Module):
def forward(self, x):
return self.conv(x)
+
+
+class HorizontalPyramidPooling(BasicConv2d):
+    """Global pooling followed by the inherited ``BasicConv2d`` forward.
+
+    Args:
+        in_channels: Passed to ``BasicConv2d``.
+        out_channels: Passed to ``BasicConv2d``.
+        kernel_size: Passed to ``BasicConv2d``; defaults to 1.
+        use_avg_pool: Pool with ``nn.AdaptiveAvgPool2d`` when true,
+            otherwise with ``nn.AdaptiveMaxPool2d``.
+        **kwargs: Passed unchanged to ``BasicConv2d``.
+    """
+
+    def __init__(
+            self,
+            in_channels: int,
+            out_channels: int,
+            kernel_size: Union[int, Tuple[int, int]] = 1,
+            use_avg_pool: bool = False,
+            **kwargs
+    ):
+        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
+        pool_type = (nn.AdaptiveAvgPool2d if use_avg_pool
+                     else nn.AdaptiveMaxPool2d)
+        # Output size 1 collapses both spatial dimensions to 1x1.
+        self.pool = pool_type(1)
+
+    def forward(self, x):
+        """Pool ``x`` to 1x1, then apply the parent class's forward."""
+        pooled = self.pool(x)
+        return super().forward(pooled)