summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-04-10 22:34:25 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-04-10 22:34:25 +0800
commitb294b715ec0de6ba94199f3b068dc828095fd2f1 (patch)
tree6b52d1639a80c1800c1fc03dd48c824f92cb0b40
parentaf7faa0f6d1eb3117359f5cf8e4d27a75f3f961c (diff)
Calculate pose similarity loss and canonical consistency loss of each part after pooling
-rw-r--r--models/auto_encoder.py14
-rw-r--r--models/hpm.py20
-rw-r--r--models/model.py99
-rw-r--r--models/part_net.py20
-rw-r--r--models/rgb_part_net.py37
5 files changed, 108 insertions, 82 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 96dfdb3..dc7843a 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -134,25 +134,13 @@ class AutoEncoder(nn.Module):
x_c1_t2_pred_ = self.decoder(f_a_c1_t1_, f_c_c1_t1_, f_p_c1_t2_)
x_c1_t2_pred = x_c1_t2_pred_.view(n, t, c, h, w)
- xrecon_loss = torch.stack([
- F.mse_loss(x_c1_t2[:, i], x_c1_t2_pred[:, i])
- for i in range(t)
- ]).sum()
-
f_c_c1_t1 = f_c_c1_t1_.view(f_size[1])
f_c_c2_t2 = f_c_c2_t2_.view(f_size[1])
- cano_cons_loss = torch.stack([
- F.mse_loss(f_c_c1_t1[:, i], f_c_c1_t2[:, i])
- + F.mse_loss(f_c_c1_t2[:, i], f_c_c2_t2[:, i])
- for i in range(t)
- ]).mean()
-
f_p_c2_t2 = f_p_c2_t2_.view(f_size[2])
- pose_sim_loss = F.mse_loss(f_p_c1_t2.mean(1), f_p_c2_t2.mean(1))
return (
(f_a_c1_t2, f_c_c1_t2, f_p_c1_t2),
- (xrecon_loss, cano_cons_loss, pose_sim_loss * 10)
+ (x_c1_t2_pred, (f_c_c1_t1, f_c_c2_t2), f_p_c2_t2)
)
else: # evaluating
return f_a_c1_t2, f_c_c1_t2, f_p_c1_t2
diff --git a/models/hpm.py b/models/hpm.py
index 8186b20..fa0f69e 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -33,8 +33,9 @@ class HorizontalPyramidMatching(nn.Module):
])
return pyramid
- def forward(self, x):
- n, c, h, w = x.size()
+ def _horizontal_pyramid_pool(self, x):
+ n, t, c, h, w = x.size()
+ x = x.view(n * t, c, h, w)
feature = []
for scale, pyramid in zip(self.scales, self.pyramids):
h_per_hpp = h // scale
@@ -43,12 +44,23 @@ class HorizontalPyramidMatching(nn.Module):
(hpp_index + 1) * h_per_hpp)
x_slice = x[:, :, h_filter, :]
x_slice = hpp(x_slice)
- x_slice = x_slice.view(n, -1)
+ x_slice = x_slice.view(n, t, c)
feature.append(x_slice)
x = torch.stack(feature)
+ return x
+ def forward(self, f_c1_t2, f_c1_t1=None, f_c2_t2=None):
+ # n, t, c, h, w
+ f_c1_t2_ = self._horizontal_pyramid_pool(f_c1_t2)
+ # p, n, t, c
+ x = f_c1_t2_.mean(2)
# p, n, c
x = x @ self.fc_mat
# p, n, d
- return x
+ if self.training:
+ f_c1_t1_ = self._horizontal_pyramid_pool(f_c1_t1)
+ f_c2_t2_ = self._horizontal_pyramid_pool(f_c2_t2)
+ return x, (f_c1_t2_, f_c1_t1_, f_c2_t2_)
+ else:
+ return x
diff --git a/models/model.py b/models/model.py
index 45067e6..6118bdf 100644
--- a/models/model.py
+++ b/models/model.py
@@ -267,11 +267,15 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+ embedding, f_loss, images = self.rgb_pn(x_c1, x_c2)
+ ae_losses = self._disentangling_loss(x_c1, f_loss)
y = batch_c1['label'].to(self.device)
- losses, hpm_result, pn_result = self._classification_loss(
- embed_c, embed_p, ae_losses, y
- )
+ results = self._classification_loss(embedding, y)
+ losses = torch.stack((
+ *ae_losses,
+ results[0]['loss'].mean(),
+ results[1]['loss'].mean()
+ ))
loss = losses.sum()
loss.backward()
self.optimizer.step()
@@ -282,9 +286,7 @@ class Model:
'Auto-encoder', 'HPM', 'PartNet'
), self.scheduler.get_last_lr())), self.curr_iter)
# Other stats
- self._write_stat(
- 'Train', embed_c, embed_p, hpm_result, pn_result, loss, losses
- )
+ self._write_stat('Train', embedding, results, losses)
# Write disentangled images
if self.image_log_on and self.curr_iter % self.image_log_steps \
@@ -306,8 +308,8 @@ class Model:
if self.curr_iter % 100 == 99:
# Validation
- embed_c = self._flatten_embedding(embed_c)
- embed_p = self._flatten_embedding(embed_p)
+ embed_c = self._flatten_embedding(embedding[0])
+ embed_p = self._flatten_embedding(embedding[1])
self._write_embedding('HPM Train', embed_c, x_c1, y)
self._write_embedding('PartNet Train', embed_p, x_c1, y)
@@ -316,18 +318,19 @@ class Model:
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
with torch.no_grad():
- embed_c, embed_p, ae_losses, _ = self.rgb_pn(x_c1, x_c2)
+ embedding, f_loss, images = self.rgb_pn(x_c1, x_c2)
+ ae_losses = self._disentangling_loss(x_c1, f_loss)
y = batch_c1['label'].to(self.device)
- losses, hpm_result, pn_result = self._classification_loss(
- embed_c, embed_p, ae_losses, y
- )
- loss = losses.sum()
-
- self._write_stat(
- 'Val', embed_c, embed_p, hpm_result, pn_result, loss, losses
- )
- embed_c = self._flatten_embedding(embed_c)
- embed_p = self._flatten_embedding(embed_p)
+ results = self._classification_loss(embedding, y)
+ losses = torch.stack((
+ *ae_losses,
+ results[0]['loss'].mean(),
+ results[1]['loss'].mean()
+ ))
+
+ self._write_stat('Val', embedding, results, losses)
+ embed_c = self._flatten_embedding(embedding[0])
+ embed_p = self._flatten_embedding(embedding[1])
self._write_embedding('HPM Val', embed_c, x_c1, y)
self._write_embedding('PartNet Val', embed_p, x_c1, y)
@@ -342,21 +345,39 @@ class Model:
self.writer.close()
- def _classification_loss(self, embed_c, embed_p, ae_losses, y):
+ @staticmethod
+ def _disentangling_loss(x_c1_t2, f_loss):
+ n, t, c, h, w = x_c1_t2.size()
+ x_c1_t2_pred = f_loss[0]
+ xrecon_loss = torch.stack([
+ F.mse_loss(x_c1_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :])
+ for i in range(t)
+ ]).sum()
+ cano_cons_loss = torch.stack([
+ torch.stack([
+ F.mse_loss(_f_c_c1_t1[:, i, :], _f_c_c1_t2[:, i, :])
+ + F.mse_loss(_f_c_c1_t2[:, i, :], _f_c_c2_t2[:, i, :])
+ for i in range(t)
+ ]).mean()
+ for _f_c_c1_t2, _f_c_c1_t1, _f_c_c2_t2 in zip(*f_loss[1])
+ ]).sum()
+ pose_sim_loss = torch.stack([
+ F.mse_loss(_f_p_c1_t2.mean(1), _f_p_c2_t2.mean(1))
+ for _f_p_c1_t2, _f_p_c2_t2 in zip(*f_loss[2])
+ ]).sum()
+
+ return xrecon_loss, cano_cons_loss * 10, pose_sim_loss * 100
+
+ def _classification_loss(self, embedding, y):
# Duplicate labels for each part
y_triplet = y.repeat(self.rgb_pn.num_parts, 1)
hpm_result = self.triplet_loss_hpm(
- embed_c, y_triplet[:self.rgb_pn.hpm.num_parts]
+ embedding[0], y_triplet[:self.rgb_pn.hpm.num_parts]
)
pn_result = self.triplet_loss_pn(
- embed_p, y_triplet[self.rgb_pn.hpm.num_parts:]
+ embedding[1], y_triplet[self.rgb_pn.hpm.num_parts:]
)
- losses = torch.stack((
- *ae_losses,
- hpm_result.pop('loss').mean(),
- pn_result.pop('loss').mean()
- ))
- return losses, hpm_result, pn_result
+ return hpm_result, pn_result
def _write_embedding(self, tag, embed, x, y):
frame = x[:, 0, :, :, :].cpu()
@@ -374,11 +395,8 @@ class Model:
def _flatten_embedding(self, embed):
return embed.detach().transpose(0, 1).reshape(self.k * self.pr, -1)
- def _write_stat(
- self, postfix, embed_c, embed_p, hpm_result, pn_result, loss, losses
- ):
+ def _write_stat(self, postfix, embeddings, results, losses):
# Write losses to TensorBoard
- self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter)
self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip((
'Cross reconstruction loss', 'Canonical consistency loss',
'Pose similarity loss'
@@ -388,30 +406,31 @@ class Model:
'PartNet': losses[4]
}, self.curr_iter)
# None-zero losses in batch
- if hpm_result['counts'] is not None and pn_result['counts'] is not None:
+ if results[0]['counts'] is not None \
+ and results[1]['counts'] is not None:
self.writer.add_scalars(f'Loss/non-zero counts {postfix}', {
- 'HPM': hpm_result['counts'].mean(),
- 'PartNet': pn_result['counts'].mean()
+ 'HPM': results[0]['counts'].mean(),
+ 'PartNet': results[1]['counts'].mean()
}, self.curr_iter)
# Embedding distance
- mean_hpm_dist = hpm_result['dist'].mean(0)
+ mean_hpm_dist = results[0]['dist'].mean(0)
self._add_ranked_scalars(
f'Embedding/HPM distance {postfix}', mean_hpm_dist,
self.num_pos_pairs, self.num_pairs, self.curr_iter
)
- mean_pn_dist = pn_result['dist'].mean(0)
+ mean_pn_dist = results[1]['dist'].mean(0)
self._add_ranked_scalars(
f'Embedding/ParNet distance {postfix}', mean_pn_dist,
self.num_pos_pairs, self.num_pairs, self.curr_iter
)
# Embedding norm
- mean_hpm_embedding = embed_c.mean(0)
+ mean_hpm_embedding = embeddings[0].mean(0)
mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
self._add_ranked_scalars(
f'Embedding/HPM norm {postfix}', mean_hpm_norm,
self.k, self.pr * self.k, self.curr_iter
)
- mean_pa_embedding = embed_p.mean(0)
+ mean_pa_embedding = embeddings[1].mean(0)
mean_pa_norm = mean_pa_embedding.norm(dim=-1)
self._add_ranked_scalars(
f'Embedding/PartNet norm {postfix}', mean_pa_norm,
diff --git a/models/part_net.py b/models/part_net.py
index f2236bf..65a2c14 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -127,23 +127,27 @@ class PartNet(nn.Module):
torch.empty(num_parts, in_channels, embedding_dims)
)
- def forward(self, x):
+ def _horizontal_pool(self, x):
n, t, c, h, w = x.size()
x = x.view(n * t, c, h, w)
- # n * t x c x h x w
-
- # Horizontal Pooling
- _, c, h, w = x.size()
split_size = h // self.num_part
x = x.split(split_size, dim=2)
x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
x = [x_.view(n, t, c) for x_ in x]
x = torch.stack(x)
+ return x
+ def forward(self, f_c1_t2, f_c2_t2=None):
+ # n, t, c, h, w
+ f_c1_t2_ = self._horizontal_pool(f_c1_t2)
# p, n, t, c
- x = self.tfa(x)
-
+ x = self.tfa(f_c1_t2_)
# p, n, c
x = x @ self.fc_mat
# p, n, d
- return x
+
+ if self.training:
+ f_c2_t2_ = self._horizontal_pool(f_c2_t2)
+ return x, (f_c1_t2_, f_c2_t2_)
+ else:
+ return x
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index b0169e3..06cbf28 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -39,26 +39,26 @@ class RGBPartNet(nn.Module):
self.num_parts = self.hpm.num_parts + tfa_num_parts
def forward(self, x_c1, x_c2=None):
- # Step 1: Disentanglement
- # n, t, c, h, w
- (f_a, f_c, f_p), ae_losses = self._disentangle(x_c1, x_c2)
+ if self.training:
+ # Step 1: Disentanglement
+ # n, t, c, h, w
+ (f_a, f_c, f_p), f_loss = self._disentangle(x_c1, x_c2)
- # Step 2.a: Static Gait Feature Aggregation & HPM
- # n, c, h, w
- f_c_mean = f_c.mean(1)
- x_c = self.hpm(f_c_mean)
- # p, n, d
+ # Step 2.a: Static Gait Feature Aggregation & HPM
+ # n, t, c, h, w
+ x_c, f_c_loss = self.hpm(f_c, *f_loss[1])
+ # p, n, d / p, n, t, c
- # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
- # n, t, c, h, w
- x_p = self.pn(f_p)
- # p, n, d
+ # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
+ # n, t, c, h, w
+ x_p, f_p_loss = self.pn(f_p, f_loss[2])
+ # p, n, d / p, n, t, c
- if self.training:
i_a, i_c, i_p = None, None, None
if self.image_log_on:
with torch.no_grad():
f_a_mean = f_a.mean(1)
+ f_c_mean = f_c.mean(1)
i_a = self.ae.decoder(
f_a_mean,
torch.zeros_like(f_c_mean),
@@ -77,15 +77,18 @@ class RGBPartNet(nn.Module):
device=f_c.device),
f_p.view(-1, *f_p_size[2:])
).view(x_c1.size())
- return x_c, x_p, ae_losses, (i_a, i_c, i_p)
- else:
+ return (x_c, x_p), (f_loss[0], f_c_loss, f_p_loss), (i_a, i_c, i_p)
+ else: # Evaluating
+ f_c, f_p = self._disentangle(x_c1, x_c2)
+ x_c = self.hpm(f_c)
+ x_p = self.pn(f_p)
return x_c, x_p
def _disentangle(self, x_c1_t2, x_c2_t2=None):
if self.training:
x_c1_t1 = x_c1_t2[:, torch.randperm(x_c1_t2.size(1)), :, :, :]
- features, losses = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
- return features, losses
+ features, f_loss = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
+ return features, f_loss
else: # evaluating
features = self.ae(x_c1_t2)
return features, None