-rw-r--r--  config.py               |  12
-rw-r--r--  models/auto_encoder.py  |  26
-rw-r--r--  models/hpm.py           |  20
-rw-r--r--  models/layers.py        |  10
-rw-r--r--  models/model.py         |  34
-rw-r--r--  models/rgb_part_net.py  | 141
-rw-r--r--  utils/configuration.py  |   4
-rw-r--r--  utils/triplet_loss.py   |   2
8 files changed, 160 insertions, 89 deletions
diff --git a/config.py b/config.py
index 641e8fb..04a22b9 100644
--- a/config.py
+++ b/config.py
@@ -8,6 +8,8 @@ config: Configuration = {
'CUDA_VISIBLE_DEVICES': '0',
# Directory used in training or testing for temporary storage
'save_dir': 'runs',
+ # Record disentangled images or not
+ 'image_log_on': False
},
# Dataset settings
'dataset': {
@@ -46,11 +48,13 @@ config: Configuration = {
'ae_feature_channels': 64,
# Appearance, canonical and pose feature dimensions
'f_a_c_p_dims': (128, 128, 64),
+ # Use 1x1 convolution in dimensionality reduction
+ 'hpm_use_1x1conv': False,
# HPM pyramid scales, of which sum is number of parts
'hpm_scales': (1, 2, 4),
# Global pooling method
'hpm_use_avg_pool': True,
- 'hpm_use_max_pool': True,
+ 'hpm_use_max_pool': False,
# FConv feature channels coefficient
'fpfe_feature_channels': 32,
# FConv blocks kernel sizes
@@ -65,13 +69,13 @@ config: Configuration = {
'tfa_num_parts': 16,
# Embedding dimension for each part
'embedding_dims': 256,
- # Triplet loss margin
- 'triplet_margin': 0.2,
+ # Triplet loss margins for HPM and PartNet
+ 'triplet_margins': (0.2, 0.2),
},
'optimizer': {
# Global parameters
# Iteration start to optimize non-disentangling parts
- # 'start_iter': 10,
+ # 'start_iter': 0,
# Initial learning rate of Adam Optimizer
'lr': 1e-4,
# Coefficients used for computing running averages of
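For reference, the reshaped keys can be consumed as below; a minimal sketch in which the surrounding dict is a hypothetical stand-in and only the key names and value types come from this hunk.

    # Hypothetical stand-in for the model section of config.py; only the key
    # names and value types are taken from the hunk above.
    model_hp = {
        'hpm_use_1x1conv': False,
        'hpm_use_max_pool': False,
        'triplet_margins': (0.2, 0.2),
    }
    # One margin per loss head: HPM first, PartNet second.
    hpm_margin, pn_margin = model_hp['triplet_margins']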
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 35cb629..f04ffdb 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -95,15 +95,14 @@ class Decoder(nn.Module):
self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
is_last_layer=True)
- def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False):
+ def forward(self, f_appearance, f_canonical, f_pose, cano_only=False):
x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
x = self.fc(x)
x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True)
- # Decode canonical features without transpose convolutions
- if no_trans_conv:
- return x
x = self.trans_conv1(x)
x = self.trans_conv2(x)
+ if cano_only:
+ return x
x = self.trans_conv3(x)
x = torch.sigmoid(self.trans_conv4(x))
@@ -125,21 +124,6 @@ class AutoEncoder(nn.Module):
# x_c1_t2 is the frame for later module
(f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
- with torch.no_grad():
- # Decode canonical features for HPM
- x_c_c1_t2 = self.decoder(
- torch.zeros_like(f_a_c1_t2),
- f_c_c1_t2,
- torch.zeros_like(f_p_c1_t2),
- no_trans_conv=True
- )
- # Decode pose features for Part Net
- x_p_c1_t2 = self.decoder(
- torch.zeros_like(f_a_c1_t2),
- torch.zeros_like(f_c_c1_t2),
- f_p_c1_t2
- )
-
if self.training:
# t1 is random time step, c2 is another condition
(f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
@@ -151,9 +135,9 @@ class AutoEncoder(nn.Module):
+ F.mse_loss(f_c_c1_t2, f_c_c2_t2))
return (
- (x_c_c1_t2, x_p_c1_t2),
+ (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2),
(f_p_c1_t2, f_p_c2_t2),
(xrecon_loss_t2, cano_cons_loss_t2)
)
else: # evaluating
- return x_c_c1_t2, x_p_c1_t2
+ return f_c_c1_t2, f_p_c1_t2
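The decoder's new cano_only flag stops decoding after trans_conv2, and the auto-encoder now returns raw feature vectors instead of decoded images, leaving decoding to the caller. A rough calling-pattern sketch, assuming an existing Decoder instance and encoded features f_a, f_c, f_p of matching sizes (none of these names are defined by this hunk):

    # Sketch only: zero out appearance and pose, decode canonical features up
    # to trans_conv2, as rgb_part_net.py (further down) does with this flag.
    x_c = decoder(
        torch.zeros_like(f_a),   # drop appearance
        f_c,                     # keep canonical
        torch.zeros_like(f_p),   # drop pose
        cano_only=True,          # return the intermediate feature map
    )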
diff --git a/models/hpm.py b/models/hpm.py
index 66503e3..9879cfb 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -9,14 +9,16 @@ class HorizontalPyramidMatching(nn.Module):
self,
in_channels: int,
out_channels: int = 128,
+ use_1x1conv: bool = False,
scales: tuple[int, ...] = (1, 2, 4),
use_avg_pool: bool = True,
- use_max_pool: bool = True,
+ use_max_pool: bool = False,
**kwargs
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
+ self.use_1x1conv = use_1x1conv
self.scales = scales
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
@@ -29,6 +31,7 @@ class HorizontalPyramidMatching(nn.Module):
pyramid = nn.ModuleList([
HorizontalPyramidPooling(self.in_channels,
self.out_channels,
+ use_1x1conv=self.use_1x1conv,
use_avg_pool=self.use_avg_pool,
use_max_pool=self.use_max_pool,
**kwargs)
@@ -37,23 +40,16 @@ class HorizontalPyramidMatching(nn.Module):
return pyramid
def forward(self, x):
- # Flatten canonical features in all batches
- t, n, c, h, w = x.size()
- x = x.view(t * n, c, h, w)
-
+ n, c, h, w = x.size()
feature = []
- for pyramid_index, pyramid in enumerate(self.pyramids):
- h_per_hpp = h // self.scales[pyramid_index]
+ for scale, pyramid in zip(self.scales, self.pyramids):
+ h_per_hpp = h // scale
for hpp_index, hpp in enumerate(pyramid):
h_filter = torch.arange(hpp_index * h_per_hpp,
(hpp_index + 1) * h_per_hpp)
x_slice = x[:, :, h_filter, :]
x_slice = hpp(x_slice)
- x_slice = x_slice.view(t * n, -1)
+ x_slice = x_slice.view(n, -1)
feature.append(x_slice)
x = torch.stack(feature)
-
- # Unfold frames to original batch
- p, _, c = x.size()
- x = x.view(p, t, n, c)
return x
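HorizontalPyramidMatching now takes a single 4-D batch rather than a 5-D clip tensor, so callers must reduce the temporal dimension first. A usage sketch, assuming the repo's models package is importable and that each HPP block pools its slice down to a vector (the pooling layers themselves are not shown in this hunk):

    import torch
    from models.hpm import HorizontalPyramidMatching

    hpm = HorizontalPyramidMatching(in_channels=128, out_channels=128,
                                    use_1x1conv=False, scales=(1, 2, 4))
    x = torch.rand(4, 128, 64, 32)   # n, c, h, w -- no time dimension
    parts = hpm(x)                   # part features stacked along dim 0: (sum(scales), n, -1)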
diff --git a/models/layers.py b/models/layers.py
index a9f04b3..7b6ba5c 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -167,12 +167,13 @@ class HorizontalPyramidPooling(BasicConv2d):
self,
in_channels: int,
out_channels: int,
- kernel_size: Union[int, tuple[int, int]] = 1,
+ use_1x1conv: bool = False,
use_avg_pool: bool = True,
- use_max_pool: bool = True,
+ use_max_pool: bool = False,
**kwargs
):
- super().__init__(in_channels, out_channels, kernel_size, **kwargs)
+ super().__init__(in_channels, out_channels, kernel_size=1, **kwargs)
+ self.use_1x1conv = use_1x1conv
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.'
@@ -186,5 +187,6 @@ class HorizontalPyramidPooling(BasicConv2d):
x = self.avg_pool(x)
elif not self.use_avg_pool and self.use_max_pool:
x = self.max_pool(x)
- x = super().forward(x)
+ if self.use_1x1conv:
+ x = super().forward(x)
return x
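With use_1x1conv left off, HorizontalPyramidPooling reduces to pure pooling and keeps the input channel count; the 1x1 convolution (and its channel reduction) only runs when the flag is set. A short sketch under the same import assumption:

    import torch
    from models.layers import HorizontalPyramidPooling

    hpp = HorizontalPyramidPooling(128, 64, use_1x1conv=True)  # pool, then 1x1 conv to 64 channels
    y = hpp(torch.rand(4, 128, 16, 32))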
diff --git a/models/model.py b/models/model.py
index ddb715d..0418070 100644
--- a/models/model.py
+++ b/models/model.py
@@ -69,6 +69,7 @@ class Model:
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
+ self.image_log_on = system_config.get('image_log_on', False)
self.CASIAB_GALLERY_SELECTOR = {
'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})}
@@ -146,7 +147,8 @@ class Model:
hpm_optim_hp = optim_hp.pop('hpm', {})
fc_optim_hp = optim_hp.pop('fc', {})
sched_hp = self.hp.get('scheduler', {})
- self.rgb_pn = RGBPartNet(self.in_channels, **model_hp)
+ self.rgb_pn = RGBPartNet(self.in_channels, **model_hp,
+ image_log_on=self.image_log_on)
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
self.optimizer = optim.Adam([
@@ -168,9 +170,9 @@ class Model:
# Training start
start_time = datetime.now()
- running_loss = torch.zeros(4).to(self.device)
+ running_loss = torch.zeros(5, device=self.device)
print(f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}",
- f"{'CanoCons':^8} {'BATrip':^8} LR(s)")
+ f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} LR(s)")
for (batch_c1, batch_c2) in dataloader:
if self.curr_iter == start_iter:
self.optimizer.add_param_group(
@@ -189,7 +191,7 @@ class Model:
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
y = batch_c1['label'].to(self.device)
- losses = self.rgb_pn(x_c1, x_c2, y)
+ losses, images = self.rgb_pn(x_c1, x_c2, y)
loss = losses.sum()
loss.backward()
self.optimizer.step()
@@ -200,13 +202,33 @@ class Model:
self.writer.add_scalar('Loss/all', loss, self.curr_iter)
self.writer.add_scalars('Loss/details', dict(zip([
'Cross reconstruction loss', 'Pose similarity loss',
- 'Canonical consistency loss', 'Batch All triplet loss'
+ 'Canonical consistency loss', 'Batch All triplet loss (HPM)',
+ 'Batch All triplet loss (PartNet)'
], losses)), self.curr_iter)
+ if self.image_log_on:
+ (appearance_image, canonical_image, pose_image) = images
+ self.writer.add_images(
+ 'Canonical image', canonical_image, self.curr_iter
+ )
+ for i in range(self.pr * self.k):
+ self.writer.add_images(
+ f'Original image/batch {i}', x_c1[i], self.curr_iter
+ )
+ self.writer.add_images(
+ f'Appearance image/batch {i}',
+ appearance_image[:, i, :, :, :],
+ self.curr_iter
+ )
+ self.writer.add_images(
+ f'Pose image/batch {i}',
+ pose_image[:, i, :, :, :],
+ self.curr_iter
+ )
if self.curr_iter % 100 == 0:
lrs = self.scheduler.get_last_lr()
print(f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
- '{:f} {:f} {:f} {:f}'.format(*running_loss / 100),
+ '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
' '.join(('{:.3e}'.format(lr) for lr in lrs)))
running_loss.zero_()
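add_images expects an (N, C, H, W) batch, which is why each sample i is sliced out of the (t, n, c, h, w) tensors before logging. A self-contained sketch of the logging call with a hypothetical log directory and random frames:

    import torch
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('runs/example')    # hypothetical log dir
    pose_frames = torch.rand(30, 3, 64, 32)   # t frames of one sample (t, c, h, w)
    writer.add_images('Pose image/batch 0', pose_frames, 0)
    writer.close()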
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 755d5dc..0e7d8b3 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -16,6 +16,7 @@ class RGBPartNet(nn.Module):
ae_in_channels: int = 3,
ae_feature_channels: int = 64,
f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
+ hpm_use_1x1conv: bool = False,
hpm_scales: tuple[int, ...] = (1, 2, 4),
hpm_use_avg_pool: bool = True,
hpm_use_max_pool: bool = True,
@@ -26,9 +27,14 @@ class RGBPartNet(nn.Module):
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
embedding_dims: int = 256,
- triplet_margin: float = 0.2
+ triplet_margins: tuple[float, float] = (0.2, 0.2),
+ image_log_on: bool = False
):
super().__init__()
+ (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
+ self.hpm_num_parts = sum(hpm_scales)
+ self.image_log_on = image_log_on
+
self.ae = AutoEncoder(
ae_in_channels, ae_feature_channels, f_a_c_p_dims
)
@@ -38,14 +44,16 @@ class RGBPartNet(nn.Module):
)
out_channels = self.pn.tfa_in_channels
self.hpm = HorizontalPyramidMatching(
- ae_feature_channels * 8, out_channels, hpm_scales,
- hpm_use_avg_pool, hpm_use_max_pool
+ ae_feature_channels * 2, out_channels, hpm_use_1x1conv,
+ hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
)
- total_parts = sum(hpm_scales) + tfa_num_parts
- empty_fc = torch.empty(total_parts, out_channels, embedding_dims)
+ empty_fc = torch.empty(self.hpm_num_parts + tfa_num_parts,
+ out_channels, embedding_dims)
self.fc_mat = nn.Parameter(empty_fc)
- self.ba_triplet_loss = BatchAllTripletLoss(triplet_margin)
+ (hpm_margin, pn_margin) = triplet_margins
+ self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin)
+ self.pn_ba_trip = BatchAllTripletLoss(pn_margin)
def fc(self, x):
return x @ self.fc_mat
@@ -59,13 +67,11 @@ class RGBPartNet(nn.Module):
# Step 1: Disentanglement
# t, n, c, h, w
- ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2)
+ ((x_c_c1, x_p_c1), images, losses) = self._disentangle(x_c1, x_c2)
- # Step 2.a: HPM & Static Gait Feature Aggregation
- # t, n, c, h, w
+ # Step 2.a: Static Gait Feature Aggregation & HPM
+ # n, c, h, w
x_c = self.hpm(x_c_c1)
- # p, t, n, c
- x_c = x_c.mean(dim=1)
# p, n, c
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
@@ -78,44 +84,83 @@ class RGBPartNet(nn.Module):
x = self.fc(x)
if self.training:
- batch_all_triplet_loss = self.ba_triplet_loss(x, y)
- losses = torch.stack((*losses, batch_all_triplet_loss))
- return losses
+ hpm_ba_trip = self.hpm_ba_trip(x[:self.hpm_num_parts], y)
+ pn_ba_trip = self.pn_ba_trip(x[self.hpm_num_parts:], y)
+ losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip))
+ return losses, images
else:
return x.unsqueeze(1).view(-1)
def _disentangle(self, x_c1, x_c2=None):
t, n, c, h, w = x_c1.size()
+ device = x_c1.device
if self.training:
- # Decoded canonical features and Pose images
- x_c_c1, x_p_c1 = [], []
+ # Encoded appearance, canonical and pose features
+ f_a_c1, f_c_c1, f_p_c1 = [], [], []
# Features required to calculate losses
- f_p_c1, f_p_c2 = [], []
+ f_p_c2 = []
xrecon_loss, cano_cons_loss = [], []
for t2 in range(t):
t1 = random.randrange(t)
output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2])
- (x_c1_t2, f_p_t2, losses) = output
+ (f_c1_t2, f_p_t2, losses) = output
- # Decoded features or image
- (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
- # Canonical Features for HPM
- x_c_c1.append(x_c_c1_t2)
- # Pose image for Part Net
- x_p_c1.append(x_p_c1_t2)
+ (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = f_c1_t2
+ if self.image_log_on:
+ f_a_c1.append(f_a_c1_t2)
+ # Save canonical features and pose features
+ f_c_c1.append(f_c_c1_t2)
+ f_p_c1.append(f_p_c1_t2)
# Losses per time step
# Used in pose similarity loss
- (f_p_c1_t2, f_p_c2_t2) = f_p_t2
- f_p_c1.append(f_p_c1_t2)
+ (_, f_p_c2_t2) = f_p_t2
f_p_c2.append(f_p_c2_t2)
+
# Cross reconstruction loss and canonical loss
(xrecon_loss_t2, cano_cons_loss_t2) = losses
xrecon_loss.append(xrecon_loss_t2)
cano_cons_loss.append(cano_cons_loss_t2)
-
- x_c_c1 = torch.stack(x_c_c1)
- x_p_c1 = torch.stack(x_p_c1)
+ if self.image_log_on:
+ f_a_c1 = torch.stack(f_a_c1)
+ f_c_c1_mean = torch.stack(f_c_c1).mean(0)
+ f_p_c1 = torch.stack(f_p_c1)
+ f_p_c2 = torch.stack(f_p_c2)
+
+ # Decode features
+ appearance_image, canonical_image, pose_image = None, None, None
+ with torch.no_grad():
+ # Decode average canonical features to higher dimension
+ x_c_c1 = self.ae.decoder(
+ torch.zeros((n, self.f_a_dim), device=device),
+ f_c_c1_mean,
+ torch.zeros((n, self.f_p_dim), device=device),
+ cano_only=True
+ )
+ # Decode pose features to images
+ f_p_c1_ = f_p_c1.view(t * n, -1)
+ x_p_c1_ = self.ae.decoder(
+ torch.zeros((t * n, self.f_a_dim), device=device),
+ torch.zeros((t * n, self.f_c_dim), device=device),
+ f_p_c1_
+ )
+ x_p_c1 = x_p_c1_.view(t, n, c, h, w)
+
+ if self.image_log_on:
+ # Decode appearance features
+ f_a_c1_ = f_a_c1.view(t * n, -1)
+ appearance_image_ = self.ae.decoder(
+ f_a_c1_,
+ torch.zeros((t * n, self.f_c_dim), device=device),
+ torch.zeros((t * n, self.f_p_dim), device=device)
+ )
+ appearance_image = appearance_image_.view(t, n, c, h, w)
+ # Continue decoding canonical features
+ canonical_image = self.ae.decoder.trans_conv3(x_c_c1)
+ canonical_image = torch.sigmoid(
+ self.ae.decoder.trans_conv4(canonical_image)
+ )
+ pose_image = x_p_c1
# Losses
xrecon_loss = torch.sum(torch.stack(xrecon_loss))
@@ -123,20 +168,36 @@ class RGBPartNet(nn.Module):
cano_cons_loss = torch.mean(torch.stack(cano_cons_loss))
return ((x_c_c1, x_p_c1),
+ (appearance_image, canonical_image, pose_image),
(xrecon_loss, pose_sim_loss, cano_cons_loss))
else: # evaluating
- x_c1 = x_c1.view(-1, c, h, w)
- x_c_c1, x_p_c1 = self.ae(x_c1)
- _, c_c, h_c, w_c = x_c_c1.size()
- x_c_c1 = x_c_c1.view(t, n, c_c, h_c, w_c)
- x_p_c1 = x_p_c1.view(t, n, c, h, w)
-
- return (x_c_c1, x_p_c1), None
+ x_c1_ = x_c1.view(t * n, c, h, w)
+ (f_c_c1_, f_p_c1_) = self.ae(x_c1_)
+
+ # Canonical features
+ f_c_c1 = f_c_c1_.view(t, n, -1)
+ f_c_c1_mean = f_c_c1.mean(0)
+ x_c_c1 = self.ae.decoder(
+ torch.zeros((n, self.f_a_dim), device=device),
+ f_c_c1_mean,
+ torch.zeros((n, self.f_p_dim), device=device),
+ cano_only=True
+ )
+
+ # Pose features
+ x_p_c1_ = self.ae.decoder(
+ torch.zeros((t * n, self.f_a_dim), device=device),
+ torch.zeros((t * n, self.f_c_dim), device=device),
+ f_p_c1_
+ )
+ x_p_c1 = x_p_c1_.view(t, n, c, h, w)
+
+ return (x_c_c1, x_p_c1), None, None
@staticmethod
- def _pose_sim_loss(f_p_c1: list[torch.Tensor],
- f_p_c2: list[torch.Tensor]) -> torch.Tensor:
- f_p_c1_mean = torch.stack(f_p_c1).mean(dim=0)
- f_p_c2_mean = torch.stack(f_p_c2).mean(dim=0)
+ def _pose_sim_loss(f_p_c1: torch.Tensor,
+ f_p_c2: torch.Tensor) -> torch.Tensor:
+ f_p_c1_mean = f_p_c1.mean(dim=0)
+ f_p_c2_mean = f_p_c2.mean(dim=0)
return F.mse_loss(f_p_c1_mean, f_p_c2_mean)
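The single triplet loss is now split by part: the first sum(hpm_scales) rows of the fc output go to hpm_ba_trip and the remainder to pn_ba_trip. A toy sketch with the default part counts (7 HPM parts, 16 PartNet parts) and made-up feature sizes:

    import torch

    hpm_num_parts = 7                      # sum(hpm_scales) with scales (1, 2, 4)
    x = torch.rand(7 + 16, 4, 256)         # (p, n, c) after self.fc
    hpm_part_feats = x[:hpm_num_parts]     # -> hpm_ba_trip
    pn_part_feats = x[hpm_num_parts:]      # -> pn_ba_trip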
diff --git a/utils/configuration.py b/utils/configuration.py
index c4c4b4d..4ab1520 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -7,6 +7,7 @@ class SystemConfiguration(TypedDict):
disable_acc: bool
CUDA_VISIBLE_DEVICES: str
save_dir: str
+ image_log_on: bool
class DatasetConfiguration(TypedDict):
@@ -31,6 +32,7 @@ class ModelHPConfiguration(TypedDict):
ae_feature_channels: int
f_a_c_p_dims: tuple[int, int, int]
hpm_scales: tuple[int, ...]
+ hpm_use_1x1conv: bool
hpm_use_avg_pool: bool
hpm_use_max_pool: bool
fpfe_feature_channels: int
@@ -40,7 +42,7 @@ class ModelHPConfiguration(TypedDict):
tfa_squeeze_ratio: int
tfa_num_parts: int
embedding_dims: int
- triplet_margin: float
+ triplet_margins: tuple[float, float]
class SubOptimizerHPConfiguration(TypedDict):
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 8c143d6..d573ef4 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -34,5 +34,5 @@ class BatchAllTripletLoss(nn.Module):
parted_loss_mean = all_loss.sum(1) / non_zero_counts
parted_loss_mean[non_zero_counts == 0] = 0
- loss = parted_loss_mean.sum()
+ loss = parted_loss_mean.mean()
return loss
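The final reduction over parts changes from a sum to a mean, so the triplet term no longer grows with the number of parts. A toy illustration with made-up per-part values:

    import torch

    parted_loss_mean = torch.tensor([0.3, 0.0, 0.1])  # one mean loss per part
    loss_old = parted_loss_mean.sum()    # 0.4 -- scales with the part count
    loss_new = parted_loss_mean.mean()   # ~0.133 -- independent of part count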