from typing import Optional, Tuple

import torch
import torch.nn as nn

from models.auto_encoder import AutoEncoder
from models.hpm import HorizontalPyramidMatching
from models.part_net import PartNet
from utils.triplet_loss import BatchAllTripletLoss


class RGBPartNet(nn.Module):
    """Gait network built from an auto-encoder, HPM and PartNet.

    The forward pass runs in three steps:

    1. ``AutoEncoder`` disentangles the input clip; its canonical and pose
       features are decoded again by ``self.ae.decoder``.
    2. ``HorizontalPyramidMatching`` aggregates the decoded canonical
       feature (static branch) and ``PartNet`` aggregates the decoded pose
       frames (dynamic branch).
    3. Both part-wise outputs are concatenated along the part axis and
       multiplied by ``fc_mat``, which holds one matrix per part.

    In training mode the model returns losses (auto-encoder losses plus one
    batch-all triplet loss per branch); in evaluation mode it returns the
    flattened embedding.
    """

    def __init__(
            self,
            ae_in_channels: int = 3,
            ae_in_size: Tuple[int, int] = (64, 48),
            ae_feature_channels: int = 64,
            f_a_c_p_dims: Tuple[int, int, int] = (128, 128, 64),
            hpm_use_1x1conv: bool = False,
            hpm_scales: Tuple[int, ...] = (1, 2, 4),
            hpm_use_avg_pool: bool = True,
            hpm_use_max_pool: bool = True,
            fpfe_feature_channels: int = 32,
            fpfe_kernel_sizes: Tuple[Tuple, ...] = ((5, 3), (3, 3), (3, 3)),
            fpfe_paddings: Tuple[Tuple, ...] = ((2, 1), (1, 1), (1, 1)),
            fpfe_halving: Tuple[int, ...] = (0, 2, 3),
            tfa_squeeze_ratio: int = 4,
            tfa_num_parts: int = 16,
            embedding_dims: int = 256,
            triplet_margins: Tuple[float, float] = (0.2, 0.2),
            image_log_on: bool = False
    ):
        """Build the sub-networks, the per-part FC matrix and the losses.

        Args:
            ae_in_channels: Forwarded to ``AutoEncoder`` and ``PartNet`` as
                their first positional argument.
            ae_in_size: Forwarded to ``AutoEncoder``.
            ae_feature_channels: Forwarded to ``AutoEncoder``; twice this
                value is passed to ``HorizontalPyramidMatching`` as its
                first positional argument.
            f_a_c_p_dims: Sizes of the appearance, canonical and pose
                feature vectors, in that order.
            hpm_use_1x1conv: Forwarded to ``HorizontalPyramidMatching``.
            hpm_scales: Forwarded to ``HorizontalPyramidMatching``; the sum
                of the scales is the number of HPM parts.
            hpm_use_avg_pool: Forwarded to ``HorizontalPyramidMatching``.
            hpm_use_max_pool: Forwarded to ``HorizontalPyramidMatching``.
            fpfe_feature_channels: Forwarded to ``PartNet``.
            fpfe_kernel_sizes: Forwarded to ``PartNet``.
            fpfe_paddings: Forwarded to ``PartNet``.
            fpfe_halving: Forwarded to ``PartNet``.
            tfa_squeeze_ratio: Forwarded to ``PartNet``.
            tfa_num_parts: Forwarded to ``PartNet``; also the number of
                parts contributed by the dynamic branch.
            embedding_dims: Output size of the per-part FC projection.
            triplet_margins: Margins of the HPM and PartNet batch-all
                triplet losses, in that order.
            image_log_on: When true, the training pass additionally decodes
                images from the disentangled features and returns them.
        """
        super().__init__()
        (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
        self.hpm_num_parts = sum(hpm_scales)
        self.image_log_on = image_log_on

        self.ae = AutoEncoder(
            ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
        )
        self.pn = PartNet(
            ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
            fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_parts
        )
        # HPM is given PartNet's TFA input width as its second argument so
        # that both branches can share one FC matrix after concatenation.
        out_channels = self.pn.tfa_in_channels
        self.hpm = HorizontalPyramidMatching(
            ae_feature_channels * 2, out_channels, hpm_use_1x1conv,
            hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
        )

        self.num_total_parts = self.hpm_num_parts + tfa_num_parts
        empty_fc = torch.empty(
            self.num_total_parts, out_channels, embedding_dims
        )
        self.fc_mat = nn.Parameter(empty_fc)
        # torch.empty leaves the storage uninitialised (it may even contain
        # NaN/Inf), so give the parameter well-defined starting values.
        # External weight initialisation or checkpoint loading can still
        # overwrite them afterwards.
        nn.init.xavier_uniform_(self.fc_mat)

        (hpm_margin, pn_margin) = triplet_margins
        self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin)
        self.pn_ba_trip = BatchAllTripletLoss(pn_margin)

    def fc(self, x: torch.Tensor) -> torch.Tensor:
        """Project part features by matrix-multiplying with ``fc_mat``.

        ``fc_mat`` has one (channels x embedding) matrix per part, so for a
        ``p, n, c`` input the batched matmul applies a separate projection
        to every part.
        """
        return x @ self.fc_mat

    def forward(
            self,
            x_c1: torch.Tensor,
            x_c2: Optional[torch.Tensor] = None,
            y: Optional[torch.Tensor] = None
    ):
        """Run the full pipeline.

        Args:
            x_c1: 5-D input clip (``n, t, c, h, w``).
            x_c2: Optional second clip, forwarded untouched to the
                auto-encoder in training mode.
            y: Labels for the triplet losses; required in training mode.

        Returns:
            Training mode: ``(losses, images)`` where ``losses`` is the
            auto-encoder losses extended by the HPM and PartNet triplet
            losses, and ``images`` is the tuple produced by
            ``_disentangle``.
            Evaluation mode: the embedding flattened into a 1-D tensor
            (all axes, including the batch axis, are merged).

        Raises:
            ValueError: If the module is in training mode and ``y`` is
                ``None``.
        """
        if self.training and y is None:
            # Fail early with a clear message instead of an AttributeError
            # on ``y.T`` after all the expensive work has been done.
            raise ValueError("labels `y` are required in training mode")

        # Step 1: Disentanglement
        # n, t, c, h, w
        ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2)

        # Step 2.a: Static Gait Feature Aggregation & HPM
        # n, c, h, w
        x_c = self.hpm(x_c)
        # p, n, c

        # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
        # n, t, c, h, w
        x_p = self.pn(x_p)
        # p, n, c

        # Step 3: Cat feature map together and fc
        x = torch.cat((x_c, x_p))
        x = self.fc(x)

        if self.training:
            # Transposed so the leading axis can be sliced per part, just
            # like ``x`` -- presumably ``y`` arrives with the part axis
            # last; TODO confirm against the caller.
            y = y.T
            hpm_ba_trip = self.hpm_ba_trip(
                x[:self.hpm_num_parts], y[:self.hpm_num_parts]
            )
            pn_ba_trip = self.pn_ba_trip(
                x[self.hpm_num_parts:], y[self.hpm_num_parts:]
            )
            losses = (*losses, hpm_ba_trip, pn_ba_trip)
            return losses, images
        else:
            return x.unsqueeze(1).view(-1)

    def _disentangle(self, x_c1_t2, x_c2_t2=None):
        """Encode the clip and decode its canonical and pose features.

        Args:
            x_c1_t2: 5-D input clip, unpacked as ``n, t, c, h, w``.
            x_c2_t2: Optional second clip, only used in training mode.

        Returns:
            ``((x_c, x_p), losses, images)``. In evaluation mode ``losses``
            and ``images`` are both ``None``. In training mode ``images``
            is ``(i_a, i_c, i_p)``, whose entries are ``None`` unless
            ``image_log_on`` is set.
        """
        n, t, c, h, w = x_c1_t2.size()
        device = x_c1_t2.device

        if self.training:
            # Shuffle the frames along the time axis. Only the training
            # pass of the auto-encoder consumes this copy, so it is built
            # here rather than on every (evaluation) call.
            x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
            ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
            # Decode features
            with torch.no_grad():
                x_c = self._decode_cano_feature(f_c_, n, t, device)
                x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)

                i_a, i_c, i_p = None, None, None
                if self.image_log_on:
                    i_a = self._decode_appr_feature(f_a_, n, t, device)
                    # Continue decoding canonical features
                    i_c = self.ae.decoder.trans_conv3(x_c)
                    i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c))
                    i_p = x_p

            return (x_c, x_p), losses, (i_a, i_c, i_p)
        else:  # evaluating
            f_c_, f_p_ = self.ae(x_c1_t2)
            x_c = self._decode_cano_feature(f_c_, n, t, device)
            x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
            return (x_c, x_p), None, None

    def _decode_appr_feature(self, f_a_, n, t, device):
        """Decode the time-averaged appearance feature.

        The canonical and pose inputs of the decoder are zero-filled so
        that only the appearance feature contributes.
        """
        # Decode appearance features
        f_a = f_a_.view(n, t, -1)
        x_a = self.ae.decoder(
            f_a.mean(1),
            torch.zeros((n, self.f_c_dim), device=device),
            torch.zeros((n, self.f_p_dim), device=device)
        )
        return x_a

    def _decode_cano_feature(self, f_c_, n, t, device):
        """Decode the time-averaged canonical feature.

        The appearance and pose inputs of the decoder are zero-filled and
        the decoder is called with ``cano_only=True``.
        """
        # Decode average canonical features to higher dimension
        f_c = f_c_.view(n, t, -1)
        x_c = self.ae.decoder(
            torch.zeros((n, self.f_a_dim), device=device),
            f_c.mean(1),
            torch.zeros((n, self.f_p_dim), device=device),
            cano_only=True
        )
        return x_c

    def _decode_pose_feature(self, f_p_, n, t, c, h, w, device):
        """Decode per-frame pose features and restore the clip layout.

        The appearance and canonical inputs of the decoder are zero-filled
        with ``n * t`` rows (one per frame); the decoded result is viewed
        back as ``n, t, c, h, w``.
        """
        # Decode pose features to images
        x_p_ = self.ae.decoder(
            torch.zeros((n * t, self.f_a_dim), device=device),
            torch.zeros((n * t, self.f_c_dim), device=device),
            f_p_
        )
        x_p = x_p_.view(n, t, c, h, w)
        return x_p