summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/auto_encoder.py36
-rw-r--r--models/hpm.py20
-rw-r--r--models/layers.py10
-rw-r--r--models/model.py58
-rw-r--r--models/rgb_part_net.py162
5 files changed, 173 insertions, 113 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index befd2d3..69dae4e 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -97,15 +97,14 @@ class Decoder(nn.Module):
self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
is_last_layer=True)
- def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False):
+ def forward(self, f_appearance, f_canonical, f_pose, cano_only=False):
x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
x = self.fc(x)
x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True)
- # Decode canonical features without transpose convolutions
- if no_trans_conv:
- return x
x = self.trans_conv1(x)
x = self.trans_conv2(x)
+ if cano_only:
+ return x
x = self.trans_conv3(x)
x = torch.sigmoid(self.trans_conv4(x))
@@ -115,7 +114,6 @@ class Decoder(nn.Module):
class AutoEncoder(nn.Module):
def __init__(
self,
- num_class: int = 74,
channels: int = 3,
feature_channels: int = 64,
embedding_dims: Tuple[int, int, int] = (128, 128, 64)
@@ -124,27 +122,10 @@ class AutoEncoder(nn.Module):
self.encoder = Encoder(channels, feature_channels, embedding_dims)
self.decoder = Decoder(embedding_dims, feature_channels, channels)
- f_c_dim = embedding_dims[1]
- self.classifier = nn.Sequential(
- nn.LeakyReLU(0.2, inplace=True),
- BasicLinear(f_c_dim, num_class)
- )
-
- def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None, y=None):
+ def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None):
# x_c1_t2 is the frame for later module
(f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
- with torch.no_grad():
- # Decode canonical features for HPM
- x_c_c1_t2 = self.decoder(
- torch.zeros_like(f_a_c1_t2), f_c_c1_t2, torch.zeros_like(f_p_c1_t2),
- no_trans_conv=True
- )
- # Decode pose features for Part Net
- x_p_c1_t2 = self.decoder(
- torch.zeros_like(f_a_c1_t2), torch.zeros_like(f_c_c1_t2), f_p_c1_t2
- )
-
if self.training:
# t1 is random time step, c2 is another condition
(f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
@@ -152,16 +133,13 @@ class AutoEncoder(nn.Module):
x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
xrecon_loss_t2 = F.mse_loss(x_c1_t2, x_c1_t2_)
-
- y_ = self.classifier(f_c_c1_t2.contiguous())
cano_cons_loss_t2 = (F.mse_loss(f_c_c1_t1, f_c_c1_t2)
- + F.mse_loss(f_c_c1_t2, f_c_c2_t2)
- + F.cross_entropy(y_, y))
+ + F.mse_loss(f_c_c1_t2, f_c_c2_t2))
return (
- (x_c_c1_t2, x_p_c1_t2),
+ (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2),
(f_p_c1_t2, f_p_c2_t2),
(xrecon_loss_t2, cano_cons_loss_t2)
)
else: # evaluating
- return x_c_c1_t2, x_p_c1_t2
+ return f_c_c1_t2, f_p_c1_t2
diff --git a/models/hpm.py b/models/hpm.py
index 7505ed7..b49be3a 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -11,14 +11,16 @@ class HorizontalPyramidMatching(nn.Module):
self,
in_channels: int,
out_channels: int = 128,
+ use_1x1conv: bool = False,
scales: Tuple[int, ...] = (1, 2, 4),
use_avg_pool: bool = True,
- use_max_pool: bool = True,
+ use_max_pool: bool = False,
**kwargs
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
+ self.use_1x1conv = use_1x1conv
self.scales = scales
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
@@ -31,6 +33,7 @@ class HorizontalPyramidMatching(nn.Module):
pyramid = nn.ModuleList([
HorizontalPyramidPooling(self.in_channels,
self.out_channels,
+ use_1x1conv=self.use_1x1conv,
use_avg_pool=self.use_avg_pool,
use_max_pool=self.use_max_pool,
**kwargs)
@@ -39,23 +42,16 @@ class HorizontalPyramidMatching(nn.Module):
return pyramid
def forward(self, x):
- # Flatten canonical features in all batches
- t, n, c, h, w = x.size()
- x = x.view(t * n, c, h, w)
-
+ n, c, h, w = x.size()
feature = []
- for pyramid_index, pyramid in enumerate(self.pyramids):
- h_per_hpp = h // self.scales[pyramid_index]
+ for scale, pyramid in zip(self.scales, self.pyramids):
+ h_per_hpp = h // scale
for hpp_index, hpp in enumerate(pyramid):
h_filter = torch.arange(hpp_index * h_per_hpp,
(hpp_index + 1) * h_per_hpp)
x_slice = x[:, :, h_filter, :]
x_slice = hpp(x_slice)
- x_slice = x_slice.view(t * n, -1)
+ x_slice = x_slice.view(n, -1)
feature.append(x_slice)
x = torch.stack(feature)
-
- # Unfold frames to original batch
- p, _, c = x.size()
- x = x.view(p, t, n, c)
return x
diff --git a/models/layers.py b/models/layers.py
index 7f2ccec..98e4c10 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -167,12 +167,13 @@ class HorizontalPyramidPooling(BasicConv2d):
self,
in_channels: int,
out_channels: int,
- kernel_size: Union[int, Tuple[int, int]] = 1,
+ use_1x1conv: bool = False,
use_avg_pool: bool = True,
- use_max_pool: bool = True,
+ use_max_pool: bool = False,
**kwargs
):
- super().__init__(in_channels, out_channels, kernel_size, **kwargs)
+ super().__init__(in_channels, out_channels, kernel_size=1, **kwargs)
+ self.use_1x1conv = use_1x1conv
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.'
@@ -186,5 +187,6 @@ class HorizontalPyramidPooling(BasicConv2d):
x = self.avg_pool(x)
elif not self.use_avg_pool and self.use_max_pool:
x = self.max_pool(x)
- x = super().forward(x)
+ if self.use_1x1conv:
+ x = super().forward(x)
return x
diff --git a/models/model.py b/models/model.py
index d11617b..912d0b9 100644
--- a/models/model.py
+++ b/models/model.py
@@ -51,7 +51,6 @@ class Model:
self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000))
self.is_train: bool = True
- self.train_size: int = 74
self.in_channels: int = 3
self.pr: Optional[int] = None
self.k: Optional[int] = None
@@ -67,6 +66,7 @@ class Model:
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
+ self.image_log_on = system_config.get('image_log_on', False)
self.CASIAB_GALLERY_SELECTOR = {
'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})}
@@ -138,19 +138,18 @@ class Model:
# Prepare for model, optimizer and scheduler
model_hp = self.hp.get('model', {})
optim_hp: Dict = self.hp.get('optimizer', {}).copy()
+ start_iter = optim_hp.pop('start_iter', 0)
ae_optim_hp = optim_hp.pop('auto_encoder', {})
pn_optim_hp = optim_hp.pop('part_net', {})
hpm_optim_hp = optim_hp.pop('hpm', {})
fc_optim_hp = optim_hp.pop('fc', {})
sched_hp = self.hp.get('scheduler', {})
- self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **model_hp)
+ self.rgb_pn = RGBPartNet(self.in_channels, **model_hp,
+ image_log_on=self.image_log_on)
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
self.optimizer = optim.Adam([
{'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
- {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
- {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
- {'params': self.rgb_pn.fc_mat, **fc_optim_hp},
], **optim_hp)
self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, **sched_hp)
self.writer = SummaryWriter(self._log_name)
@@ -168,10 +167,20 @@ class Model:
# Training start
start_time = datetime.now()
- running_loss = torch.zeros(4).to(self.device)
+ running_loss = torch.zeros(5, device=self.device)
print(f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}",
- f"{'CanoCons':^8} {'BATrip':^8} {'LR':^9}")
+ f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} LR(s)")
for (batch_c1, batch_c2) in dataloader:
+ if self.curr_iter == start_iter:
+ self.optimizer.add_param_group(
+ {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}
+ )
+ self.optimizer.add_param_group(
+ {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp}
+ )
+ self.optimizer.add_param_group(
+ {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
+ )
self.curr_iter += 1
# Zero the parameter gradients
self.optimizer.zero_grad()
@@ -179,12 +188,10 @@ class Model:
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
y = batch_c1['label'].to(self.device)
- losses = self.rgb_pn(x_c1, x_c2, y)
+ losses, images = self.rgb_pn(x_c1, x_c2, y)
loss = losses.sum()
loss.backward()
self.optimizer.step()
- # Step scheduler
- self.scheduler.step()
# Statistics and checkpoint
running_loss += losses.detach()
@@ -192,15 +199,39 @@ class Model:
self.writer.add_scalar('Loss/all', loss, self.curr_iter)
self.writer.add_scalars('Loss/details', dict(zip([
'Cross reconstruction loss', 'Pose similarity loss',
- 'Canonical consistency loss', 'Batch All triplet loss'
+ 'Canonical consistency loss', 'Batch All triplet loss (HPM)',
+ 'Batch All triplet loss (PartNet)'
], losses)), self.curr_iter)
+ if self.image_log_on:
+ (appearance_image, canonical_image, pose_image) = images
+ self.writer.add_images(
+ 'Canonical image', canonical_image, self.curr_iter
+ )
+ for i in range(self.pr * self.k):
+ self.writer.add_images(
+ f'Original image/batch {i}', x_c1[i], self.curr_iter
+ )
+ self.writer.add_images(
+ f'Appearance image/batch {i}',
+ appearance_image[:, i, :, :, :],
+ self.curr_iter
+ )
+ self.writer.add_images(
+ f'Pose image/batch {i}',
+ pose_image[:, i, :, :, :],
+ self.curr_iter
+ )
if self.curr_iter % 100 == 0:
+ lrs = self.scheduler.get_last_lr()
print(f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
- '{:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- f'{self.scheduler.get_last_lr()[0]:.3e}')
+ '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
+ ' '.join(('{:.3e}'.format(lr) for lr in lrs)))
running_loss.zero_()
+ # Step scheduler
+ self.scheduler.step()
+
if self.curr_iter % 1000 == 0:
torch.save({
'iter': self.curr_iter,
@@ -396,7 +427,6 @@ class Model:
self,
dataset_config: Dict
) -> Union[CASIAB]:
- self.train_size = dataset_config.get('train_size', 74)
self.in_channels = dataset_config.get('num_input_channels', 3)
self._dataset_sig = self._make_signature(
dataset_config,
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 326ec81..f6dc131 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -14,10 +14,10 @@ from utils.triplet_loss import BatchAllTripletLoss
class RGBPartNet(nn.Module):
def __init__(
self,
- num_class: int = 74,
ae_in_channels: int = 3,
ae_feature_channels: int = 64,
f_a_c_p_dims: Tuple[int, int, int] = (128, 128, 64),
+ hpm_use_1x1conv: bool = False,
hpm_scales: Tuple[int, ...] = (1, 2, 4),
hpm_use_avg_pool: bool = True,
hpm_use_max_pool: bool = True,
@@ -28,11 +28,16 @@ class RGBPartNet(nn.Module):
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
embedding_dims: int = 256,
- triplet_margin: float = 0.2
+ triplet_margins: Tuple[float, float] = (0.2, 0.2),
+ image_log_on: bool = False
):
super().__init__()
+ (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
+ self.hpm_num_parts = sum(hpm_scales)
+ self.image_log_on = image_log_on
+
self.ae = AutoEncoder(
- num_class, ae_in_channels, ae_feature_channels, f_a_c_p_dims
+ ae_in_channels, ae_feature_channels, f_a_c_p_dims
)
self.pn = PartNet(
ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
@@ -40,14 +45,16 @@ class RGBPartNet(nn.Module):
)
out_channels = self.pn.tfa_in_channels
self.hpm = HorizontalPyramidMatching(
- ae_feature_channels * 8, out_channels, hpm_scales,
- hpm_use_avg_pool, hpm_use_max_pool
+ ae_feature_channels * 2, out_channels, hpm_use_1x1conv,
+ hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
)
- total_parts = sum(hpm_scales) + tfa_num_parts
- empty_fc = torch.empty(total_parts, out_channels, embedding_dims)
+ empty_fc = torch.empty(self.hpm_num_parts + tfa_num_parts,
+ out_channels, embedding_dims)
self.fc_mat = nn.Parameter(empty_fc)
- self.ba_triplet_loss = BatchAllTripletLoss(triplet_margin)
+ (hpm_margin, pn_margin) = triplet_margins
+ self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin)
+ self.pn_ba_trip = BatchAllTripletLoss(pn_margin)
def fc(self, x):
return x @ self.fc_mat
@@ -61,13 +68,11 @@ class RGBPartNet(nn.Module):
# Step 1: Disentanglement
# t, n, c, h, w
- ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2, y)
+ ((x_c_c1, x_p_c1), images, losses) = self._disentangle(x_c1, x_c2)
- # Step 2.a: HPM & Static Gait Feature Aggregation
- # t, n, c, h, w
+ # Step 2.a: Static Gait Feature Aggregation & HPM
+ # n, c, h, w
x_c = self.hpm(x_c_c1)
- # p, t, n, c
- x_c = x_c.mean(dim=1)
# p, n, c
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
@@ -80,44 +85,83 @@ class RGBPartNet(nn.Module):
x = self.fc(x)
if self.training:
- batch_all_triplet_loss = self.ba_triplet_loss(x, y)
- losses = torch.stack((*losses, batch_all_triplet_loss))
- return losses
+ hpm_ba_trip = self.hpm_ba_trip(x[:self.hpm_num_parts], y)
+ pn_ba_trip = self.pn_ba_trip(x[self.hpm_num_parts:], y)
+ losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip))
+ return losses, images
else:
return x.unsqueeze(1).view(-1)
- def _disentangle(self, x_c1, x_c2=None, y=None):
- num_frames = len(x_c1)
- # Decoded canonical features and Pose images
- x_c_c1, x_p_c1 = [], []
+ def _disentangle(self, x_c1, x_c2=None):
+ t, n, c, h, w = x_c1.size()
+ device = x_c1.device
if self.training:
+ # Encoded appearance, canonical and pose features
+ f_a_c1, f_c_c1, f_p_c1 = [], [], []
# Features required to calculate losses
- f_p_c1, f_p_c2 = [], []
+ f_p_c2 = []
xrecon_loss, cano_cons_loss = [], []
- for t2 in range(num_frames):
- t1 = random.randrange(num_frames)
- output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2], y)
- (x_c1_t2, f_p_t2, losses) = output
-
- # Decoded features or image
- (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
- # Canonical Features for HPM
- x_c_c1.append(x_c_c1_t2)
- # Pose image for Part Net
- x_p_c1.append(x_p_c1_t2)
+ for t2 in range(t):
+ t1 = random.randrange(t)
+ output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2])
+ (f_c1_t2, f_p_t2, losses) = output
+
+ (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = f_c1_t2
+ if self.image_log_on:
+ f_a_c1.append(f_a_c1_t2)
+ # Save canonical features and pose features
+ f_c_c1.append(f_c_c1_t2)
+ f_p_c1.append(f_p_c1_t2)
# Losses per time step
# Used in pose similarity loss
- (f_p_c1_t2, f_p_c2_t2) = f_p_t2
- f_p_c1.append(f_p_c1_t2)
+ (_, f_p_c2_t2) = f_p_t2
f_p_c2.append(f_p_c2_t2)
+
# Cross reconstruction loss and canonical loss
(xrecon_loss_t2, cano_cons_loss_t2) = losses
xrecon_loss.append(xrecon_loss_t2)
cano_cons_loss.append(cano_cons_loss_t2)
-
- x_c_c1 = torch.stack(x_c_c1)
- x_p_c1 = torch.stack(x_p_c1)
+ if self.image_log_on:
+ f_a_c1 = torch.stack(f_a_c1)
+ f_c_c1_mean = torch.stack(f_c_c1).mean(0)
+ f_p_c1 = torch.stack(f_p_c1)
+ f_p_c2 = torch.stack(f_p_c2)
+
+ # Decode features
+ appearance_image, canonical_image, pose_image = None, None, None
+ with torch.no_grad():
+ # Decode average canonical features to higher dimension
+ x_c_c1 = self.ae.decoder(
+ torch.zeros((n, self.f_a_dim), device=device),
+ f_c_c1_mean,
+ torch.zeros((n, self.f_p_dim), device=device),
+ cano_only=True
+ )
+ # Decode pose features to images
+ f_p_c1_ = f_p_c1.view(t * n, -1)
+ x_p_c1_ = self.ae.decoder(
+ torch.zeros((t * n, self.f_a_dim), device=device),
+ torch.zeros((t * n, self.f_c_dim), device=device),
+ f_p_c1_
+ )
+ x_p_c1 = x_p_c1_.view(t, n, c, h, w)
+
+ if self.image_log_on:
+ # Decode appearance features
+ f_a_c1_ = f_a_c1.view(t * n, -1)
+ appearance_image_ = self.ae.decoder(
+ f_a_c1_,
+ torch.zeros((t * n, self.f_c_dim), device=device),
+ torch.zeros((t * n, self.f_p_dim), device=device)
+ )
+ appearance_image = appearance_image_.view(t, n, c, h, w)
+ # Continue decoding canonical features
+ canonical_image = self.ae.decoder.trans_conv3(x_c_c1)
+ canonical_image = torch.sigmoid(
+ self.ae.decoder.trans_conv4(canonical_image)
+ )
+ pose_image = x_p_c1
# Losses
xrecon_loss = torch.sum(torch.stack(xrecon_loss))
@@ -125,26 +169,36 @@ class RGBPartNet(nn.Module):
cano_cons_loss = torch.mean(torch.stack(cano_cons_loss))
return ((x_c_c1, x_p_c1),
+ (appearance_image, canonical_image, pose_image),
(xrecon_loss, pose_sim_loss, cano_cons_loss))
else: # evaluating
- for t2 in range(num_frames):
- x_c1_t2 = self.ae(x_c1[t2])
- # Decoded features or image
- (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
- # Canonical Features for HPM
- x_c_c1.append(x_c_c1_t2)
- # Pose image for Part Net
- x_p_c1.append(x_p_c1_t2)
-
- x_c_c1 = torch.stack(x_c_c1)
- x_p_c1 = torch.stack(x_p_c1)
-
- return (x_c_c1, x_p_c1), None
+ x_c1_ = x_c1.view(t * n, c, h, w)
+ (f_c_c1_, f_p_c1_) = self.ae(x_c1_)
+
+ # Canonical features
+ f_c_c1 = f_c_c1_.view(t, n, -1)
+ f_c_c1_mean = f_c_c1.mean(0)
+ x_c_c1 = self.ae.decoder(
+ torch.zeros((n, self.f_a_dim)),
+ f_c_c1_mean,
+ torch.zeros((n, self.f_p_dim)),
+ cano_only=True
+ )
+
+ # Pose features
+ x_p_c1_ = self.ae.decoder(
+ torch.zeros((t * n, self.f_a_dim)),
+ torch.zeros((t * n, self.f_c_dim)),
+ f_p_c1_
+ )
+ x_p_c1 = x_p_c1_.view(t, n, c, h, w)
+
+ return (x_c_c1, x_p_c1), None, None
@staticmethod
- def _pose_sim_loss(f_p_c1: List[torch.Tensor],
- f_p_c2: List[torch.Tensor]) -> torch.Tensor:
- f_p_c1_mean = torch.stack(f_p_c1).mean(dim=0)
- f_p_c2_mean = torch.stack(f_p_c2).mean(dim=0)
+ def _pose_sim_loss(f_p_c1: torch.Tensor,
+ f_p_c2: torch.Tensor) -> torch.Tensor:
+ f_p_c1_mean = f_p_c1.mean(dim=0)
+ f_p_c2_mean = f_p_c2.mean(dim=0)
return F.mse_loss(f_p_c1_mean, f_p_c2_mean)