diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/auto_encoder.py | 26 | ||||
-rw-r--r-- | models/hpm.py | 20 | ||||
-rw-r--r-- | models/layers.py | 10 | ||||
-rw-r--r-- | models/model.py | 34 | ||||
-rw-r--r-- | models/rgb_part_net.py | 141 |
5 files changed, 148 insertions, 83 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 35cb629..f04ffdb 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -95,15 +95,14 @@ class Decoder(nn.Module): self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels, is_last_layer=True) - def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False): + def forward(self, f_appearance, f_canonical, f_pose, cano_only=False): x = torch.cat((f_appearance, f_canonical, f_pose), dim=1) x = self.fc(x) x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True) - # Decode canonical features without transpose convolutions - if no_trans_conv: - return x x = self.trans_conv1(x) x = self.trans_conv2(x) + if cano_only: + return x x = self.trans_conv3(x) x = torch.sigmoid(self.trans_conv4(x)) @@ -125,21 +124,6 @@ class AutoEncoder(nn.Module): # x_c1_t2 is the frame for later module (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2) - with torch.no_grad(): - # Decode canonical features for HPM - x_c_c1_t2 = self.decoder( - torch.zeros_like(f_a_c1_t2), - f_c_c1_t2, - torch.zeros_like(f_p_c1_t2), - no_trans_conv=True - ) - # Decode pose features for Part Net - x_p_c1_t2 = self.decoder( - torch.zeros_like(f_a_c1_t2), - torch.zeros_like(f_c_c1_t2), - f_p_c1_t2 - ) - if self.training: # t1 is random time step, c2 is another condition (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1) @@ -151,9 +135,9 @@ class AutoEncoder(nn.Module): + F.mse_loss(f_c_c1_t2, f_c_c2_t2)) return ( - (x_c_c1_t2, x_p_c1_t2), + (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2), (f_p_c1_t2, f_p_c2_t2), (xrecon_loss_t2, cano_cons_loss_t2) ) else: # evaluating - return x_c_c1_t2, x_p_c1_t2 + return f_c_c1_t2, f_p_c1_t2 diff --git a/models/hpm.py b/models/hpm.py index 66503e3..9879cfb 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -9,14 +9,16 @@ class HorizontalPyramidMatching(nn.Module): self, in_channels: int, out_channels: int = 128, + use_1x1conv: bool = False, scales: tuple[int, ...] = (1, 2, 4), use_avg_pool: bool = True, - use_max_pool: bool = True, + use_max_pool: bool = False, **kwargs ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels + self.use_1x1conv = use_1x1conv self.scales = scales self.use_avg_pool = use_avg_pool self.use_max_pool = use_max_pool @@ -29,6 +31,7 @@ class HorizontalPyramidMatching(nn.Module): pyramid = nn.ModuleList([ HorizontalPyramidPooling(self.in_channels, self.out_channels, + use_1x1conv=self.use_1x1conv, use_avg_pool=self.use_avg_pool, use_max_pool=self.use_max_pool, **kwargs) @@ -37,23 +40,16 @@ class HorizontalPyramidMatching(nn.Module): return pyramid def forward(self, x): - # Flatten canonical features in all batches - t, n, c, h, w = x.size() - x = x.view(t * n, c, h, w) - + n, c, h, w = x.size() feature = [] - for pyramid_index, pyramid in enumerate(self.pyramids): - h_per_hpp = h // self.scales[pyramid_index] + for scale, pyramid in zip(self.scales, self.pyramids): + h_per_hpp = h // scale for hpp_index, hpp in enumerate(pyramid): h_filter = torch.arange(hpp_index * h_per_hpp, (hpp_index + 1) * h_per_hpp) x_slice = x[:, :, h_filter, :] x_slice = hpp(x_slice) - x_slice = x_slice.view(t * n, -1) + x_slice = x_slice.view(n, -1) feature.append(x_slice) x = torch.stack(feature) - - # Unfold frames to original batch - p, _, c = x.size() - x = x.view(p, t, n, c) return x diff --git a/models/layers.py b/models/layers.py index a9f04b3..7b6ba5c 100644 --- a/models/layers.py +++ b/models/layers.py @@ -167,12 +167,13 @@ class HorizontalPyramidPooling(BasicConv2d): self, in_channels: int, out_channels: int, - kernel_size: Union[int, tuple[int, int]] = 1, + use_1x1conv: bool = False, use_avg_pool: bool = True, - use_max_pool: bool = True, + use_max_pool: bool = False, **kwargs ): - super().__init__(in_channels, out_channels, kernel_size, **kwargs) + super().__init__(in_channels, out_channels, kernel_size=1, **kwargs) + self.use_1x1conv = use_1x1conv self.use_avg_pool = use_avg_pool self.use_max_pool = use_max_pool assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.' @@ -186,5 +187,6 @@ class HorizontalPyramidPooling(BasicConv2d): x = self.avg_pool(x) elif not self.use_avg_pool and self.use_max_pool: x = self.max_pool(x) - x = super().forward(x) + if self.use_1x1conv: + x = super().forward(x) return x diff --git a/models/model.py b/models/model.py index ddb715d..0418070 100644 --- a/models/model.py +++ b/models/model.py @@ -69,6 +69,7 @@ class Model: self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None + self.image_log_on = system_config.get('image_log_on', False) self.CASIAB_GALLERY_SELECTOR = { 'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})} @@ -146,7 +147,8 @@ class Model: hpm_optim_hp = optim_hp.pop('hpm', {}) fc_optim_hp = optim_hp.pop('fc', {}) sched_hp = self.hp.get('scheduler', {}) - self.rgb_pn = RGBPartNet(self.in_channels, **model_hp) + self.rgb_pn = RGBPartNet(self.in_channels, **model_hp, + image_log_on=self.image_log_on) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) self.optimizer = optim.Adam([ @@ -168,9 +170,9 @@ class Model: # Training start start_time = datetime.now() - running_loss = torch.zeros(4).to(self.device) + running_loss = torch.zeros(5, device=self.device) print(f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}", - f"{'CanoCons':^8} {'BATrip':^8} LR(s)") + f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} LR(s)") for (batch_c1, batch_c2) in dataloader: if self.curr_iter == start_iter: self.optimizer.add_param_group( @@ -189,7 +191,7 @@ class Model: x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) y = batch_c1['label'].to(self.device) - losses = self.rgb_pn(x_c1, x_c2, y) + losses, images = self.rgb_pn(x_c1, x_c2, y) loss = losses.sum() loss.backward() self.optimizer.step() @@ -200,13 +202,33 @@ class Model: self.writer.add_scalar('Loss/all', loss, self.curr_iter) self.writer.add_scalars('Loss/details', dict(zip([ 'Cross reconstruction loss', 'Pose similarity loss', - 'Canonical consistency loss', 'Batch All triplet loss' + 'Canonical consistency loss', 'Batch All triplet loss (HPM)', + 'Batch All triplet loss (PartNet)' ], losses)), self.curr_iter) + if self.image_log_on: + (appearance_image, canonical_image, pose_image) = images + self.writer.add_images( + 'Canonical image', canonical_image, self.curr_iter + ) + for i in range(self.pr * self.k): + self.writer.add_images( + f'Original image/batch {i}', x_c1[i], self.curr_iter + ) + self.writer.add_images( + f'Appearance image/batch {i}', + appearance_image[:, i, :, :, :], + self.curr_iter + ) + self.writer.add_images( + f'Pose image/batch {i}', + pose_image[:, i, :, :, :], + self.curr_iter + ) if self.curr_iter % 100 == 0: lrs = self.scheduler.get_last_lr() print(f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', - '{:f} {:f} {:f} {:f}'.format(*running_loss / 100), + '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100), ' '.join(('{:.3e}'.format(lr) for lr in lrs))) running_loss.zero_() diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 755d5dc..0e7d8b3 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -16,6 +16,7 @@ class RGBPartNet(nn.Module): ae_in_channels: int = 3, ae_feature_channels: int = 64, f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), + hpm_use_1x1conv: bool = False, hpm_scales: tuple[int, ...] = (1, 2, 4), hpm_use_avg_pool: bool = True, hpm_use_max_pool: bool = True, @@ -26,9 +27,14 @@ class RGBPartNet(nn.Module): tfa_squeeze_ratio: int = 4, tfa_num_parts: int = 16, embedding_dims: int = 256, - triplet_margin: float = 0.2 + triplet_margins: tuple[float, float] = (0.2, 0.2), + image_log_on: bool = False ): super().__init__() + (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims + self.hpm_num_parts = sum(hpm_scales) + self.image_log_on = image_log_on + self.ae = AutoEncoder( ae_in_channels, ae_feature_channels, f_a_c_p_dims ) @@ -38,14 +44,16 @@ class RGBPartNet(nn.Module): ) out_channels = self.pn.tfa_in_channels self.hpm = HorizontalPyramidMatching( - ae_feature_channels * 8, out_channels, hpm_scales, - hpm_use_avg_pool, hpm_use_max_pool + ae_feature_channels * 2, out_channels, hpm_use_1x1conv, + hpm_scales, hpm_use_avg_pool, hpm_use_max_pool ) - total_parts = sum(hpm_scales) + tfa_num_parts - empty_fc = torch.empty(total_parts, out_channels, embedding_dims) + empty_fc = torch.empty(self.hpm_num_parts + tfa_num_parts, + out_channels, embedding_dims) self.fc_mat = nn.Parameter(empty_fc) - self.ba_triplet_loss = BatchAllTripletLoss(triplet_margin) + (hpm_margin, pn_margin) = triplet_margins + self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin) + self.pn_ba_trip = BatchAllTripletLoss(pn_margin) def fc(self, x): return x @ self.fc_mat @@ -59,13 +67,11 @@ class RGBPartNet(nn.Module): # Step 1: Disentanglement # t, n, c, h, w - ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2) + ((x_c_c1, x_p_c1), images, losses) = self._disentangle(x_c1, x_c2) - # Step 2.a: HPM & Static Gait Feature Aggregation - # t, n, c, h, w + # Step 2.a: Static Gait Feature Aggregation & HPM + # n, c, h, w x_c = self.hpm(x_c_c1) - # p, t, n, c - x_c = x_c.mean(dim=1) # p, n, c # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) @@ -78,44 +84,83 @@ class RGBPartNet(nn.Module): x = self.fc(x) if self.training: - batch_all_triplet_loss = self.ba_triplet_loss(x, y) - losses = torch.stack((*losses, batch_all_triplet_loss)) - return losses + hpm_ba_trip = self.hpm_ba_trip(x[:self.hpm_num_parts], y) + pn_ba_trip = self.pn_ba_trip(x[self.hpm_num_parts:], y) + losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip)) + return losses, images else: return x.unsqueeze(1).view(-1) def _disentangle(self, x_c1, x_c2=None): t, n, c, h, w = x_c1.size() + device = x_c1.device if self.training: - # Decoded canonical features and Pose images - x_c_c1, x_p_c1 = [], [] + # Encoded appearance, canonical and pose features + f_a_c1, f_c_c1, f_p_c1 = [], [], [] # Features required to calculate losses - f_p_c1, f_p_c2 = [], [] + f_p_c2 = [] xrecon_loss, cano_cons_loss = [], [] for t2 in range(t): t1 = random.randrange(t) output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2]) - (x_c1_t2, f_p_t2, losses) = output + (f_c1_t2, f_p_t2, losses) = output - # Decoded features or image - (x_c_c1_t2, x_p_c1_t2) = x_c1_t2 - # Canonical Features for HPM - x_c_c1.append(x_c_c1_t2) - # Pose image for Part Net - x_p_c1.append(x_p_c1_t2) + (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = f_c1_t2 + if self.image_log_on: + f_a_c1.append(f_a_c1_t2) + # Save canonical features and pose features + f_c_c1.append(f_c_c1_t2) + f_p_c1.append(f_p_c1_t2) # Losses per time step # Used in pose similarity loss - (f_p_c1_t2, f_p_c2_t2) = f_p_t2 - f_p_c1.append(f_p_c1_t2) + (_, f_p_c2_t2) = f_p_t2 f_p_c2.append(f_p_c2_t2) + # Cross reconstruction loss and canonical loss (xrecon_loss_t2, cano_cons_loss_t2) = losses xrecon_loss.append(xrecon_loss_t2) cano_cons_loss.append(cano_cons_loss_t2) - - x_c_c1 = torch.stack(x_c_c1) - x_p_c1 = torch.stack(x_p_c1) + if self.image_log_on: + f_a_c1 = torch.stack(f_a_c1) + f_c_c1_mean = torch.stack(f_c_c1).mean(0) + f_p_c1 = torch.stack(f_p_c1) + f_p_c2 = torch.stack(f_p_c2) + + # Decode features + appearance_image, canonical_image, pose_image = None, None, None + with torch.no_grad(): + # Decode average canonical features to higher dimension + x_c_c1 = self.ae.decoder( + torch.zeros((n, self.f_a_dim), device=device), + f_c_c1_mean, + torch.zeros((n, self.f_p_dim), device=device), + cano_only=True + ) + # Decode pose features to images + f_p_c1_ = f_p_c1.view(t * n, -1) + x_p_c1_ = self.ae.decoder( + torch.zeros((t * n, self.f_a_dim), device=device), + torch.zeros((t * n, self.f_c_dim), device=device), + f_p_c1_ + ) + x_p_c1 = x_p_c1_.view(t, n, c, h, w) + + if self.image_log_on: + # Decode appearance features + f_a_c1_ = f_a_c1.view(t * n, -1) + appearance_image_ = self.ae.decoder( + f_a_c1_, + torch.zeros((t * n, self.f_c_dim), device=device), + torch.zeros((t * n, self.f_p_dim), device=device) + ) + appearance_image = appearance_image_.view(t, n, c, h, w) + # Continue decoding canonical features + canonical_image = self.ae.decoder.trans_conv3(x_c_c1) + canonical_image = torch.sigmoid( + self.ae.decoder.trans_conv4(canonical_image) + ) + pose_image = x_p_c1 # Losses xrecon_loss = torch.sum(torch.stack(xrecon_loss)) @@ -123,20 +168,36 @@ class RGBPartNet(nn.Module): cano_cons_loss = torch.mean(torch.stack(cano_cons_loss)) return ((x_c_c1, x_p_c1), + (appearance_image, canonical_image, pose_image), (xrecon_loss, pose_sim_loss, cano_cons_loss)) else: # evaluating - x_c1 = x_c1.view(-1, c, h, w) - x_c_c1, x_p_c1 = self.ae(x_c1) - _, c_c, h_c, w_c = x_c_c1.size() - x_c_c1 = x_c_c1.view(t, n, c_c, h_c, w_c) - x_p_c1 = x_p_c1.view(t, n, c, h, w) - - return (x_c_c1, x_p_c1), None + x_c1_ = x_c1.view(t * n, c, h, w) + (f_c_c1_, f_p_c1_) = self.ae(x_c1_) + + # Canonical features + f_c_c1 = f_c_c1_.view(t, n, -1) + f_c_c1_mean = f_c_c1.mean(0) + x_c_c1 = self.ae.decoder( + torch.zeros((n, self.f_a_dim)), + f_c_c1_mean, + torch.zeros((n, self.f_p_dim)), + cano_only=True + ) + + # Pose features + x_p_c1_ = self.ae.decoder( + torch.zeros((t * n, self.f_a_dim)), + torch.zeros((t * n, self.f_c_dim)), + f_p_c1_ + ) + x_p_c1 = x_p_c1_.view(t, n, c, h, w) + + return (x_c_c1, x_p_c1), None, None @staticmethod - def _pose_sim_loss(f_p_c1: list[torch.Tensor], - f_p_c2: list[torch.Tensor]) -> torch.Tensor: - f_p_c1_mean = torch.stack(f_p_c1).mean(dim=0) - f_p_c2_mean = torch.stack(f_p_c2).mean(dim=0) + def _pose_sim_loss(f_p_c1: torch.Tensor, + f_p_c2: torch.Tensor) -> torch.Tensor: + f_p_c1_mean = f_p_c1.mean(dim=0) + f_p_c2_mean = f_p_c2.mean(dim=0) return F.mse_loss(f_p_c1_mean, f_p_c2_mean) |