summary refs log tree commit diff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--  models/model.py  50
1 file changed, 41 insertions, 9 deletions
diff --git a/models/model.py b/models/model.py
index ceadb92..9cac5e5 100644
--- a/models/model.py
+++ b/models/model.py
@@ -196,18 +196,21 @@ class Model:
triplet_is_hard, triplet_is_mean, None
)
+ num_sampled_frames = dataset_config.get('num_sampled_frames', 30)
self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
# Try to accelerate computation using CUDA or others
+ self.rgb_pn = nn.DataParallel(self.rgb_pn)
self.rgb_pn = self.rgb_pn.to(self.device)
+ self.triplet_loss_hpm = nn.DataParallel(self.triplet_loss_hpm)
self.triplet_loss_hpm = self.triplet_loss_hpm.to(self.device)
+ self.triplet_loss_pn = nn.DataParallel(self.triplet_loss_pn)
self.triplet_loss_pn = self.triplet_loss_pn.to(self.device)
-
self.optimizer = optim.Adam([
- {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
- {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
- {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
+ {'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp},
+ {'params': self.rgb_pn.module.hpm.parameters(), **hpm_optim_hp},
+ {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
], **optim_hp)
# Scheduler
@@ -259,7 +262,11 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+ embed_c, embed_p, images, f_loss = self.rgb_pn(x_c1, x_c2)
+ ae_losses = self._disentangling_loss(
+ x_c1, f_loss, num_sampled_frames
+ )
+ embed_c, embed_p = embed_c.transpose(0, 1), embed_p.transpose(0, 1)
y = batch_c1['label'].to(self.device)
losses, hpm_result, pn_result = self._classification_loss(
embed_c, embed_p, ae_losses, y
@@ -307,7 +314,12 @@ class Model:
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
with torch.no_grad():
- embed_c, embed_p, ae_losses, _ = self.rgb_pn(x_c1, x_c2)
+ embed_c, embed_p, _, f_loss = self.rgb_pn(x_c1, x_c2)
+ ae_losses = self._disentangling_loss(
+ x_c1, f_loss, num_sampled_frames
+ )
+ embed_c = embed_c.transpose(0, 1)
+ embed_p = embed_p.transpose(0, 1)
y = batch_c1['label'].to(self.device)
losses, hpm_result, pn_result = self._classification_loss(
embed_c, embed_p, ae_losses, y
@@ -333,14 +345,33 @@ class Model:
self.writer.close()
+ @staticmethod
+ def _disentangling_loss(x_c1, feature_for_loss, num_sampled_frames):
+ x_c1_pred = feature_for_loss[0]
+ xrecon_loss = torch.stack([
+ F.mse_loss(x_c1_pred[:, i, :, :, :], x_c1[:, i, :, :, :])
+ for i in range(num_sampled_frames)
+ ]).sum()
+ f_c_c1_t1, f_c_c1_t2, f_c_c2_t2 = feature_for_loss[1]
+ cano_cons_loss = torch.stack([
+ F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
+ + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :])
+ for i in range(num_sampled_frames)
+ ]).mean()
+ f_p_c1_t2, f_p_c2_t2 = feature_for_loss[2]
+ pose_sim_loss = F.mse_loss(
+ f_p_c1_t2.mean(1), f_p_c2_t2.mean(1)
+ ) * 10
+ return xrecon_loss, cano_cons_loss, pose_sim_loss
+
def _classification_loss(self, embed_c, embed_p, ae_losses, y):
# Duplicate labels for each part
- y_triplet = y.repeat(self.rgb_pn.num_parts, 1)
+ y_triplet = y.repeat(self.rgb_pn.module.num_parts, 1)
hpm_result = self.triplet_loss_hpm(
- embed_c, y_triplet[:self.rgb_pn.hpm.num_parts]
+ embed_c, y_triplet[:self.rgb_pn.module.hpm.num_parts]
)
pn_result = self.triplet_loss_pn(
- embed_p, y_triplet[self.rgb_pn.hpm.num_parts:]
+ embed_p, y_triplet[self.rgb_pn.module.hpm.num_parts:]
)
losses = torch.stack((
*ae_losses,
@@ -471,6 +502,7 @@ class Model:
model_hp.pop('triplet_margins', None)
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp)
# Try to accelerate computation using CUDA or others
+ self.rgb_pn = nn.DataParallel(self.rgb_pn)
self.rgb_pn = self.rgb_pn.to(self.device)
self.rgb_pn.eval()