author    Jordan Gong <jordan.gong@protonmail.com>  2021-02-09 21:28:38 +0800
committer Jordan Gong <jordan.gong@protonmail.com>  2021-02-09 21:28:38 +0800
commit    58ef39d75098bce92654492e09edf1e83033d0c8 (patch)
tree      8af7fe4fb5adfe1b189353dcff4efc38f62cd0c4 /models
parent    d380e04df37593e414bd5641db100613fb2ad882 (diff)
parent    916cf90d04e57fee23092c966740fbe94fd92cff (diff)
Merge branch 'master' into python3.8
# Conflicts:
#	models/rgb_part_net.py
Diffstat (limited to 'models')
-rw-r--r--  models/auto_encoder.py  |  43
-rw-r--r--  models/model.py         |  83
-rw-r--r--  models/part_net.py      |  17
-rw-r--r--  models/rgb_part_net.py  | 169
4 files changed, 138 insertions(+), 174 deletions(-)
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 69dae4e..7b9b29f 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -123,23 +123,42 @@ class AutoEncoder(nn.Module):
self.decoder = Decoder(embedding_dims, feature_channels, channels)
def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None):
+ n, t, c, h, w = x_c1_t2.size()
# x_c1_t2 is the frame for later module
- (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
+ x_c1_t2_ = x_c1_t2.view(n * t, c, h, w)
+ (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_) = self.encoder(x_c1_t2_)
if self.training:
# t1 is random time step, c2 is another condition
- (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
- (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
-
- x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
- xrecon_loss_t2 = F.mse_loss(x_c1_t2, x_c1_t2_)
- cano_cons_loss_t2 = (F.mse_loss(f_c_c1_t1, f_c_c1_t2)
- + F.mse_loss(f_c_c1_t2, f_c_c2_t2))
+ x_c1_t1 = x_c1_t1.view(n * t, c, h, w)
+ (f_a_c1_t1_, f_c_c1_t1_, _) = self.encoder(x_c1_t1)
+ x_c2_t2 = x_c2_t2.view(n * t, c, h, w)
+ (_, f_c_c2_t2_, f_p_c2_t2_) = self.encoder(x_c2_t2)
+
+ x_c1_t2_pred_ = self.decoder(f_a_c1_t1_, f_c_c1_t1_, f_p_c1_t2_)
+ x_c1_t2_pred = x_c1_t2_pred_.view(n, t, c, h, w)
+
+ xrecon_loss = torch.stack([
+ F.mse_loss(x_c1_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :])
+ for i in range(t)
+ ]).sum()
+
+ f_c_c1_t1 = f_c_c1_t1_.view(n, t, -1)
+ f_c_c1_t2 = f_c_c1_t2_.view(n, t, -1)
+ f_c_c2_t2 = f_c_c2_t2_.view(n, t, -1)
+ cano_cons_loss = torch.stack([
+ F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
+ + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :])
+ for i in range(t)
+ ]).mean()
+
+ f_p_c1_t2 = f_p_c1_t2_.view(n, t, -1)
+ f_p_c2_t2 = f_p_c2_t2_.view(n, t, -1)
+ pose_sim_loss = F.mse_loss(f_p_c1_t2.mean(1), f_p_c2_t2.mean(1))
return (
- (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2),
- (f_p_c1_t2, f_p_c2_t2),
- (xrecon_loss_t2, cano_cons_loss_t2)
+ (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_),
+ (xrecon_loss, cano_cons_loss, pose_sim_loss * 10)
)
else: # evaluating
- return f_c_c1_t2, f_p_c1_t2
+ return f_c_c1_t2_, f_p_c1_t2_
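Note: the auto-encoder now receives whole sequences shaped (n, t, c, h, w) and flattens the time axis into the batch before encoding, restoring it afterwards to compute the per-time-step losses. A minimal sketch of that flatten-then-restore pattern, with a stand-in `encoder` module and shapes assumed from the diff:

import torch
import torch.nn as nn

def encode_sequence(encoder: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # x: (n, t, c, h, w) -- a batch of n clips with t frames each
    n, t, c, h, w = x.size()
    feats = encoder(x.view(n * t, c, h, w))  # every frame becomes a batch item
    return feats.view(n, t, -1)              # restore the per-clip time axis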
diff --git a/models/model.py b/models/model.py
index 5af2d76..b5daa54 100644
--- a/models/model.py
+++ b/models/model.py
@@ -153,8 +153,18 @@ class Model:
self.rgb_pn = self.rgb_pn.to(self.device)
self.optimizer = optim.Adam([
{'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
+ {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
+ {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
+ {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
], **optim_hp)
- self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, **sched_hp)
+ sched_gamma = sched_hp.get('gamma', 0.9)
+ sched_step_size = sched_hp.get('step_size', 500)
+ self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
+ lambda epoch: sched_gamma ** (epoch // sched_step_size),
+ lambda epoch: 0 if epoch < start_iter else 1,
+ lambda epoch: 0 if epoch < start_iter else 1,
+ lambda epoch: 0 if epoch < start_iter else 1,
+ ])
self.writer = SummaryWriter(self._log_name)
self.rgb_pn.train()
@@ -171,19 +181,10 @@ class Model:
# Training start
start_time = datetime.now()
running_loss = torch.zeros(5, device=self.device)
- print(f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}",
- f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} LR(s)")
+ print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
+ f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
+ f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
for (batch_c1, batch_c2) in dataloader:
- if self.curr_iter == start_iter:
- self.optimizer.add_param_group(
- {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}
- )
- self.optimizer.add_param_group(
- {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp}
- )
- self.optimizer.add_param_group(
- {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
- )
self.curr_iter += 1
# Zero the parameter gradients
self.optimizer.zero_grad()
@@ -201,35 +202,43 @@ class Model:
# Write losses to TensorBoard
self.writer.add_scalar('Loss/all', loss, self.curr_iter)
self.writer.add_scalars('Loss/details', dict(zip([
- 'Cross reconstruction loss', 'Pose similarity loss',
- 'Canonical consistency loss', 'Batch All triplet loss (HPM)',
+ 'Cross reconstruction loss', 'Canonical consistency loss',
+ 'Pose similarity loss', 'Batch All triplet loss (HPM)',
'Batch All triplet loss (PartNet)'
], losses)), self.curr_iter)
- if self.image_log_on:
- (appearance_image, canonical_image, pose_image) = images
- self.writer.add_images(
- 'Canonical image', canonical_image, self.curr_iter
- )
- for i in range(self.pr * self.k):
- self.writer.add_images(
- f'Original image/batch {i}', x_c1[i], self.curr_iter
- )
- self.writer.add_images(
- f'Appearance image/batch {i}',
- appearance_image[:, i, :, :, :],
- self.curr_iter
- )
- self.writer.add_images(
- f'Pose image/batch {i}',
- pose_image[:, i, :, :, :],
- self.curr_iter
- )
if self.curr_iter % 100 == 0:
lrs = self.scheduler.get_last_lr()
- print(f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
+ # Write learning rates
+ self.writer.add_scalar(
+ 'Learning rate/Auto-encoder', lrs[0], self.curr_iter
+ )
+ self.writer.add_scalar(
+ 'Learning rate/Others', lrs[1], self.curr_iter
+ )
+ # Write disentangled images
+ if self.image_log_on:
+ i_a, i_c, i_p = images
+ self.writer.add_images(
+ 'Canonical image', i_c, self.curr_iter
+ )
+ for (i, (o, a, p)) in enumerate(zip(x_c1, i_a, i_p)):
+ self.writer.add_images(
+ f'Original image/batch {i}', o, self.curr_iter
+ )
+ self.writer.add_images(
+ f'Appearance image/batch {i}', a, self.curr_iter
+ )
+ self.writer.add_images(
+ f'Pose image/batch {i}', p, self.curr_iter
+ )
+ time_used = datetime.now() - start_time
+ remaining_minute, second = divmod(time_used.seconds, 60)
+ hour, minute = divmod(remaining_minute, 60)
+ print(f'{hour:02}:{minute:02}:{second:02}',
+ f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
'{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- ' '.join(('{:.3e}'.format(lr) for lr in lrs)))
+ '{:.3e} {:.3e}'.format(lrs[0], lrs[1]))
running_loss.zero_()
# Step scheduler
@@ -242,8 +251,6 @@ class Model:
'optim_state_dict': self.optimizer.state_dict(),
'loss': loss,
}, self._checkpoint_name)
- print(datetime.now() - start_time, 'used')
- start_time = datetime.now()
if self.curr_iter == self.total_iter:
self.writer.close()
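Note: the optimizer change registers all four parameter groups up front and drives them with a LambdaLR that takes one lambda per group, so the auto-encoder follows a step-wise decay while the other groups stay at zero learning rate until start_iter. A minimal two-group sketch of that pattern (the hyper-parameters below are placeholders, not the project's values):

import torch.nn as nn
import torch.optim as optim

net_a, net_b = nn.Linear(4, 4), nn.Linear(4, 4)
optimizer = optim.Adam([
    {'params': net_a.parameters(), 'lr': 1e-4},
    {'params': net_b.parameters(), 'lr': 1e-4},
])
start_iter, gamma, step_size = 500, 0.9, 500
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[
    lambda i: gamma ** (i // step_size),   # group 0: decays every step_size iterations
    lambda i: 0 if i < start_iter else 1,  # group 1: frozen until start_iter
])
for _ in range(1000):
    optimizer.step()   # loss.backward() omitted in this sketch
    scheduler.step()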
diff --git a/models/part_net.py b/models/part_net.py
index 6d8d4e1..f34f993 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -31,8 +31,8 @@ class FrameLevelPartFeatureExtractor(nn.Module):
def forward(self, x):
# Flatten frames in all batches
- t, n, c, h, w = x.size()
- x = x.view(-1, c, h, w)
+ n, t, c, h, w = x.size()
+ x = x.view(n * t, c, h, w)
for fconv_block in self.fconv_blocks:
x = fconv_block(x)
@@ -76,8 +76,8 @@ class TemporalFeatureAggregator(nn.Module):
for _ in range(self.num_part)])
def forward(self, x):
- # p, t, n, c
- x = x.permute(0, 2, 3, 1).contiguous()
+ # p, n, t, c
+ x = x.transpose(2, 3)
p, n, c, t = x.size()
feature = x.split(1, dim=0)
feature = [f.squeeze(0) for f in feature]
@@ -135,19 +135,18 @@ class PartNet(nn.Module):
self.max_pool = nn.AdaptiveMaxPool2d(1)
def forward(self, x):
- t, n, _, _, _ = x.size()
- # t, n, c, h, w
+ n, t, _, _, _ = x.size()
x = self.fpfe(x)
- # t_n, c, h, w
+ # n * t x c x h x w
# Horizontal Pooling
_, c, h, w = x.size()
split_size = h // self.num_part
x = x.split(split_size, dim=2)
x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
- x = [x_.view(t, n, c) for x_ in x]
+ x = [x_.view(n, t, c) for x_ in x]
x = torch.stack(x)
- # p, t, n, c
+ # p, n, t, c
x = self.tfa(x)
return x
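Note: part_net.py now expects batch-first input, so the shape flow becomes (n, t, c, h, w) -> (n * t, c, h, w) in FrameLevelPartFeatureExtractor and (p, n, t, c) -> (p, n, c, t) in TemporalFeatureAggregator. A small shape walk-through under that assumption (the part count and channel width here are made-up values):

import torch

n, t, c, h, w = 2, 30, 3, 64, 32
x = torch.rand(n, t, c, h, w)
frames = x.view(n * t, c, h, w)     # what the frame-level extractor consumes
parts = torch.rand(16, n, t, 128)   # p, n, t, c after horizontal pooling (p=16, c=128 assumed)
parts = parts.transpose(2, 3)       # p, n, c, t, the layout TemporalFeatureAggregator reads
print(frames.shape, parts.shape)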
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index f6dc131..841de96 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -1,9 +1,7 @@
-import random
-from typing import Tuple, List
+from typing import Tuple
import torch
import torch.nn as nn
-import torch.nn.functional as F
from models.auto_encoder import AutoEncoder
from models.hpm import HorizontalPyramidMatching
@@ -60,24 +58,18 @@ class RGBPartNet(nn.Module):
return x @ self.fc_mat
def forward(self, x_c1, x_c2=None, y=None):
- # Step 0: Swap batch_size and time dimensions for next step
- # n, t, c, h, w
- x_c1 = x_c1.transpose(0, 1)
- if self.training:
- x_c2 = x_c2.transpose(0, 1)
-
# Step 1: Disentanglement
- # t, n, c, h, w
- ((x_c_c1, x_p_c1), images, losses) = self._disentangle(x_c1, x_c2)
+ # n, t, c, h, w
+ ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2)
# Step 2.a: Static Gait Feature Aggregation & HPM
# n, c, h, w
- x_c = self.hpm(x_c_c1)
+ x_c = self.hpm(x_c)
# p, n, c
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
- # t, n, c, h, w
- x_p = self.pn(x_p_c1)
+ # n, t, c, h, w
+ x_p = self.pn(x_p)
# p, n, c
# Step 3: Cat feature map together and fc
@@ -92,113 +84,60 @@ class RGBPartNet(nn.Module):
else:
return x.unsqueeze(1).view(-1)
- def _disentangle(self, x_c1, x_c2=None):
- t, n, c, h, w = x_c1.size()
- device = x_c1.device
+ def _disentangle(self, x_c1_t2, x_c2_t2=None):
+ n, t, c, h, w = x_c1_t2.size()
+ device = x_c1_t2.device
+ x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
if self.training:
- # Encoded appearance, canonical and pose features
- f_a_c1, f_c_c1, f_p_c1 = [], [], []
- # Features required to calculate losses
- f_p_c2 = []
- xrecon_loss, cano_cons_loss = [], []
- for t2 in range(t):
- t1 = random.randrange(t)
- output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2])
- (f_c1_t2, f_p_t2, losses) = output
-
- (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = f_c1_t2
- if self.image_log_on:
- f_a_c1.append(f_a_c1_t2)
- # Save canonical features and pose features
- f_c_c1.append(f_c_c1_t2)
- f_p_c1.append(f_p_c1_t2)
-
- # Losses per time step
- # Used in pose similarity loss
- (_, f_p_c2_t2) = f_p_t2
- f_p_c2.append(f_p_c2_t2)
-
- # Cross reconstruction loss and canonical loss
- (xrecon_loss_t2, cano_cons_loss_t2) = losses
- xrecon_loss.append(xrecon_loss_t2)
- cano_cons_loss.append(cano_cons_loss_t2)
- if self.image_log_on:
- f_a_c1 = torch.stack(f_a_c1)
- f_c_c1_mean = torch.stack(f_c_c1).mean(0)
- f_p_c1 = torch.stack(f_p_c1)
- f_p_c2 = torch.stack(f_p_c2)
-
+ ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
# Decode features
- appearance_image, canonical_image, pose_image = None, None, None
with torch.no_grad():
- # Decode average canonical features to higher dimension
- x_c_c1 = self.ae.decoder(
- torch.zeros((n, self.f_a_dim), device=device),
- f_c_c1_mean,
- torch.zeros((n, self.f_p_dim), device=device),
- cano_only=True
- )
- # Decode pose features to images
- f_p_c1_ = f_p_c1.view(t * n, -1)
- x_p_c1_ = self.ae.decoder(
- torch.zeros((t * n, self.f_a_dim), device=device),
- torch.zeros((t * n, self.f_c_dim), device=device),
- f_p_c1_
- )
- x_p_c1 = x_p_c1_.view(t, n, c, h, w)
+ x_c = self._decode_cano_feature(f_c_, n, t, device)
+ x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+ i_a, i_c, i_p = None, None, None
if self.image_log_on:
- # Decode appearance features
- f_a_c1_ = f_a_c1.view(t * n, -1)
- appearance_image_ = self.ae.decoder(
- f_a_c1_,
- torch.zeros((t * n, self.f_c_dim), device=device),
- torch.zeros((t * n, self.f_p_dim), device=device)
- )
- appearance_image = appearance_image_.view(t, n, c, h, w)
+ i_a = self._decode_appr_feature(f_a_, n, t, c, h, w, device)
# Continue decoding canonical features
- canonical_image = self.ae.decoder.trans_conv3(x_c_c1)
- canonical_image = torch.sigmoid(
- self.ae.decoder.trans_conv4(canonical_image)
- )
- pose_image = x_p_c1
-
- # Losses
- xrecon_loss = torch.sum(torch.stack(xrecon_loss))
- pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) * 10
- cano_cons_loss = torch.mean(torch.stack(cano_cons_loss))
+ i_c = self.ae.decoder.trans_conv3(x_c)
+ i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c))
+ i_p = x_p
- return ((x_c_c1, x_p_c1),
- (appearance_image, canonical_image, pose_image),
- (xrecon_loss, pose_sim_loss, cano_cons_loss))
+ return (x_c, x_p), losses, (i_a, i_c, i_p)
else: # evaluating
- x_c1_ = x_c1.view(t * n, c, h, w)
- (f_c_c1_, f_p_c1_) = self.ae(x_c1_)
-
- # Canonical features
- f_c_c1 = f_c_c1_.view(t, n, -1)
- f_c_c1_mean = f_c_c1.mean(0)
- x_c_c1 = self.ae.decoder(
- torch.zeros((n, self.f_a_dim)),
- f_c_c1_mean,
- torch.zeros((n, self.f_p_dim)),
- cano_only=True
- )
-
- # Pose features
- x_p_c1_ = self.ae.decoder(
- torch.zeros((t * n, self.f_a_dim)),
- torch.zeros((t * n, self.f_c_dim)),
- f_p_c1_
- )
- x_p_c1 = x_p_c1_.view(t, n, c, h, w)
-
- return (x_c_c1, x_p_c1), None, None
-
- @staticmethod
- def _pose_sim_loss(f_p_c1: torch.Tensor,
- f_p_c2: torch.Tensor) -> torch.Tensor:
- f_p_c1_mean = f_p_c1.mean(dim=0)
- f_p_c2_mean = f_p_c2.mean(dim=0)
- return F.mse_loss(f_p_c1_mean, f_p_c2_mean)
+ f_c_, f_p_ = self.ae(x_c1_t2)
+ x_c = self._decode_cano_feature(f_c_, n, t, device)
+ x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+ return (x_c, x_p), None, None
+
+ def _decode_appr_feature(self, f_a_, n, t, c, h, w, device):
+ # Decode appearance features
+ x_a_ = self.ae.decoder(
+ f_a_,
+ torch.zeros((n * t, self.f_c_dim), device=device),
+ torch.zeros((n * t, self.f_p_dim), device=device)
+ )
+ x_a = x_a_.view(n, t, c, h, w)
+ return x_a
+
+ def _decode_cano_feature(self, f_c_, n, t, device):
+ # Decode average canonical features to higher dimension
+ f_c = f_c_.view(n, t, -1)
+ x_c = self.ae.decoder(
+ torch.zeros((n, self.f_a_dim), device=device),
+ f_c.mean(1),
+ torch.zeros((n, self.f_p_dim), device=device),
+ cano_only=True
+ )
+ return x_c
+
+ def _decode_pose_feature(self, f_p_, n, t, c, h, w, device):
+ # Decode pose features to images
+ x_p_ = self.ae.decoder(
+ torch.zeros((n * t, self.f_a_dim), device=device),
+ torch.zeros((n * t, self.f_c_dim), device=device),
+ f_p_
+ )
+ x_p = x_p_.view(n, t, c, h, w)
+ return x_p
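Note: the rewritten _disentangle drops the per-frame Python loop and random.randrange sampling; a single torch.randperm along the time axis yields x_c1_t1 for the whole batch, after which the auto-encoder runs one batched pass and decoding is delegated to the three helpers above. A minimal sketch of that sampling step, with shapes assumed from the diff:

import torch

n, t, c, h, w = 2, 30, 3, 64, 32
x_c1_t2 = torch.rand(n, t, c, h, w)
# Shuffle frames along the time axis; the same permutation is applied to every
# clip in the batch, replacing a separate random t1 draw for each time step.
x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]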