diff options
-rw-r--r-- models/model.py        | 17 ++++++++++++-----
-rw-r--r-- models/rgb_part_net.py |  2 +-
2 files changed, 13 insertions, 6 deletions
diff --git a/models/model.py b/models/model.py
index f79b832..eb12285 100644
--- a/models/model.py
+++ b/models/model.py
@@ -150,12 +150,13 @@ class Model:
         self.rgb_pn = RGBPartNet(self.in_channels, **model_hp,
                                  image_log_on=self.image_log_on)
         # Try to accelerate computation using CUDA or others
+        self.rgb_pn = nn.DataParallel(self.rgb_pn)
         self.rgb_pn = self.rgb_pn.to(self.device)
         self.optimizer = optim.Adam([
-            {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
-            {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
-            {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
-            {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
+            {'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp},
+            {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
+            {'params': self.rgb_pn.module.hpm.parameters(), **hpm_optim_hp},
+            {'params': self.rgb_pn.module.fc_mat, **fc_optim_hp}
         ], **optim_hp)
         sched_gamma = sched_hp.get('gamma', 0.9)
         sched_step_size = sched_hp.get('step_size', 500)
@@ -194,8 +195,14 @@ class Model:
             x_c2 = batch_c2['clip'].to(self.device)
             y = batch_c1['label'].to(self.device)
             # Duplicate labels for each part
-            y = y.unsqueeze(1).repeat(1, self.rgb_pn.num_total_parts)
+            y = y.unsqueeze(1).repeat(1, self.rgb_pn.module.num_total_parts)
             losses, images = self.rgb_pn(x_c1, x_c2, y)
+            losses = torch.stack((
+                # xrecon cano_cons pose_sim
+                losses[0].sum(), losses[1].mean(), losses[2].mean(),
+                # hpm_ba_trip pn_ba_trip
+                losses[3].mean(), losses[4].mean()
+            ))
             loss = losses.sum()
             loss.backward()
             self.optimizer.step()
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 2aa680c..66609fd 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -83,7 +83,7 @@ class RGBPartNet(nn.Module):
             pn_ba_trip = self.pn_ba_trip(
                 x[self.hpm_num_parts:], y[self.hpm_num_parts:]
             )
-            losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip))
+            losses = (*losses, hpm_ba_trip, pn_ba_trip)
             return losses, images
         else:
             return x.unsqueeze(1).view(-1)