summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py99
1 files changed, 59 insertions, 40 deletions
diff --git a/models/model.py b/models/model.py
index f4a53bd..eb59862 100644
--- a/models/model.py
+++ b/models/model.py
@@ -267,11 +267,15 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+ embedding, f_loss, images = self.rgb_pn(x_c1, x_c2)
+ ae_losses = self._disentangling_loss(x_c1, f_loss)
y = batch_c1['label'].to(self.device)
- losses, hpm_result, pn_result = self._classification_loss(
- embed_c, embed_p, ae_losses, y
- )
+ results = self._classification_loss(embedding, y)
+ losses = torch.stack((
+ *ae_losses,
+ results[0]['loss'].mean(),
+ results[1]['loss'].mean()
+ ))
loss = losses.sum()
loss.backward()
self.optimizer.step()
@@ -282,9 +286,7 @@ class Model:
'Auto-encoder', 'HPM', 'PartNet'
), self.scheduler.get_last_lr())), self.curr_iter)
# Other stats
- self._write_stat(
- 'Train', embed_c, embed_p, hpm_result, pn_result, loss, losses
- )
+ self._write_stat('Train', embedding, results, losses)
# Write disentangled images
if self.image_log_on and self.curr_iter % self.image_log_steps \
@@ -306,8 +308,8 @@ class Model:
if self.curr_iter % 100 == 99:
# Validation
- embed_c = self._flatten_embedding(embed_c)
- embed_p = self._flatten_embedding(embed_p)
+ embed_c = self._flatten_embedding(embedding[0])
+ embed_p = self._flatten_embedding(embedding[1])
self._write_embedding('HPM Train', embed_c, x_c1, y)
self._write_embedding('PartNet Train', embed_p, x_c1, y)
@@ -316,18 +318,19 @@ class Model:
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
with torch.no_grad():
- embed_c, embed_p, ae_losses, _ = self.rgb_pn(x_c1, x_c2)
+ embedding, f_loss, images = self.rgb_pn(x_c1, x_c2)
+ ae_losses = self._disentangling_loss(x_c1, f_loss)
y = batch_c1['label'].to(self.device)
- losses, hpm_result, pn_result = self._classification_loss(
- embed_c, embed_p, ae_losses, y
- )
- loss = losses.sum()
-
- self._write_stat(
- 'Val', embed_c, embed_p, hpm_result, pn_result, loss, losses
- )
- embed_c = self._flatten_embedding(embed_c)
- embed_p = self._flatten_embedding(embed_p)
+ results = self._classification_loss(embedding, y)
+ losses = torch.stack((
+ *ae_losses,
+ results[0]['loss'].mean(),
+ results[1]['loss'].mean()
+ ))
+
+ self._write_stat('Val', embedding, results, losses)
+ embed_c = self._flatten_embedding(embedding[0])
+ embed_p = self._flatten_embedding(embedding[1])
self._write_embedding('HPM Val', embed_c, x_c1, y)
self._write_embedding('PartNet Val', embed_p, x_c1, y)
@@ -342,21 +345,39 @@ class Model:
self.writer.close()
- def _classification_loss(self, embed_c, embed_p, ae_losses, y):
+ @staticmethod
+ def _disentangling_loss(x_c1_t2, f_loss):
+ n, t, c, h, w = x_c1_t2.size()
+ x_c1_t2_pred = f_loss[0]
+ xrecon_loss = torch.stack([
+ F.mse_loss(x_c1_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :])
+ for i in range(t)
+ ]).sum()
+ cano_cons_loss = torch.stack([
+ torch.stack([
+ F.mse_loss(_f_c_c1_t1[:, i, :], _f_c_c1_t2[:, i, :])
+ + F.mse_loss(_f_c_c1_t2[:, i, :], _f_c_c2_t2[:, i, :])
+ for i in range(t)
+ ]).mean()
+ for _f_c_c1_t2, _f_c_c1_t1, _f_c_c2_t2 in zip(*f_loss[1])
+ ]).sum()
+ pose_sim_loss = torch.stack([
+ F.mse_loss(_f_p_c1_t2.mean(1), _f_p_c2_t2.mean(1))
+ for _f_p_c1_t2, _f_p_c2_t2 in zip(*f_loss[2])
+ ]).sum()
+
+ return xrecon_loss, cano_cons_loss * 10, pose_sim_loss * 100
+
+ def _classification_loss(self, embedding, y):
# Duplicate labels for each part
y_triplet = y.repeat(self.rgb_pn.num_parts, 1)
hpm_result = self.triplet_loss_hpm(
- embed_c, y_triplet[:self.rgb_pn.hpm.num_parts]
+ embedding[0], y_triplet[:self.rgb_pn.hpm.num_parts]
)
pn_result = self.triplet_loss_pn(
- embed_p, y_triplet[self.rgb_pn.hpm.num_parts:]
+ embedding[1], y_triplet[self.rgb_pn.hpm.num_parts:]
)
- losses = torch.stack((
- *ae_losses,
- hpm_result.pop('loss').mean(),
- pn_result.pop('loss').mean()
- ))
- return losses, hpm_result, pn_result
+ return hpm_result, pn_result
def _write_embedding(self, tag, embed, x, y):
frame = x[:, 0, :, :, :].cpu()
@@ -374,11 +395,8 @@ class Model:
def _flatten_embedding(self, embed):
return embed.detach().transpose(0, 1).reshape(self.k * self.pr, -1)
- def _write_stat(
- self, postfix, embed_c, embed_p, hpm_result, pn_result, loss, losses
- ):
+ def _write_stat(self, postfix, embeddings, results, losses):
# Write losses to TensorBoard
- self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter)
self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip((
'Cross reconstruction loss', 'Canonical consistency loss',
'Pose similarity loss'
@@ -388,30 +406,31 @@ class Model:
'PartNet': losses[4]
}, self.curr_iter)
# None-zero losses in batch
- if hpm_result['counts'] is not None and pn_result['counts'] is not None:
+ if results[0]['counts'] is not None \
+ and results[1]['counts'] is not None:
self.writer.add_scalars(f'Loss/non-zero counts {postfix}', {
- 'HPM': hpm_result['counts'].mean(),
- 'PartNet': pn_result['counts'].mean()
+ 'HPM': results[0]['counts'].mean(),
+ 'PartNet': results[1]['counts'].mean()
}, self.curr_iter)
# Embedding distance
- mean_hpm_dist = hpm_result['dist'].mean(0)
+ mean_hpm_dist = results[0]['dist'].mean(0)
self._add_ranked_scalars(
f'Embedding/HPM distance {postfix}', mean_hpm_dist,
self.num_pos_pairs, self.num_pairs, self.curr_iter
)
- mean_pn_dist = pn_result['dist'].mean(0)
+ mean_pn_dist = results[1]['dist'].mean(0)
self._add_ranked_scalars(
f'Embedding/ParNet distance {postfix}', mean_pn_dist,
self.num_pos_pairs, self.num_pairs, self.curr_iter
)
# Embedding norm
- mean_hpm_embedding = embed_c.mean(0)
+ mean_hpm_embedding = embeddings[0].mean(0)
mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
self._add_ranked_scalars(
f'Embedding/HPM norm {postfix}', mean_hpm_norm,
self.k, self.pr * self.k, self.curr_iter
)
- mean_pa_embedding = embed_p.mean(0)
+ mean_pa_embedding = embeddings[1].mean(0)
mean_pa_norm = mean_pa_embedding.norm(dim=-1)
self._add_ranked_scalars(
f'Embedding/PartNet norm {postfix}', mean_pa_norm,