summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py76
1 files changed, 49 insertions, 27 deletions
diff --git a/models/model.py b/models/model.py
index e83cc7f..4335bc9 100644
--- a/models/model.py
+++ b/models/model.py
@@ -164,7 +164,7 @@ class Model:
)
else: # Different margins
self.triplet_loss = JointBatchTripletLoss(
- self.rgb_pn.hpm_num_parts,
+ self.rgb_pn.module.hpm_num_parts,
triplet_is_hard, triplet_is_mean, triplet_margins
)
else: # Soft margins
@@ -172,17 +172,20 @@ class Model:
triplet_is_hard, triplet_is_mean, None
)
+ num_sampled_frames = dataset_config.get('num_sampled_frames', 30)
num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
# Try to accelerate computation using CUDA or others
+ self.rgb_pn = nn.DataParallel(self.rgb_pn)
self.rgb_pn = self.rgb_pn.to(self.device)
+ self.triplet_loss = nn.DataParallel(self.triplet_loss)
self.triplet_loss = self.triplet_loss.to(self.device)
self.optimizer = optim.Adam([
- {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
- {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
- {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
- {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
+ {'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp},
+ {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
+ {'params': self.rgb_pn.module.hpm.parameters(), **hpm_optim_hp},
+ {'params': self.rgb_pn.module.fc_mat, **fc_optim_hp}
], **optim_hp)
sched_final_gamma = sched_hp.get('final_gamma', 0.001)
sched_start_step = sched_hp.get('start_step', 15_000)
@@ -228,17 +231,31 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+ embedding, images, feature_for_loss = self.rgb_pn(x_c1, x_c2)
+ x_c1_pred = feature_for_loss[0]
+ xrecon_loss = torch.stack([
+ F.mse_loss(x_c1_pred[:, i, :, :, :], x_c1[:, i, :, :, :])
+ for i in range(num_sampled_frames)
+ ]).sum()
+ f_c_c1_t1, f_c_c1_t2, f_c_c2_t2 = feature_for_loss[1]
+ cano_cons_loss = torch.stack([
+ F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
+ + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :])
+ for i in range(num_sampled_frames)
+ ]).mean()
+ f_p_c1_t2, f_p_c2_t2 = feature_for_loss[2]
+ pose_sim_loss = F.mse_loss(
+ f_p_c1_t2.mean(1), f_p_c2_t2.mean(1)
+ ) * 10
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
- y = y.repeat(self.rgb_pn.num_total_parts, 1)
- trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
- losses = torch.cat((
- ae_losses,
- torch.stack((
- trip_loss[:self.rgb_pn.hpm_num_parts].mean(),
- trip_loss[self.rgb_pn.hpm_num_parts:].mean()
- ))
+ y = y.repeat(self.rgb_pn.module.num_total_parts, 1)
+ embedding = embedding.transpose(0, 1)
+ triplet_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
+ hpm_loss = triplet_loss[:self.rgb_pn.module.hpm_num_parts].mean()
+ pn_loss = triplet_loss[self.rgb_pn.module.hpm_num_parts:].mean()
+ losses = torch.stack((
+ xrecon_loss, cano_cons_loss, pose_sim_loss, hpm_loss, pn_loss
))
loss = losses.sum()
loss.backward()
@@ -248,39 +265,43 @@ class Model:
running_loss += losses.detach()
# Write losses to TensorBoard
self.writer.add_scalar('Loss/all', loss, self.curr_iter)
- self.writer.add_scalars('Loss/disentanglement', dict(zip((
- 'Cross reconstruction loss', 'Canonical consistency loss',
- 'Pose similarity loss'
- ), ae_losses)), self.curr_iter)
+ self.writer.add_scalars('Loss/disentanglement', {
+ 'Cross reconstruction loss': xrecon_loss,
+ 'Canonical consistency loss': cano_cons_loss,
+ 'Pose similarity loss': pose_sim_loss
+ }, self.curr_iter)
self.writer.add_scalars('Loss/triplet loss', {
- 'HPM': losses[3],
- 'PartNet': losses[4]
+ 'HPM': hpm_loss, 'PartNet': pn_loss
}, self.curr_iter)
# None-zero losses in batch
if num_non_zero is not None:
self.writer.add_scalars('Loss/non-zero counts', {
- 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(),
- 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean()
+ 'HPM': num_non_zero[
+ :self.rgb_pn.module.hpm_num_parts].mean(),
+ 'PartNet': num_non_zero[
+ self.rgb_pn.module.hpm_num_parts:].mean()
}, self.curr_iter)
# Embedding distance
- mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0)
+ mean_hpm_dist = dist[:self.rgb_pn.module.hpm_num_parts].mean(0)
self._add_ranked_scalars(
'Embedding/HPM distance', mean_hpm_dist,
num_pos_pairs, num_pairs, self.curr_iter
)
- mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0)
+ mean_pa_dist = dist[self.rgb_pn.module.hpm_num_parts:].mean(0)
self._add_ranked_scalars(
'Embedding/ParNet distance', mean_pa_dist,
num_pos_pairs, num_pairs, self.curr_iter
)
# Embedding norm
- mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0)
+ mean_hpm_embedding = embedding[
+ :self.rgb_pn.module.hpm_num_parts].mean(0)
mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/HPM norm', mean_hpm_norm,
self.k, self.pr * self.k, self.curr_iter
)
- mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0)
+ mean_pa_embedding = embedding[
+ self.rgb_pn.module.hpm_num_parts:].mean(0)
mean_pa_norm = mean_pa_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/PartNet norm', mean_pa_norm,
@@ -390,12 +411,13 @@ class Model:
)
# Init models
- model_hp: dict = self.hp.get('model', {}).copy()
+ model_hp: Dict = self.hp.get('model', {}).copy()
model_hp.pop('triplet_is_hard', True)
model_hp.pop('triplet_is_mean', True)
model_hp.pop('triplet_margins', None)
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp)
# Try to accelerate computation using CUDA or others
+ self.rgb_pn = nn.DataParallel(self.rgb_pn)
self.rgb_pn = self.rgb_pn.to(self.device)
self.rgb_pn.eval()