diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-28 22:14:27 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-28 22:32:31 +0800 |
commit | b837336695213e3e660992fcd01c5a52c654ea4f (patch) | |
tree | d1c072f9e69f3c8b6952785b9cc9f219a72d5397 | |
parent | c96a6c88fa63d62ec62807abf957c9a8df307b43 (diff) |
Log n-ile embedding distance and norm
-rw-r--r-- | models/model.py | 66 | ||||
-rw-r--r-- | utils/triplet_loss.py | 13 |
2 files changed, 61 insertions, 18 deletions
diff --git a/models/model.py b/models/model.py index 79952cb..18896ae 100644 --- a/models/model.py +++ b/models/model.py @@ -59,6 +59,8 @@ class Model: self.in_size: tuple[int, int] = (64, 48) self.pr: Optional[int] = None self.k: Optional[int] = None + self.num_pairs: Optional[int] = None + self.num_pos_pairs: Optional[int] = None self._gallery_dataset_meta: Optional[dict[str, list]] = None self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None @@ -216,7 +218,7 @@ class Model: y = batch_c1['label'].to(self.device) # Duplicate labels for each part y = y.repeat(self.rgb_pn.num_total_parts, 1) - trip_loss, dist, non_zero_counts = self.triplet_loss(embedding, y) + trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y) losses = torch.cat(( ae_losses, torch.stack(( @@ -240,18 +242,36 @@ class Model: 'HPM': losses[3], 'PartNet': losses[4] }, self.curr_iter) - self.writer.add_scalars('Loss/non-zero counts', { - 'HPM': non_zero_counts[:self.rgb_pn.hpm_num_parts].mean(), - 'PartNet': non_zero_counts[self.rgb_pn.hpm_num_parts:].mean() - }, self.curr_iter) - self.writer.add_scalars('Embedding/distance', { - 'HPM': dist[:self.rgb_pn.hpm_num_parts].mean(), - 'PartNet': dist[self.rgb_pn.hpm_num_parts].mean() - }, self.curr_iter) - self.writer.add_scalars('Embedding/2-norm', { - 'HPM': embedding[:self.rgb_pn.hpm_num_parts].norm(), - 'PartNet': embedding[self.rgb_pn.hpm_num_parts].norm() - }, self.curr_iter) + # None-zero losses in batch + if num_non_zero: + self.writer.add_scalars('Loss/non-zero counts', { + 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(), + 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean() + }, self.curr_iter) + # Embedding distance + mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0) + self._add_ranked_scalars( + 'Embedding/HPM distance', mean_hpm_dist, + self.num_pos_pairs, self.num_pairs, self.curr_iter + ) + mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0) + self._add_ranked_scalars( + 'Embedding/ParNet distance', mean_pa_dist, + self.num_pos_pairs, self.num_pairs, self.curr_iter + ) + # Embedding norm + mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0) + mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) + self._add_ranked_scalars( + 'Embedding/HPM norm', mean_hpm_norm, + self.k, self.pr * self.k, self.curr_iter + ) + mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0) + mean_pa_norm = mean_pa_embedding.norm(dim=-1) + self._add_ranked_scalars( + 'Embedding/PartNet norm', mean_pa_norm, + self.k, self.pr * self.k, self.curr_iter + ) if self.curr_iter % 100 == 0: lrs = self.scheduler.get_last_lr() @@ -303,6 +323,24 @@ class Model: self.writer.close() break + def _add_ranked_scalars( + self, + main_tag: str, + metric: torch.Tensor, + num_pos: int, + num_all: int, + global_step: int + ): + rank = metric.argsort() + pos_ile = 100 - (num_pos - 1) * 100 // num_all + self.writer.add_scalars(main_tag, { + '0%-ile': metric[rank[-1]], + f'{100 - pos_ile}%-ile': metric[rank[-num_pos]], + '50%-ile': metric[rank[num_all // 2 - 1]], + f'{pos_ile}%-ile': metric[rank[num_pos - 1]], + '100%-ile': metric[rank[0]] + }, global_step) + def predict_all( self, iters: tuple[int], @@ -524,6 +562,8 @@ class Model: ) -> DataLoader: config: dict = dataloader_config.copy() (self.pr, self.k) = config.pop('batch_size', (8, 16)) + self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 + self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr if self.is_train: triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) return DataLoader(dataset, diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index c3e5802..52d676e 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -18,6 +18,8 @@ class BatchTripletLoss(nn.Module): def forward(self, x, y): p, n, c = x.size() dist = self._batch_distance(x) + flat_dist = dist.tril(-1) + flat_dist = flat_dist[flat_dist != 0].view(p, -1) if self.is_hard: positive_negative_dist = self._hard_distance(dist, y, p, n) @@ -26,11 +28,12 @@ class BatchTripletLoss(nn.Module): if self.margin: all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) - else: + loss_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) + return loss_mean, flat_dist, non_zero_counts + else: # Soft margin all_loss = F.softplus(positive_negative_dist).view(p, -1) - non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - - return non_zero_mean, dist.mean((1, 2)), non_zero_counts + loss_mean = all_loss.mean(1) + return loss_mean, flat_dist, None @staticmethod def _batch_distance(x): @@ -103,4 +106,4 @@ class JointBatchTripletLoss(BatchTripletLoss): all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - return non_zero_mean, dist.mean((1, 2)), non_zero_counts + return non_zero_mean, dist, non_zero_counts |