From 820d3dec284f38e6a3089dad5277bc3f6c5123bf Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 20 Feb 2021 14:19:30 +0800 Subject: Separate triplet loss from model --- models/auto_encoder.py | 2 +- models/model.py | 21 +++++++++++++++--- models/rgb_part_net.py | 20 +++-------------- utils/triplet_loss.py | 58 +++++++++++++++++++++++++++++++++++++++++++------- 4 files changed, 72 insertions(+), 29 deletions(-) diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 2d715db..1ef7494 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -171,7 +171,7 @@ class AutoEncoder(nn.Module): return ( (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_), - (xrecon_loss, cano_cons_loss, pose_sim_loss * 10) + torch.stack((xrecon_loss, cano_cons_loss, pose_sim_loss * 10)) ) else: # evaluating return f_c_c1_t2_, f_p_c1_t2_ diff --git a/models/model.py b/models/model.py index 82d6461..5899fc0 100644 --- a/models/model.py +++ b/models/model.py @@ -18,6 +18,7 @@ from utils.configuration import DataloaderConfiguration, \ SystemConfiguration from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses from utils.sampler import TripletSampler +from utils.triplet_loss import JointBatchAllTripletLoss class Model: @@ -67,6 +68,7 @@ class Model: self._dataset_sig: str = 'undefined' self.rgb_pn: Optional[RGBPartNet] = None + self.ba_triplet_loss: Optional[JointBatchAllTripletLoss] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None @@ -140,7 +142,8 @@ class Model: dataset = self._parse_dataset_config(dataset_config) dataloader = self._parse_dataloader_config(dataset, dataloader_config) # Prepare for model, optimizer and scheduler - model_hp = self.hp.get('model', {}) + model_hp: dict = self.hp.get('model', {}).copy() + triplet_margins = model_hp.pop('triplet_margins', (0.2, 0.2)) optim_hp: dict = self.hp.get('optimizer', {}).copy() start_iter = optim_hp.pop('start_iter', 0) ae_optim_hp = optim_hp.pop('auto_encoder', {}) @@ -150,8 +153,12 @@ class Model: sched_hp = self.hp.get('scheduler', {}) self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) + self.ba_triplet_loss = JointBatchAllTripletLoss( + self.rgb_pn.hpm_num_parts, triplet_margins + ) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) + self.ba_triplet_loss = self.ba_triplet_loss.to(self.device) self.optimizer = optim.Adam([ {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp}, {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, @@ -193,10 +200,18 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) + feature, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) # Duplicate labels for each part - y = y.unsqueeze(1).repeat(1, self.rgb_pn.num_total_parts) - losses, images = self.rgb_pn(x_c1, x_c2, y) + y = y.repeat(self.rgb_pn.num_total_parts, 1) + triplet_loss = self.ba_triplet_loss(feature, y) + losses = torch.cat(( + ae_losses, + torch.stack(( + triplet_loss[:self.rgb_pn.hpm_num_parts].mean(), + triplet_loss[self.rgb_pn.hpm_num_parts:].mean() + )) + )) loss = losses.sum() loss.backward() self.optimizer.step() diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 67acac3..408bca0 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -4,7 +4,6 @@ import torch.nn as nn from models.auto_encoder import AutoEncoder from models.hpm import HorizontalPyramidMatching from models.part_net import PartNet -from utils.triplet_loss import BatchAllTripletLoss class RGBPartNet(nn.Module): @@ -25,7 +24,6 @@ class RGBPartNet(nn.Module): tfa_squeeze_ratio: int = 4, tfa_num_parts: int = 16, embedding_dims: int = 256, - triplet_margins: tuple[float, float] = (0.2, 0.2), image_log_on: bool = False ): super().__init__() @@ -50,17 +48,13 @@ class RGBPartNet(nn.Module): out_channels, embedding_dims) self.fc_mat = nn.Parameter(empty_fc) - (hpm_margin, pn_margin) = triplet_margins - self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin) - self.pn_ba_trip = BatchAllTripletLoss(pn_margin) - def fc(self, x): return x @ self.fc_mat - def forward(self, x_c1, x_c2=None, y=None): + def forward(self, x_c1, x_c2=None): # Step 1: Disentanglement # n, t, c, h, w - ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2) + ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2) # Step 2.a: Static Gait Feature Aggregation & HPM # n, c, h, w @@ -77,15 +71,7 @@ class RGBPartNet(nn.Module): x = self.fc(x) if self.training: - y = y.T - hpm_ba_trip = self.hpm_ba_trip( - x[:self.hpm_num_parts], y[:self.hpm_num_parts] - ) - pn_ba_trip = self.pn_ba_trip( - x[self.hpm_num_parts:], y[self.hpm_num_parts:] - ) - losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip)) - return losses, images + return x, ae_losses, images else: return x.unsqueeze(1).view(-1) diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 954def2..0df2188 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -11,6 +11,25 @@ class BatchAllTripletLoss(nn.Module): def forward(self, x, y): p, n, c = x.size() + dist = self._batch_distance(x) + positive_negative_dist = self._hard_distance(dist, y, p, n) + all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) + parted_loss_mean = self._none_zero_parted_mean(all_loss) + + return parted_loss_mean + + @staticmethod + def _hard_distance(dist, y, p, n): + hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2) + hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2) + all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1) + all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1) + positive_negative_dist = all_hard_positive - all_hard_negative + + return positive_negative_dist + + @staticmethod + def _batch_distance(x): # Euclidean distance p x n x n x_squared_sum = torch.sum(x ** 2, dim=2) x1_squared_sum = x_squared_sum.unsqueeze(2) @@ -20,17 +39,40 @@ class BatchAllTripletLoss(nn.Module): F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum) ) - hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2) - hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2) - all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1) - all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1) - positive_negative_dist = all_hard_positive - all_hard_negative - all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) + return dist + @staticmethod + def _none_zero_parted_mean(all_loss): # Non-zero parted mean non_zero_counts = (all_loss != 0).sum(1) parted_loss_mean = all_loss.sum(1) / non_zero_counts parted_loss_mean[non_zero_counts == 0] = 0 - loss = parted_loss_mean.mean() - return loss + return parted_loss_mean + + +class JointBatchAllTripletLoss(BatchAllTripletLoss): + def __init__( + self, + hpm_num_parts: int, + margins: tuple[float, float] = (0.2, 0.2) + ): + super().__init__() + self.hpm_num_parts = hpm_num_parts + self.margin_hpm, self.margin_pn = margins + + def forward(self, x, y): + p, n, c = x.size() + + dist = self._batch_distance(x) + positive_negative_dist = self._hard_distance(dist, y, p, n) + hpm_part_loss = F.relu( + self.margin_hpm + positive_negative_dist[:self.hpm_num_parts] + ).view(self.hpm_num_parts, -1) + pn_part_loss = F.relu( + self.margin_pn + positive_negative_dist[self.hpm_num_parts:] + ).view(p - self.hpm_num_parts, -1) + all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) + parted_loss_mean = self._none_zero_parted_mean(all_loss) + + return parted_loss_mean -- cgit v1.2.3 From c52fdc2748e272a5195303299a9739291be32281 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 21 Feb 2021 19:00:30 +0800 Subject: Remove FConv blocks --- config.py | 12 ++---------- models/auto_encoder.py | 4 ++-- models/part_net.py | 18 ++++-------------- models/rgb_part_net.py | 37 ++++++++++++++++++------------------- 4 files changed, 26 insertions(+), 45 deletions(-) diff --git a/config.py b/config.py index 424bf5b..88ad371 100644 --- a/config.py +++ b/config.py @@ -49,22 +49,14 @@ config: Configuration = { # Auto-encoder feature channels coefficient 'ae_feature_channels': 64, # Appearance, canonical and pose feature dimensions - 'f_a_c_p_dims': (128, 128, 64), + 'f_a_c_p_dims': (192, 192, 96), # Use 1x1 convolution in dimensionality reduction 'hpm_use_1x1conv': False, # HPM pyramid scales, of which sum is number of parts 'hpm_scales': (1, 2, 4), # Global pooling method 'hpm_use_avg_pool': True, - 'hpm_use_max_pool': False, - # FConv feature channels coefficient - 'fpfe_feature_channels': 32, - # FConv blocks kernel sizes - 'fpfe_kernel_sizes': ((5, 3), (3, 3), (3, 3)), - # FConv blocks paddings - 'fpfe_paddings': ((2, 1), (1, 1), (1, 1)), - # FConv blocks halving - 'fpfe_halving': (0, 2, 3), + 'hpm_use_max_pool': True, # Attention squeeze ratio 'tfa_squeeze_ratio': 4, # Number of parts after Part Net diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 1ef7494..e6a3e60 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -106,14 +106,14 @@ class Decoder(nn.Module): self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels, is_last_layer=True) - def forward(self, f_appearance, f_canonical, f_pose, cano_only=False): + def forward(self, f_appearance, f_canonical, f_pose, is_feature_map=False): x = torch.cat((f_appearance, f_canonical, f_pose), dim=1) x = self.fc(x) x = x.view(-1, self.feature_channels * 8, self.h_0, self.w_0) x = F.relu(x, inplace=True) x = self.trans_conv1(x) x = self.trans_conv2(x) - if cano_only: + if is_feature_map: return x x = self.trans_conv3(x) x = torch.sigmoid(self.trans_conv4(x)) diff --git a/models/part_net.py b/models/part_net.py index 62a2bac..29cf9cd 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -110,32 +110,22 @@ class TemporalFeatureAggregator(nn.Module): class PartNet(nn.Module): def __init__( self, - in_channels: int = 3, - feature_channels: int = 32, - kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)), - paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)), - halving: tuple[int, ...] = (0, 2, 3), + in_channels: int = 128, squeeze_ratio: int = 4, num_part: int = 16 ): super().__init__() self.num_part = num_part - self.fpfe = FrameLevelPartFeatureExtractor( - in_channels, feature_channels, kernel_sizes, paddings, halving - ) - - num_fconv_blocks = len(self.fpfe.fconv_blocks) - self.tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1) self.tfa = TemporalFeatureAggregator( - self.tfa_in_channels, squeeze_ratio, self.num_part + in_channels, squeeze_ratio, self.num_part ) self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) def forward(self, x): - n, t, _, _, _ = x.size() - x = self.fpfe(x) + n, t, c, h, w = x.size() + x = x.view(n * t, c, h, w) # n * t x c x h x w # Horizontal Pooling diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 408bca0..936ec46 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -17,16 +17,13 @@ class RGBPartNet(nn.Module): hpm_scales: tuple[int, ...] = (1, 2, 4), hpm_use_avg_pool: bool = True, hpm_use_max_pool: bool = True, - fpfe_feature_channels: int = 32, - fpfe_kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)), - fpfe_paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)), - fpfe_halving: tuple[int, ...] = (0, 2, 3), tfa_squeeze_ratio: int = 4, tfa_num_parts: int = 16, embedding_dims: int = 256, image_log_on: bool = False ): super().__init__() + self.h, self.w = ae_in_size (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims self.hpm_num_parts = sum(hpm_scales) self.image_log_on = image_log_on @@ -34,18 +31,17 @@ class RGBPartNet(nn.Module): self.ae = AutoEncoder( ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims ) + self.pn_in_channels = ae_feature_channels * 2 self.pn = PartNet( - ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes, - fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_parts + self.pn_in_channels, tfa_squeeze_ratio, tfa_num_parts ) - out_channels = self.pn.tfa_in_channels self.hpm = HorizontalPyramidMatching( - ae_feature_channels * 2, out_channels, hpm_use_1x1conv, + ae_feature_channels * 2, self.pn_in_channels, hpm_use_1x1conv, hpm_scales, hpm_use_avg_pool, hpm_use_max_pool ) self.num_total_parts = self.hpm_num_parts + tfa_num_parts empty_fc = torch.empty(self.num_total_parts, - out_channels, embedding_dims) + self.pn_in_channels, embedding_dims) self.fc_mat = nn.Parameter(empty_fc) def fc(self, x): @@ -82,17 +78,20 @@ class RGBPartNet(nn.Module): if self.training: ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) # Decode features - with torch.no_grad(): - x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device) + x_c = self._decode_cano_feature(f_c_, n, t, device) + x_p_ = self._decode_pose_feature(f_p_, n, t, c, h, w, device) + x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) - i_a, i_c, i_p = None, None, None - if self.image_log_on: + i_a, i_c, i_p = None, None, None + if self.image_log_on: + with torch.no_grad(): i_a = self._decode_appr_feature(f_a_, n, t, device) # Continue decoding canonical features i_c = self.ae.decoder.trans_conv3(x_c) i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c)) - i_p = x_p + i_p_ = self.ae.decoder.trans_conv3(x_p_) + i_p_ = torch.sigmoid(self.ae.decoder.trans_conv4(i_p_)) + i_p = i_p_.view(n, t, c, h, w) return (x_c, x_p), losses, (i_a, i_c, i_p) @@ -119,7 +118,7 @@ class RGBPartNet(nn.Module): torch.zeros((n, self.f_a_dim), device=device), f_c.mean(1), torch.zeros((n, self.f_p_dim), device=device), - cano_only=True + is_feature_map=True ) return x_c @@ -128,7 +127,7 @@ class RGBPartNet(nn.Module): x_p_ = self.ae.decoder( torch.zeros((n * t, self.f_a_dim), device=device), torch.zeros((n * t, self.f_c_dim), device=device), - f_p_ + f_p_, + is_feature_map=True ) - x_p = x_p_.view(n, t, c, h, w) - return x_p + return x_p_ -- cgit v1.2.3 From 390bac976ff52fe0c3cf6bea820c22084613ee94 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Fri, 26 Feb 2021 20:09:22 +0800 Subject: Fix predict function --- models/model.py | 3 ++- models/rgb_part_net.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/models/model.py b/models/model.py index 5899fc0..90d48e0 100644 --- a/models/model.py +++ b/models/model.py @@ -314,7 +314,8 @@ class Model: ) # Init models - model_hp = self.hp.get('model', {}) + model_hp: dict = self.hp.get('model', {}).copy() + model_hp.pop('triplet_margins', None) self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 936ec46..4367c62 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -74,8 +74,8 @@ class RGBPartNet(nn.Module): def _disentangle(self, x_c1_t2, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() device = x_c1_t2.device - x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :] if self.training: + x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :] ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) # Decode features x_c = self._decode_cano_feature(f_c_, n, t, device) @@ -98,7 +98,8 @@ class RGBPartNet(nn.Module): else: # evaluating f_c_, f_p_ = self.ae(x_c1_t2) x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device) + x_p_ = self._decode_pose_feature(f_p_, n, t, c, h, w, device) + x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) return (x_c, x_p), None, None def _decode_appr_feature(self, f_a_, n, t, device): -- cgit v1.2.3 From 9001f7e13d8985b220bd218d8de716bc586dbdcf Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Fri, 26 Feb 2021 20:17:03 +0800 Subject: Update default config --- config.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/config.py b/config.py index 88ad371..03f2f0d 100644 --- a/config.py +++ b/config.py @@ -37,7 +37,7 @@ config: Configuration = { # Batch size (pr, k) # `pr` denotes number of persons # `k` denotes number of sequences per person - 'batch_size': (4, 8), + 'batch_size': (4, 6), # Number of workers of Dataloader 'num_workers': 4, # Faster data transfer from RAM to GPU if enabled @@ -64,7 +64,7 @@ config: Configuration = { # Embedding dimension for each part 'embedding_dims': 256, # Triplet loss margins for HPM and PartNet - 'triplet_margins': (0.2, 0.2), + 'triplet_margins': (1.5, 1.5), }, 'optimizer': { # Global parameters @@ -83,15 +83,15 @@ config: Configuration = { # 'amsgrad': False, # Local parameters (override global ones) - 'auto_encoder': { - 'weight_decay': 0.001 - }, + # 'auto_encoder': { + # 'weight_decay': 0.001 + # }, }, 'scheduler': { # Period of learning rate decay 'step_size': 500, # Multiplicative factor of decay - 'gamma': 0.9, + 'gamma': 1, } }, # Model metadata -- cgit v1.2.3 From 46391257ff50848efa1aa251ab3f15dc8b7a2d2c Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 27 Feb 2021 22:14:21 +0800 Subject: Implement Batch Hard triplet loss and soft margin --- config.py | 6 ++-- models/model.py | 57 +++++++++++++++++++++++++--------- models/rgb_part_net.py | 6 ++-- utils/configuration.py | 1 + utils/triplet_loss.py | 84 +++++++++++++++++++++++++++++++++----------------- 5 files changed, 106 insertions(+), 48 deletions(-) diff --git a/config.py b/config.py index 03f2f0d..f76cea5 100644 --- a/config.py +++ b/config.py @@ -63,8 +63,10 @@ config: Configuration = { 'tfa_num_parts': 16, # Embedding dimension for each part 'embedding_dims': 256, - # Triplet loss margins for HPM and PartNet - 'triplet_margins': (1.5, 1.5), + # Batch Hard or Batch All + 'triplet_is_hard': True, + # Triplet loss margins for HPM and PartNet, None for soft margin + 'triplet_margins': None, }, 'optimizer': { # Global parameters diff --git a/models/model.py b/models/model.py index 90d48e0..79952cb 100644 --- a/models/model.py +++ b/models/model.py @@ -18,7 +18,7 @@ from utils.configuration import DataloaderConfiguration, \ SystemConfiguration from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses from utils.sampler import TripletSampler -from utils.triplet_loss import JointBatchAllTripletLoss +from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss class Model: @@ -68,7 +68,7 @@ class Model: self._dataset_sig: str = 'undefined' self.rgb_pn: Optional[RGBPartNet] = None - self.ba_triplet_loss: Optional[JointBatchAllTripletLoss] = None + self.triplet_loss: Optional[JointBatchTripletLoss] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None @@ -143,7 +143,8 @@ class Model: dataloader = self._parse_dataloader_config(dataset, dataloader_config) # Prepare for model, optimizer and scheduler model_hp: dict = self.hp.get('model', {}).copy() - triplet_margins = model_hp.pop('triplet_margins', (0.2, 0.2)) + triplet_is_hard = model_hp.pop('triplet_is_hard', True) + triplet_margins = model_hp.pop('triplet_margins', None) optim_hp: dict = self.hp.get('optimizer', {}).copy() start_iter = optim_hp.pop('start_iter', 0) ae_optim_hp = optim_hp.pop('auto_encoder', {}) @@ -153,12 +154,23 @@ class Model: sched_hp = self.hp.get('scheduler', {}) self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) - self.ba_triplet_loss = JointBatchAllTripletLoss( - self.rgb_pn.hpm_num_parts, triplet_margins - ) + # Hard margins + if triplet_margins: + # Same margins + if triplet_margins[0] == triplet_margins[1]: + self.triplet_loss = BatchTripletLoss( + triplet_is_hard, triplet_margins[0] + ) + else: # Different margins + self.triplet_loss = JointBatchTripletLoss( + self.rgb_pn.hpm_num_parts, triplet_is_hard, triplet_margins + ) + else: # Soft margins + self.triplet_loss = BatchTripletLoss(triplet_is_hard, None) + # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) - self.ba_triplet_loss = self.ba_triplet_loss.to(self.device) + self.triplet_loss = self.triplet_loss.to(self.device) self.optimizer = optim.Adam([ {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp}, {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, @@ -200,16 +212,16 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - feature, ae_losses, images = self.rgb_pn(x_c1, x_c2) + embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) # Duplicate labels for each part y = y.repeat(self.rgb_pn.num_total_parts, 1) - triplet_loss = self.ba_triplet_loss(feature, y) + trip_loss, dist, non_zero_counts = self.triplet_loss(embedding, y) losses = torch.cat(( ae_losses, torch.stack(( - triplet_loss[:self.rgb_pn.hpm_num_parts].mean(), - triplet_loss[self.rgb_pn.hpm_num_parts:].mean() + trip_loss[:self.rgb_pn.hpm_num_parts].mean(), + trip_loss[self.rgb_pn.hpm_num_parts:].mean() )) )) loss = losses.sum() @@ -220,11 +232,26 @@ class Model: running_loss += losses.detach() # Write losses to TensorBoard self.writer.add_scalar('Loss/all', loss, self.curr_iter) - self.writer.add_scalars('Loss/details', dict(zip([ + self.writer.add_scalars('Loss/disentanglement', dict(zip(( 'Cross reconstruction loss', 'Canonical consistency loss', - 'Pose similarity loss', 'Batch All triplet loss (HPM)', - 'Batch All triplet loss (PartNet)' - ], losses)), self.curr_iter) + 'Pose similarity loss' + ), ae_losses)), self.curr_iter) + self.writer.add_scalars('Loss/triplet loss', { + 'HPM': losses[3], + 'PartNet': losses[4] + }, self.curr_iter) + self.writer.add_scalars('Loss/non-zero counts', { + 'HPM': non_zero_counts[:self.rgb_pn.hpm_num_parts].mean(), + 'PartNet': non_zero_counts[self.rgb_pn.hpm_num_parts:].mean() + }, self.curr_iter) + self.writer.add_scalars('Embedding/distance', { + 'HPM': dist[:self.rgb_pn.hpm_num_parts].mean(), + 'PartNet': dist[self.rgb_pn.hpm_num_parts].mean() + }, self.curr_iter) + self.writer.add_scalars('Embedding/2-norm', { + 'HPM': embedding[:self.rgb_pn.hpm_num_parts].norm(), + 'PartNet': embedding[self.rgb_pn.hpm_num_parts].norm() + }, self.curr_iter) if self.curr_iter % 100 == 0: lrs = self.scheduler.get_last_lr() diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 4367c62..8a0f3a7 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -79,7 +79,7 @@ class RGBPartNet(nn.Module): ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) # Decode features x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p_ = self._decode_pose_feature(f_p_, n, t, c, h, w, device) + x_p_ = self._decode_pose_feature(f_p_, n, t, device) x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) i_a, i_c, i_p = None, None, None @@ -98,7 +98,7 @@ class RGBPartNet(nn.Module): else: # evaluating f_c_, f_p_ = self.ae(x_c1_t2) x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p_ = self._decode_pose_feature(f_p_, n, t, c, h, w, device) + x_p_ = self._decode_pose_feature(f_p_, n, t, device) x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) return (x_c, x_p), None, None @@ -123,7 +123,7 @@ class RGBPartNet(nn.Module): ) return x_c - def _decode_pose_feature(self, f_p_, n, t, c, h, w, device): + def _decode_pose_feature(self, f_p_, n, t, device): # Decode pose features to images x_p_ = self.ae.decoder( torch.zeros((n * t, self.f_a_dim), device=device), diff --git a/utils/configuration.py b/utils/configuration.py index 435d815..20aec76 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -43,6 +43,7 @@ class ModelHPConfiguration(TypedDict): tfa_squeeze_ratio: int tfa_num_parts: int embedding_dims: int + triplet_is_hard: bool triplet_margins: tuple[float, float] diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 0df2188..c3e5802 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -1,32 +1,36 @@ +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F -class BatchAllTripletLoss(nn.Module): - def __init__(self, margin: float = 0.2): +class BatchTripletLoss(nn.Module): + def __init__( + self, + is_hard: bool = True, + margin: Optional[float] = 0.2, + ): super().__init__() + self.is_hard = is_hard self.margin = margin def forward(self, x, y): p, n, c = x.size() - dist = self._batch_distance(x) - positive_negative_dist = self._hard_distance(dist, y, p, n) - all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) - parted_loss_mean = self._none_zero_parted_mean(all_loss) - return parted_loss_mean + if self.is_hard: + positive_negative_dist = self._hard_distance(dist, y, p, n) + else: # is_all + positive_negative_dist = self._all_distance(dist, y, p, n) - @staticmethod - def _hard_distance(dist, y, p, n): - hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2) - hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2) - all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1) - all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1) - positive_negative_dist = all_hard_positive - all_hard_negative + if self.margin: + all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) + else: + all_loss = F.softplus(positive_negative_dist).view(p, -1) + non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - return positive_negative_dist + return non_zero_mean, dist.mean((1, 2)), non_zero_counts @staticmethod def _batch_distance(x): @@ -38,41 +42,65 @@ class BatchAllTripletLoss(nn.Module): dist = torch.sqrt( F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum) ) - return dist + @staticmethod + def _hard_distance(dist, y, p, n): + positive_mask = y.unsqueeze(1) == y.unsqueeze(2) + negative_mask = y.unsqueeze(1) != y.unsqueeze(2) + hard_positive = dist[positive_mask].view(p, n, -1).max(-1).values + hard_negative = dist[negative_mask].view(p, n, -1).min(-1).values + positive_negative_dist = hard_positive - hard_negative + + return positive_negative_dist + + @staticmethod + def _all_distance(dist, y, p, n): + positive_mask = y.unsqueeze(1) == y.unsqueeze(2) + negative_mask = y.unsqueeze(1) != y.unsqueeze(2) + all_positive = dist[positive_mask].view(p, n, -1, 1) + all_negative = dist[negative_mask].view(p, n, 1, -1) + positive_negative_dist = all_positive - all_negative + + return positive_negative_dist + @staticmethod def _none_zero_parted_mean(all_loss): # Non-zero parted mean - non_zero_counts = (all_loss != 0).sum(1) - parted_loss_mean = all_loss.sum(1) / non_zero_counts - parted_loss_mean[non_zero_counts == 0] = 0 + non_zero_counts = (all_loss != 0).sum(1).float() + non_zero_mean = all_loss.sum(1) / non_zero_counts + non_zero_mean[non_zero_counts == 0] = 0 - return parted_loss_mean + return non_zero_mean, non_zero_counts -class JointBatchAllTripletLoss(BatchAllTripletLoss): +class JointBatchTripletLoss(BatchTripletLoss): def __init__( self, hpm_num_parts: int, + is_hard: bool = True, margins: tuple[float, float] = (0.2, 0.2) ): - super().__init__() + super().__init__(is_hard) self.hpm_num_parts = hpm_num_parts self.margin_hpm, self.margin_pn = margins def forward(self, x, y): p, n, c = x.size() - dist = self._batch_distance(x) - positive_negative_dist = self._hard_distance(dist, y, p, n) + + if self.is_hard: + positive_negative_dist = self._hard_distance(dist, y, p, n) + else: # is_all + positive_negative_dist = self._all_distance(dist, y, p, n) + hpm_part_loss = F.relu( self.margin_hpm + positive_negative_dist[:self.hpm_num_parts] - ).view(self.hpm_num_parts, -1) + ) pn_part_loss = F.relu( self.margin_pn + positive_negative_dist[self.hpm_num_parts:] - ).view(p - self.hpm_num_parts, -1) + ) all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) - parted_loss_mean = self._none_zero_parted_mean(all_loss) + non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - return parted_loss_mean + return non_zero_mean, dist.mean((1, 2)), non_zero_counts -- cgit v1.2.3 From bbd89fda7f6bf30c9ce4a3b576c0087858b407b3 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 28 Feb 2021 13:31:09 +0800 Subject: Bump up version for tqdm --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 4d30e17..926a587 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ torch~=1.7.1 torchvision~=0.8.0a0+ecf4e9c numpy~=1.19.4 -tqdm~=4.57.0 +tqdm~=4.58.0 Pillow~=8.1.0 scikit-learn~=0.24.0 \ No newline at end of file -- cgit v1.2.3 From c96a6c88fa63d62ec62807abf957c9a8df307b43 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 28 Feb 2021 22:13:43 +0800 Subject: Modify default parameters 1. Change ReLU to Leaky ReLU in decoder 2. Add 8-scale-pyramid in HPM --- config.py | 2 +- models/layers.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/config.py b/config.py index f76cea5..9072982 100644 --- a/config.py +++ b/config.py @@ -53,7 +53,7 @@ config: Configuration = { # Use 1x1 convolution in dimensionality reduction 'hpm_use_1x1conv': False, # HPM pyramid scales, of which sum is number of parts - 'hpm_scales': (1, 2, 4), + 'hpm_scales': (1, 2, 4, 8), # Global pooling method 'hpm_use_avg_pool': True, 'hpm_use_max_pool': True, diff --git a/models/layers.py b/models/layers.py index ef53a95..f1d72b6 100644 --- a/models/layers.py +++ b/models/layers.py @@ -80,7 +80,9 @@ class DCGANConvTranspose2d(BasicConvTranspose2d): if self.is_last_layer: return self.trans_conv(x) else: - return super().forward(x) + x = self.trans_conv(x) + x = self.bn(x) + return F.leaky_relu(x, 0.2, inplace=True) class BasicLinear(nn.Module): -- cgit v1.2.3 From b837336695213e3e660992fcd01c5a52c654ea4f Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 28 Feb 2021 22:14:27 +0800 Subject: Log n-ile embedding distance and norm --- models/model.py | 66 +++++++++++++++++++++++++++++++++++++++++---------- utils/triplet_loss.py | 13 ++++++---- 2 files changed, 61 insertions(+), 18 deletions(-) diff --git a/models/model.py b/models/model.py index 79952cb..18896ae 100644 --- a/models/model.py +++ b/models/model.py @@ -59,6 +59,8 @@ class Model: self.in_size: tuple[int, int] = (64, 48) self.pr: Optional[int] = None self.k: Optional[int] = None + self.num_pairs: Optional[int] = None + self.num_pos_pairs: Optional[int] = None self._gallery_dataset_meta: Optional[dict[str, list]] = None self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None @@ -216,7 +218,7 @@ class Model: y = batch_c1['label'].to(self.device) # Duplicate labels for each part y = y.repeat(self.rgb_pn.num_total_parts, 1) - trip_loss, dist, non_zero_counts = self.triplet_loss(embedding, y) + trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y) losses = torch.cat(( ae_losses, torch.stack(( @@ -240,18 +242,36 @@ class Model: 'HPM': losses[3], 'PartNet': losses[4] }, self.curr_iter) - self.writer.add_scalars('Loss/non-zero counts', { - 'HPM': non_zero_counts[:self.rgb_pn.hpm_num_parts].mean(), - 'PartNet': non_zero_counts[self.rgb_pn.hpm_num_parts:].mean() - }, self.curr_iter) - self.writer.add_scalars('Embedding/distance', { - 'HPM': dist[:self.rgb_pn.hpm_num_parts].mean(), - 'PartNet': dist[self.rgb_pn.hpm_num_parts].mean() - }, self.curr_iter) - self.writer.add_scalars('Embedding/2-norm', { - 'HPM': embedding[:self.rgb_pn.hpm_num_parts].norm(), - 'PartNet': embedding[self.rgb_pn.hpm_num_parts].norm() - }, self.curr_iter) + # None-zero losses in batch + if num_non_zero: + self.writer.add_scalars('Loss/non-zero counts', { + 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(), + 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean() + }, self.curr_iter) + # Embedding distance + mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0) + self._add_ranked_scalars( + 'Embedding/HPM distance', mean_hpm_dist, + self.num_pos_pairs, self.num_pairs, self.curr_iter + ) + mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0) + self._add_ranked_scalars( + 'Embedding/ParNet distance', mean_pa_dist, + self.num_pos_pairs, self.num_pairs, self.curr_iter + ) + # Embedding norm + mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0) + mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) + self._add_ranked_scalars( + 'Embedding/HPM norm', mean_hpm_norm, + self.k, self.pr * self.k, self.curr_iter + ) + mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0) + mean_pa_norm = mean_pa_embedding.norm(dim=-1) + self._add_ranked_scalars( + 'Embedding/PartNet norm', mean_pa_norm, + self.k, self.pr * self.k, self.curr_iter + ) if self.curr_iter % 100 == 0: lrs = self.scheduler.get_last_lr() @@ -303,6 +323,24 @@ class Model: self.writer.close() break + def _add_ranked_scalars( + self, + main_tag: str, + metric: torch.Tensor, + num_pos: int, + num_all: int, + global_step: int + ): + rank = metric.argsort() + pos_ile = 100 - (num_pos - 1) * 100 // num_all + self.writer.add_scalars(main_tag, { + '0%-ile': metric[rank[-1]], + f'{100 - pos_ile}%-ile': metric[rank[-num_pos]], + '50%-ile': metric[rank[num_all // 2 - 1]], + f'{pos_ile}%-ile': metric[rank[num_pos - 1]], + '100%-ile': metric[rank[0]] + }, global_step) + def predict_all( self, iters: tuple[int], @@ -524,6 +562,8 @@ class Model: ) -> DataLoader: config: dict = dataloader_config.copy() (self.pr, self.k) = config.pop('batch_size', (8, 16)) + self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 + self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr if self.is_train: triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) return DataLoader(dataset, diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index c3e5802..52d676e 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -18,6 +18,8 @@ class BatchTripletLoss(nn.Module): def forward(self, x, y): p, n, c = x.size() dist = self._batch_distance(x) + flat_dist = dist.tril(-1) + flat_dist = flat_dist[flat_dist != 0].view(p, -1) if self.is_hard: positive_negative_dist = self._hard_distance(dist, y, p, n) @@ -26,11 +28,12 @@ class BatchTripletLoss(nn.Module): if self.margin: all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) - else: + loss_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) + return loss_mean, flat_dist, non_zero_counts + else: # Soft margin all_loss = F.softplus(positive_negative_dist).view(p, -1) - non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - - return non_zero_mean, dist.mean((1, 2)), non_zero_counts + loss_mean = all_loss.mean(1) + return loss_mean, flat_dist, None @staticmethod def _batch_distance(x): @@ -103,4 +106,4 @@ class JointBatchTripletLoss(BatchTripletLoss): all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - return non_zero_mean, dist.mean((1, 2)), non_zero_counts + return non_zero_mean, dist, non_zero_counts -- cgit v1.2.3 From fed5e6a9b35fda8306147e9ce772dfbf3142a061 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 28 Feb 2021 23:11:05 +0800 Subject: Implement sum of loss default in [1] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [1]A. Hermans, L. Beyer, and B. Leibe, “In defense of the triplet loss for person re-identification,” arXiv preprint arXiv:1703.07737, 2017. --- config.py | 2 ++ models/model.py | 10 +++++++--- utils/configuration.py | 1 + utils/triplet_loss.py | 43 ++++++++++++++++++++++++++++--------------- 4 files changed, 38 insertions(+), 18 deletions(-) diff --git a/config.py b/config.py index 9072982..4c108e2 100644 --- a/config.py +++ b/config.py @@ -65,6 +65,8 @@ config: Configuration = { 'embedding_dims': 256, # Batch Hard or Batch All 'triplet_is_hard': True, + # Use non-zero mean or sum + 'triplet_is_mean': True, # Triplet loss margins for HPM and PartNet, None for soft margin 'triplet_margins': None, }, diff --git a/models/model.py b/models/model.py index 18896ae..34cb816 100644 --- a/models/model.py +++ b/models/model.py @@ -146,6 +146,7 @@ class Model: # Prepare for model, optimizer and scheduler model_hp: dict = self.hp.get('model', {}).copy() triplet_is_hard = model_hp.pop('triplet_is_hard', True) + triplet_is_mean = model_hp.pop('triplet_is_mean', True) triplet_margins = model_hp.pop('triplet_margins', None) optim_hp: dict = self.hp.get('optimizer', {}).copy() start_iter = optim_hp.pop('start_iter', 0) @@ -165,10 +166,13 @@ class Model: ) else: # Different margins self.triplet_loss = JointBatchTripletLoss( - self.rgb_pn.hpm_num_parts, triplet_is_hard, triplet_margins + self.rgb_pn.hpm_num_parts, + triplet_is_hard, triplet_is_mean, triplet_margins ) else: # Soft margins - self.triplet_loss = BatchTripletLoss(triplet_is_hard, None) + self.triplet_loss = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, None + ) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) @@ -243,7 +247,7 @@ class Model: 'PartNet': losses[4] }, self.curr_iter) # None-zero losses in batch - if num_non_zero: + if num_non_zero is not None: self.writer.add_scalars('Loss/non-zero counts', { 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(), 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean() diff --git a/utils/configuration.py b/utils/configuration.py index 20aec76..31eb243 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -44,6 +44,7 @@ class ModelHPConfiguration(TypedDict): tfa_num_parts: int embedding_dims: int triplet_is_hard: bool + triplet_is_mean: bool triplet_margins: tuple[float, float] diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 52d676e..db0cf0f 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -9,10 +9,12 @@ class BatchTripletLoss(nn.Module): def __init__( self, is_hard: bool = True, + is_mean: bool = True, margin: Optional[float] = 0.2, ): super().__init__() self.is_hard = is_hard + self.is_mean = is_mean self.margin = margin def forward(self, x, y): @@ -27,13 +29,20 @@ class BatchTripletLoss(nn.Module): positive_negative_dist = self._all_distance(dist, y, p, n) if self.margin: - all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) - loss_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - return loss_mean, flat_dist, non_zero_counts + losses = F.relu(self.margin + positive_negative_dist).view(p, -1) + non_zero_counts = (losses != 0).sum(1).float() + if self.is_mean: + loss_metric = self._none_zero_mean(losses, non_zero_counts) + else: # is_sum + loss_metric = losses.sum(1) + return loss_metric, flat_dist, non_zero_counts else: # Soft margin - all_loss = F.softplus(positive_negative_dist).view(p, -1) - loss_mean = all_loss.mean(1) - return loss_mean, flat_dist, None + losses = F.softplus(positive_negative_dist).view(p, -1) + if self.is_mean: + loss_metric = losses.mean(1) + else: # is_sum + loss_metric = losses.sum(1) + return loss_metric, flat_dist, None @staticmethod def _batch_distance(x): @@ -68,13 +77,11 @@ class BatchTripletLoss(nn.Module): return positive_negative_dist @staticmethod - def _none_zero_parted_mean(all_loss): + def _none_zero_mean(losses, non_zero_counts): # Non-zero parted mean - non_zero_counts = (all_loss != 0).sum(1).float() - non_zero_mean = all_loss.sum(1) / non_zero_counts + non_zero_mean = losses.sum(1) / non_zero_counts non_zero_mean[non_zero_counts == 0] = 0 - - return non_zero_mean, non_zero_counts + return non_zero_mean class JointBatchTripletLoss(BatchTripletLoss): @@ -82,9 +89,10 @@ class JointBatchTripletLoss(BatchTripletLoss): self, hpm_num_parts: int, is_hard: bool = True, + is_mean: bool = True, margins: tuple[float, float] = (0.2, 0.2) ): - super().__init__(is_hard) + super().__init__(is_hard, is_mean) self.hpm_num_parts = hpm_num_parts self.margin_hpm, self.margin_pn = margins @@ -103,7 +111,12 @@ class JointBatchTripletLoss(BatchTripletLoss): pn_part_loss = F.relu( self.margin_pn + positive_negative_dist[self.hpm_num_parts:] ) - all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) - non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) + losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) + + non_zero_counts = (losses != 0).sum(1).float() + if self.is_mean: + loss_metric = self._none_zero_mean(losses, non_zero_counts) + else: # is_sum + loss_metric = losses.sum(1) - return non_zero_mean, dist, non_zero_counts + return loss_metric, dist, non_zero_counts -- cgit v1.2.3 From 4bdc37bbd86a83647bbbda7bd1367c08e6c6f6d4 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 1 Mar 2021 11:20:34 +0800 Subject: Remove identical sample in Batch All case --- utils/triplet_loss.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index db0cf0f..6822cf6 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -68,7 +68,10 @@ class BatchTripletLoss(nn.Module): @staticmethod def _all_distance(dist, y, p, n): - positive_mask = y.unsqueeze(1) == y.unsqueeze(2) + # Unmask identical samples + positive_mask = torch.eye( + n, dtype=torch.bool, device=y.device + ) ^ (y.unsqueeze(1) == y.unsqueeze(2)) negative_mask = y.unsqueeze(1) != y.unsqueeze(2) all_positive = dist[positive_mask].view(p, n, -1, 1) all_negative = dist[negative_mask].view(p, n, 1, -1) -- cgit v1.2.3 From db0564967d8cfc03b2d3fe4f7d10eff0867e1771 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 1 Mar 2021 11:22:16 +0800 Subject: Move pairs variable to local --- models/model.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/models/model.py b/models/model.py index 34cb816..b942eb8 100644 --- a/models/model.py +++ b/models/model.py @@ -59,8 +59,6 @@ class Model: self.in_size: tuple[int, int] = (64, 48) self.pr: Optional[int] = None self.k: Optional[int] = None - self.num_pairs: Optional[int] = None - self.num_pos_pairs: Optional[int] = None self._gallery_dataset_meta: Optional[dict[str, list]] = None self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None @@ -174,6 +172,9 @@ class Model: triplet_is_hard, triplet_is_mean, None ) + num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 + num_pos_pairs = (self.k*(self.k-1)//2) * self.pr + # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) self.triplet_loss = self.triplet_loss.to(self.device) @@ -256,12 +257,12 @@ class Model: mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0) self._add_ranked_scalars( 'Embedding/HPM distance', mean_hpm_dist, - self.num_pos_pairs, self.num_pairs, self.curr_iter + num_pos_pairs, num_pairs, self.curr_iter ) mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0) self._add_ranked_scalars( 'Embedding/ParNet distance', mean_pa_dist, - self.num_pos_pairs, self.num_pairs, self.curr_iter + num_pos_pairs, num_pairs, self.curr_iter ) # Embedding norm mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0) @@ -566,8 +567,6 @@ class Model: ) -> DataLoader: config: dict = dataloader_config.copy() (self.pr, self.k) = config.pop('batch_size', (8, 16)) - self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 - self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr if self.is_train: triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) return DataLoader(dataset, -- cgit v1.2.3 From 7318a09451852e3f7d5f68180964f03bd0b0f616 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 1 Mar 2021 14:04:02 +0800 Subject: Change flat distance calculation method --- utils/triplet_loss.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 6822cf6..e05b69d 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -20,8 +20,8 @@ class BatchTripletLoss(nn.Module): def forward(self, x, y): p, n, c = x.size() dist = self._batch_distance(x) - flat_dist = dist.tril(-1) - flat_dist = flat_dist[flat_dist != 0].view(p, -1) + flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device) + flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]] if self.is_hard: positive_negative_dist = self._hard_distance(dist, y, p, n) @@ -102,6 +102,8 @@ class JointBatchTripletLoss(BatchTripletLoss): def forward(self, x, y): p, n, c = x.size() dist = self._batch_distance(x) + flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device) + flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]] if self.is_hard: positive_negative_dist = self._hard_distance(dist, y, p, n) @@ -122,4 +124,4 @@ class JointBatchTripletLoss(BatchTripletLoss): else: # is_sum loss_metric = losses.sum(1) - return loss_metric, dist, non_zero_counts + return loss_metric, flat_dist, non_zero_counts -- cgit v1.2.3 From 6002b2d2017912f90e8917e6e8b71b78ce58e7c2 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 1 Mar 2021 18:20:38 +0800 Subject: New scheduler and new config --- config.py | 18 ++++++++---------- models/model.py | 27 ++++++++++++++------------- utils/configuration.py | 5 ++--- 3 files changed, 24 insertions(+), 26 deletions(-) diff --git a/config.py b/config.py index 4c108e2..e70c2bd 100644 --- a/config.py +++ b/config.py @@ -72,8 +72,6 @@ config: Configuration = { }, 'optimizer': { # Global parameters - # Iteration start to optimize non-disentangling parts - # 'start_iter': 0, # Initial learning rate of Adam Optimizer 'lr': 1e-4, # Coefficients used for computing running averages of @@ -87,15 +85,15 @@ config: Configuration = { # 'amsgrad': False, # Local parameters (override global ones) - # 'auto_encoder': { - # 'weight_decay': 0.001 - # }, + 'auto_encoder': { + 'weight_decay': 0.001 + }, }, 'scheduler': { - # Period of learning rate decay - 'step_size': 500, - # Multiplicative factor of decay - 'gamma': 1, + # Step start to decay + 'start_step': 15_000, + # Multiplicative factor of decay in the end + 'final_gamma': 0.001, } }, # Model metadata @@ -109,6 +107,6 @@ config: Configuration = { # Restoration iteration (multiple models, e.g. nm, bg and cl) 'restore_iters': (0, 0, 0), # Total iteration for training (multiple models) - 'total_iters': (80_000, 80_000, 80_000), + 'total_iters': (25_000, 25_000, 25_000), }, } diff --git a/models/model.py b/models/model.py index b942eb8..497a0ea 100644 --- a/models/model.py +++ b/models/model.py @@ -147,7 +147,6 @@ class Model: triplet_is_mean = model_hp.pop('triplet_is_mean', True) triplet_margins = model_hp.pop('triplet_margins', None) optim_hp: dict = self.hp.get('optimizer', {}).copy() - start_iter = optim_hp.pop('start_iter', 0) ae_optim_hp = optim_hp.pop('auto_encoder', {}) pn_optim_hp = optim_hp.pop('part_net', {}) hpm_optim_hp = optim_hp.pop('hpm', {}) @@ -184,14 +183,17 @@ class Model: {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp}, {'params': self.rgb_pn.fc_mat, **fc_optim_hp} ], **optim_hp) - sched_gamma = sched_hp.get('gamma', 0.9) - sched_step_size = sched_hp.get('step_size', 500) + sched_final_gamma = sched_hp.get('final_gamma', 0.001) + sched_start_step = sched_hp.get('start_step', 15_000) + + def lr_lambda(epoch): + passed_step = epoch - sched_start_step + all_step = self.total_iter - sched_start_step + return sched_final_gamma ** (passed_step / all_step) self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ - lambda epoch: sched_gamma ** (epoch // sched_step_size), - lambda epoch: 0 if epoch < start_iter else 1, - lambda epoch: 0 if epoch < start_iter else 1, - lambda epoch: 0 if epoch < start_iter else 1, + lr_lambda, lr_lambda, lr_lambda, lr_lambda ]) + self.writer = SummaryWriter(self._log_name) self.rgb_pn.train() @@ -211,7 +213,7 @@ class Model: running_loss = torch.zeros(5, device=self.device) print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}", f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}", - f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}") + f"{'BATripH':^8} {'BATripP':^8} {'LR':^9}") for (batch_c1, batch_c2) in dataloader: self.curr_iter += 1 # Zero the parameter gradients @@ -282,10 +284,7 @@ class Model: lrs = self.scheduler.get_last_lr() # Write learning rates self.writer.add_scalar( - 'Learning rate/Auto-encoder', lrs[0], self.curr_iter - ) - self.writer.add_scalar( - 'Learning rate/Others', lrs[1], self.curr_iter + 'Learning rate', lrs[0], self.curr_iter ) # Write disentangled images if self.image_log_on: @@ -309,7 +308,7 @@ class Model: print(f'{hour:02}:{minute:02}:{second:02}', f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100), - '{:.3e} {:.3e}'.format(lrs[0], lrs[1])) + f'{lrs[0]:.3e}') running_loss.zero_() # Step scheduler @@ -385,6 +384,8 @@ class Model: # Init models model_hp: dict = self.hp.get('model', {}).copy() + model_hp.pop('triplet_is_hard', True) + model_hp.pop('triplet_is_mean', True) model_hp.pop('triplet_margins', None) self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp) # Try to accelerate computation using CUDA or others diff --git a/utils/configuration.py b/utils/configuration.py index 31eb243..0f8d9ff 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -57,7 +57,6 @@ class SubOptimizerHPConfiguration(TypedDict): class OptimizerHPConfiguration(TypedDict): - start_iter: int lr: int betas: tuple[float, float] eps: float @@ -70,8 +69,8 @@ class OptimizerHPConfiguration(TypedDict): class SchedulerHPConfiguration(TypedDict): - step_size: int - gamma: float + start_step: int + final_gamma: float class HyperparameterConfiguration(TypedDict): -- cgit v1.2.3 From 543a35d814754c86ccafa1243bece387c1a780d6 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Tue, 2 Mar 2021 19:22:25 +0800 Subject: Fix bugs in new scheduler --- models/model.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/models/model.py b/models/model.py index 497a0ea..3242141 100644 --- a/models/model.py +++ b/models/model.py @@ -185,11 +185,14 @@ class Model: ], **optim_hp) sched_final_gamma = sched_hp.get('final_gamma', 0.001) sched_start_step = sched_hp.get('start_step', 15_000) + all_step = self.total_iter - sched_start_step def lr_lambda(epoch): - passed_step = epoch - sched_start_step - all_step = self.total_iter - sched_start_step - return sched_final_gamma ** (passed_step / all_step) + if epoch > sched_start_step: + passed_step = epoch - sched_start_step + return sched_final_gamma ** (passed_step / all_step) + else: + return 1 self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ lr_lambda, lr_lambda, lr_lambda, lr_lambda ]) -- cgit v1.2.3 From 02780c31385af7e1103448bd1994012ac95dd2bb Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Tue, 2 Mar 2021 19:26:11 +0800 Subject: Record learning rate every step --- models/model.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/models/model.py b/models/model.py index 3242141..acccbff 100644 --- a/models/model.py +++ b/models/model.py @@ -282,13 +282,14 @@ class Model: 'Embedding/PartNet norm', mean_pa_norm, self.k, self.pr * self.k, self.curr_iter ) + # Learning rate + lrs = self.scheduler.get_last_lr() + # Write learning rates + self.writer.add_scalar( + 'Learning rate', lrs[0], self.curr_iter + ) if self.curr_iter % 100 == 0: - lrs = self.scheduler.get_last_lr() - # Write learning rates - self.writer.add_scalar( - 'Learning rate', lrs[0], self.curr_iter - ) # Write disentangled images if self.image_log_on: i_a, i_c, i_p = images -- cgit v1.2.3 From 0527c5b657c7b4fdfd7d57bf9bc5334eac480731 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 3 Mar 2021 10:04:38 +0800 Subject: Add L2 penalty to global --- config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/config.py b/config.py index e70c2bd..d6de788 100644 --- a/config.py +++ b/config.py @@ -80,14 +80,14 @@ config: Configuration = { # Term added to the denominator # 'eps': 1e-8, # Weight decay (L2 penalty) - # 'weight_decay': 0, + 'weight_decay': 0.001, # Use AMSGrad or not # 'amsgrad': False, # Local parameters (override global ones) - 'auto_encoder': { - 'weight_decay': 0.001 - }, + # 'auto_encoder': { + # 'weight_decay': 0.001 + # }, }, 'scheduler': { # Step start to decay -- cgit v1.2.3 From 96c1f9fa2fc747ff3c54c7f06e65706ce27fccfd Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 4 Mar 2021 13:20:52 +0800 Subject: Set seed for reproducibility --- models/model.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/models/model.py b/models/model.py index acccbff..83b970a 100644 --- a/models/model.py +++ b/models/model.py @@ -1,4 +1,5 @@ import os +import random from datetime import datetime from typing import Union, Optional @@ -199,14 +200,17 @@ class Model: self.writer = SummaryWriter(self._log_name) + # Set seeds for reproducibility + random.seed(0) + torch.manual_seed(0) self.rgb_pn.train() # Init weights at first iter if self.curr_iter == 0: self.rgb_pn.apply(self.init_weights) else: # Load saved state dicts checkpoint = torch.load(self._checkpoint_name) - iter_, loss = checkpoint['iter'], checkpoint['loss'] - print('{0:5d} loss: {1:.3f}'.format(iter_, loss)) + random.setstate(checkpoint['rand_states'][0]) + torch.set_rng_state(checkpoint['rand_states'][1]) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optim_state_dict']) self.scheduler.load_state_dict(checkpoint['sched_state_dict']) @@ -320,11 +324,10 @@ class Model: if self.curr_iter % 1000 == 0: torch.save({ - 'iter': self.curr_iter, + 'rand_states': (random.getstate(), torch.get_rng_state()), 'model_state_dict': self.rgb_pn.state_dict(), 'optim_state_dict': self.optimizer.state_dict(), 'sched_state_dict': self.scheduler.state_dict(), - 'loss': loss, }, self._checkpoint_name) if self.curr_iter == self.total_iter: -- cgit v1.2.3 From 8578a141969720ec93b9bc172c8f20d0ef66ed16 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 4 Mar 2021 13:29:07 +0800 Subject: Replace detach with no_grad in evaluation --- models/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/model.py b/models/model.py index 83b970a..7aff6c4 100644 --- a/models/model.py +++ b/models/model.py @@ -423,7 +423,8 @@ class Model: def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]): label = sample.pop('label').item() clip = sample.pop('clip').to(self.device) - feature = self.rgb_pn(clip).detach() + with torch.no_grad(): + feature = self.rgb_pn(clip) return { **{'label': label}, **sample, -- cgit v1.2.3 From 0c174f09ce3c2c70834e5a973347c7dc22012bf3 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Fri, 5 Mar 2021 10:39:22 +0800 Subject: Bump up torch version --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 926a587..de81280 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch~=1.7.1 +torch~=1.8.0 torchvision~=0.8.0a0+ecf4e9c numpy~=1.19.4 tqdm~=4.58.0 -- cgit v1.2.3 From ad78ce7f8d80f3a4ccd72e93ab3244f130b5ec1f Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 7 Mar 2021 11:20:24 +0800 Subject: Bump up package version --- requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index de81280..8b0a41b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ torch~=1.8.0 -torchvision~=0.8.0a0+ecf4e9c -numpy~=1.19.4 -tqdm~=4.58.0 +torchvision~=0.9.0a0+3f090d0 +numpy~=1.20.1 +tqdm~=4.59.0 Pillow~=8.1.0 scikit-learn~=0.24.0 \ No newline at end of file -- cgit v1.2.3 From 1b8d1614168ce6590c5e029c7f1007ac9b17048c Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 10 Mar 2021 14:10:24 +0800 Subject: Bug fixes 1. Resolve reference problems when parsing dataset selectors 2. Transform gallery using different models --- models/model.py | 49 +++++++++++++++++++++++++------------------------ utils/dataset.py | 6 +++--- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/models/model.py b/models/model.py index 7aff6c4..cb455f3 100644 --- a/models/model.py +++ b/models/model.py @@ -399,20 +399,20 @@ class Model: self.rgb_pn = self.rgb_pn.to(self.device) self.rgb_pn.eval() - gallery_samples, probe_samples = [], {} - # Gallery - checkpoint = torch.load(list(checkpoints.values())[0]) - self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) - for sample in tqdm(gallery_dataloader, - desc='Transforming gallery', unit='clips'): - gallery_samples.append(self._get_eval_sample(sample)) - gallery_samples = default_collate(gallery_samples) - # Probe - for (condition, dataloader) in probe_dataloaders.items(): + gallery_samples, probe_samples = {}, {} + for (condition, probe_dataloader) in probe_dataloaders.items(): checkpoint = torch.load(checkpoints[condition]) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) + # Gallery + gallery_samples_c = [] + for sample in tqdm(gallery_dataloader, + desc=f'Transforming gallery {condition}', + unit='clips'): + gallery_samples_c.append(self._get_eval_sample(sample)) + gallery_samples[condition] = default_collate(gallery_samples_c) + # Probe probe_samples_c = [] - for sample in tqdm(dataloader, + for sample in tqdm(probe_dataloader, desc=f'Transforming probe {condition}', unit='clips'): probe_samples_c.append(self._get_eval_sample(sample)) @@ -437,27 +437,28 @@ class Model: probe_samples: dict[str, dict[str, Union[list[str], torch.Tensor]]], num_ranks: int = 5 ) -> dict[str, torch.Tensor]: - probe_conditions = self._probe_datasets_meta.keys() + conditions = gallery_samples.keys() gallery_views_meta = self._gallery_dataset_meta['views'] probe_views_meta = list(self._probe_datasets_meta.values())[0]['views'] accuracy = { condition: torch.empty( len(gallery_views_meta), len(probe_views_meta), num_ranks ) - for condition in self._probe_datasets_meta.keys() + for condition in conditions } - (labels_g, _, views_g, features_g) = gallery_samples.values() - views_g = np.asarray(views_g) - for (v_g_i, view_g) in enumerate(gallery_views_meta): - gallery_view_mask = (views_g == view_g) - f_g = features_g[gallery_view_mask] - y_g = labels_g[gallery_view_mask] - for condition in probe_conditions: - probe_samples_c = probe_samples[condition] - accuracy_c = accuracy[condition] - (labels_p, _, views_p, features_p) = probe_samples_c.values() - views_p = np.asarray(views_p) + for condition in conditions: + gallery_samples_c = gallery_samples[condition] + (labels_g, _, views_g, features_g) = gallery_samples_c.values() + views_g = np.asarray(views_g) + probe_samples_c = probe_samples[condition] + (labels_p, _, views_p, features_p) = probe_samples_c.values() + views_p = np.asarray(views_p) + accuracy_c = accuracy[condition] + for (v_g_i, view_g) in enumerate(gallery_views_meta): + gallery_view_mask = (views_g == view_g) + f_g = features_g[gallery_view_mask] + y_g = labels_g[gallery_view_mask] for (v_p_i, view_p) in enumerate(probe_views_meta): probe_view_mask = (views_p == view_p) f_p = features_p[probe_view_mask] diff --git a/utils/dataset.py b/utils/dataset.py index c487988..387c211 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -111,9 +111,9 @@ class CASIAB(data.Dataset): # in Bag #2 condition from 90 degree angle classes, conditions, views = [], [], [] if selector: - selected_classes = selector.pop('classes', None) - selected_conditions = selector.pop('conditions', None) - selected_views = selector.pop('views', None) + selected_classes = selector.get('classes', None) + selected_conditions = selector.get('conditions', None) + selected_views = selector.get('views', None) class_regex = r'\d{3}' condition_regex = r'(nm|bg|cl)-0[0-6]' -- cgit v1.2.3 From c74df416b00f837ba051f3947be92f76e7afbd88 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Fri, 12 Mar 2021 13:56:17 +0800 Subject: Code refactoring 1. Separate FCs and triplet losses for HPM and PartNet 2. Remove FC-equivalent 1x1 conv layers in HPM 3. Support adjustable learning rate schedulers --- config.py | 21 +++++---- models/auto_encoder.py | 2 +- models/hpm.py | 25 +++++------ models/layers.py | 9 ---- models/model.py | 119 ++++++++++++++++++++++++++++--------------------- models/part_net.py | 18 +++++--- models/rgb_part_net.py | 32 +++++-------- test/part_net.py | 2 +- utils/configuration.py | 20 +++++---- utils/triplet_loss.py | 40 ----------------- 10 files changed, 127 insertions(+), 161 deletions(-) diff --git a/config.py b/config.py index d6de788..8abeba3 100644 --- a/config.py +++ b/config.py @@ -50,19 +50,17 @@ config: Configuration = { 'ae_feature_channels': 64, # Appearance, canonical and pose feature dimensions 'f_a_c_p_dims': (192, 192, 96), - # Use 1x1 convolution in dimensionality reduction - 'hpm_use_1x1conv': False, # HPM pyramid scales, of which sum is number of parts 'hpm_scales': (1, 2, 4, 8), # Global pooling method 'hpm_use_avg_pool': True, 'hpm_use_max_pool': True, - # Attention squeeze ratio - 'tfa_squeeze_ratio': 4, # Number of parts after Part Net 'tfa_num_parts': 16, - # Embedding dimension for each part - 'embedding_dims': 256, + # Attention squeeze ratio + 'tfa_squeeze_ratio': 4, + # Embedding dimensions for each part + 'embedding_dims': (256, 256), # Batch Hard or Batch All 'triplet_is_hard': True, # Use non-zero mean or sum @@ -91,9 +89,14 @@ config: Configuration = { }, 'scheduler': { # Step start to decay - 'start_step': 15_000, + 'start_step': 500, # Multiplicative factor of decay in the end - 'final_gamma': 0.001, + 'final_gamma': 0.01, + + # Local parameters (override global ones) + 'hpm': { + 'final_gamma': 0.001 + } } }, # Model metadata @@ -107,6 +110,6 @@ config: Configuration = { # Restoration iteration (multiple models, e.g. nm, bg and cl) 'restore_iters': (0, 0, 0), # Total iteration for training (multiple models) - 'total_iters': (25_000, 25_000, 25_000), + 'total_iters': (30_000, 40_000, 60_000), }, } diff --git a/models/auto_encoder.py b/models/auto_encoder.py index e6a3e60..4fece69 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -171,7 +171,7 @@ class AutoEncoder(nn.Module): return ( (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_), - torch.stack((xrecon_loss, cano_cons_loss, pose_sim_loss * 10)) + (xrecon_loss, cano_cons_loss, pose_sim_loss * 10) ) else: # evaluating return f_c_c1_t2_, f_p_c1_t2_ diff --git a/models/hpm.py b/models/hpm.py index 9879cfb..8186b20 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -9,32 +9,26 @@ class HorizontalPyramidMatching(nn.Module): self, in_channels: int, out_channels: int = 128, - use_1x1conv: bool = False, scales: tuple[int, ...] = (1, 2, 4), use_avg_pool: bool = True, use_max_pool: bool = False, - **kwargs ): super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.use_1x1conv = use_1x1conv self.scales = scales + self.num_parts = sum(scales) self.use_avg_pool = use_avg_pool self.use_max_pool = use_max_pool self.pyramids = nn.ModuleList([ - self._make_pyramid(scale, **kwargs) for scale in self.scales + self._make_pyramid(scale) for scale in scales ]) + self.fc_mat = nn.Parameter( + torch.empty(self.num_parts, in_channels, out_channels) + ) - def _make_pyramid(self, scale: int, **kwargs): + def _make_pyramid(self, scale: int): pyramid = nn.ModuleList([ - HorizontalPyramidPooling(self.in_channels, - self.out_channels, - use_1x1conv=self.use_1x1conv, - use_avg_pool=self.use_avg_pool, - use_max_pool=self.use_max_pool, - **kwargs) + HorizontalPyramidPooling(self.use_avg_pool, self.use_max_pool) for _ in range(scale) ]) return pyramid @@ -52,4 +46,9 @@ class HorizontalPyramidMatching(nn.Module): x_slice = x_slice.view(n, -1) feature.append(x_slice) x = torch.stack(feature) + + # p, n, c + x = x @ self.fc_mat + # p, n, d + return x diff --git a/models/layers.py b/models/layers.py index f1d72b6..c609698 100644 --- a/models/layers.py +++ b/models/layers.py @@ -167,17 +167,10 @@ class BasicConv1d(nn.Module): class HorizontalPyramidPooling(nn.Module): def __init__( self, - in_channels: int, - out_channels: int, - use_1x1conv: bool = False, use_avg_pool: bool = True, use_max_pool: bool = False, - **kwargs ): super().__init__() - self.use_1x1conv = use_1x1conv - if use_1x1conv: - self.conv = BasicConv2d(in_channels, out_channels, 1, **kwargs) self.use_avg_pool = use_avg_pool self.use_max_pool = use_max_pool assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.' @@ -191,6 +184,4 @@ class HorizontalPyramidPooling(nn.Module): x = self.avg_pool(x) elif not self.use_avg_pool and self.use_max_pool: x = self.max_pool(x) - if self.use_1x1conv: - x = self.conv(x) return x diff --git a/models/model.py b/models/model.py index cb455f3..adea626 100644 --- a/models/model.py +++ b/models/model.py @@ -13,13 +13,15 @@ from torch.utils.data.dataloader import default_collate from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm +from models.hpm import HorizontalPyramidMatching +from models.part_net import PartNet from models.rgb_part_net import RGBPartNet from utils.configuration import DataloaderConfiguration, \ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \ SystemConfiguration from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses from utils.sampler import TripletSampler -from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss +from utils.triplet_loss import BatchTripletLoss class Model: @@ -69,7 +71,8 @@ class Model: self._dataset_sig: str = 'undefined' self.rgb_pn: Optional[RGBPartNet] = None - self.triplet_loss: Optional[JointBatchTripletLoss] = None + self.triplet_loss_hpm: Optional[BatchTripletLoss] = None + self.triplet_loss_pn: Optional[BatchTripletLoss] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None @@ -149,26 +152,28 @@ class Model: triplet_margins = model_hp.pop('triplet_margins', None) optim_hp: dict = self.hp.get('optimizer', {}).copy() ae_optim_hp = optim_hp.pop('auto_encoder', {}) - pn_optim_hp = optim_hp.pop('part_net', {}) hpm_optim_hp = optim_hp.pop('hpm', {}) - fc_optim_hp = optim_hp.pop('fc', {}) + pn_optim_hp = optim_hp.pop('part_net', {}) sched_hp = self.hp.get('scheduler', {}) + ae_sched_hp = sched_hp.get('auto_encoder', {}) + hpm_sched_hp = sched_hp.get('hpm', {}) + pn_sched_hp = sched_hp.get('part_net', {}) + self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) # Hard margins if triplet_margins: - # Same margins - if triplet_margins[0] == triplet_margins[1]: - self.triplet_loss = BatchTripletLoss( - triplet_is_hard, triplet_margins[0] - ) - else: # Different margins - self.triplet_loss = JointBatchTripletLoss( - self.rgb_pn.hpm_num_parts, - triplet_is_hard, triplet_is_mean, triplet_margins - ) + self.triplet_loss_hpm = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, triplet_margins[0] + ) + self.triplet_loss_pn = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, triplet_margins[1] + ) else: # Soft margins - self.triplet_loss = BatchTripletLoss( + self.triplet_loss_hpm = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, None + ) + self.triplet_loss_pn = BatchTripletLoss( triplet_is_hard, triplet_is_mean, None ) @@ -177,25 +182,33 @@ class Model: # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) - self.triplet_loss = self.triplet_loss.to(self.device) + self.triplet_loss_hpm = self.triplet_loss_hpm.to(self.device) + self.triplet_loss_pn = self.triplet_loss_pn.to(self.device) + self.optimizer = optim.Adam([ {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp}, - {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp}, - {'params': self.rgb_pn.fc_mat, **fc_optim_hp} + {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, ], **optim_hp) - sched_final_gamma = sched_hp.get('final_gamma', 0.001) - sched_start_step = sched_hp.get('start_step', 15_000) - all_step = self.total_iter - sched_start_step - - def lr_lambda(epoch): - if epoch > sched_start_step: - passed_step = epoch - sched_start_step - return sched_final_gamma ** (passed_step / all_step) - else: - return 1 + + start_step = sched_hp.get('start_step', 15_000) + final_gamma = sched_hp.get('final_gamma', 0.001) + ae_start_step = ae_sched_hp.get('start_step', start_step) + ae_final_gamma = ae_sched_hp.get('final_gamma', final_gamma) + ae_all_step = self.total_iter - ae_start_step + hpm_start_step = hpm_sched_hp.get('start_step', start_step) + hpm_final_gamma = hpm_sched_hp.get('final_gamma', final_gamma) + hpm_all_step = self.total_iter - hpm_start_step + pn_start_step = pn_sched_hp.get('start_step', start_step) + pn_final_gamma = pn_sched_hp.get('final_gamma', final_gamma) + pn_all_step = self.total_iter - pn_start_step self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ - lr_lambda, lr_lambda, lr_lambda, lr_lambda + lambda t: ae_final_gamma ** ((t - ae_start_step) / ae_all_step) + if t > ae_start_step else 1, + lambda t: hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step) + if t > hpm_start_step else 1, + lambda t: pn_final_gamma ** ((t - pn_start_step) / pn_all_step) + if t > pn_start_step else 1, ]) self.writer = SummaryWriter(self._log_name) @@ -220,7 +233,7 @@ class Model: running_loss = torch.zeros(5, device=self.device) print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}", f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}", - f"{'BATripH':^8} {'BATripP':^8} {'LR':^9}") + f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}") for (batch_c1, batch_c2) in dataloader: self.curr_iter += 1 # Zero the parameter gradients @@ -228,17 +241,20 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2) + embedding_c, embedding_p, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) # Duplicate labels for each part - y = y.repeat(self.rgb_pn.num_total_parts, 1) - trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y) - losses = torch.cat(( - ae_losses, - torch.stack(( - trip_loss[:self.rgb_pn.hpm_num_parts].mean(), - trip_loss[self.rgb_pn.hpm_num_parts:].mean() - )) + y = y.repeat(self.rgb_pn.num_parts, 1) + trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm( + embedding_c, y[:self.rgb_pn.hpm.num_parts] + ) + trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn( + embedding_p, y[self.rgb_pn.hpm.num_parts:] + ) + losses = torch.stack(( + *ae_losses, + trip_loss_hpm.mean(), + trip_loss_pn.mean() )) loss = losses.sum() loss.backward() @@ -257,30 +273,30 @@ class Model: 'PartNet': losses[4] }, self.curr_iter) # None-zero losses in batch - if num_non_zero is not None: + if hpm_num_non_zero is not None and hpm_num_non_zero is not None: self.writer.add_scalars('Loss/non-zero counts', { - 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(), - 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean() + 'HPM': hpm_num_non_zero.mean(), + 'PartNet': pn_num_non_zero.mean() }, self.curr_iter) # Embedding distance - mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0) + mean_hpm_dist = hpm_dist.mean(0) self._add_ranked_scalars( 'Embedding/HPM distance', mean_hpm_dist, num_pos_pairs, num_pairs, self.curr_iter ) - mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0) + mean_pa_dist = pn_dist.mean(0) self._add_ranked_scalars( 'Embedding/ParNet distance', mean_pa_dist, num_pos_pairs, num_pairs, self.curr_iter ) # Embedding norm - mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0) + mean_hpm_embedding = embedding_c.mean(0) mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/HPM norm', mean_hpm_norm, self.k, self.pr * self.k, self.curr_iter ) - mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0) + mean_pa_embedding = embedding_p.mean(0) mean_pa_norm = mean_pa_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/PartNet norm', mean_pa_norm, @@ -288,10 +304,9 @@ class Model: ) # Learning rate lrs = self.scheduler.get_last_lr() - # Write learning rates - self.writer.add_scalar( - 'Learning rate', lrs[0], self.curr_iter - ) + self.writer.add_scalars('Learning rate', dict(zip(( + 'Auto-encoder', 'HPM', 'PartNet' + ), lrs)), self.curr_iter) if self.curr_iter % 100 == 0: # Write disentangled images @@ -316,7 +331,7 @@ class Model: print(f'{hour:02}:{minute:02}:{second:02}', f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100), - f'{lrs[0]:.3e}') + '{:.3e} {:.3e} {:.3e}'.format(*lrs)) running_loss.zero_() # Step scheduler @@ -548,7 +563,7 @@ class Model: nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) - elif isinstance(m, RGBPartNet): + elif isinstance(m, (HorizontalPyramidMatching, PartNet)): nn.init.xavier_uniform_(m.fc_mat) def _parse_dataset_config( diff --git a/models/part_net.py b/models/part_net.py index 29cf9cd..f2236bf 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -111,17 +111,21 @@ class PartNet(nn.Module): def __init__( self, in_channels: int = 128, + embedding_dims: int = 256, + num_parts: int = 16, squeeze_ratio: int = 4, - num_part: int = 16 ): super().__init__() - self.num_part = num_part - self.tfa = TemporalFeatureAggregator( - in_channels, squeeze_ratio, self.num_part - ) + self.num_part = num_parts self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) + self.tfa = TemporalFeatureAggregator( + in_channels, squeeze_ratio, self.num_part + ) + self.fc_mat = nn.Parameter( + torch.empty(num_parts, in_channels, embedding_dims) + ) def forward(self, x): n, t, c, h, w = x.size() @@ -138,4 +142,8 @@ class PartNet(nn.Module): # p, n, t, c x = self.tfa(x) + + # p, n, c + x = x @ self.fc_mat + # p, n, d return x diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 8a0f3a7..c38a567 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -13,39 +13,31 @@ class RGBPartNet(nn.Module): ae_in_size: tuple[int, int] = (64, 48), ae_feature_channels: int = 64, f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), - hpm_use_1x1conv: bool = False, hpm_scales: tuple[int, ...] = (1, 2, 4), hpm_use_avg_pool: bool = True, hpm_use_max_pool: bool = True, tfa_squeeze_ratio: int = 4, tfa_num_parts: int = 16, - embedding_dims: int = 256, + embedding_dims: tuple[int] = (256, 256), image_log_on: bool = False ): super().__init__() self.h, self.w = ae_in_size (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims - self.hpm_num_parts = sum(hpm_scales) self.image_log_on = image_log_on self.ae = AutoEncoder( ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims ) self.pn_in_channels = ae_feature_channels * 2 - self.pn = PartNet( - self.pn_in_channels, tfa_squeeze_ratio, tfa_num_parts - ) self.hpm = HorizontalPyramidMatching( - ae_feature_channels * 2, self.pn_in_channels, hpm_use_1x1conv, - hpm_scales, hpm_use_avg_pool, hpm_use_max_pool + self.pn_in_channels, embedding_dims[0], hpm_scales, + hpm_use_avg_pool, hpm_use_max_pool ) - self.num_total_parts = self.hpm_num_parts + tfa_num_parts - empty_fc = torch.empty(self.num_total_parts, - self.pn_in_channels, embedding_dims) - self.fc_mat = nn.Parameter(empty_fc) + self.pn = PartNet(self.pn_in_channels, embedding_dims[1], + tfa_num_parts, tfa_squeeze_ratio) - def fc(self, x): - return x @ self.fc_mat + self.num_parts = self.hpm.num_parts + tfa_num_parts def forward(self, x_c1, x_c2=None): # Step 1: Disentanglement @@ -55,21 +47,17 @@ class RGBPartNet(nn.Module): # Step 2.a: Static Gait Feature Aggregation & HPM # n, c, h, w x_c = self.hpm(x_c) - # p, n, c + # p, n, d # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) # n, t, c, h, w x_p = self.pn(x_p) - # p, n, c - - # Step 3: Cat feature map together and fc - x = torch.cat((x_c, x_p)) - x = self.fc(x) + # p, n, d if self.training: - return x, ae_losses, images + return x_c, x_p, ae_losses, images else: - return x.unsqueeze(1).view(-1) + return torch.cat((x_c, x_p)).unsqueeze(1).view(-1) def _disentangle(self, x_c1_t2, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() diff --git a/test/part_net.py b/test/part_net.py index 25e92ae..fada2c4 100644 --- a/test/part_net.py +++ b/test/part_net.py @@ -64,7 +64,7 @@ def test_custom_part_net(): paddings=((2, 1), (1, 1), (1, 1), (1, 1)), halving=(1, 1, 3, 3), squeeze_ratio=8, - num_part=8) + num_parts=8) x = torch.rand(T, N, 1, H, W) x = pa(x) diff --git a/utils/configuration.py b/utils/configuration.py index 0f8d9ff..f6ac182 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -33,16 +33,11 @@ class ModelHPConfiguration(TypedDict): ae_feature_channels: int f_a_c_p_dims: tuple[int, int, int] hpm_scales: tuple[int, ...] - hpm_use_1x1conv: bool hpm_use_avg_pool: bool hpm_use_max_pool: bool - fpfe_feature_channels: int - fpfe_kernel_sizes: tuple[tuple, ...] - fpfe_paddings: tuple[tuple, ...] - fpfe_halving: tuple[int, ...] - tfa_squeeze_ratio: int tfa_num_parts: int - embedding_dims: int + tfa_squeeze_ratio: int + embedding_dims: tuple[int] triplet_is_hard: bool triplet_is_mean: bool triplet_margins: tuple[float, float] @@ -63,14 +58,21 @@ class OptimizerHPConfiguration(TypedDict): weight_decay: float amsgrad: bool auto_encoder: SubOptimizerHPConfiguration - part_net: SubOptimizerHPConfiguration hpm: SubOptimizerHPConfiguration - fc: SubOptimizerHPConfiguration + part_net: SubOptimizerHPConfiguration + + +class SubSchedulerHPConfiguration(TypedDict): + start_step: int + final_gamma: float class SchedulerHPConfiguration(TypedDict): start_step: int final_gamma: float + auto_encoder: SubSchedulerHPConfiguration + hpm: SubSchedulerHPConfiguration + part_net: SubSchedulerHPConfiguration class HyperparameterConfiguration(TypedDict): diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index e05b69d..03fff21 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -85,43 +85,3 @@ class BatchTripletLoss(nn.Module): non_zero_mean = losses.sum(1) / non_zero_counts non_zero_mean[non_zero_counts == 0] = 0 return non_zero_mean - - -class JointBatchTripletLoss(BatchTripletLoss): - def __init__( - self, - hpm_num_parts: int, - is_hard: bool = True, - is_mean: bool = True, - margins: tuple[float, float] = (0.2, 0.2) - ): - super().__init__(is_hard, is_mean) - self.hpm_num_parts = hpm_num_parts - self.margin_hpm, self.margin_pn = margins - - def forward(self, x, y): - p, n, c = x.size() - dist = self._batch_distance(x) - flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device) - flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]] - - if self.is_hard: - positive_negative_dist = self._hard_distance(dist, y, p, n) - else: # is_all - positive_negative_dist = self._all_distance(dist, y, p, n) - - hpm_part_loss = F.relu( - self.margin_hpm + positive_negative_dist[:self.hpm_num_parts] - ) - pn_part_loss = F.relu( - self.margin_pn + positive_negative_dist[self.hpm_num_parts:] - ) - losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) - - non_zero_counts = (losses != 0).sum(1).float() - if self.is_mean: - loss_metric = self._none_zero_mean(losses, non_zero_counts) - else: # is_sum - loss_metric = losses.sum(1) - - return loss_metric, flat_dist, non_zero_counts -- cgit v1.2.3 From 2ea916b2a963eae7d47151b41c8c78a578c402e2 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Fri, 12 Mar 2021 15:31:44 +0800 Subject: Make evaluate method static --- models/model.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/models/model.py b/models/model.py index adea626..b09d600 100644 --- a/models/model.py +++ b/models/model.py @@ -241,15 +241,15 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - embedding_c, embedding_p, ae_losses, images = self.rgb_pn(x_c1, x_c2) + embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) # Duplicate labels for each part y = y.repeat(self.rgb_pn.num_parts, 1) trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm( - embedding_c, y[:self.rgb_pn.hpm.num_parts] + embed_c, y[:self.rgb_pn.hpm.num_parts] ) trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn( - embedding_p, y[self.rgb_pn.hpm.num_parts:] + embed_p, y[self.rgb_pn.hpm.num_parts:] ) losses = torch.stack(( *ae_losses, @@ -290,13 +290,13 @@ class Model: num_pos_pairs, num_pairs, self.curr_iter ) # Embedding norm - mean_hpm_embedding = embedding_c.mean(0) + mean_hpm_embedding = embed_c.mean(0) mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/HPM norm', mean_hpm_norm, self.k, self.pr * self.k, self.curr_iter ) - mean_pa_embedding = embedding_p.mean(0) + mean_pa_embedding = embed_p.mean(0) mean_pa_norm = mean_pa_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/PartNet norm', mean_pa_norm, @@ -425,13 +425,16 @@ class Model: unit='clips'): gallery_samples_c.append(self._get_eval_sample(sample)) gallery_samples[condition] = default_collate(gallery_samples_c) + gallery_samples['meta'] = self._gallery_dataset_meta # Probe probe_samples_c = [] for sample in tqdm(probe_dataloader, desc=f'Transforming probe {condition}', unit='clips'): probe_samples_c.append(self._get_eval_sample(sample)) - probe_samples[condition] = default_collate(probe_samples_c) + probe_samples_c = default_collate(probe_samples_c) + probe_samples_c['meta'] = self._probe_datasets_meta[condition] + probe_samples[condition] = probe_samples_c return gallery_samples, probe_samples @@ -446,15 +449,15 @@ class Model: **{'feature': feature} } + @staticmethod def evaluate( - self, - gallery_samples: dict[str, Union[list[str], torch.Tensor]], - probe_samples: dict[str, dict[str, Union[list[str], torch.Tensor]]], + gallery_samples: dict[str, dict[str, Union[list, torch.Tensor]]], + probe_samples: dict[str, dict[str, Union[list, torch.Tensor]]], num_ranks: int = 5 ) -> dict[str, torch.Tensor]: - conditions = gallery_samples.keys() - gallery_views_meta = self._gallery_dataset_meta['views'] - probe_views_meta = list(self._probe_datasets_meta.values())[0]['views'] + conditions = list(probe_samples.keys()) + gallery_views_meta = gallery_samples['meta']['views'] + probe_views_meta = probe_samples[conditions[0]]['meta']['views'] accuracy = { condition: torch.empty( len(gallery_views_meta), len(probe_views_meta), num_ranks @@ -467,7 +470,7 @@ class Model: (labels_g, _, views_g, features_g) = gallery_samples_c.values() views_g = np.asarray(views_g) probe_samples_c = probe_samples[condition] - (labels_p, _, views_p, features_p) = probe_samples_c.values() + (labels_p, _, views_p, features_p, _) = probe_samples_c.values() views_p = np.asarray(views_p) accuracy_c = accuracy[condition] for (v_g_i, view_g) in enumerate(gallery_views_meta): @@ -492,7 +495,6 @@ class Model: positive_counts = positive_mat.sum(0) total_counts, _ = dist.size() accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts - return accuracy def _load_pretrained( -- cgit v1.2.3 From 5c4483ea0c7b2e9166f94d6b6e3b8a070eb959f2 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Fri, 12 Mar 2021 20:36:27 +0800 Subject: Fix a typo when record none-zero counts --- models/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/model.py b/models/model.py index b09d600..89b34aa 100644 --- a/models/model.py +++ b/models/model.py @@ -273,7 +273,7 @@ class Model: 'PartNet': losses[4] }, self.curr_iter) # None-zero losses in batch - if hpm_num_non_zero is not None and hpm_num_non_zero is not None: + if hpm_num_non_zero is not None and pn_num_non_zero is not None: self.writer.add_scalars('Loss/non-zero counts', { 'HPM': hpm_num_non_zero.mean(), 'PartNet': pn_num_non_zero.mean() -- cgit v1.2.3 From 263b8001ca1b25a43d1c87f187423054e141925d Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 14 Mar 2021 21:07:03 +0800 Subject: Fix unbalanced datasets --- utils/sampler.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/utils/sampler.py b/utils/sampler.py index cdf1984..0c9872c 100644 --- a/utils/sampler.py +++ b/utils/sampler.py @@ -16,7 +16,18 @@ class TripletSampler(data.Sampler): ): super().__init__(data_source) self.metadata_labels = data_source.metadata['labels'] + metadata_conditions = data_source.metadata['conditions'] + self.subsets = {} + for condition in metadata_conditions: + pre, _ = condition.split('-') + if self.subsets.get(pre, None) is None: + self.subsets[pre] = [] + self.subsets[pre].append(condition) + self.num_subsets = len(self.subsets) + self.num_seq = {pre: len(seq) for (pre, seq) in self.subsets.items()} + self.min_num_seq = min(self.num_seq.values()) self.labels = data_source.labels + self.conditions = data_source.conditions self.length = len(self.labels) self.indexes = np.arange(0, self.length) (self.pr, self.k) = batch_size @@ -27,15 +38,31 @@ class TripletSampler(data.Sampler): # Sample pr subjects by sampling labels appeared in dataset sampled_subjects = random.sample(self.metadata_labels, k=self.pr) for label in sampled_subjects: - clips_from_subject = self.indexes[self.labels == label].tolist() + mask = self.labels == label + # Fix unbalanced datasets + if self.num_subsets > 1: + condition_mask = np.zeros(self.conditions.shape, dtype=bool) + for num, conditions_ in zip( + self.num_seq.values(), self.subsets.values() + ): + if num > self.min_num_seq: + conditions = random.sample( + conditions_, self.min_num_seq + ) + else: + conditions = conditions_ + for condition in conditions: + condition_mask |= self.conditions == condition + mask &= condition_mask + clips = self.indexes[mask].tolist() # Sample k clips from the subject without replacement if # have enough clips, k more clips will sampled for # disentanglement k = self.k * 2 - if len(clips_from_subject) >= k: - _sampled_indexes = random.sample(clips_from_subject, k=k) + if len(clips) >= k: + _sampled_indexes = random.sample(clips, k=k) else: - _sampled_indexes = random.choices(clips_from_subject, k=k) + _sampled_indexes = random.choices(clips, k=k) sampled_indexes += _sampled_indexes yield sampled_indexes -- cgit v1.2.3 From da922be042d96338a3f207386e410b6746d046f5 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 14 Mar 2021 21:07:28 +0800 Subject: Bug fix when transforming and new config --- config.py | 6 +++--- models/rgb_part_net.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/config.py b/config.py index 8abeba3..c928067 100644 --- a/config.py +++ b/config.py @@ -94,9 +94,9 @@ config: Configuration = { 'final_gamma': 0.01, # Local parameters (override global ones) - 'hpm': { - 'final_gamma': 0.001 - } + # 'hpm': { + # 'final_gamma': 0.001 + # } } }, # Model metadata diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index c38a567..c136040 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -57,7 +57,7 @@ class RGBPartNet(nn.Module): if self.training: return x_c, x_p, ae_losses, images else: - return torch.cat((x_c, x_p)).unsqueeze(1).view(-1) + return torch.cat((x_c.view(-1), x_p.view(-1))) def _disentangle(self, x_c1_t2, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() -- cgit v1.2.3 From fe6cd66d19c16153322577fb13779020934cf1e2 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 15 Mar 2021 17:15:33 +0800 Subject: Support transforming on training datasets --- models/model.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/models/model.py b/models/model.py index 89b34aa..5966ae1 100644 --- a/models/model.py +++ b/models/model.py @@ -392,12 +392,12 @@ class Model: dataset_selectors: dict[ str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]] ], - dataloader_config: DataloaderConfiguration + dataloader_config: DataloaderConfiguration, + is_train: bool = False ): - self.is_train = False # Split gallery and probe dataset gallery_dataloader, probe_dataloaders = self._split_gallery_probe( - dataset_config, dataloader_config + dataset_config, dataloader_config, is_train ) # Get pretrained models at iter_ checkpoints = self._load_pretrained( @@ -506,10 +506,11 @@ class Model: ] ) -> dict[str, str]: checkpoints = {} - for (iter_, (condition, selector)) in zip( - iters, dataset_selectors.items() + for (iter_, total_iter, (condition, selector)) in zip( + iters, self.total_iters, dataset_selectors.items() ): self.curr_iter = iter_ + self.total_iter = total_iter self._dataset_sig = self._make_signature( dict(**dataset_config, **selector), popped_keys=['root_dir', 'cache_on'] @@ -521,26 +522,29 @@ class Model: self, dataset_config: DatasetConfiguration, dataloader_config: DataloaderConfiguration, + is_train: bool = False ) -> tuple[DataLoader, dict[str, DataLoader]]: dataset_name = dataset_config.get('name', 'CASIA-B') if dataset_name == 'CASIA-B': + self.is_train = is_train gallery_dataset = self._parse_dataset_config( dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR) ) - self._gallery_dataset_meta = gallery_dataset.metadata - gallery_dataloader = self._parse_dataloader_config( - gallery_dataset, dataloader_config - ) probe_datasets = { condition: self._parse_dataset_config( dict(**dataset_config, **selector) ) for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items() } + self._gallery_dataset_meta = gallery_dataset.metadata self._probe_datasets_meta = { condition: dataset.metadata for (condition, dataset) in probe_datasets.items() } + self.is_train = False + gallery_dataloader = self._parse_dataloader_config( + gallery_dataset, dataloader_config + ) probe_dataloaders = { condition: self._parse_dataloader_config( dataset, dataloader_config -- cgit v1.2.3 From 864fca2c9ca65847c0f1f318dfe50a1e6155e418 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 15 Mar 2021 19:41:28 +0800 Subject: Fix redundant gallery_dataset_meta assignment --- models/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/model.py b/models/model.py index 5966ae1..c350d11 100644 --- a/models/model.py +++ b/models/model.py @@ -425,7 +425,6 @@ class Model: unit='clips'): gallery_samples_c.append(self._get_eval_sample(sample)) gallery_samples[condition] = default_collate(gallery_samples_c) - gallery_samples['meta'] = self._gallery_dataset_meta # Probe probe_samples_c = [] for sample in tqdm(probe_dataloader, @@ -435,6 +434,7 @@ class Model: probe_samples_c = default_collate(probe_samples_c) probe_samples_c['meta'] = self._probe_datasets_meta[condition] probe_samples[condition] = probe_samples_c + gallery_samples['meta'] = self._gallery_dataset_meta return gallery_samples, probe_samples -- cgit v1.2.3 From a68562cbb7f602cc75b3f8f0bf0c285d9e4e4c8b Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 15 Mar 2021 20:28:06 +0800 Subject: Remove redundant wrapper given by dataloader --- models/model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/models/model.py b/models/model.py index c350d11..c1cc703 100644 --- a/models/model.py +++ b/models/model.py @@ -439,14 +439,14 @@ class Model: return gallery_samples, probe_samples def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]): - label = sample.pop('label').item() - clip = sample.pop('clip').to(self.device) + label, condition, view, clip = sample.values() with torch.no_grad(): - feature = self.rgb_pn(clip) + feature = self.rgb_pn(clip.to(self.device)) return { - **{'label': label}, - **sample, - **{'feature': feature} + 'label': label.item(), + 'condition': condition[0], + 'view': view[0], + 'feature': feature } @staticmethod -- cgit v1.2.3 From 38555617816cfaef6f330e7fc90f3cfa65d692fb Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Tue, 16 Mar 2021 15:42:06 +0800 Subject: Set *_iter as *_iters in default --- models/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/model.py b/models/model.py index c1cc703..36a4f7f 100644 --- a/models/model.py +++ b/models/model.py @@ -54,8 +54,8 @@ class Model: self.hp = hyperparameter_config self.curr_iter = self.meta.get('restore_iter', 0) self.total_iter = self.meta.get('total_iter', 80_000) - self.curr_iters = self.meta.get('restore_iters', (0, 0, 0)) - self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000)) + self.curr_iters = self.meta.get('restore_iters', (self.curr_iter,)) + self.total_iters = self.meta.get('total_iters', (self.total_iter,)) self.is_train: bool = True self.in_channels: int = 3 -- cgit v1.2.3 From b6e5972b64cc61fc967cf3d098fc629d781adce4 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 22 Mar 2021 19:32:16 +0800 Subject: Add embedding visualization and validate on testing set --- config.py | 2 + models/model.py | 249 ++++++++++++++++++++++++++++++------------------- models/rgb_part_net.py | 2 +- utils/configuration.py | 1 + utils/triplet_loss.py | 9 +- 5 files changed, 165 insertions(+), 98 deletions(-) diff --git a/config.py b/config.py index c928067..3d98263 100644 --- a/config.py +++ b/config.py @@ -19,6 +19,8 @@ config: Configuration = { 'root_dir': 'data/CASIA-B-MRCNN-V2/SEG', # The number of subjects for training 'train_size': 74, + # The number of subjects for validating (Part of testing set) + 'val_size': 10, # Number of sampled frames per sequence (Training only) 'num_sampled_frames': 30, # Truncate clips longer than `truncate_threshold` diff --git a/models/model.py b/models/model.py index 36a4f7f..766e513 100644 --- a/models/model.py +++ b/models/model.py @@ -1,6 +1,6 @@ +import copy import os import random -from datetime import datetime from typing import Union, Optional import numpy as np @@ -52,9 +52,9 @@ class Model: self.meta = model_config self.hp = hyperparameter_config - self.curr_iter = self.meta.get('restore_iter', 0) + self.restore_iter = self.curr_iter = self.meta.get('restore_iter', 0) self.total_iter = self.meta.get('total_iter', 80_000) - self.curr_iters = self.meta.get('restore_iters', (self.curr_iter,)) + self.restore_iters = self.meta.get('restore_iters', (self.curr_iter,)) self.total_iters = self.meta.get('total_iters', (self.total_iter,)) self.is_train: bool = True @@ -62,6 +62,8 @@ class Model: self.in_size: tuple[int, int] = (64, 48) self.pr: Optional[int] = None self.k: Optional[int] = None + self.num_pairs: Optional[int] = None + self.num_pos_pairs: Optional[int] = None self._gallery_dataset_meta: Optional[dict[str, list]] = None self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None @@ -90,7 +92,7 @@ class Model: @property def _model_sig(self) -> str: return '_'.join( - (self._model_name, str(self.curr_iter), str(self.total_iter)) + (self._model_name, str(self.curr_iter + 1), str(self.total_iter)) ) @property @@ -119,18 +121,18 @@ class Model: ], dataloader_config: DataloaderConfiguration, ): - for (curr_iter, total_iter, (condition, selector)) in zip( - self.curr_iters, self.total_iters, dataset_selectors.items() + for (restore_iter, total_iter, (condition, selector)) in zip( + self.restore_iters, self.total_iters, dataset_selectors.items() ): print(f'Training model {condition} ...') # Skip finished model - if curr_iter == total_iter: + if restore_iter == total_iter: continue # Check invalid restore iter - elif curr_iter > total_iter: + elif restore_iter > total_iter: raise ValueError("Restore iter '{}' should less than total " - "iter '{}'".format(curr_iter, total_iter)) - self.curr_iter = curr_iter + "iter '{}'".format(restore_iter, total_iter)) + self.restore_iter = self.curr_iter = restore_iter self.total_iter = total_iter self.fit( dict(**dataset_config, **{'selector': selector}), @@ -143,8 +145,24 @@ class Model: dataloader_config: DataloaderConfiguration, ): self.is_train = True - dataset = self._parse_dataset_config(dataset_config) - dataloader = self._parse_dataloader_config(dataset, dataloader_config) + # Validation dataset + # (the first `val_size` subjects from evaluation set) + val_size = dataset_config.pop('val_size', 10) + val_dataset_config = copy.deepcopy(dataset_config) + train_size = dataset_config.get('train_size', 74) + val_dataset_config['train_size'] = train_size + val_size + val_dataset_config['selector']['classes'] = ClipClasses({ + str(c).zfill(3) for c in range(train_size, train_size + val_size) + }) + val_dataset = self._parse_dataset_config(val_dataset_config) + val_dataloader = iter(self._parse_dataloader_config( + val_dataset, dataloader_config + )) + # Training dataset + train_dataset = self._parse_dataset_config(dataset_config) + train_dataloader = iter(self._parse_dataloader_config( + train_dataset, dataloader_config + )) # Prepare for model, optimizer and scheduler model_hp: dict = self.hp.get('model', {}).copy() triplet_is_hard = model_hp.pop('triplet_is_hard', True) @@ -177,8 +195,8 @@ class Model: triplet_is_hard, triplet_is_mean, None ) - num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 - num_pos_pairs = (self.k*(self.k-1)//2) * self.pr + self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 + self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) @@ -191,6 +209,7 @@ class Model: {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, ], **optim_hp) + # Scheduler start_step = sched_hp.get('start_step', 15_000) final_gamma = sched_hp.get('final_gamma', 0.001) ae_start_step = ae_sched_hp.get('start_step', start_step) @@ -221,6 +240,8 @@ class Model: if self.curr_iter == 0: self.rgb_pn.apply(self.init_weights) else: # Load saved state dicts + # Offset a iter to load last checkpoint + self.curr_iter -= 1 checkpoint = torch.load(self._checkpoint_name) random.setstate(checkpoint['rand_states'][0]) torch.set_rng_state(checkpoint['rand_states'][1]) @@ -229,13 +250,9 @@ class Model: self.scheduler.load_state_dict(checkpoint['sched_state_dict']) # Training start - start_time = datetime.now() - running_loss = torch.zeros(5, device=self.device) - print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}", - f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}", - f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}") - for (batch_c1, batch_c2) in dataloader: - self.curr_iter += 1 + for self.curr_iter in tqdm(range(self.restore_iter, self.total_iter), + desc='Training'): + batch_c1, batch_c2 = next(train_dataloader) # Zero the parameter gradients self.optimizer.zero_grad() # forward + backward + optimize @@ -243,72 +260,24 @@ class Model: x_c2 = batch_c2['clip'].to(self.device) embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) - # Duplicate labels for each part - y = y.repeat(self.rgb_pn.num_parts, 1) - trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm( - embed_c, y[:self.rgb_pn.hpm.num_parts] + losses, hpm_result, pn_result = self._classification_loss( + embed_c, embed_p, ae_losses, y ) - trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn( - embed_p, y[self.rgb_pn.hpm.num_parts:] - ) - losses = torch.stack(( - *ae_losses, - trip_loss_hpm.mean(), - trip_loss_pn.mean() - )) loss = losses.sum() loss.backward() self.optimizer.step() + self.scheduler.step() - # Statistics and checkpoint - running_loss += losses.detach() - # Write losses to TensorBoard - self.writer.add_scalar('Loss/all', loss, self.curr_iter) - self.writer.add_scalars('Loss/disentanglement', dict(zip(( - 'Cross reconstruction loss', 'Canonical consistency loss', - 'Pose similarity loss' - ), ae_losses)), self.curr_iter) - self.writer.add_scalars('Loss/triplet loss', { - 'HPM': losses[3], - 'PartNet': losses[4] - }, self.curr_iter) - # None-zero losses in batch - if hpm_num_non_zero is not None and pn_num_non_zero is not None: - self.writer.add_scalars('Loss/non-zero counts', { - 'HPM': hpm_num_non_zero.mean(), - 'PartNet': pn_num_non_zero.mean() - }, self.curr_iter) - # Embedding distance - mean_hpm_dist = hpm_dist.mean(0) - self._add_ranked_scalars( - 'Embedding/HPM distance', mean_hpm_dist, - num_pos_pairs, num_pairs, self.curr_iter - ) - mean_pa_dist = pn_dist.mean(0) - self._add_ranked_scalars( - 'Embedding/ParNet distance', mean_pa_dist, - num_pos_pairs, num_pairs, self.curr_iter - ) - # Embedding norm - mean_hpm_embedding = embed_c.mean(0) - mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) - self._add_ranked_scalars( - 'Embedding/HPM norm', mean_hpm_norm, - self.k, self.pr * self.k, self.curr_iter - ) - mean_pa_embedding = embed_p.mean(0) - mean_pa_norm = mean_pa_embedding.norm(dim=-1) - self._add_ranked_scalars( - 'Embedding/PartNet norm', mean_pa_norm, - self.k, self.pr * self.k, self.curr_iter - ) # Learning rate - lrs = self.scheduler.get_last_lr() self.writer.add_scalars('Learning rate', dict(zip(( 'Auto-encoder', 'HPM', 'PartNet' - ), lrs)), self.curr_iter) + ), self.scheduler.get_last_lr())), self.curr_iter) + # Other stats + self._write_stat( + 'Train', embed_c, embed_p, hpm_result, pn_result, loss, losses + ) - if self.curr_iter % 100 == 0: + if self.curr_iter % 100 == 99: # Write disentangled images if self.image_log_on: i_a, i_c, i_p = images @@ -325,19 +294,35 @@ class Model: self.writer.add_images( f'Pose image/batch {i}', p, self.curr_iter ) - time_used = datetime.now() - start_time - remaining_minute, second = divmod(time_used.seconds, 60) - hour, minute = divmod(remaining_minute, 60) - print(f'{hour:02}:{minute:02}:{second:02}', - f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', - '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100), - '{:.3e} {:.3e} {:.3e}'.format(*lrs)) - running_loss.zero_() - - # Step scheduler - self.scheduler.step() - if self.curr_iter % 1000 == 0: + # Validation + embed_c = self._flatten_embedding(embed_c) + embed_p = self._flatten_embedding(embed_p) + self._write_embedding('HPM Train', embed_c, x_c1, y) + self._write_embedding('PartNet Train', embed_p, x_c1, y) + + # Calculate losses on testing batch + batch_c1, batch_c2 = next(val_dataloader) + x_c1 = batch_c1['clip'].to(self.device) + x_c2 = batch_c2['clip'].to(self.device) + with torch.no_grad(): + embed_c, embed_p, ae_losses, _ = self.rgb_pn(x_c1, x_c2) + y = batch_c1['label'].to(self.device) + losses, hpm_result, pn_result = self._classification_loss( + embed_c, embed_p, ae_losses, y + ) + loss = losses.sum() + + self._write_stat( + 'Val', embed_c, embed_p, hpm_result, pn_result, loss, losses + ) + embed_c = self._flatten_embedding(embed_c) + embed_p = self._flatten_embedding(embed_p) + self._write_embedding('HPM Val', embed_c, x_c1, y) + self._write_embedding('PartNet Val', embed_p, x_c1, y) + + # Checkpoint + if self.curr_iter % 1000 == 999: torch.save({ 'rand_states': (random.getstate(), torch.get_rng_state()), 'model_state_dict': self.rgb_pn.state_dict(), @@ -345,9 +330,83 @@ class Model: 'sched_state_dict': self.scheduler.state_dict(), }, self._checkpoint_name) - if self.curr_iter == self.total_iter: - self.writer.close() - break + self.writer.close() + + def _classification_loss(self, embed_c, embed_p, ae_losses, y): + # Duplicate labels for each part + y_triplet = y.repeat(self.rgb_pn.num_parts, 1) + hpm_result = self.triplet_loss_hpm( + embed_c, y_triplet[:self.rgb_pn.hpm.num_parts] + ) + pn_result = self.triplet_loss_pn( + embed_p, y_triplet[self.rgb_pn.hpm.num_parts:] + ) + losses = torch.stack(( + *ae_losses, + hpm_result.pop('loss').mean(), + pn_result.pop('loss').mean() + )) + return losses, hpm_result, pn_result + + def _write_embedding(self, tag, embed, x, y): + frame = x[:, 0, :, :, :].cpu() + n, c, h, w = frame.size() + padding = torch.zeros(n, c, h, (h-w) // 2) + padded_frame = torch.cat((padding, frame, padding), dim=-1) + self.writer.add_embedding( + embed, + metadata=y.cpu().tolist(), + label_img=padded_frame, + global_step=self.curr_iter, + tag=tag + ) + + def _flatten_embedding(self, embed): + return embed.detach().transpose(0, 1).reshape(self.k * self.pr, -1) + + def _write_stat( + self, postfix, embed_c, embed_p, hpm_result, pn_result, loss, losses + ): + # Write losses to TensorBoard + self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter) + self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip(( + 'Cross reconstruction loss', 'Canonical consistency loss', + 'Pose similarity loss' + ), losses[:3])), self.curr_iter) + self.writer.add_scalars(f'Loss/triplet loss {postfix}', { + 'HPM': losses[3], + 'PartNet': losses[4] + }, self.curr_iter) + # None-zero losses in batch + if hpm_result['counts'] is not None and pn_result['counts'] is not None: + self.writer.add_scalars(f'Loss/non-zero counts {postfix}', { + 'HPM': hpm_result['counts'].mean(), + 'PartNet': pn_result['counts'].mean() + }, self.curr_iter) + # Embedding distance + mean_hpm_dist = hpm_result['dist'].mean(0) + self._add_ranked_scalars( + f'Embedding/HPM distance {postfix}', mean_hpm_dist, + self.num_pos_pairs, self.num_pairs, self.curr_iter + ) + mean_pn_dist = pn_result['dist'].mean(0) + self._add_ranked_scalars( + f'Embedding/ParNet distance {postfix}', mean_pn_dist, + self.num_pos_pairs, self.num_pairs, self.curr_iter + ) + # Embedding norm + mean_hpm_embedding = embed_c.mean(0) + mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) + self._add_ranked_scalars( + f'Embedding/HPM norm {postfix}', mean_hpm_norm, + self.k, self.pr * self.k, self.curr_iter + ) + mean_pa_embedding = embed_p.mean(0) + mean_pa_norm = mean_pa_embedding.norm(dim=-1) + self._add_ranked_scalars( + f'Embedding/PartNet norm {postfix}', mean_pa_norm, + self.k, self.pr * self.k, self.curr_iter + ) def _add_ranked_scalars( self, @@ -441,12 +500,12 @@ class Model: def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]): label, condition, view, clip = sample.values() with torch.no_grad(): - feature = self.rgb_pn(clip.to(self.device)) + feature_c, feature_p = self.rgb_pn(clip.to(self.device)) return { 'label': label.item(), 'condition': condition[0], 'view': view[0], - 'feature': feature + 'feature': torch.cat((feature_c, feature_p)).view(-1) } @staticmethod diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index c136040..4a82da3 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -57,7 +57,7 @@ class RGBPartNet(nn.Module): if self.training: return x_c, x_p, ae_losses, images else: - return torch.cat((x_c.view(-1), x_p.view(-1))) + return x_c, x_p def _disentangle(self, x_c1_t2, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() diff --git a/utils/configuration.py b/utils/configuration.py index f6ac182..157d249 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -14,6 +14,7 @@ class DatasetConfiguration(TypedDict): name: str root_dir: str train_size: int + val_size: int num_sampled_frames: int truncate_threshold: int discard_threshold: int diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 03fff21..5e3a97a 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -28,6 +28,7 @@ class BatchTripletLoss(nn.Module): else: # is_all positive_negative_dist = self._all_distance(dist, y, p, n) + non_zero_counts = None if self.margin: losses = F.relu(self.margin + positive_negative_dist).view(p, -1) non_zero_counts = (losses != 0).sum(1).float() @@ -35,14 +36,18 @@ class BatchTripletLoss(nn.Module): loss_metric = self._none_zero_mean(losses, non_zero_counts) else: # is_sum loss_metric = losses.sum(1) - return loss_metric, flat_dist, non_zero_counts else: # Soft margin losses = F.softplus(positive_negative_dist).view(p, -1) if self.is_mean: loss_metric = losses.mean(1) else: # is_sum loss_metric = losses.sum(1) - return loss_metric, flat_dist, None + + return { + 'loss': loss_metric, + 'dist': flat_dist, + 'counts': non_zero_counts + } @staticmethod def _batch_distance(x): -- cgit v1.2.3 From c967a2ac88e075082473f3ca219660fa12cf67d1 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Tue, 23 Mar 2021 10:26:29 +0800 Subject: Fix indexing bugs in validation dataset selector --- models/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/model.py b/models/model.py index 766e513..17716ad 100644 --- a/models/model.py +++ b/models/model.py @@ -152,7 +152,8 @@ class Model: train_size = dataset_config.get('train_size', 74) val_dataset_config['train_size'] = train_size + val_size val_dataset_config['selector']['classes'] = ClipClasses({ - str(c).zfill(3) for c in range(train_size, train_size + val_size) + str(c).zfill(3) + for c in range(train_size + 1, train_size + val_size + 1) }) val_dataset = self._parse_dataset_config(val_dataset_config) val_dataloader = iter(self._parse_dataloader_config( -- cgit v1.2.3 From 5a063855dbecb8f1a86ad25d9e61a9c8b63312b3 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 25 Mar 2021 12:23:23 +0800 Subject: Bug fixes and refactoring 1. Correct trained model signature 2. Move `val_size` to system config --- config.py | 6 +++--- models/model.py | 8 ++++---- utils/configuration.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/config.py b/config.py index 3d98263..66eab98 100644 --- a/config.py +++ b/config.py @@ -9,7 +9,9 @@ config: Configuration = { # Directory used in training or testing for temporary storage 'save_dir': 'runs', # Recorde disentangled image or not - 'image_log_on': False + 'image_log_on': False, + # The number of subjects for validating (Part of testing set) + 'val_size': 10, }, # Dataset settings 'dataset': { @@ -19,8 +21,6 @@ config: Configuration = { 'root_dir': 'data/CASIA-B-MRCNN-V2/SEG', # The number of subjects for training 'train_size': 74, - # The number of subjects for validating (Part of testing set) - 'val_size': 10, # Number of sampled frames per sequence (Training only) 'num_sampled_frames': 30, # Truncate clips longer than `truncate_threshold` diff --git a/models/model.py b/models/model.py index 17716ad..ceadb92 100644 --- a/models/model.py +++ b/models/model.py @@ -79,6 +79,7 @@ class Model: self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None self.image_log_on = system_config.get('image_log_on', False) + self.val_size = system_config.get('val_size', 10) self.CASIAB_GALLERY_SELECTOR = { 'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})} @@ -147,13 +148,12 @@ class Model: self.is_train = True # Validation dataset # (the first `val_size` subjects from evaluation set) - val_size = dataset_config.pop('val_size', 10) val_dataset_config = copy.deepcopy(dataset_config) train_size = dataset_config.get('train_size', 74) - val_dataset_config['train_size'] = train_size + val_size + val_dataset_config['train_size'] = train_size + self.val_size val_dataset_config['selector']['classes'] = ClipClasses({ str(c).zfill(3) - for c in range(train_size + 1, train_size + val_size + 1) + for c in range(train_size + 1, train_size + self.val_size + 1) }) val_dataset = self._parse_dataset_config(val_dataset_config) val_dataloader = iter(self._parse_dataloader_config( @@ -569,7 +569,7 @@ class Model: for (iter_, total_iter, (condition, selector)) in zip( iters, self.total_iters, dataset_selectors.items() ): - self.curr_iter = iter_ + self.curr_iter = iter_ - 1 self.total_iter = total_iter self._dataset_sig = self._make_signature( dict(**dataset_config, **selector), diff --git a/utils/configuration.py b/utils/configuration.py index 157d249..5a5bc0c 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -8,13 +8,13 @@ class SystemConfiguration(TypedDict): CUDA_VISIBLE_DEVICES: str save_dir: str image_log_on: bool + val_size: int class DatasetConfiguration(TypedDict): name: str root_dir: str train_size: int - val_size: int num_sampled_frames: int truncate_threshold: int discard_threshold: int -- cgit v1.2.3 From 99d1b18ac380bcf7d7579d80e09c9ddaecde44b3 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 27 Mar 2021 21:27:12 +0800 Subject: Normalize triplet losses --- models/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/model.py b/models/model.py index ceadb92..cc5887e 100644 --- a/models/model.py +++ b/models/model.py @@ -344,8 +344,8 @@ class Model: ) losses = torch.stack(( *ae_losses, - hpm_result.pop('loss').mean(), - pn_result.pop('loss').mean() + torch.log(hpm_result.pop('loss').mean() + 1), + torch.log(pn_result.pop('loss').mean() + 1) )) return losses, hpm_result, pn_result -- cgit v1.2.3 From b9f35fbe7d78b3c478086ea26c2a76f72ce35687 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 3 Apr 2021 19:45:00 +0800 Subject: Revert "Normalize triplet losses" This reverts commit 99d1b18a --- models/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/model.py b/models/model.py index cc5887e..ceadb92 100644 --- a/models/model.py +++ b/models/model.py @@ -344,8 +344,8 @@ class Model: ) losses = torch.stack(( *ae_losses, - torch.log(hpm_result.pop('loss').mean() + 1), - torch.log(pn_result.pop('loss').mean() + 1) + hpm_result.pop('loss').mean(), + pn_result.pop('loss').mean() )) return losses, hpm_result, pn_result -- cgit v1.2.3