From c74df416b00f837ba051f3947be92f76e7afbd88 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Fri, 12 Mar 2021 13:56:17 +0800 Subject: Code refactoring 1. Separate FCs and triplet losses for HPM and PartNet 2. Remove FC-equivalent 1x1 conv layers in HPM 3. Support adjustable learning rate schedulers --- config.py | 21 +++++---- models/auto_encoder.py | 2 +- models/hpm.py | 25 +++++------ models/layers.py | 9 ---- models/model.py | 119 ++++++++++++++++++++++++++++--------------------- models/part_net.py | 18 +++++--- models/rgb_part_net.py | 32 +++++-------- test/part_net.py | 2 +- utils/configuration.py | 20 +++++---- utils/triplet_loss.py | 40 ----------------- 10 files changed, 127 insertions(+), 161 deletions(-) diff --git a/config.py b/config.py index d6de788..8abeba3 100644 --- a/config.py +++ b/config.py @@ -50,19 +50,17 @@ config: Configuration = { 'ae_feature_channels': 64, # Appearance, canonical and pose feature dimensions 'f_a_c_p_dims': (192, 192, 96), - # Use 1x1 convolution in dimensionality reduction - 'hpm_use_1x1conv': False, # HPM pyramid scales, of which sum is number of parts 'hpm_scales': (1, 2, 4, 8), # Global pooling method 'hpm_use_avg_pool': True, 'hpm_use_max_pool': True, - # Attention squeeze ratio - 'tfa_squeeze_ratio': 4, # Number of parts after Part Net 'tfa_num_parts': 16, - # Embedding dimension for each part - 'embedding_dims': 256, + # Attention squeeze ratio + 'tfa_squeeze_ratio': 4, + # Embedding dimensions for each part + 'embedding_dims': (256, 256), # Batch Hard or Batch All 'triplet_is_hard': True, # Use non-zero mean or sum @@ -91,9 +89,14 @@ config: Configuration = { }, 'scheduler': { # Step start to decay - 'start_step': 15_000, + 'start_step': 500, # Multiplicative factor of decay in the end - 'final_gamma': 0.001, + 'final_gamma': 0.01, + + # Local parameters (override global ones) + 'hpm': { + 'final_gamma': 0.001 + } } }, # Model metadata @@ -107,6 +110,6 @@ config: Configuration = { # Restoration iteration (multiple models, e.g. nm, bg and cl) 'restore_iters': (0, 0, 0), # Total iteration for training (multiple models) - 'total_iters': (25_000, 25_000, 25_000), + 'total_iters': (30_000, 40_000, 60_000), }, } diff --git a/models/auto_encoder.py b/models/auto_encoder.py index e6a3e60..4fece69 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -171,7 +171,7 @@ class AutoEncoder(nn.Module): return ( (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_), - torch.stack((xrecon_loss, cano_cons_loss, pose_sim_loss * 10)) + (xrecon_loss, cano_cons_loss, pose_sim_loss * 10) ) else: # evaluating return f_c_c1_t2_, f_p_c1_t2_ diff --git a/models/hpm.py b/models/hpm.py index 9879cfb..8186b20 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -9,32 +9,26 @@ class HorizontalPyramidMatching(nn.Module): self, in_channels: int, out_channels: int = 128, - use_1x1conv: bool = False, scales: tuple[int, ...] = (1, 2, 4), use_avg_pool: bool = True, use_max_pool: bool = False, - **kwargs ): super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.use_1x1conv = use_1x1conv self.scales = scales + self.num_parts = sum(scales) self.use_avg_pool = use_avg_pool self.use_max_pool = use_max_pool self.pyramids = nn.ModuleList([ - self._make_pyramid(scale, **kwargs) for scale in self.scales + self._make_pyramid(scale) for scale in scales ]) + self.fc_mat = nn.Parameter( + torch.empty(self.num_parts, in_channels, out_channels) + ) - def _make_pyramid(self, scale: int, **kwargs): + def _make_pyramid(self, scale: int): pyramid = nn.ModuleList([ - HorizontalPyramidPooling(self.in_channels, - self.out_channels, - use_1x1conv=self.use_1x1conv, - use_avg_pool=self.use_avg_pool, - use_max_pool=self.use_max_pool, - **kwargs) + HorizontalPyramidPooling(self.use_avg_pool, self.use_max_pool) for _ in range(scale) ]) return pyramid @@ -52,4 +46,9 @@ class HorizontalPyramidMatching(nn.Module): x_slice = x_slice.view(n, -1) feature.append(x_slice) x = torch.stack(feature) + + # p, n, c + x = x @ self.fc_mat + # p, n, d + return x diff --git a/models/layers.py b/models/layers.py index f1d72b6..c609698 100644 --- a/models/layers.py +++ b/models/layers.py @@ -167,17 +167,10 @@ class BasicConv1d(nn.Module): class HorizontalPyramidPooling(nn.Module): def __init__( self, - in_channels: int, - out_channels: int, - use_1x1conv: bool = False, use_avg_pool: bool = True, use_max_pool: bool = False, - **kwargs ): super().__init__() - self.use_1x1conv = use_1x1conv - if use_1x1conv: - self.conv = BasicConv2d(in_channels, out_channels, 1, **kwargs) self.use_avg_pool = use_avg_pool self.use_max_pool = use_max_pool assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.' @@ -191,6 +184,4 @@ class HorizontalPyramidPooling(nn.Module): x = self.avg_pool(x) elif not self.use_avg_pool and self.use_max_pool: x = self.max_pool(x) - if self.use_1x1conv: - x = self.conv(x) return x diff --git a/models/model.py b/models/model.py index cb455f3..adea626 100644 --- a/models/model.py +++ b/models/model.py @@ -13,13 +13,15 @@ from torch.utils.data.dataloader import default_collate from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm +from models.hpm import HorizontalPyramidMatching +from models.part_net import PartNet from models.rgb_part_net import RGBPartNet from utils.configuration import DataloaderConfiguration, \ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \ SystemConfiguration from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses from utils.sampler import TripletSampler -from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss +from utils.triplet_loss import BatchTripletLoss class Model: @@ -69,7 +71,8 @@ class Model: self._dataset_sig: str = 'undefined' self.rgb_pn: Optional[RGBPartNet] = None - self.triplet_loss: Optional[JointBatchTripletLoss] = None + self.triplet_loss_hpm: Optional[BatchTripletLoss] = None + self.triplet_loss_pn: Optional[BatchTripletLoss] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None @@ -149,26 +152,28 @@ class Model: triplet_margins = model_hp.pop('triplet_margins', None) optim_hp: dict = self.hp.get('optimizer', {}).copy() ae_optim_hp = optim_hp.pop('auto_encoder', {}) - pn_optim_hp = optim_hp.pop('part_net', {}) hpm_optim_hp = optim_hp.pop('hpm', {}) - fc_optim_hp = optim_hp.pop('fc', {}) + pn_optim_hp = optim_hp.pop('part_net', {}) sched_hp = self.hp.get('scheduler', {}) + ae_sched_hp = sched_hp.get('auto_encoder', {}) + hpm_sched_hp = sched_hp.get('hpm', {}) + pn_sched_hp = sched_hp.get('part_net', {}) + self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) # Hard margins if triplet_margins: - # Same margins - if triplet_margins[0] == triplet_margins[1]: - self.triplet_loss = BatchTripletLoss( - triplet_is_hard, triplet_margins[0] - ) - else: # Different margins - self.triplet_loss = JointBatchTripletLoss( - self.rgb_pn.hpm_num_parts, - triplet_is_hard, triplet_is_mean, triplet_margins - ) + self.triplet_loss_hpm = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, triplet_margins[0] + ) + self.triplet_loss_pn = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, triplet_margins[1] + ) else: # Soft margins - self.triplet_loss = BatchTripletLoss( + self.triplet_loss_hpm = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, None + ) + self.triplet_loss_pn = BatchTripletLoss( triplet_is_hard, triplet_is_mean, None ) @@ -177,25 +182,33 @@ class Model: # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) - self.triplet_loss = self.triplet_loss.to(self.device) + self.triplet_loss_hpm = self.triplet_loss_hpm.to(self.device) + self.triplet_loss_pn = self.triplet_loss_pn.to(self.device) + self.optimizer = optim.Adam([ {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp}, - {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp}, - {'params': self.rgb_pn.fc_mat, **fc_optim_hp} + {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, ], **optim_hp) - sched_final_gamma = sched_hp.get('final_gamma', 0.001) - sched_start_step = sched_hp.get('start_step', 15_000) - all_step = self.total_iter - sched_start_step - - def lr_lambda(epoch): - if epoch > sched_start_step: - passed_step = epoch - sched_start_step - return sched_final_gamma ** (passed_step / all_step) - else: - return 1 + + start_step = sched_hp.get('start_step', 15_000) + final_gamma = sched_hp.get('final_gamma', 0.001) + ae_start_step = ae_sched_hp.get('start_step', start_step) + ae_final_gamma = ae_sched_hp.get('final_gamma', final_gamma) + ae_all_step = self.total_iter - ae_start_step + hpm_start_step = hpm_sched_hp.get('start_step', start_step) + hpm_final_gamma = hpm_sched_hp.get('final_gamma', final_gamma) + hpm_all_step = self.total_iter - hpm_start_step + pn_start_step = pn_sched_hp.get('start_step', start_step) + pn_final_gamma = pn_sched_hp.get('final_gamma', final_gamma) + pn_all_step = self.total_iter - pn_start_step self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ - lr_lambda, lr_lambda, lr_lambda, lr_lambda + lambda t: ae_final_gamma ** ((t - ae_start_step) / ae_all_step) + if t > ae_start_step else 1, + lambda t: hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step) + if t > hpm_start_step else 1, + lambda t: pn_final_gamma ** ((t - pn_start_step) / pn_all_step) + if t > pn_start_step else 1, ]) self.writer = SummaryWriter(self._log_name) @@ -220,7 +233,7 @@ class Model: running_loss = torch.zeros(5, device=self.device) print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}", f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}", - f"{'BATripH':^8} {'BATripP':^8} {'LR':^9}") + f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}") for (batch_c1, batch_c2) in dataloader: self.curr_iter += 1 # Zero the parameter gradients @@ -228,17 +241,20 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2) + embedding_c, embedding_p, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) # Duplicate labels for each part - y = y.repeat(self.rgb_pn.num_total_parts, 1) - trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y) - losses = torch.cat(( - ae_losses, - torch.stack(( - trip_loss[:self.rgb_pn.hpm_num_parts].mean(), - trip_loss[self.rgb_pn.hpm_num_parts:].mean() - )) + y = y.repeat(self.rgb_pn.num_parts, 1) + trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm( + embedding_c, y[:self.rgb_pn.hpm.num_parts] + ) + trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn( + embedding_p, y[self.rgb_pn.hpm.num_parts:] + ) + losses = torch.stack(( + *ae_losses, + trip_loss_hpm.mean(), + trip_loss_pn.mean() )) loss = losses.sum() loss.backward() @@ -257,30 +273,30 @@ class Model: 'PartNet': losses[4] }, self.curr_iter) # None-zero losses in batch - if num_non_zero is not None: + if hpm_num_non_zero is not None and hpm_num_non_zero is not None: self.writer.add_scalars('Loss/non-zero counts', { - 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(), - 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean() + 'HPM': hpm_num_non_zero.mean(), + 'PartNet': pn_num_non_zero.mean() }, self.curr_iter) # Embedding distance - mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0) + mean_hpm_dist = hpm_dist.mean(0) self._add_ranked_scalars( 'Embedding/HPM distance', mean_hpm_dist, num_pos_pairs, num_pairs, self.curr_iter ) - mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0) + mean_pa_dist = pn_dist.mean(0) self._add_ranked_scalars( 'Embedding/ParNet distance', mean_pa_dist, num_pos_pairs, num_pairs, self.curr_iter ) # Embedding norm - mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0) + mean_hpm_embedding = embedding_c.mean(0) mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/HPM norm', mean_hpm_norm, self.k, self.pr * self.k, self.curr_iter ) - mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0) + mean_pa_embedding = embedding_p.mean(0) mean_pa_norm = mean_pa_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/PartNet norm', mean_pa_norm, @@ -288,10 +304,9 @@ class Model: ) # Learning rate lrs = self.scheduler.get_last_lr() - # Write learning rates - self.writer.add_scalar( - 'Learning rate', lrs[0], self.curr_iter - ) + self.writer.add_scalars('Learning rate', dict(zip(( + 'Auto-encoder', 'HPM', 'PartNet' + ), lrs)), self.curr_iter) if self.curr_iter % 100 == 0: # Write disentangled images @@ -316,7 +331,7 @@ class Model: print(f'{hour:02}:{minute:02}:{second:02}', f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100), - f'{lrs[0]:.3e}') + '{:.3e} {:.3e} {:.3e}'.format(*lrs)) running_loss.zero_() # Step scheduler @@ -548,7 +563,7 @@ class Model: nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) - elif isinstance(m, RGBPartNet): + elif isinstance(m, (HorizontalPyramidMatching, PartNet)): nn.init.xavier_uniform_(m.fc_mat) def _parse_dataset_config( diff --git a/models/part_net.py b/models/part_net.py index 29cf9cd..f2236bf 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -111,17 +111,21 @@ class PartNet(nn.Module): def __init__( self, in_channels: int = 128, + embedding_dims: int = 256, + num_parts: int = 16, squeeze_ratio: int = 4, - num_part: int = 16 ): super().__init__() - self.num_part = num_part - self.tfa = TemporalFeatureAggregator( - in_channels, squeeze_ratio, self.num_part - ) + self.num_part = num_parts self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) + self.tfa = TemporalFeatureAggregator( + in_channels, squeeze_ratio, self.num_part + ) + self.fc_mat = nn.Parameter( + torch.empty(num_parts, in_channels, embedding_dims) + ) def forward(self, x): n, t, c, h, w = x.size() @@ -138,4 +142,8 @@ class PartNet(nn.Module): # p, n, t, c x = self.tfa(x) + + # p, n, c + x = x @ self.fc_mat + # p, n, d return x diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 8a0f3a7..c38a567 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -13,39 +13,31 @@ class RGBPartNet(nn.Module): ae_in_size: tuple[int, int] = (64, 48), ae_feature_channels: int = 64, f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), - hpm_use_1x1conv: bool = False, hpm_scales: tuple[int, ...] = (1, 2, 4), hpm_use_avg_pool: bool = True, hpm_use_max_pool: bool = True, tfa_squeeze_ratio: int = 4, tfa_num_parts: int = 16, - embedding_dims: int = 256, + embedding_dims: tuple[int] = (256, 256), image_log_on: bool = False ): super().__init__() self.h, self.w = ae_in_size (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims - self.hpm_num_parts = sum(hpm_scales) self.image_log_on = image_log_on self.ae = AutoEncoder( ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims ) self.pn_in_channels = ae_feature_channels * 2 - self.pn = PartNet( - self.pn_in_channels, tfa_squeeze_ratio, tfa_num_parts - ) self.hpm = HorizontalPyramidMatching( - ae_feature_channels * 2, self.pn_in_channels, hpm_use_1x1conv, - hpm_scales, hpm_use_avg_pool, hpm_use_max_pool + self.pn_in_channels, embedding_dims[0], hpm_scales, + hpm_use_avg_pool, hpm_use_max_pool ) - self.num_total_parts = self.hpm_num_parts + tfa_num_parts - empty_fc = torch.empty(self.num_total_parts, - self.pn_in_channels, embedding_dims) - self.fc_mat = nn.Parameter(empty_fc) + self.pn = PartNet(self.pn_in_channels, embedding_dims[1], + tfa_num_parts, tfa_squeeze_ratio) - def fc(self, x): - return x @ self.fc_mat + self.num_parts = self.hpm.num_parts + tfa_num_parts def forward(self, x_c1, x_c2=None): # Step 1: Disentanglement @@ -55,21 +47,17 @@ class RGBPartNet(nn.Module): # Step 2.a: Static Gait Feature Aggregation & HPM # n, c, h, w x_c = self.hpm(x_c) - # p, n, c + # p, n, d # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) # n, t, c, h, w x_p = self.pn(x_p) - # p, n, c - - # Step 3: Cat feature map together and fc - x = torch.cat((x_c, x_p)) - x = self.fc(x) + # p, n, d if self.training: - return x, ae_losses, images + return x_c, x_p, ae_losses, images else: - return x.unsqueeze(1).view(-1) + return torch.cat((x_c, x_p)).unsqueeze(1).view(-1) def _disentangle(self, x_c1_t2, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() diff --git a/test/part_net.py b/test/part_net.py index 25e92ae..fada2c4 100644 --- a/test/part_net.py +++ b/test/part_net.py @@ -64,7 +64,7 @@ def test_custom_part_net(): paddings=((2, 1), (1, 1), (1, 1), (1, 1)), halving=(1, 1, 3, 3), squeeze_ratio=8, - num_part=8) + num_parts=8) x = torch.rand(T, N, 1, H, W) x = pa(x) diff --git a/utils/configuration.py b/utils/configuration.py index 0f8d9ff..f6ac182 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -33,16 +33,11 @@ class ModelHPConfiguration(TypedDict): ae_feature_channels: int f_a_c_p_dims: tuple[int, int, int] hpm_scales: tuple[int, ...] - hpm_use_1x1conv: bool hpm_use_avg_pool: bool hpm_use_max_pool: bool - fpfe_feature_channels: int - fpfe_kernel_sizes: tuple[tuple, ...] - fpfe_paddings: tuple[tuple, ...] - fpfe_halving: tuple[int, ...] - tfa_squeeze_ratio: int tfa_num_parts: int - embedding_dims: int + tfa_squeeze_ratio: int + embedding_dims: tuple[int] triplet_is_hard: bool triplet_is_mean: bool triplet_margins: tuple[float, float] @@ -63,14 +58,21 @@ class OptimizerHPConfiguration(TypedDict): weight_decay: float amsgrad: bool auto_encoder: SubOptimizerHPConfiguration - part_net: SubOptimizerHPConfiguration hpm: SubOptimizerHPConfiguration - fc: SubOptimizerHPConfiguration + part_net: SubOptimizerHPConfiguration + + +class SubSchedulerHPConfiguration(TypedDict): + start_step: int + final_gamma: float class SchedulerHPConfiguration(TypedDict): start_step: int final_gamma: float + auto_encoder: SubSchedulerHPConfiguration + hpm: SubSchedulerHPConfiguration + part_net: SubSchedulerHPConfiguration class HyperparameterConfiguration(TypedDict): diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index e05b69d..03fff21 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -85,43 +85,3 @@ class BatchTripletLoss(nn.Module): non_zero_mean = losses.sum(1) / non_zero_counts non_zero_mean[non_zero_counts == 0] = 0 return non_zero_mean - - -class JointBatchTripletLoss(BatchTripletLoss): - def __init__( - self, - hpm_num_parts: int, - is_hard: bool = True, - is_mean: bool = True, - margins: tuple[float, float] = (0.2, 0.2) - ): - super().__init__(is_hard, is_mean) - self.hpm_num_parts = hpm_num_parts - self.margin_hpm, self.margin_pn = margins - - def forward(self, x, y): - p, n, c = x.size() - dist = self._batch_distance(x) - flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device) - flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]] - - if self.is_hard: - positive_negative_dist = self._hard_distance(dist, y, p, n) - else: # is_all - positive_negative_dist = self._all_distance(dist, y, p, n) - - hpm_part_loss = F.relu( - self.margin_hpm + positive_negative_dist[:self.hpm_num_parts] - ) - pn_part_loss = F.relu( - self.margin_pn + positive_negative_dist[self.hpm_num_parts:] - ) - losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) - - non_zero_counts = (losses != 0).sum(1).float() - if self.is_mean: - loss_metric = self._none_zero_mean(losses, non_zero_counts) - else: # is_sum - loss_metric = losses.sum(1) - - return loss_metric, flat_dist, non_zero_counts -- cgit v1.2.3