-rw-r--r--  config.py              |  6
-rw-r--r--  models/auto_encoder.py | 15
-rw-r--r--  models/model.py        | 76
-rw-r--r--  models/rgb_part_net.py |  8

4 files changed, 59 insertions(+), 46 deletions(-)
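The headline change is two-GPU training via `nn.DataParallel`: `CUDA_VISIBLE_DEVICES` grows to `'0,1'` and the batch shape from `(4, 6)` to `(6, 8)`. A quick sanity check on the resulting batch geometry (the pair-count formulas are the ones `models/model.py` uses below; the two-way split is an assumption based on the new device list):

```python
# Batch geometry before and after this commit: `pr` persons x `k`
# sequences per person, as documented in config.py.
def batch_geometry(pr: int, k: int, num_gpus: int):
    batch = pr * k                            # sequences per iteration
    num_pairs = (pr * k - 1) * (pr * k) // 2  # all sample pairs (model.py)
    num_pos_pairs = (k * (k - 1) // 2) * pr   # same-person pairs (model.py)
    return batch, num_pairs, num_pos_pairs, batch // num_gpus

print(batch_geometry(4, 6, 1))  # old: (24, 276, 60, 24)
print(batch_geometry(6, 8, 2))  # new: (48, 1128, 168, 24)
```

Note that the per-GPU slice stays at 24 sequences, so the larger global batch should mainly buy more triplet pairs per iteration rather than more per-device memory pressure.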
diff --git a/config.py b/config.py
--- a/config.py
+++ b/config.py
@@ -5,7 +5,7 @@ config: Configuration = {
     # Disable accelerator
     'disable_acc': False,
     # GPU(s) used in training or testing if available
-    'CUDA_VISIBLE_DEVICES': '0',
+    'CUDA_VISIBLE_DEVICES': '0,1',
     # Directory used in training or testing for temporary storage
     'save_dir': 'runs',
     # Record disentangled image or not
@@ -30,14 +30,14 @@ config: Configuration = {
         # Resolution after resize, can be divided by 16
         'frame_size': (64, 48),
         # Cache dataset or not
-        'cache_on': False,
+        'cache_on': True,
     },
     # Dataloader settings
     'dataloader': {
         # Batch size (pr, k)
         # `pr` denotes number of persons
         # `k` denotes number of sequences per person
-        'batch_size': (4, 6),
+        'batch_size': (6, 8),
         # Number of workers of Dataloader
         'num_workers': 4,
         # Faster data transfer from RAM to GPU if enabled
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index dbd1da0..023b462 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -153,27 +153,18 @@ class AutoEncoder(nn.Module):
             x_c1_t2_pred_ = self.decoder(f_a_c1_t1_, f_c_c1_t1_, f_p_c1_t2_)
             x_c1_t2_pred = x_c1_t2_pred_.view(n, t, c, h, w)
 
-            xrecon_loss = torch.stack([
-                F.mse_loss(x_c1_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :])
-                for i in range(t)
-            ]).sum()
-
             f_c_c1_t1 = f_c_c1_t1_.view(n, t, -1)
             f_c_c1_t2 = f_c_c1_t2_.view(n, t, -1)
             f_c_c2_t2 = f_c_c2_t2_.view(n, t, -1)
-            cano_cons_loss = torch.stack([
-                F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
-                + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :])
-                for i in range(t)
-            ]).mean()
 
             f_p_c1_t2 = f_p_c1_t2_.view(n, t, -1)
             f_p_c2_t2 = f_p_c2_t2_.view(n, t, -1)
-            pose_sim_loss = F.mse_loss(f_p_c1_t2.mean(1), f_p_c2_t2.mean(1))
 
             return (
                 (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_),
-                torch.stack((xrecon_loss, cano_cons_loss, pose_sim_loss * 10))
+                (x_c1_t2_pred,
+                 (f_c_c1_t1, f_c_c1_t2, f_c_c2_t2),
+                 (f_p_c1_t2, f_p_c2_t2))
             )
         else:  # evaluating
             return f_c_c1_t2_, f_p_c1_t2_
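The `AutoEncoder` change above strips the loss computation out of `forward()` and returns the reconstructed clip and the canonical/pose features instead. A plausible motivation, given the `nn.DataParallel` wrapping introduced in `models/model.py` below: `DataParallel` gathers whatever `forward()` returns along dim 0, so the old stacked `(3,)` loss vector from each replica would concatenate into a `(3 * num_gpus,)` tensor, and batch-level losses would silently become per-slice statistics. Per-sample tensors gather back into full-batch shapes, letting the losses be computed once on the whole batch. A minimal sketch of the gather behaviour, with illustrative shapes:

```python
import torch

# What DataParallel's gather does to the old return value: each replica's
# (3,) loss vector is concatenated along dim 0.
losses_gpu0 = torch.tensor([0.50, 0.20, 0.10])
losses_gpu1 = torch.tensor([0.60, 0.30, 0.20])
print(torch.cat((losses_gpu0, losses_gpu1)).shape)  # torch.Size([6]), not (3,)

# Per-sample outputs compose transparently instead: half-batch predictions
# from each replica gather back into the full batch.
pred_gpu0 = torch.randn(24, 30, 3, 64, 48)  # (n/2, t, c, h, w)
pred_gpu1 = torch.randn(24, 30, 3, 64, 48)
print(torch.cat((pred_gpu0, pred_gpu1)).shape)      # (48, 30, 3, 64, 48)
```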
diff --git a/models/model.py b/models/model.py
index 7ce189c..2a74c8c 100644
--- a/models/model.py
+++ b/models/model.py
@@ -164,7 +164,7 @@ class Model:
                 )
             else:  # Different margins
                 self.triplet_loss = JointBatchTripletLoss(
-                    self.rgb_pn.hpm_num_parts,
+                    self.rgb_pn.module.hpm_num_parts,
                     triplet_is_hard, triplet_is_mean, triplet_margins
                 )
         else:  # Soft margins
@@ -172,17 +172,20 @@ class Model:
                 triplet_is_hard, triplet_is_mean, None
             )
 
+        num_sampled_frames = dataset_config.get('num_sampled_frames', 30)
         num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
         num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
         # Try to accelerate computation using CUDA or others
+        self.rgb_pn = nn.DataParallel(self.rgb_pn)
         self.rgb_pn = self.rgb_pn.to(self.device)
+        self.triplet_loss = nn.DataParallel(self.triplet_loss)
         self.triplet_loss = self.triplet_loss.to(self.device)
 
         self.optimizer = optim.Adam([
-            {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
-            {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
-            {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
-            {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
+            {'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp},
+            {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
+            {'params': self.rgb_pn.module.hpm.parameters(), **hpm_optim_hp},
+            {'params': self.rgb_pn.module.fc_mat, **fc_optim_hp}
         ], **optim_hp)
         sched_final_gamma = sched_hp.get('final_gamma', 0.001)
         sched_start_step = sched_hp.get('start_step', 15_000)
@@ -228,17 +231,31 @@ class Model:
             # forward + backward + optimize
             x_c1 = batch_c1['clip'].to(self.device)
             x_c2 = batch_c2['clip'].to(self.device)
-            embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+            embedding, images, feature_for_loss = self.rgb_pn(x_c1, x_c2)
+            x_c1_pred = feature_for_loss[0]
+            xrecon_loss = torch.stack([
+                F.mse_loss(x_c1_pred[:, i, :, :, :], x_c1[:, i, :, :, :])
+                for i in range(num_sampled_frames)
+            ]).sum()
+            f_c_c1_t1, f_c_c1_t2, f_c_c2_t2 = feature_for_loss[1]
+            cano_cons_loss = torch.stack([
+                F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
+                + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :])
+                for i in range(num_sampled_frames)
+            ]).mean()
+            f_p_c1_t2, f_p_c2_t2 = feature_for_loss[2]
+            pose_sim_loss = F.mse_loss(
+                f_p_c1_t2.mean(1), f_p_c2_t2.mean(1)
+            ) * 10
             y = batch_c1['label'].to(self.device)
             # Duplicate labels for each part
-            y = y.repeat(self.rgb_pn.num_total_parts, 1)
-            trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
-            losses = torch.cat((
-                ae_losses,
-                torch.stack((
-                    trip_loss[:self.rgb_pn.hpm_num_parts].mean(),
-                    trip_loss[self.rgb_pn.hpm_num_parts:].mean()
-                ))
+            y = y.repeat(self.rgb_pn.module.num_total_parts, 1)
+            embedding = embedding.transpose(0, 1)
+            triplet_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
+            hpm_loss = triplet_loss[:self.rgb_pn.module.hpm_num_parts].mean()
+            pn_loss = triplet_loss[self.rgb_pn.module.hpm_num_parts:].mean()
+            losses = torch.stack((
+                xrecon_loss, cano_cons_loss, pose_sim_loss, hpm_loss, pn_loss
             ))
             loss = losses.sum()
             loss.backward()
@@ -248,39 +265,43 @@ class Model:
             running_loss += losses.detach()
             # Write losses to TensorBoard
             self.writer.add_scalar('Loss/all', loss, self.curr_iter)
-            self.writer.add_scalars('Loss/disentanglement', dict(zip((
-                'Cross reconstruction loss', 'Canonical consistency loss',
-                'Pose similarity loss'
-            ), ae_losses)), self.curr_iter)
+            self.writer.add_scalars('Loss/disentanglement', {
+                'Cross reconstruction loss': xrecon_loss,
+                'Canonical consistency loss': cano_cons_loss,
+                'Pose similarity loss': pose_sim_loss
+            }, self.curr_iter)
             self.writer.add_scalars('Loss/triplet loss', {
-                'HPM': losses[3],
-                'PartNet': losses[4]
+                'HPM': hpm_loss, 'PartNet': pn_loss
             }, self.curr_iter)
             # Non-zero losses in batch
             if num_non_zero is not None:
                 self.writer.add_scalars('Loss/non-zero counts', {
-                    'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(),
-                    'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean()
+                    'HPM': num_non_zero[
+                        :self.rgb_pn.module.hpm_num_parts].mean(),
+                    'PartNet': num_non_zero[
+                        self.rgb_pn.module.hpm_num_parts:].mean()
                 }, self.curr_iter)
             # Embedding distance
-            mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0)
+            mean_hpm_dist = dist[:self.rgb_pn.module.hpm_num_parts].mean(0)
             self._add_ranked_scalars(
                 'Embedding/HPM distance', mean_hpm_dist,
                 num_pos_pairs, num_pairs, self.curr_iter
             )
-            mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0)
+            mean_pa_dist = dist[self.rgb_pn.module.hpm_num_parts:].mean(0)
            self._add_ranked_scalars(
                 'Embedding/ParNet distance', mean_pa_dist,
                 num_pos_pairs, num_pairs, self.curr_iter
             )
             # Embedding norm
-            mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0)
+            mean_hpm_embedding = embedding[
+                :self.rgb_pn.module.hpm_num_parts].mean(0)
             mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
             self._add_ranked_scalars(
                 'Embedding/HPM norm', mean_hpm_norm,
                 self.k, self.pr * self.k, self.curr_iter
             )
-            mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0)
+            mean_pa_embedding = embedding[
+                self.rgb_pn.module.hpm_num_parts:].mean(0)
             mean_pa_norm = mean_pa_embedding.norm(dim=-1)
             self._add_ranked_scalars(
                 'Embedding/PartNet norm', mean_pa_norm,
@@ -390,12 +411,13 @@ class Model:
         )
 
         # Init models
-        model_hp: dict = self.hp.get('model', {}).copy()
+        model_hp: Dict = self.hp.get('model', {}).copy()
         model_hp.pop('triplet_is_hard', True)
         model_hp.pop('triplet_is_mean', True)
         model_hp.pop('triplet_margins', None)
         self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp)
 
         # Try to accelerate computation using CUDA or others
+        self.rgb_pn = nn.DataParallel(self.rgb_pn)
         self.rgb_pn = self.rgb_pn.to(self.device)
         self.rgb_pn.eval()
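The training loop now rebuilds the three disentanglement losses exactly as `AutoEncoder.forward()` used to, including the factor of 10 on `pose_sim_loss`, but on gathered full-batch tensors and over `num_sampled_frames` rather than a replica-local `t`. The `embedding.transpose(0, 1)` before the triplet loss is the subtle part: `DataParallel` scatters and gathers along dim 0, so `RGBPartNet` returns embeddings batch-first for gathering, and transposing back to parts-first lets the wrapped triplet loss shard parts across GPUs while each part still sees the full batch, which pairwise distances require. A sketch with illustrative sizes (the real part and channel counts depend on the model config):

```python
import torch

n, p, c = 48, 82, 256        # batch, total parts, channels (illustrative)

# Gathered from DataParallel(RGBPartNet): batch-first, as returned by
# the new `x.transpose(0, 1)` in rgb_part_net.py.
embedding = torch.randn(n, p, c)

# Parts-first for DataParallel(triplet_loss): dim 0 is now `p`, so each
# loss replica gets a subset of parts with the *full* batch inside.
embedding = embedding.transpose(0, 1)        # (p, n, c)
y = torch.randint(0, 6, (n,)).repeat(p, 1)   # labels duplicated per part
assert embedding.shape[:2] == y.shape        # both split along parts
```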
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 310ef25..2853571 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -52,7 +52,7 @@ class RGBPartNet(nn.Module):
     def forward(self, x_c1, x_c2=None):
         # Step 1: Disentanglement
         # n, t, c, h, w
-        ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2)
+        ((x_c, x_p), images, f_loss) = self._disentangle(x_c1, x_c2)
 
         # Step 2.a: Static Gait Feature Aggregation & HPM
         # n, c, h, w
@@ -69,7 +69,7 @@ class RGBPartNet(nn.Module):
         x = self.fc(x)
 
         if self.training:
-            return x, ae_losses, images
+            return x.transpose(0, 1), images, f_loss
         else:
             return x.unsqueeze(1).view(-1)
 
@@ -78,7 +78,7 @@ class RGBPartNet(nn.Module):
         device = x_c1_t2.device
         if self.training:
             x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
-            ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
+            (f_a_, f_c_, f_p_), f_loss = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
             # Decode features
             x_c = self._decode_cano_feature(f_c_, n, t, device)
             x_p_ = self._decode_pose_feature(f_p_, n, t, device)
@@ -95,7 +95,7 @@ class RGBPartNet(nn.Module):
             i_p_ = torch.sigmoid(self.ae.decoder.trans_conv4(i_p_))
             i_p = i_p_.view(n, t, c, h, w)
 
-            return (x_c, x_p), losses, (i_a, i_c, i_p)
+            return (x_c, x_p), (i_a, i_c, i_p), f_loss
         else:  # evaluating
             f_c_, f_p_ = self.ae(x_c1_t2)
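Finally, the pervasive `self.rgb_pn.xxx` to `self.rgb_pn.module.xxx` rewrites follow from how `nn.DataParallel` wraps a network: the wrapper is itself an `nn.Module` that holds the original model as its `.module` attribute and does not proxy arbitrary attributes or sub-modules. A self-contained sketch (class and attribute names are stand-ins, not the real `RGBPartNet`):

```python
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.ae = nn.Linear(8, 8)  # stand-in for the real sub-modules
        self.hpm_num_parts = 16    # illustrative value

net = nn.DataParallel(Net())
print(hasattr(net, 'hpm_num_parts'))  # False: the wrapper doesn't proxy it
print(net.module.hpm_num_parts)       # 16: reach through .module instead

# Same pattern as the optimizer setup above: parameter groups must be
# built from the wrapped module's sub-modules.
optim_groups = [{'params': net.module.ae.parameters()}]
```

One related caveat: a `DataParallel`-wrapped model saves its `state_dict` with a `module.` prefix on every key, so checkpoints written before and after this commit are not directly interchangeable.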