summary | refs | log | tree | commit | diff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- models/rgb_part_net.py 63
1 file changed, 3 insertions, 60 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 2af990e..797e02b 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -4,9 +4,6 @@ import torch
import torch.nn as nn
from models.auto_encoder import AutoEncoder
-from models.hpm import HorizontalPyramidMatching
-from models.part_net import PartNet
-from utils.triplet_loss import BatchAllTripletLoss
class RGBPartNet(nn.Module):
@@ -16,80 +13,26 @@ class RGBPartNet(nn.Module):
ae_in_size: Tuple[int, int] = (64, 48),
ae_feature_channels: int = 64,
f_a_c_p_dims: Tuple[int, int, int] = (128, 128, 64),
- hpm_use_1x1conv: bool = False,
- hpm_scales: Tuple[int, ...] = (1, 2, 4),
- hpm_use_avg_pool: bool = True,
- hpm_use_max_pool: bool = True,
- fpfe_feature_channels: int = 32,
- fpfe_kernel_sizes: Tuple[Tuple, ...] = ((5, 3), (3, 3), (3, 3)),
- fpfe_paddings: Tuple[Tuple, ...] = ((2, 1), (1, 1), (1, 1)),
- fpfe_halving: Tuple[int, ...] = (0, 2, 3),
- tfa_squeeze_ratio: int = 4,
- tfa_num_parts: int = 16,
- embedding_dims: int = 256,
- triplet_margins: Tuple[float, float] = (0.2, 0.2),
image_log_on: bool = False
):
super().__init__()
(self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
- self.hpm_num_parts = sum(hpm_scales)
self.image_log_on = image_log_on
self.ae = AutoEncoder(
ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
)
- self.pn = PartNet(
- ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
- fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_parts
- )
- out_channels = self.pn.tfa_in_channels
- self.hpm = HorizontalPyramidMatching(
- ae_feature_channels * 2, out_channels, hpm_use_1x1conv,
- hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
- )
- self.num_total_parts = self.hpm_num_parts + tfa_num_parts
- empty_fc = torch.empty(self.num_total_parts,
- out_channels, embedding_dims)
- self.fc_mat = nn.Parameter(empty_fc)
-
- (hpm_margin, pn_margin) = triplet_margins
- self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin)
- self.pn_ba_trip = BatchAllTripletLoss(pn_margin)
-
- def fc(self, x):
- return x @ self.fc_mat
- def forward(self, x_c1, x_c2=None, y=None):
+ def forward(self, x_c1, x_c2=None):
# Step 1: Disentanglement
# n, t, c, h, w
((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2)
- # Step 2.a: Static Gait Feature Aggregation & HPM
- # n, c, h, w
- x_c = self.hpm(x_c)
- # p, n, c
-
- # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
- # n, t, c, h, w
- x_p = self.pn(x_p)
- # p, n, c
-
- # Step 3: Cat feature map together and fc
- x = torch.cat((x_c, x_p))
- x = self.fc(x)
-
if self.training:
- y = y.T
- hpm_ba_trip = self.hpm_ba_trip(
- x[:self.hpm_num_parts], y[:self.hpm_num_parts]
- )
- pn_ba_trip = self.pn_ba_trip(
- x[self.hpm_num_parts:], y[self.hpm_num_parts:]
- )
- losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip))
+ losses = torch.stack(losses)
return losses, images
else:
- return x.unsqueeze(1).view(-1)
+ return x_c, x_p
def _disentangle(self, x_c1_t2, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()