diff options
-rw-r--r-- | config.py | 4 | ||||
-rw-r--r-- | models/layers.py | 4 | ||||
-rw-r--r-- | models/model.py | 74 | ||||
-rw-r--r-- | requirements.txt | 2 | ||||
-rw-r--r-- | utils/triplet_loss.py | 46 |
5 files changed, 97 insertions, 33 deletions
@@ -51,7 +51,7 @@ config = { # Use 1x1 convolution in dimensionality reduction 'hpm_use_1x1conv': False, # HPM pyramid scales, of which sum is number of parts - 'hpm_scales': (1, 2, 4), + 'hpm_scales': (1, 2, 4, 8), # Global pooling method 'hpm_use_avg_pool': True, 'hpm_use_max_pool': True, @@ -63,6 +63,8 @@ config = { 'embedding_dims': 256, # Batch Hard or Batch All 'triplet_is_hard': True, + # Use non-zero mean or sum + 'triplet_is_mean': True, # Triplet loss margins for HPM and PartNet, None for soft margin 'triplet_margins': None, }, diff --git a/models/layers.py b/models/layers.py index ae61583..e30d0c4 100644 --- a/models/layers.py +++ b/models/layers.py @@ -80,7 +80,9 @@ class DCGANConvTranspose2d(BasicConvTranspose2d): if self.is_last_layer: return self.trans_conv(x) else: - return super().forward(x) + x = self.trans_conv(x) + x = self.bn(x) + return F.leaky_relu(x, 0.2, inplace=True) class BasicLinear(nn.Module): diff --git a/models/model.py b/models/model.py index f4654f3..aef5302 100644 --- a/models/model.py +++ b/models/model.py @@ -56,6 +56,8 @@ class Model: self.in_size: Tuple[int, int] = (64, 48) self.pr: Optional[int] = None self.k: Optional[int] = None + self.num_pairs: Optional[int] = None + self.num_pos_pairs: Optional[int] = None self._gallery_dataset_meta: Optional[Dict[str, List]] = None self._probe_datasets_meta: Optional[Dict[str, Dict[str, List]]] = None @@ -141,6 +143,7 @@ class Model: # Prepare for model, optimizer and scheduler model_hp: Dict = self.hp.get('model', {}).copy() triplet_is_hard = model_hp.pop('triplet_is_hard', True) + triplet_is_mean = model_hp.pop('triplet_is_mean', True) triplet_margins = model_hp.pop('triplet_margins', None) optim_hp: Dict = self.hp.get('optimizer', {}).copy() start_iter = optim_hp.pop('start_iter', 0) @@ -160,10 +163,13 @@ class Model: ) else: # Different margins self.triplet_loss = JointBatchTripletLoss( - self.rgb_pn.hpm_num_parts, triplet_is_hard, triplet_margins + self.rgb_pn.hpm_num_parts, + triplet_is_hard, triplet_is_mean, triplet_margins ) else: # Soft margins - self.triplet_loss = BatchTripletLoss(triplet_is_hard, None) + self.triplet_loss = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, None + ) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) @@ -213,7 +219,7 @@ class Model: y = batch_c1['label'].to(self.device) # Duplicate labels for each part y = y.repeat(self.rgb_pn.num_total_parts, 1) - trip_loss, dist, non_zero_counts = self.triplet_loss(embedding, y) + trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y) losses = torch.cat(( ae_losses, torch.stack(( @@ -237,18 +243,36 @@ class Model: 'HPM': losses[3], 'PartNet': losses[4] }, self.curr_iter) - self.writer.add_scalars('Loss/non-zero counts', { - 'HPM': non_zero_counts[:self.rgb_pn.hpm_num_parts].mean(), - 'PartNet': non_zero_counts[self.rgb_pn.hpm_num_parts:].mean() - }, self.curr_iter) - self.writer.add_scalars('Embedding/distance', { - 'HPM': dist[:self.rgb_pn.hpm_num_parts].mean(), - 'PartNet': dist[self.rgb_pn.hpm_num_parts].mean() - }, self.curr_iter) - self.writer.add_scalars('Embedding/2-norm', { - 'HPM': embedding[:self.rgb_pn.hpm_num_parts].norm(), - 'PartNet': embedding[self.rgb_pn.hpm_num_parts].norm() - }, self.curr_iter) + # None-zero losses in batch + if num_non_zero is not None: + self.writer.add_scalars('Loss/non-zero counts', { + 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(), + 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean() + }, self.curr_iter) + # Embedding distance + mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0) + self._add_ranked_scalars( + 'Embedding/HPM distance', mean_hpm_dist, + self.num_pos_pairs, self.num_pairs, self.curr_iter + ) + mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0) + self._add_ranked_scalars( + 'Embedding/ParNet distance', mean_pa_dist, + self.num_pos_pairs, self.num_pairs, self.curr_iter + ) + # Embedding norm + mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0) + mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) + self._add_ranked_scalars( + 'Embedding/HPM norm', mean_hpm_norm, + self.k, self.pr * self.k, self.curr_iter + ) + mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0) + mean_pa_norm = mean_pa_embedding.norm(dim=-1) + self._add_ranked_scalars( + 'Embedding/PartNet norm', mean_pa_norm, + self.k, self.pr * self.k, self.curr_iter + ) if self.curr_iter % 100 == 0: lrs = self.scheduler.get_last_lr() @@ -300,6 +324,24 @@ class Model: self.writer.close() break + def _add_ranked_scalars( + self, + main_tag: str, + metric: torch.Tensor, + num_pos: int, + num_all: int, + global_step: int + ): + rank = metric.argsort() + pos_ile = 100 - (num_pos - 1) * 100 // num_all + self.writer.add_scalars(main_tag, { + '0%-ile': metric[rank[-1]], + f'{100 - pos_ile}%-ile': metric[rank[-num_pos]], + '50%-ile': metric[rank[num_all // 2 - 1]], + f'{pos_ile}%-ile': metric[rank[num_pos - 1]], + '100%-ile': metric[rank[0]] + }, global_step) + def predict_all( self, iters: Tuple[int], @@ -521,6 +563,8 @@ class Model: ) -> DataLoader: config: Dict = dataloader_config.copy() (self.pr, self.k) = config.pop('batch_size', (8, 16)) + self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 + self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr if self.is_train: triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) return DataLoader(dataset, diff --git a/requirements.txt b/requirements.txt index 4d30e17..926a587 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ torch~=1.7.1 torchvision~=0.8.0a0+ecf4e9c numpy~=1.19.4 -tqdm~=4.57.0 +tqdm~=4.58.0 Pillow~=8.1.0 scikit-learn~=0.24.0
\ No newline at end of file diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 22ac2ab..77c7234 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -9,15 +9,19 @@ class BatchTripletLoss(nn.Module): def __init__( self, is_hard: bool = True, + is_mean: bool = True, margin: Optional[float] = 0.2, ): super().__init__() self.is_hard = is_hard + self.is_mean = is_mean self.margin = margin def forward(self, x, y): p, n, c = x.size() dist = self._batch_distance(x) + flat_dist = dist.tril(-1) + flat_dist = flat_dist[flat_dist != 0].view(p, -1) if self.is_hard: positive_negative_dist = self._hard_distance(dist, y, p, n) @@ -25,12 +29,20 @@ class BatchTripletLoss(nn.Module): positive_negative_dist = self._all_distance(dist, y, p, n) if self.margin: - all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) - else: - all_loss = F.softplus(positive_negative_dist).view(p, -1) - non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - - return non_zero_mean, dist.mean((1, 2)), non_zero_counts + losses = F.relu(self.margin + positive_negative_dist).view(p, -1) + non_zero_counts = (losses != 0).sum(1).float() + if self.is_mean: + loss_metric = self._none_zero_mean(losses, non_zero_counts) + else: # is_sum + loss_metric = losses.sum(1) + return loss_metric, flat_dist, non_zero_counts + else: # Soft margin + losses = F.softplus(positive_negative_dist).view(p, -1) + if self.is_mean: + loss_metric = losses.mean(1) + else: # is_sum + loss_metric = losses.sum(1) + return loss_metric, flat_dist, None @staticmethod def _batch_distance(x): @@ -65,13 +77,11 @@ class BatchTripletLoss(nn.Module): return positive_negative_dist @staticmethod - def _none_zero_parted_mean(all_loss): + def _none_zero_mean(losses, non_zero_counts): # Non-zero parted mean - non_zero_counts = (all_loss != 0).sum(1).float() - non_zero_mean = all_loss.sum(1) / non_zero_counts + non_zero_mean = losses.sum(1) / non_zero_counts non_zero_mean[non_zero_counts == 0] = 0 - - return non_zero_mean, non_zero_counts + return non_zero_mean class JointBatchTripletLoss(BatchTripletLoss): @@ -79,9 +89,10 @@ class JointBatchTripletLoss(BatchTripletLoss): self, hpm_num_parts: int, is_hard: bool = True, + is_mean: bool = True, margins: Tuple[float, float] = (0.2, 0.2) ): - super().__init__(is_hard) + super().__init__(is_hard, is_mean) self.hpm_num_parts = hpm_num_parts self.margin_hpm, self.margin_pn = margins @@ -100,7 +111,12 @@ class JointBatchTripletLoss(BatchTripletLoss): pn_part_loss = F.relu( self.margin_pn + positive_negative_dist[self.hpm_num_parts:] ) - all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) - non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) + losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) + + non_zero_counts = (losses != 0).sum(1).float() + if self.is_mean: + loss_metric = self._none_zero_mean(losses, non_zero_counts) + else: # is_sum + loss_metric = losses.sum(1) - return non_zero_mean, dist.mean((1, 2)), non_zero_counts + return loss_metric, dist, non_zero_counts |