summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/auto_encoder.py15
-rw-r--r--models/model.py50
-rw-r--r--models/rgb_part_net.py8
3 files changed, 48 insertions, 25 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 4fece69..0694ff1 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -151,27 +151,18 @@ class AutoEncoder(nn.Module):
x_c1_t2_pred_ = self.decoder(f_a_c1_t1_, f_c_c1_t1_, f_p_c1_t2_)
x_c1_t2_pred = x_c1_t2_pred_.view(n, t, c, h, w)
- xrecon_loss = torch.stack([
- F.mse_loss(x_c1_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :])
- for i in range(t)
- ]).sum()
-
f_c_c1_t1 = f_c_c1_t1_.view(n, t, -1)
f_c_c1_t2 = f_c_c1_t2_.view(n, t, -1)
f_c_c2_t2 = f_c_c2_t2_.view(n, t, -1)
- cano_cons_loss = torch.stack([
- F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
- + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :])
- for i in range(t)
- ]).mean()
f_p_c1_t2 = f_p_c1_t2_.view(n, t, -1)
f_p_c2_t2 = f_p_c2_t2_.view(n, t, -1)
- pose_sim_loss = F.mse_loss(f_p_c1_t2.mean(1), f_p_c2_t2.mean(1))
return (
(f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_),
- (xrecon_loss, cano_cons_loss, pose_sim_loss * 10)
+ (x_c1_t2_pred,
+ (f_c_c1_t1, f_c_c1_t2, f_c_c2_t2),
+ (f_p_c1_t2, f_p_c2_t2))
)
else: # evaluating
return f_c_c1_t2_, f_p_c1_t2_
diff --git a/models/model.py b/models/model.py
index ceadb92..9cac5e5 100644
--- a/models/model.py
+++ b/models/model.py
@@ -196,18 +196,21 @@ class Model:
triplet_is_hard, triplet_is_mean, None
)
+ num_sampled_frames = dataset_config.get('num_sampled_frames', 30)
self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
# Try to accelerate computation using CUDA or others
+ self.rgb_pn = nn.DataParallel(self.rgb_pn)
self.rgb_pn = self.rgb_pn.to(self.device)
+ self.triplet_loss_hpm = nn.DataParallel(self.triplet_loss_hpm)
self.triplet_loss_hpm = self.triplet_loss_hpm.to(self.device)
+ self.triplet_loss_pn = nn.DataParallel(self.triplet_loss_pn)
self.triplet_loss_pn = self.triplet_loss_pn.to(self.device)
-
self.optimizer = optim.Adam([
- {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
- {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
- {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
+ {'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp},
+ {'params': self.rgb_pn.module.hpm.parameters(), **hpm_optim_hp},
+ {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
], **optim_hp)
# Scheduler
@@ -259,7 +262,11 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+ embed_c, embed_p, images, f_loss = self.rgb_pn(x_c1, x_c2)
+ ae_losses = self._disentangling_loss(
+ x_c1, f_loss, num_sampled_frames
+ )
+ embed_c, embed_p = embed_c.transpose(0, 1), embed_p.transpose(0, 1)
y = batch_c1['label'].to(self.device)
losses, hpm_result, pn_result = self._classification_loss(
embed_c, embed_p, ae_losses, y
@@ -307,7 +314,12 @@ class Model:
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
with torch.no_grad():
- embed_c, embed_p, ae_losses, _ = self.rgb_pn(x_c1, x_c2)
+ embed_c, embed_p, _, f_loss = self.rgb_pn(x_c1, x_c2)
+ ae_losses = self._disentangling_loss(
+ x_c1, f_loss, num_sampled_frames
+ )
+ embed_c = embed_c.transpose(0, 1)
+ embed_p = embed_p.transpose(0, 1)
y = batch_c1['label'].to(self.device)
losses, hpm_result, pn_result = self._classification_loss(
embed_c, embed_p, ae_losses, y
@@ -333,14 +345,33 @@ class Model:
self.writer.close()
+ @staticmethod
+ def _disentangling_loss(x_c1, feature_for_loss, num_sampled_frames):
+ x_c1_pred = feature_for_loss[0]
+ xrecon_loss = torch.stack([
+ F.mse_loss(x_c1_pred[:, i, :, :, :], x_c1[:, i, :, :, :])
+ for i in range(num_sampled_frames)
+ ]).sum()
+ f_c_c1_t1, f_c_c1_t2, f_c_c2_t2 = feature_for_loss[1]
+ cano_cons_loss = torch.stack([
+ F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
+ + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :])
+ for i in range(num_sampled_frames)
+ ]).mean()
+ f_p_c1_t2, f_p_c2_t2 = feature_for_loss[2]
+ pose_sim_loss = F.mse_loss(
+ f_p_c1_t2.mean(1), f_p_c2_t2.mean(1)
+ ) * 10
+ return xrecon_loss, cano_cons_loss, pose_sim_loss
+
def _classification_loss(self, embed_c, embed_p, ae_losses, y):
# Duplicate labels for each part
- y_triplet = y.repeat(self.rgb_pn.num_parts, 1)
+ y_triplet = y.repeat(self.rgb_pn.module.num_parts, 1)
hpm_result = self.triplet_loss_hpm(
- embed_c, y_triplet[:self.rgb_pn.hpm.num_parts]
+ embed_c, y_triplet[:self.rgb_pn.module.hpm.num_parts]
)
pn_result = self.triplet_loss_pn(
- embed_p, y_triplet[self.rgb_pn.hpm.num_parts:]
+ embed_p, y_triplet[self.rgb_pn.module.hpm.num_parts:]
)
losses = torch.stack((
*ae_losses,
@@ -471,6 +502,7 @@ class Model:
model_hp.pop('triplet_margins', None)
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp)
# Try to accelerate computation using CUDA or others
+ self.rgb_pn = nn.DataParallel(self.rgb_pn)
self.rgb_pn = self.rgb_pn.to(self.device)
self.rgb_pn.eval()
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 4a82da3..5d2c142 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -42,7 +42,7 @@ class RGBPartNet(nn.Module):
def forward(self, x_c1, x_c2=None):
# Step 1: Disentanglement
# n, t, c, h, w
- ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2)
+ ((x_c, x_p), images, f_loss) = self._disentangle(x_c1, x_c2)
# Step 2.a: Static Gait Feature Aggregation & HPM
# n, c, h, w
@@ -55,7 +55,7 @@ class RGBPartNet(nn.Module):
# p, n, d
if self.training:
- return x_c, x_p, ae_losses, images
+ return x_c.transpose(0, 1), x_p.transpose(0, 1), images, f_loss
else:
return x_c, x_p
@@ -64,7 +64,7 @@ class RGBPartNet(nn.Module):
device = x_c1_t2.device
if self.training:
x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
- ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
+ (f_a_, f_c_, f_p_), f_loss = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
# Decode features
x_c = self._decode_cano_feature(f_c_, n, t, device)
x_p_ = self._decode_pose_feature(f_p_, n, t, device)
@@ -81,7 +81,7 @@ class RGBPartNet(nn.Module):
i_p_ = torch.sigmoid(self.ae.decoder.trans_conv4(i_p_))
i_p = i_p_.view(n, t, c, h, w)
- return (x_c, x_p), losses, (i_a, i_c, i_p)
+ return (x_c, x_p), (i_a, i_c, i_p), f_loss
else: # evaluating
f_c_, f_p_ = self.ae(x_c1_t2)