summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-28 22:14:27 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-28 22:32:31 +0800
commitb837336695213e3e660992fcd01c5a52c654ea4f (patch)
treed1c072f9e69f3c8b6952785b9cc9f219a72d5397 /models/model.py
parentc96a6c88fa63d62ec62807abf957c9a8df307b43 (diff)
Log n-ile embedding distance and norm
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py66
1 files changed, 53 insertions, 13 deletions
diff --git a/models/model.py b/models/model.py
index 79952cb..18896ae 100644
--- a/models/model.py
+++ b/models/model.py
@@ -59,6 +59,8 @@ class Model:
self.in_size: tuple[int, int] = (64, 48)
self.pr: Optional[int] = None
self.k: Optional[int] = None
+ self.num_pairs: Optional[int] = None
+ self.num_pos_pairs: Optional[int] = None
self._gallery_dataset_meta: Optional[dict[str, list]] = None
self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None
@@ -216,7 +218,7 @@ class Model:
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
y = y.repeat(self.rgb_pn.num_total_parts, 1)
- trip_loss, dist, non_zero_counts = self.triplet_loss(embedding, y)
+ trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
losses = torch.cat((
ae_losses,
torch.stack((
@@ -240,18 +242,36 @@ class Model:
'HPM': losses[3],
'PartNet': losses[4]
}, self.curr_iter)
- self.writer.add_scalars('Loss/non-zero counts', {
- 'HPM': non_zero_counts[:self.rgb_pn.hpm_num_parts].mean(),
- 'PartNet': non_zero_counts[self.rgb_pn.hpm_num_parts:].mean()
- }, self.curr_iter)
- self.writer.add_scalars('Embedding/distance', {
- 'HPM': dist[:self.rgb_pn.hpm_num_parts].mean(),
- 'PartNet': dist[self.rgb_pn.hpm_num_parts].mean()
- }, self.curr_iter)
- self.writer.add_scalars('Embedding/2-norm', {
- 'HPM': embedding[:self.rgb_pn.hpm_num_parts].norm(),
- 'PartNet': embedding[self.rgb_pn.hpm_num_parts].norm()
- }, self.curr_iter)
+ # None-zero losses in batch
+ if num_non_zero:
+ self.writer.add_scalars('Loss/non-zero counts', {
+ 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(),
+ 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean()
+ }, self.curr_iter)
+ # Embedding distance
+ mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0)
+ self._add_ranked_scalars(
+ 'Embedding/HPM distance', mean_hpm_dist,
+ self.num_pos_pairs, self.num_pairs, self.curr_iter
+ )
+ mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0)
+ self._add_ranked_scalars(
+ 'Embedding/ParNet distance', mean_pa_dist,
+ self.num_pos_pairs, self.num_pairs, self.curr_iter
+ )
+ # Embedding norm
+ mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0)
+ mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
+ self._add_ranked_scalars(
+ 'Embedding/HPM norm', mean_hpm_norm,
+ self.k, self.pr * self.k, self.curr_iter
+ )
+ mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0)
+ mean_pa_norm = mean_pa_embedding.norm(dim=-1)
+ self._add_ranked_scalars(
+ 'Embedding/PartNet norm', mean_pa_norm,
+ self.k, self.pr * self.k, self.curr_iter
+ )
if self.curr_iter % 100 == 0:
lrs = self.scheduler.get_last_lr()
@@ -303,6 +323,24 @@ class Model:
self.writer.close()
break
+ def _add_ranked_scalars(
+ self,
+ main_tag: str,
+ metric: torch.Tensor,
+ num_pos: int,
+ num_all: int,
+ global_step: int
+ ):
+ rank = metric.argsort()
+ pos_ile = 100 - (num_pos - 1) * 100 // num_all
+ self.writer.add_scalars(main_tag, {
+ '0%-ile': metric[rank[-1]],
+ f'{100 - pos_ile}%-ile': metric[rank[-num_pos]],
+ '50%-ile': metric[rank[num_all // 2 - 1]],
+ f'{pos_ile}%-ile': metric[rank[num_pos - 1]],
+ '100%-ile': metric[rank[0]]
+ }, global_step)
+
def predict_all(
self,
iters: tuple[int],
@@ -524,6 +562,8 @@ class Model:
) -> DataLoader:
config: dict = dataloader_config.copy()
(self.pr, self.k) = config.pop('batch_size', (8, 16))
+ self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
+ self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
if self.is_train:
triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
return DataLoader(dataset,