summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--models/__init__.py3
-rw-r--r--models/auto_encoder.py4
-rw-r--r--models/hpm.py29
-rw-r--r--models/part_net.py28
-rw-r--r--models/rgb_part_net.py98
5 files changed, 135 insertions, 27 deletions
diff --git a/models/__init__.py b/models/__init__.py
index 3b4d86e..51c86af 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1 +1,4 @@
from .model import Model
+from .auto_encoder import AutoEncoder
+from .hpm import HorizontalPyramidMatching
+from .part_net import PartNet
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index de38572..c84061c 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -129,7 +129,7 @@ class AutoEncoder(nn.Module):
self.xent_loss = nn.CrossEntropyLoss()
def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y):
- # t2 is random time step
+ # t1 is random time step
(f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
(_, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
(_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
@@ -142,5 +142,5 @@ class AutoEncoder(nn.Module):
+ self.mse_loss(f_c_c1_t2, f_c_c2_t2)
+ self.xent_loss(y, y_))
- return (f_c_c1_t2, f_p_c1_t2, f_p_c2_t2,
+ return ((f_c_c1_t2, f_p_c1_t2, f_p_c2_t2),
xrecon_loss_t2, cano_cons_loss_t2)
diff --git a/models/hpm.py b/models/hpm.py
index 4a1f1a4..5553094 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -8,20 +8,25 @@ from models.layers import HorizontalPyramidPooling
class HorizontalPyramidMatching(nn.Module):
def __init__(
self,
+ in_channels: int = 3,
+ out_channels: int = 128,
scales: tuple[int, ...] = (1, 2, 4, 8),
- out_channels: int = 256,
use_avg_pool: bool = True,
use_max_pool: bool = True,
+ use_backbone: bool = False,
**kwargs
):
super().__init__()
- self.scales = scales
+ self.in_channels = in_channels
self.out_channels = out_channels
+ self.scales = scales
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
+ self.use_backbone = use_backbone
- self.backbone = resnet50(pretrained=True)
- self.in_channels = self.backbone.layer4[-1].conv1.in_channels
+ if self.use_backbone:
+ self.backbone = resnet50(pretrained=True)
+ self.in_channels = self.backbone.layer4[-1].conv1.in_channels
self.pyramids = nn.ModuleList([
self._make_pyramid(scale, **kwargs) for scale in self.scales
@@ -40,12 +45,14 @@ class HorizontalPyramidMatching(nn.Module):
def forward(self, x):
# Flatten frames in all batches
- n, t, c, h, w = x.size()
+ t, n, c, h, w = x.size()
x = x.view(-1, c, h, w)
- x = self.backbone(x)
- n, c, h, w = x.size()
+ if self.use_backbone:
+ # FIXME Inconsistent dimensions
+ x = self.backbone(x)
+ t_n, _, h, _ = x.size()
feature = []
for pyramid_index, pyramid in enumerate(self.pyramids):
h_per_hpp = h // self.scales[pyramid_index]
@@ -54,11 +61,11 @@ class HorizontalPyramidMatching(nn.Module):
(hpp_index + 1) * h_per_hpp)
x_slice = x[:, :, h_filter, :]
x_slice = hpp(x_slice)
- x_slice = x_slice.view(n, -1)
+ x_slice = x_slice.view(t_n, -1)
feature.append(x_slice)
- x = torch.cat(feature, dim=1)
+ x = torch.stack(feature)
# Unfold frames to original batch
- _, d = x.size()
- x = x.view(n, t, d)
+ p, _, c = x.size()
+ x = x.view(p, t, n, c)
return x
diff --git a/models/part_net.py b/models/part_net.py
index 66e61fc..ac7c434 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -30,15 +30,11 @@ class FrameLevelPartFeatureExtractor(nn.Module):
def forward(self, x):
# Flatten frames in all batches
- n, t, c, h, w = x.size()
+ t, n, c, h, w = x.size()
x = x.view(-1, c, h, w)
for fconv_block in self.fconv_blocks:
x = fconv_block(x)
-
- # Unfold frames to original batch
- _, c, h, w = x.size()
- x = x.view(n, t, c, h, w)
return x
@@ -79,7 +75,8 @@ class TemporalFeatureAggregator(nn.Module):
for _ in range(self.num_part)])
def forward(self, x):
- x = x.transpose(2, 3)
+ # p, t, n, c
+ x = x.permute(0, 2, 3, 1).contiguous()
p, n, c, t = x.size()
feature = x.split(1, dim=0)
feature = [f.squeeze(0) for f in feature]
@@ -87,7 +84,7 @@ class TemporalFeatureAggregator(nn.Module):
# MTB1: ConvNet1d & Sigmoid
logits3x1 = torch.stack(
- [conv(f) for conv, f in zip(self.conv1d3x1, feature)], dim=0
+ [conv(f) for conv, f in zip(self.conv1d3x1, feature)]
)
scores3x1 = torch.sigmoid(logits3x1)
# MTB1: Template Function
@@ -97,7 +94,7 @@ class TemporalFeatureAggregator(nn.Module):
# MTB2: ConvNet1d & Sigmoid
logits3x3 = torch.stack(
- [conv(f) for conv, f in zip(self.conv1d3x3, feature)], dim=0
+ [conv(f) for conv, f in zip(self.conv1d3x3, feature)]
)
scores3x3 = torch.sigmoid(logits3x3)
# MTB2: Template Function
@@ -128,25 +125,28 @@ class PartNet(nn.Module):
)
num_fconv_blocks = len(self.fpfe.fconv_blocks)
- tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1)
+ self.tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1)
self.tfa = TemporalFeatureAggregator(
- tfa_in_channels, squeeze_ratio, self.num_part
+ self.tfa_in_channels, squeeze_ratio, self.num_part
)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
def forward(self, x):
+ t, n, _, _, _ = x.size()
+ # t, n, c, h, w
x = self.fpfe(x)
+ # t_n, c, h, w
# Horizontal Pooling
- n, t, c, h, w = x.size()
+ _, c, h, w = x.size()
split_size = h // self.num_part
- x = x.split(split_size, dim=3)
+ x = x.split(split_size, dim=2)
x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
- x = [x_.view(n, t, c, -1) for x_ in x]
+ x = [x_.view(t, n, c) for x_ in x]
x = torch.stack(x)
- # p, n, t, c
+ # p, t, n, c
x = self.tfa(x)
return x
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
new file mode 100644
index 0000000..9768dec
--- /dev/null
+++ b/models/rgb_part_net.py
@@ -0,0 +1,98 @@
+import random
+
+import torch
+import torch.nn as nn
+
+from models import AutoEncoder, HorizontalPyramidMatching, PartNet
+
+
+class RGBPartNet(nn.Module):
+ def __init__(
+ self,
+ num_class: int = 74,
+ ae_in_channels: int = 3,
+ ae_feature_channels: int = 64,
+ f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
+ hpm_scales: tuple[int, ...] = (1, 2, 4, 8),
+ hpm_use_avg_pool: bool = True,
+ hpm_use_max_pool: bool = True,
+ fpfe_feature_channels: int = 32,
+ fpfe_kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)),
+ fpfe_paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
+ fpfe_halving: tuple[int, ...] = (0, 2, 3),
+ tfa_squeeze_ratio: int = 4,
+ tfa_num_part: int = 16,
+ ):
+ super().__init__()
+ self.ae = AutoEncoder(
+ num_class, ae_in_channels, ae_feature_channels, f_a_c_p_dims
+ )
+ self.pn = PartNet(
+ ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
+ fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_part
+ )
+ self.hpm = HorizontalPyramidMatching(
+ ae_in_channels, self.pn.tfa_in_channels, hpm_scales,
+ hpm_use_avg_pool, hpm_use_max_pool
+ )
+
+ self.mse_loss = nn.MSELoss()
+
+ # TODO Weight inti here
+
+ def pose_sim_loss(self, f_p_c1: torch.Tensor,
+ f_p_c2: torch.Tensor) -> torch.Tensor:
+ f_p_c1_mean = f_p_c1.mean(dim=0)
+ f_p_c2_mean = f_p_c2.mean(dim=0)
+ return self.mse_loss(f_p_c1_mean, f_p_c2_mean).item()
+
+ def forward(self, x_c1, x_c2, y):
+ # Step 0: Swap batch_size and time dimensions for next step
+ # n, t, c, h, w
+ x_c1, x_c2 = x_c1.transpose(0, 1), x_c2.transpose(0, 1)
+
+ # Step 1: Disentanglement
+ # t, n, c, h, w
+ num_frames = len(x_c1)
+ f_c_c1, f_p_c1, f_p_c2 = [], [], []
+ xrecon_loss, cano_cons_loss = 0, 0
+ for t2 in range(num_frames):
+ t1 = random.randrange(num_frames)
+ output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y)
+ (feature_t2, xrecon_loss_t2, cano_cons_loss_t2) = output
+ (f_c_c1_t2, f_p_c1_t2, f_p_c2_t2) = feature_t2
+ # Features for next step
+ f_c_c1.append(f_c_c1_t2)
+ f_p_c1.append(f_p_c1_t2)
+ # Losses per time step
+ f_p_c2.append(f_p_c2_t2)
+ xrecon_loss += xrecon_loss_t2
+ cano_cons_loss += cano_cons_loss_t2
+ f_c_c1 = torch.stack(f_c_c1)
+ f_p_c1 = torch.stack(f_p_c1)
+
+ # Step 2.a: HPM & Static Gait Feature Aggregation
+ # t, n, c, h, w
+ x_c = self.hpm(f_c_c1)
+ # p, t, n, c
+ x_c = x_c.mean(dim=1)
+ # p, n, c
+
+ # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
+ # t, n, c, h, w
+ x_p = self.pn(f_p_c1)
+ # p, n, c
+
+ # Step 3: Cat feature map together and calculate losses
+ x = torch.cat(x_c, x_p)
+ # Losses
+ xrecon_loss /= num_frames
+ f_p_c2 = torch.stack(f_p_c2)
+ pose_sim_loss = self.pose_sim_loss(f_p_c1, f_p_c2)
+ cano_cons_loss /= num_frames
+ # TODO Implement Batch All triplet loss function
+ batch_all_triplet_loss = 0
+ loss = (xrecon_loss + pose_sim_loss + cano_cons_loss
+ + batch_all_triplet_loss)
+
+ return x, loss