path: root/models
author     Jordan Gong <jordan.gong@protonmail.com>  2021-02-08 18:11:25 +0800
committer  Jordan Gong <jordan.gong@protonmail.com>  2021-02-08 18:25:42 +0800
commit     99ddd7c142a4ec97cb8bd14b204651790b3cf4ee (patch)
tree       a4ccbd08a7155e90df63aba60eb93ab2b7969c9b /models
parent     507e1d163aaa6ea4be23e7f08ff6ce0ef58c830b (diff)
Code refactoring, modifications and new features
1. Decode features outside of auto-encoder
2. Turn off HPM 1x1 conv by default
3. Change canonical feature map size from `feature_channels * 8 x 4 x 2` to `feature_channels * 2 x 16 x 8`
4. Use mean of canonical embeddings instead of mean of static features
5. Calculate static and dynamic loss separately
6. Calculate mean of parts in triplet loss instead of sum of parts
7. Add switch to log disentangled images
8. Change default configuration
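
The feature-map size in item 3 and the static/dynamic split in item 5 work out as follows. This is a quick sketch using the defaults that appear in the diffs below (feature_channels = 64, hpm_scales = (1, 2, 4), tfa_num_parts = 16, embedding_dims = 256), not code from the repository:

import torch

feature_channels = 64
# New canonical map fed to HPM: n x (feature_channels * 2) x 16 x 8
cano_map = torch.zeros(1, feature_channels * 2, 16, 8)

hpm_num_parts = sum((1, 2, 4))   # 7 horizontal strips from the pyramid scales
tfa_num_parts = 16               # dynamic parts from PartNet
# Part-wise embeddings after the fully-connected matrix: (p, n, c)
embeddings = torch.zeros(hpm_num_parts + tfa_num_parts, 1, 256)
# Static (HPM) and dynamic (PartNet) losses are now computed separately
static_emb = embeddings[:hpm_num_parts]
dynamic_emb = embeddings[hpm_num_parts:]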
Diffstat (limited to 'models')
-rw-r--r--  models/auto_encoder.py   26
-rw-r--r--  models/hpm.py            20
-rw-r--r--  models/layers.py         10
-rw-r--r--  models/model.py          34
-rw-r--r--  models/rgb_part_net.py  141
5 files changed, 148 insertions, 83 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 35cb629..f04ffdb 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -95,15 +95,14 @@ class Decoder(nn.Module):
self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
is_last_layer=True)
- def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False):
+ def forward(self, f_appearance, f_canonical, f_pose, cano_only=False):
x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
x = self.fc(x)
x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True)
- # Decode canonical features without transpose convolutions
- if no_trans_conv:
- return x
x = self.trans_conv1(x)
x = self.trans_conv2(x)
+ if cano_only:
+ return x
x = self.trans_conv3(x)
x = torch.sigmoid(self.trans_conv4(x))
@@ -125,21 +124,6 @@ class AutoEncoder(nn.Module):
# x_c1_t2 is the frame for later module
(f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
- with torch.no_grad():
- # Decode canonical features for HPM
- x_c_c1_t2 = self.decoder(
- torch.zeros_like(f_a_c1_t2),
- f_c_c1_t2,
- torch.zeros_like(f_p_c1_t2),
- no_trans_conv=True
- )
- # Decode pose features for Part Net
- x_p_c1_t2 = self.decoder(
- torch.zeros_like(f_a_c1_t2),
- torch.zeros_like(f_c_c1_t2),
- f_p_c1_t2
- )
-
if self.training:
# t1 is random time step, c2 is another condition
(f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
@@ -151,9 +135,9 @@ class AutoEncoder(nn.Module):
+ F.mse_loss(f_c_c1_t2, f_c_c2_t2))
return (
- (x_c_c1_t2, x_p_c1_t2),
+ (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2),
(f_p_c1_t2, f_p_c2_t2),
(xrecon_loss_t2, cano_cons_loss_t2)
)
else: # evaluating
- return x_c_c1_t2, x_p_c1_t2
+ return f_c_c1_t2, f_p_c1_t2
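
For reference, the `cano_only` early return can be reduced to the following stand-alone sketch. `ToyDecoder` is a simplified stand-in for `Decoder` (plain `ConvTranspose2d` layers stand in for `DCGANConvTranspose2d`), with the default `feature_channels = 64` and `f_a_c_p_dims = (128, 128, 64)` assumed:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyDecoder(nn.Module):
    """Simplified stand-in for models.auto_encoder.Decoder, illustration only."""
    def __init__(self, feature_dims=320, feature_channels=64, out_channels=3):
        super().__init__()
        self.feature_channels = feature_channels
        self.fc = nn.Linear(feature_dims, feature_channels * 8 * 4 * 2)
        self.trans_conv1 = nn.ConvTranspose2d(feature_channels * 8,
                                              feature_channels * 4, 4, 2, 1)
        self.trans_conv2 = nn.ConvTranspose2d(feature_channels * 4,
                                              feature_channels * 2, 4, 2, 1)
        self.trans_conv3 = nn.ConvTranspose2d(feature_channels * 2,
                                              feature_channels, 4, 2, 1)
        self.trans_conv4 = nn.ConvTranspose2d(feature_channels,
                                              out_channels, 4, 2, 1)

    def forward(self, f_appearance, f_canonical, f_pose, cano_only=False):
        x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
        x = F.relu(self.fc(x).view(-1, self.feature_channels * 8, 4, 2),
                   inplace=True)
        x = self.trans_conv2(self.trans_conv1(x))
        if cano_only:  # stop at the feature_channels * 2 x 16 x 8 map for HPM
            return x
        x = self.trans_conv3(x)
        return torch.sigmoid(self.trans_conv4(x))

dec = ToyDecoder()
f_a, f_c, f_p = torch.zeros(2, 128), torch.randn(2, 128), torch.zeros(2, 64)
print(dec(f_a, f_c, f_p, cano_only=True).shape)   # torch.Size([2, 128, 16, 8])
print(dec(f_a, f_c, f_p).shape)                   # torch.Size([2, 3, 64, 32])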
diff --git a/models/hpm.py b/models/hpm.py
index 66503e3..9879cfb 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -9,14 +9,16 @@ class HorizontalPyramidMatching(nn.Module):
self,
in_channels: int,
out_channels: int = 128,
+ use_1x1conv: bool = False,
scales: tuple[int, ...] = (1, 2, 4),
use_avg_pool: bool = True,
- use_max_pool: bool = True,
+ use_max_pool: bool = False,
**kwargs
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
+ self.use_1x1conv = use_1x1conv
self.scales = scales
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
@@ -29,6 +31,7 @@ class HorizontalPyramidMatching(nn.Module):
pyramid = nn.ModuleList([
HorizontalPyramidPooling(self.in_channels,
self.out_channels,
+ use_1x1conv=self.use_1x1conv,
use_avg_pool=self.use_avg_pool,
use_max_pool=self.use_max_pool,
**kwargs)
@@ -37,23 +40,16 @@ class HorizontalPyramidMatching(nn.Module):
return pyramid
def forward(self, x):
- # Flatten canonical features in all batches
- t, n, c, h, w = x.size()
- x = x.view(t * n, c, h, w)
-
+ n, c, h, w = x.size()
feature = []
- for pyramid_index, pyramid in enumerate(self.pyramids):
- h_per_hpp = h // self.scales[pyramid_index]
+ for scale, pyramid in zip(self.scales, self.pyramids):
+ h_per_hpp = h // scale
for hpp_index, hpp in enumerate(pyramid):
h_filter = torch.arange(hpp_index * h_per_hpp,
(hpp_index + 1) * h_per_hpp)
x_slice = x[:, :, h_filter, :]
x_slice = hpp(x_slice)
- x_slice = x_slice.view(t * n, -1)
+ x_slice = x_slice.view(n, -1)
feature.append(x_slice)
x = torch.stack(feature)
-
- # Unfold frames to original batch
- p, _, c = x.size()
- x = x.view(p, t, n, c)
return x
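
After this change HPM receives a single 4-D canonical map rather than per-frame maps, so the strip slicing in the new forward can be pictured like this (assumed sizes n = 2, c = 128, h = 16, w = 8, i.e. the new feature_channels * 2 x 16 x 8 map; the pooling/conv head is omitted):

import torch

x = torch.randn(2, 128, 16, 8)     # n, c, h, w — no more frame dimension
scales = (1, 2, 4)
strips = []
for scale in scales:
    h_per_hpp = x.size(2) // scale
    for i in range(scale):
        strips.append(x[:, :, i * h_per_hpp:(i + 1) * h_per_hpp, :])
print(len(strips))       # 7 == sum(scales)
print(strips[-1].shape)  # torch.Size([2, 128, 4, 8]) — finest strips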
diff --git a/models/layers.py b/models/layers.py
index a9f04b3..7b6ba5c 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -167,12 +167,13 @@ class HorizontalPyramidPooling(BasicConv2d):
self,
in_channels: int,
out_channels: int,
- kernel_size: Union[int, tuple[int, int]] = 1,
+ use_1x1conv: bool = False,
use_avg_pool: bool = True,
- use_max_pool: bool = True,
+ use_max_pool: bool = False,
**kwargs
):
- super().__init__(in_channels, out_channels, kernel_size, **kwargs)
+ super().__init__(in_channels, out_channels, kernel_size=1, **kwargs)
+ self.use_1x1conv = use_1x1conv
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.'
@@ -186,5 +187,6 @@ class HorizontalPyramidPooling(BasicConv2d):
x = self.avg_pool(x)
elif not self.use_avg_pool and self.use_max_pool:
x = self.max_pool(x)
- x = super().forward(x)
+ if self.use_1x1conv:
+ x = super().forward(x)
return x
diff --git a/models/model.py b/models/model.py
index ddb715d..0418070 100644
--- a/models/model.py
+++ b/models/model.py
@@ -69,6 +69,7 @@ class Model:
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
+ self.image_log_on = system_config.get('image_log_on', False)
self.CASIAB_GALLERY_SELECTOR = {
'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})}
@@ -146,7 +147,8 @@ class Model:
hpm_optim_hp = optim_hp.pop('hpm', {})
fc_optim_hp = optim_hp.pop('fc', {})
sched_hp = self.hp.get('scheduler', {})
- self.rgb_pn = RGBPartNet(self.in_channels, **model_hp)
+ self.rgb_pn = RGBPartNet(self.in_channels, **model_hp,
+ image_log_on=self.image_log_on)
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
self.optimizer = optim.Adam([
@@ -168,9 +170,9 @@ class Model:
# Training start
start_time = datetime.now()
- running_loss = torch.zeros(4).to(self.device)
+ running_loss = torch.zeros(5, device=self.device)
print(f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}",
- f"{'CanoCons':^8} {'BATrip':^8} LR(s)")
+ f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} LR(s)")
for (batch_c1, batch_c2) in dataloader:
if self.curr_iter == start_iter:
self.optimizer.add_param_group(
@@ -189,7 +191,7 @@ class Model:
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
y = batch_c1['label'].to(self.device)
- losses = self.rgb_pn(x_c1, x_c2, y)
+ losses, images = self.rgb_pn(x_c1, x_c2, y)
loss = losses.sum()
loss.backward()
self.optimizer.step()
@@ -200,13 +202,33 @@ class Model:
self.writer.add_scalar('Loss/all', loss, self.curr_iter)
self.writer.add_scalars('Loss/details', dict(zip([
'Cross reconstruction loss', 'Pose similarity loss',
- 'Canonical consistency loss', 'Batch All triplet loss'
+ 'Canonical consistency loss', 'Batch All triplet loss (HPM)',
+ 'Batch All triplet loss (PartNet)'
], losses)), self.curr_iter)
+ if self.image_log_on:
+ (appearance_image, canonical_image, pose_image) = images
+ self.writer.add_images(
+ 'Canonical image', canonical_image, self.curr_iter
+ )
+ for i in range(self.pr * self.k):
+ self.writer.add_images(
+ f'Original image/batch {i}', x_c1[i], self.curr_iter
+ )
+ self.writer.add_images(
+ f'Appearance image/batch {i}',
+ appearance_image[:, i, :, :, :],
+ self.curr_iter
+ )
+ self.writer.add_images(
+ f'Pose image/batch {i}',
+ pose_image[:, i, :, :, :],
+ self.curr_iter
+ )
if self.curr_iter % 100 == 0:
lrs = self.scheduler.get_last_lr()
print(f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
- '{:f} {:f} {:f} {:f}'.format(*running_loss / 100),
+ '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
' '.join(('{:.3e}'.format(lr) for lr in lrs)))
running_loss.zero_()
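
The shapes involved in the new image logging are easy to get wrong, so here is the bookkeeping with assumed sizes (t frames per clip, n = pr * k samples per batch, 64 x 32 RGB frames). `add_images` expects an (N, C, H, W) batch, which is exactly what `pose_image[:, i]` yields per sample:

import torch

t, n, c, h, w = 30, 8, 3, 64, 32
canonical_image = torch.rand(n, c, h, w)      # one canonical image per sample
appearance_image = torch.rand(t, n, c, h, w)  # decoded appearance features
pose_image = torch.rand(t, n, c, h, w)        # decoded pose features

frames_of_sample_0 = pose_image[:, 0]         # (t, c, h, w) — one add_images call
assert frames_of_sample_0.shape == (30, 3, 64, 32)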
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 755d5dc..0e7d8b3 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -16,6 +16,7 @@ class RGBPartNet(nn.Module):
ae_in_channels: int = 3,
ae_feature_channels: int = 64,
f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
+ hpm_use_1x1conv: bool = False,
hpm_scales: tuple[int, ...] = (1, 2, 4),
hpm_use_avg_pool: bool = True,
hpm_use_max_pool: bool = True,
@@ -26,9 +27,14 @@ class RGBPartNet(nn.Module):
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
embedding_dims: int = 256,
- triplet_margin: float = 0.2
+ triplet_margins: tuple[float, float] = (0.2, 0.2),
+ image_log_on: bool = False
):
super().__init__()
+ (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
+ self.hpm_num_parts = sum(hpm_scales)
+ self.image_log_on = image_log_on
+
self.ae = AutoEncoder(
ae_in_channels, ae_feature_channels, f_a_c_p_dims
)
@@ -38,14 +44,16 @@ class RGBPartNet(nn.Module):
)
out_channels = self.pn.tfa_in_channels
self.hpm = HorizontalPyramidMatching(
- ae_feature_channels * 8, out_channels, hpm_scales,
- hpm_use_avg_pool, hpm_use_max_pool
+ ae_feature_channels * 2, out_channels, hpm_use_1x1conv,
+ hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
)
- total_parts = sum(hpm_scales) + tfa_num_parts
- empty_fc = torch.empty(total_parts, out_channels, embedding_dims)
+ empty_fc = torch.empty(self.hpm_num_parts + tfa_num_parts,
+ out_channels, embedding_dims)
self.fc_mat = nn.Parameter(empty_fc)
- self.ba_triplet_loss = BatchAllTripletLoss(triplet_margin)
+ (hpm_margin, pn_margin) = triplet_margins
+ self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin)
+ self.pn_ba_trip = BatchAllTripletLoss(pn_margin)
def fc(self, x):
return x @ self.fc_mat
@@ -59,13 +67,11 @@ class RGBPartNet(nn.Module):
# Step 1: Disentanglement
# t, n, c, h, w
- ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2)
+ ((x_c_c1, x_p_c1), images, losses) = self._disentangle(x_c1, x_c2)
- # Step 2.a: HPM & Static Gait Feature Aggregation
- # t, n, c, h, w
+ # Step 2.a: Static Gait Feature Aggregation & HPM
+ # n, c, h, w
x_c = self.hpm(x_c_c1)
- # p, t, n, c
- x_c = x_c.mean(dim=1)
# p, n, c
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
@@ -78,44 +84,83 @@ class RGBPartNet(nn.Module):
x = self.fc(x)
if self.training:
- batch_all_triplet_loss = self.ba_triplet_loss(x, y)
- losses = torch.stack((*losses, batch_all_triplet_loss))
- return losses
+ hpm_ba_trip = self.hpm_ba_trip(x[:self.hpm_num_parts], y)
+ pn_ba_trip = self.pn_ba_trip(x[self.hpm_num_parts:], y)
+ losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip))
+ return losses, images
else:
return x.unsqueeze(1).view(-1)
def _disentangle(self, x_c1, x_c2=None):
t, n, c, h, w = x_c1.size()
+ device = x_c1.device
if self.training:
- # Decoded canonical features and Pose images
- x_c_c1, x_p_c1 = [], []
+ # Encoded appearance, canonical and pose features
+ f_a_c1, f_c_c1, f_p_c1 = [], [], []
# Features required to calculate losses
- f_p_c1, f_p_c2 = [], []
+ f_p_c2 = []
xrecon_loss, cano_cons_loss = [], []
for t2 in range(t):
t1 = random.randrange(t)
output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2])
- (x_c1_t2, f_p_t2, losses) = output
+ (f_c1_t2, f_p_t2, losses) = output
- # Decoded features or image
- (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
- # Canonical Features for HPM
- x_c_c1.append(x_c_c1_t2)
- # Pose image for Part Net
- x_p_c1.append(x_p_c1_t2)
+ (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = f_c1_t2
+ if self.image_log_on:
+ f_a_c1.append(f_a_c1_t2)
+ # Save canonical features and pose features
+ f_c_c1.append(f_c_c1_t2)
+ f_p_c1.append(f_p_c1_t2)
# Losses per time step
# Used in pose similarity loss
- (f_p_c1_t2, f_p_c2_t2) = f_p_t2
- f_p_c1.append(f_p_c1_t2)
+ (_, f_p_c2_t2) = f_p_t2
f_p_c2.append(f_p_c2_t2)
+
# Cross reconstruction loss and canonical loss
(xrecon_loss_t2, cano_cons_loss_t2) = losses
xrecon_loss.append(xrecon_loss_t2)
cano_cons_loss.append(cano_cons_loss_t2)
-
- x_c_c1 = torch.stack(x_c_c1)
- x_p_c1 = torch.stack(x_p_c1)
+ if self.image_log_on:
+ f_a_c1 = torch.stack(f_a_c1)
+ f_c_c1_mean = torch.stack(f_c_c1).mean(0)
+ f_p_c1 = torch.stack(f_p_c1)
+ f_p_c2 = torch.stack(f_p_c2)
+
+ # Decode features
+ appearance_image, canonical_image, pose_image = None, None, None
+ with torch.no_grad():
+ # Decode average canonical features to higher dimension
+ x_c_c1 = self.ae.decoder(
+ torch.zeros((n, self.f_a_dim), device=device),
+ f_c_c1_mean,
+ torch.zeros((n, self.f_p_dim), device=device),
+ cano_only=True
+ )
+ # Decode pose features to images
+ f_p_c1_ = f_p_c1.view(t * n, -1)
+ x_p_c1_ = self.ae.decoder(
+ torch.zeros((t * n, self.f_a_dim), device=device),
+ torch.zeros((t * n, self.f_c_dim), device=device),
+ f_p_c1_
+ )
+ x_p_c1 = x_p_c1_.view(t, n, c, h, w)
+
+ if self.image_log_on:
+ # Decode appearance features
+ f_a_c1_ = f_a_c1.view(t * n, -1)
+ appearance_image_ = self.ae.decoder(
+ f_a_c1_,
+ torch.zeros((t * n, self.f_c_dim), device=device),
+ torch.zeros((t * n, self.f_p_dim), device=device)
+ )
+ appearance_image = appearance_image_.view(t, n, c, h, w)
+ # Continue decoding canonical features
+ canonical_image = self.ae.decoder.trans_conv3(x_c_c1)
+ canonical_image = torch.sigmoid(
+ self.ae.decoder.trans_conv4(canonical_image)
+ )
+ pose_image = x_p_c1
# Losses
xrecon_loss = torch.sum(torch.stack(xrecon_loss))
@@ -123,20 +168,36 @@ class RGBPartNet(nn.Module):
cano_cons_loss = torch.mean(torch.stack(cano_cons_loss))
return ((x_c_c1, x_p_c1),
+ (appearance_image, canonical_image, pose_image),
(xrecon_loss, pose_sim_loss, cano_cons_loss))
else: # evaluating
- x_c1 = x_c1.view(-1, c, h, w)
- x_c_c1, x_p_c1 = self.ae(x_c1)
- _, c_c, h_c, w_c = x_c_c1.size()
- x_c_c1 = x_c_c1.view(t, n, c_c, h_c, w_c)
- x_p_c1 = x_p_c1.view(t, n, c, h, w)
-
- return (x_c_c1, x_p_c1), None
+ x_c1_ = x_c1.view(t * n, c, h, w)
+ (f_c_c1_, f_p_c1_) = self.ae(x_c1_)
+
+ # Canonical features
+ f_c_c1 = f_c_c1_.view(t, n, -1)
+ f_c_c1_mean = f_c_c1.mean(0)
+ x_c_c1 = self.ae.decoder(
+ torch.zeros((n, self.f_a_dim)),
+ f_c_c1_mean,
+ torch.zeros((n, self.f_p_dim)),
+ cano_only=True
+ )
+
+ # Pose features
+ x_p_c1_ = self.ae.decoder(
+ torch.zeros((t * n, self.f_a_dim)),
+ torch.zeros((t * n, self.f_c_dim)),
+ f_p_c1_
+ )
+ x_p_c1 = x_p_c1_.view(t, n, c, h, w)
+
+ return (x_c_c1, x_p_c1), None, None
@staticmethod
- def _pose_sim_loss(f_p_c1: list[torch.Tensor],
- f_p_c2: list[torch.Tensor]) -> torch.Tensor:
- f_p_c1_mean = torch.stack(f_p_c1).mean(dim=0)
- f_p_c2_mean = torch.stack(f_p_c2).mean(dim=0)
+ def _pose_sim_loss(f_p_c1: torch.Tensor,
+ f_p_c2: torch.Tensor) -> torch.Tensor:
+ f_p_c1_mean = f_p_c1.mean(dim=0)
+ f_p_c2_mean = f_p_c2.mean(dim=0)
return F.mse_loss(f_p_c1_mean, f_p_c2_mean)
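
The BatchAllTripletLoss implementation is not part of this diff, so the following is only a guess at its behaviour after item 6 of the commit message (mean over parts instead of sum), not the project's code. It shows how the separate HPM and PartNet losses above could be computed on part-wise embeddings:

import torch
import torch.nn.functional as F

def batch_all_triplet_loss(x, y, margin=0.2):
    """x: (p, n, c) part-wise embeddings, y: (n,) labels."""
    p, n, _ = x.size()
    dist = torch.cdist(x, x)                       # (p, n, n) pairwise distances
    same = y.unsqueeze(0) == y.unsqueeze(1)        # (n, n) same-label mask
    pos_mask = same & ~torch.eye(n, dtype=torch.bool, device=y.device)
    neg_mask = ~same
    # anchor-positive minus anchor-negative for every (anchor, pos, neg) triplet
    hinge = F.relu(dist.unsqueeze(3) - dist.unsqueeze(2) + margin)  # (p, n, n, n)
    valid = (pos_mask.unsqueeze(2) & neg_mask.unsqueeze(1)).unsqueeze(0)
    per_part = (hinge * valid).sum((1, 2, 3)) / valid.sum((1, 2, 3)).clamp(min=1)
    return per_part.mean()                         # mean over parts, not sum

x = torch.randn(23, 8, 256)            # 7 HPM + 16 PartNet parts, 8 samples
y = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
hpm_ba_trip = batch_all_triplet_loss(x[:7], y)
pn_ba_trip = batch_all_triplet_loss(x[7:], y)
print(hpm_ba_trip.item(), pn_ba_trip.item())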