diff options
Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 49 |
1 file changed, 34 insertions, 15 deletions
diff --git a/models/model.py b/models/model.py index b09d600..bf79564 100644 --- a/models/model.py +++ b/models/model.py @@ -177,18 +177,21 @@ class Model: triplet_is_hard, triplet_is_mean, None ) + num_sampled_frames = dataset_config.get('num_sampled_frames', 30) num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 num_pos_pairs = (self.k*(self.k-1)//2) * self.pr # Try to accelerate computation using CUDA or others + self.rgb_pn = nn.DataParallel(self.rgb_pn) self.rgb_pn = self.rgb_pn.to(self.device) + self.triplet_loss_hpm = nn.DataParallel(self.triplet_loss_hpm) self.triplet_loss_hpm = self.triplet_loss_hpm.to(self.device) + self.triplet_loss_pn = nn.DataParallel(self.triplet_loss_pn) self.triplet_loss_pn = self.triplet_loss_pn.to(self.device) - self.optimizer = optim.Adam([ - {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp}, - {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp}, - {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, + {'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp}, + {'params': self.rgb_pn.module.hpm.parameters(), **hpm_optim_hp}, + {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp}, ], **optim_hp) start_step = sched_hp.get('start_step', 15_000) @@ -241,20 +244,34 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2) + embed_c, embed_p, images, feature_for_loss = self.rgb_pn(x_c1, x_c2) + x_c1_pred = feature_for_loss[0] + xrecon_loss = torch.stack([ + F.mse_loss(x_c1_pred[:, i, :, :, :], x_c1[:, i, :, :, :]) + for i in range(num_sampled_frames) + ]).sum() + f_c_c1_t1, f_c_c1_t2, f_c_c2_t2 = feature_for_loss[1] + cano_cons_loss = torch.stack([ + F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :]) + + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :]) + for i in range(num_sampled_frames) + ]).mean() + f_p_c1_t2, f_p_c2_t2 = feature_for_loss[2] + pose_sim_loss = F.mse_loss( + f_p_c1_t2.mean(1), f_p_c2_t2.mean(1) + ) * 10 y = batch_c1['label'].to(self.device) # Duplicate labels for each part - y = y.repeat(self.rgb_pn.num_parts, 1) + y = y.repeat(self.rgb_pn.module.num_parts, 1) trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm( - embed_c, y[:self.rgb_pn.hpm.num_parts] + embed_c.transpose(0, 1), y[:self.rgb_pn.module.hpm.num_parts] ) trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn( - embed_p, y[self.rgb_pn.hpm.num_parts:] + embed_p.transpose(0, 1), y[self.rgb_pn.module.hpm.num_parts:] ) losses = torch.stack(( - *ae_losses, - trip_loss_hpm.mean(), - trip_loss_pn.mean() + xrecon_loss, cano_cons_loss, pose_sim_loss, + trip_loss_hpm.mean(), trip_loss_pn.mean() )) loss = losses.sum() loss.backward() @@ -264,10 +281,11 @@ class Model: running_loss += losses.detach() # Write losses to TensorBoard self.writer.add_scalar('Loss/all', loss, self.curr_iter) - self.writer.add_scalars('Loss/disentanglement', dict(zip(( - 'Cross reconstruction loss', 'Canonical consistency loss', - 'Pose similarity loss' - ), ae_losses)), self.curr_iter) + self.writer.add_scalars('Loss/disentanglement', { + 'Cross reconstruction loss': xrecon_loss, + 'Canonical consistency loss': cano_cons_loss, + 'Pose similarity loss': pose_sim_loss + }, self.curr_iter) self.writer.add_scalars('Loss/triplet loss', { 'HPM': losses[3], 'PartNet': losses[4] }, self.curr_iter) @@ -411,6 +429,7 @@ class Model: model_hp.pop('triplet_margins', None) self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp) # Try to accelerate computation using CUDA or others + self.rgb_pn = nn.DataParallel(self.rgb_pn) self.rgb_pn = self.rgb_pn.to(self.device) self.rgb_pn.eval() |