diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-01 18:23:33 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-01 18:26:47 +0800 |
commit | d88e40217f56d96e568335ccee1f14ff3ea5a696 (patch) | |
tree | f24b204794fbb8fc501d2124ae67a73faf82db1f | |
parent | 5f75d7ef65f6dcd0e72df320c58b6bd141937b5f (diff) | |
parent | 6002b2d2017912f90e8917e6e8b71b78ce58e7c2 (diff) |
Merge branch 'master' into data_parallel
# Conflicts:
# models/model.py
-rw-r--r-- | config.py | 34 | ||||
-rw-r--r-- | models/layers.py | 4 | ||||
-rw-r--r-- | models/model.py | 129 | ||||
-rw-r--r-- | models/rgb_part_net.py | 6 | ||||
-rw-r--r-- | requirements.txt | 2 | ||||
-rw-r--r-- | utils/configuration.py | 7 | ||||
-rw-r--r-- | utils/triplet_loss.py | 117 |
7 files changed, 212 insertions, 87 deletions
@@ -5,7 +5,7 @@ config: Configuration = { # Disable accelerator 'disable_acc': False, # GPU(s) used in training or testing if available - 'CUDA_VISIBLE_DEVICES': '0', + 'CUDA_VISIBLE_DEVICES': '0,1', # Directory used in training or testing for temporary storage 'save_dir': 'runs', # Recorde disentangled image or not @@ -30,14 +30,14 @@ config: Configuration = { # Resolution after resize, can be divided 16 'frame_size': (64, 48), # Cache dataset or not - 'cache_on': False, + 'cache_on': True, }, # Dataloader settings 'dataloader': { # Batch size (pr, k) # `pr` denotes number of persons # `k` denotes number of sequences per person - 'batch_size': (4, 6), + 'batch_size': (6, 8), # Number of workers of Dataloader 'num_workers': 4, # Faster data transfer from RAM to GPU if enabled @@ -53,7 +53,7 @@ config: Configuration = { # Use 1x1 convolution in dimensionality reduction 'hpm_use_1x1conv': False, # HPM pyramid scales, of which sum is number of parts - 'hpm_scales': (1, 2, 4), + 'hpm_scales': (1, 2, 4, 8), # Global pooling method 'hpm_use_avg_pool': True, 'hpm_use_max_pool': True, @@ -63,13 +63,15 @@ config: Configuration = { 'tfa_num_parts': 16, # Embedding dimension for each part 'embedding_dims': 256, - # Triplet loss margins for HPM and PartNet - 'triplet_margins': (1.5, 1.5), + # Batch Hard or Batch All + 'triplet_is_hard': True, + # Use non-zero mean or sum + 'triplet_is_mean': True, + # Triplet loss margins for HPM and PartNet, None for soft margin + 'triplet_margins': None, }, 'optimizer': { # Global parameters - # Iteration start to optimize non-disentangling parts - # 'start_iter': 0, # Initial learning rate of Adam Optimizer 'lr': 1e-4, # Coefficients used for computing running averages of @@ -83,15 +85,15 @@ config: Configuration = { # 'amsgrad': False, # Local parameters (override global ones) - # 'auto_encoder': { - # 'weight_decay': 0.001 - # }, + 'auto_encoder': { + 'weight_decay': 0.001 + }, }, 'scheduler': { - # Period of learning rate decay - 'step_size': 500, - # Multiplicative factor of decay - 'gamma': 1, + # Step start to decay + 'start_step': 15_000, + # Multiplicative factor of decay in the end + 'final_gamma': 0.001, } }, # Model metadata @@ -105,6 +107,6 @@ config: Configuration = { # Restoration iteration (multiple models, e.g. nm, bg and cl) 'restore_iters': (0, 0, 0), # Total iteration for training (multiple models) - 'total_iters': (80_000, 80_000, 80_000), + 'total_iters': (25_000, 25_000, 25_000), }, } diff --git a/models/layers.py b/models/layers.py index ef53a95..f1d72b6 100644 --- a/models/layers.py +++ b/models/layers.py @@ -80,7 +80,9 @@ class DCGANConvTranspose2d(BasicConvTranspose2d): if self.is_last_layer: return self.trans_conv(x) else: - return super().forward(x) + x = self.trans_conv(x) + x = self.bn(x) + return F.leaky_relu(x, 0.2, inplace=True) class BasicLinear(nn.Module): diff --git a/models/model.py b/models/model.py index a086e7b..2eeaf5e 100644 --- a/models/model.py +++ b/models/model.py @@ -18,7 +18,7 @@ from utils.configuration import DataloaderConfiguration, \ SystemConfiguration from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses from utils.sampler import TripletSampler -from utils.triplet_loss import JointBatchAllTripletLoss +from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss class Model: @@ -68,7 +68,7 @@ class Model: self._dataset_sig: str = 'undefined' self.rgb_pn: Optional[RGBPartNet] = None - self.ba_triplet_loss: Optional[JointBatchAllTripletLoss] = None + self.triplet_loss: Optional[JointBatchTripletLoss] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None @@ -143,9 +143,10 @@ class Model: dataloader = self._parse_dataloader_config(dataset, dataloader_config) # Prepare for model, optimizer and scheduler model_hp: dict = self.hp.get('model', {}).copy() - triplet_margins = model_hp.pop('triplet_margins', (0.2, 0.2)) + triplet_is_hard = model_hp.pop('triplet_is_hard', True) + triplet_is_mean = model_hp.pop('triplet_is_mean', True) + triplet_margins = model_hp.pop('triplet_margins', None) optim_hp: dict = self.hp.get('optimizer', {}).copy() - start_iter = optim_hp.pop('start_iter', 0) ae_optim_hp = optim_hp.pop('auto_encoder', {}) pn_optim_hp = optim_hp.pop('part_net', {}) hpm_optim_hp = optim_hp.pop('hpm', {}) @@ -153,28 +154,48 @@ class Model: sched_hp = self.hp.get('scheduler', {}) self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) - self.ba_triplet_loss = JointBatchAllTripletLoss( - self.rgb_pn.hpm_num_parts, triplet_margins - ) + # Hard margins + if triplet_margins: + # Same margins + if triplet_margins[0] == triplet_margins[1]: + self.triplet_loss = BatchTripletLoss( + triplet_is_hard, triplet_margins[0] + ) + else: # Different margins + self.triplet_loss = JointBatchTripletLoss( + self.rgb_pn.hpm_num_parts, + triplet_is_hard, triplet_is_mean, triplet_margins + ) + else: # Soft margins + self.triplet_loss = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, None + ) + + num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 + num_pos_pairs = (self.k*(self.k-1)//2) * self.pr + # Try to accelerate computation using CUDA or others self.rgb_pn = nn.DataParallel(self.rgb_pn) self.rgb_pn = self.rgb_pn.to(self.device) - self.ba_triplet_loss = nn.DataParallel(self.ba_triplet_loss) - self.ba_triplet_loss = self.ba_triplet_loss.to(self.device) + self.triplet_loss = nn.DataParallel(self.triplet_loss) + self.triplet_loss = self.triplet_loss.to(self.device) self.optimizer = optim.Adam([ {'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp}, {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp}, {'params': self.rgb_pn.module.hpm.parameters(), **hpm_optim_hp}, {'params': self.rgb_pn.module.fc_mat, **fc_optim_hp} ], **optim_hp) - sched_gamma = sched_hp.get('gamma', 0.9) - sched_step_size = sched_hp.get('step_size', 500) + sched_final_gamma = sched_hp.get('final_gamma', 0.001) + sched_start_step = sched_hp.get('start_step', 15_000) + + def lr_lambda(epoch): + passed_step = epoch - sched_start_step + all_step = self.total_iter - sched_start_step + return sched_final_gamma ** (passed_step / all_step) self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ - lambda epoch: sched_gamma ** (epoch // sched_step_size), - lambda epoch: 0 if epoch < start_iter else 1, - lambda epoch: 0 if epoch < start_iter else 1, - lambda epoch: 0 if epoch < start_iter else 1, + lr_lambda, lr_lambda, lr_lambda, lr_lambda ]) + self.writer = SummaryWriter(self._log_name) self.rgb_pn.train() @@ -194,7 +215,7 @@ class Model: running_loss = torch.zeros(5, device=self.device) print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}", f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}", - f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}") + f"{'BATripH':^8} {'BATripP':^8} {'LR':^9}") for (batch_c1, batch_c2) in dataloader: self.curr_iter += 1 # Zero the parameter gradients @@ -202,16 +223,16 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - feature, ae_losses, images = self.rgb_pn(x_c1, x_c2) + embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) # Duplicate labels for each part y = y.repeat(self.rgb_pn.num_total_parts, 1) - triplet_loss = self.ba_triplet_loss(feature, y) + trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y) losses = torch.cat(( ae_losses.mean(0), torch.stack(( - triplet_loss[:self.rgb_pn.hpm_num_parts].mean(), - triplet_loss[self.rgb_pn.hpm_num_parts:].mean() + trip_loss[:self.rgb_pn.hpm_num_parts].mean(), + trip_loss[self.rgb_pn.hpm_num_parts:].mean() )) )) loss = losses.sum() @@ -222,20 +243,50 @@ class Model: running_loss += losses.detach() # Write losses to TensorBoard self.writer.add_scalar('Loss/all', loss, self.curr_iter) - self.writer.add_scalars('Loss/details', dict(zip([ + self.writer.add_scalars('Loss/disentanglement', dict(zip(( 'Cross reconstruction loss', 'Canonical consistency loss', - 'Pose similarity loss', 'Batch All triplet loss (HPM)', - 'Batch All triplet loss (PartNet)' - ], losses)), self.curr_iter) + 'Pose similarity loss' + ), ae_losses)), self.curr_iter) + self.writer.add_scalars('Loss/triplet loss', { + 'HPM': losses[3], + 'PartNet': losses[4] + }, self.curr_iter) + # None-zero losses in batch + if num_non_zero is not None: + self.writer.add_scalars('Loss/non-zero counts', { + 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(), + 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean() + }, self.curr_iter) + # Embedding distance + mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0) + self._add_ranked_scalars( + 'Embedding/HPM distance', mean_hpm_dist, + num_pos_pairs, num_pairs, self.curr_iter + ) + mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0) + self._add_ranked_scalars( + 'Embedding/ParNet distance', mean_pa_dist, + num_pos_pairs, num_pairs, self.curr_iter + ) + # Embedding norm + mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0) + mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) + self._add_ranked_scalars( + 'Embedding/HPM norm', mean_hpm_norm, + self.k, self.pr * self.k, self.curr_iter + ) + mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0) + mean_pa_norm = mean_pa_embedding.norm(dim=-1) + self._add_ranked_scalars( + 'Embedding/PartNet norm', mean_pa_norm, + self.k, self.pr * self.k, self.curr_iter + ) if self.curr_iter % 100 == 0: lrs = self.scheduler.get_last_lr() # Write learning rates self.writer.add_scalar( - 'Learning rate/Auto-encoder', lrs[0], self.curr_iter - ) - self.writer.add_scalar( - 'Learning rate/Others', lrs[1], self.curr_iter + 'Learning rate', lrs[0], self.curr_iter ) # Write disentangled images if self.image_log_on: @@ -259,7 +310,7 @@ class Model: print(f'{hour:02}:{minute:02}:{second:02}', f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100), - '{:.3e} {:.3e}'.format(lrs[0], lrs[1])) + f'{lrs[0]:.3e}') running_loss.zero_() # Step scheduler @@ -278,6 +329,24 @@ class Model: self.writer.close() break + def _add_ranked_scalars( + self, + main_tag: str, + metric: torch.Tensor, + num_pos: int, + num_all: int, + global_step: int + ): + rank = metric.argsort() + pos_ile = 100 - (num_pos - 1) * 100 // num_all + self.writer.add_scalars(main_tag, { + '0%-ile': metric[rank[-1]], + f'{100 - pos_ile}%-ile': metric[rank[-num_pos]], + '50%-ile': metric[rank[num_all // 2 - 1]], + f'{pos_ile}%-ile': metric[rank[num_pos - 1]], + '100%-ile': metric[rank[0]] + }, global_step) + def predict_all( self, iters: tuple[int], @@ -317,6 +386,8 @@ class Model: # Init models model_hp: dict = self.hp.get('model', {}).copy() + model_hp.pop('triplet_is_hard', True) + model_hp.pop('triplet_is_mean', True) model_hp.pop('triplet_margins', None) self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp) # Try to accelerate computation using CUDA or others diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 4367c62..8a0f3a7 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -79,7 +79,7 @@ class RGBPartNet(nn.Module): ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) # Decode features x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p_ = self._decode_pose_feature(f_p_, n, t, c, h, w, device) + x_p_ = self._decode_pose_feature(f_p_, n, t, device) x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) i_a, i_c, i_p = None, None, None @@ -98,7 +98,7 @@ class RGBPartNet(nn.Module): else: # evaluating f_c_, f_p_ = self.ae(x_c1_t2) x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p_ = self._decode_pose_feature(f_p_, n, t, c, h, w, device) + x_p_ = self._decode_pose_feature(f_p_, n, t, device) x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) return (x_c, x_p), None, None @@ -123,7 +123,7 @@ class RGBPartNet(nn.Module): ) return x_c - def _decode_pose_feature(self, f_p_, n, t, c, h, w, device): + def _decode_pose_feature(self, f_p_, n, t, device): # Decode pose features to images x_p_ = self.ae.decoder( torch.zeros((n * t, self.f_a_dim), device=device), diff --git a/requirements.txt b/requirements.txt index 4d30e17..926a587 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ torch~=1.7.1 torchvision~=0.8.0a0+ecf4e9c numpy~=1.19.4 -tqdm~=4.57.0 +tqdm~=4.58.0 Pillow~=8.1.0 scikit-learn~=0.24.0
\ No newline at end of file diff --git a/utils/configuration.py b/utils/configuration.py index 435d815..0f8d9ff 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -43,6 +43,8 @@ class ModelHPConfiguration(TypedDict): tfa_squeeze_ratio: int tfa_num_parts: int embedding_dims: int + triplet_is_hard: bool + triplet_is_mean: bool triplet_margins: tuple[float, float] @@ -55,7 +57,6 @@ class SubOptimizerHPConfiguration(TypedDict): class OptimizerHPConfiguration(TypedDict): - start_iter: int lr: int betas: tuple[float, float] eps: float @@ -68,8 +69,8 @@ class OptimizerHPConfiguration(TypedDict): class SchedulerHPConfiguration(TypedDict): - step_size: int - gamma: float + start_step: int + final_gamma: float class HyperparameterConfiguration(TypedDict): diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 0df2188..e05b69d 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -1,32 +1,48 @@ +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F -class BatchAllTripletLoss(nn.Module): - def __init__(self, margin: float = 0.2): +class BatchTripletLoss(nn.Module): + def __init__( + self, + is_hard: bool = True, + is_mean: bool = True, + margin: Optional[float] = 0.2, + ): super().__init__() + self.is_hard = is_hard + self.is_mean = is_mean self.margin = margin def forward(self, x, y): p, n, c = x.size() - dist = self._batch_distance(x) - positive_negative_dist = self._hard_distance(dist, y, p, n) - all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) - parted_loss_mean = self._none_zero_parted_mean(all_loss) - - return parted_loss_mean - - @staticmethod - def _hard_distance(dist, y, p, n): - hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2) - hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2) - all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1) - all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1) - positive_negative_dist = all_hard_positive - all_hard_negative - - return positive_negative_dist + flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device) + flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]] + + if self.is_hard: + positive_negative_dist = self._hard_distance(dist, y, p, n) + else: # is_all + positive_negative_dist = self._all_distance(dist, y, p, n) + + if self.margin: + losses = F.relu(self.margin + positive_negative_dist).view(p, -1) + non_zero_counts = (losses != 0).sum(1).float() + if self.is_mean: + loss_metric = self._none_zero_mean(losses, non_zero_counts) + else: # is_sum + loss_metric = losses.sum(1) + return loss_metric, flat_dist, non_zero_counts + else: # Soft margin + losses = F.softplus(positive_negative_dist).view(p, -1) + if self.is_mean: + loss_metric = losses.mean(1) + else: # is_sum + loss_metric = losses.sum(1) + return loss_metric, flat_dist, None @staticmethod def _batch_distance(x): @@ -38,41 +54,74 @@ class BatchAllTripletLoss(nn.Module): dist = torch.sqrt( F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum) ) - return dist @staticmethod - def _none_zero_parted_mean(all_loss): - # Non-zero parted mean - non_zero_counts = (all_loss != 0).sum(1) - parted_loss_mean = all_loss.sum(1) / non_zero_counts - parted_loss_mean[non_zero_counts == 0] = 0 + def _hard_distance(dist, y, p, n): + positive_mask = y.unsqueeze(1) == y.unsqueeze(2) + negative_mask = y.unsqueeze(1) != y.unsqueeze(2) + hard_positive = dist[positive_mask].view(p, n, -1).max(-1).values + hard_negative = dist[negative_mask].view(p, n, -1).min(-1).values + positive_negative_dist = hard_positive - hard_negative - return parted_loss_mean + return positive_negative_dist + + @staticmethod + def _all_distance(dist, y, p, n): + # Unmask identical samples + positive_mask = torch.eye( + n, dtype=torch.bool, device=y.device + ) ^ (y.unsqueeze(1) == y.unsqueeze(2)) + negative_mask = y.unsqueeze(1) != y.unsqueeze(2) + all_positive = dist[positive_mask].view(p, n, -1, 1) + all_negative = dist[negative_mask].view(p, n, 1, -1) + positive_negative_dist = all_positive - all_negative + + return positive_negative_dist + + @staticmethod + def _none_zero_mean(losses, non_zero_counts): + # Non-zero parted mean + non_zero_mean = losses.sum(1) / non_zero_counts + non_zero_mean[non_zero_counts == 0] = 0 + return non_zero_mean -class JointBatchAllTripletLoss(BatchAllTripletLoss): +class JointBatchTripletLoss(BatchTripletLoss): def __init__( self, hpm_num_parts: int, + is_hard: bool = True, + is_mean: bool = True, margins: tuple[float, float] = (0.2, 0.2) ): - super().__init__() + super().__init__(is_hard, is_mean) self.hpm_num_parts = hpm_num_parts self.margin_hpm, self.margin_pn = margins def forward(self, x, y): p, n, c = x.size() - dist = self._batch_distance(x) - positive_negative_dist = self._hard_distance(dist, y, p, n) + flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device) + flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]] + + if self.is_hard: + positive_negative_dist = self._hard_distance(dist, y, p, n) + else: # is_all + positive_negative_dist = self._all_distance(dist, y, p, n) + hpm_part_loss = F.relu( self.margin_hpm + positive_negative_dist[:self.hpm_num_parts] - ).view(self.hpm_num_parts, -1) + ) pn_part_loss = F.relu( self.margin_pn + positive_negative_dist[self.hpm_num_parts:] - ).view(p - self.hpm_num_parts, -1) - all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) - parted_loss_mean = self._none_zero_parted_mean(all_loss) + ) + losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) + + non_zero_counts = (losses != 0).sum(1).float() + if self.is_mean: + loss_metric = self._none_zero_mean(losses, non_zero_counts) + else: # is_sum + loss_metric = losses.sum(1) - return parted_loss_mean + return loss_metric, flat_dist, non_zero_counts |