diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-11 21:15:58 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-11 21:15:58 +0800 |
commit | 7188d71b2b6faf3da527c8d0ade9a32ec4893dc5 (patch) | |
tree | a012ba1a14bcc4b58833f8ee53cc3114bea5ba0f /models/model.py | |
parent | d30cf2cb280e83e4a4abe1e9c2abdbba17d903a3 (diff) |
Implement evaluator
Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 83 |
1 files changed, 73 insertions, 10 deletions
diff --git a/models/model.py b/models/model.py index cabfa97..456c2f1 100644 --- a/models/model.py +++ b/models/model.py @@ -4,6 +4,7 @@ from typing import Union, Optional import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate @@ -53,6 +54,9 @@ class Model: self.pr: Optional[int] = None self.k: Optional[int] = None + self._gallery_dataset_meta: Optional[dict[str, list]] = None + self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None + self._model_sig: str = self._make_signature(self.meta, ['restore_iter']) self._hp_sig: str = self._make_signature(self.hp) self._dataset_sig: str = 'undefined' @@ -164,7 +168,7 @@ class Model: str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]] ], dataloader_config: DataloaderConfiguration, - ): + ) -> dict[str, torch.Tensor]: self.is_train = False # Split gallery and probe dataset gallery_dataloader, probe_dataloaders = self._split_gallery_probe( @@ -189,7 +193,7 @@ class Model: for sample in tqdm(gallery_dataloader, desc='Transforming gallery', unit='clips'): clip = sample.pop('clip').to(self.device) - feature = self.rgb_pn(clip).detach().cpu() + feature = self.rgb_pn(clip).detach() gallery_samples.append({**sample, **{'feature': feature}}) gallery_samples = default_collate(gallery_samples) @@ -200,12 +204,60 @@ class Model: for sample in tqdm(dataloader, desc=f'Transforming probe {name}', unit='clips'): clip = sample.pop('clip').to(self.device) - feature = self.rgb_pn(clip) + feature = self.rgb_pn(clip).detach() probe_samples[name].append({**sample, **{'feature': feature}}) for (k, v) in probe_samples.items(): probe_samples[k] = default_collate(v) - # TODO Implement evaluation function here + return self._evaluate(gallery_samples, probe_samples) + + def _evaluate( + self, + gallery_samples: dict[str, Union[list[str], torch.Tensor]], + probe_samples: dict[str, dict[str, Union[list[str], torch.Tensor]]], + num_ranks: int = 5 + ) -> dict[str, torch.Tensor]: + probe_conditions = self._probe_datasets_meta.keys() + gallery_views_meta = self._gallery_dataset_meta['views'] + probe_views_meta = list(self._probe_datasets_meta.values())[0]['views'] + accuracy = { + condition: torch.empty( + len(gallery_views_meta), len(probe_views_meta), num_ranks + ) + for condition in self._probe_datasets_meta.keys() + } + + (labels_g, _, views_g, features_g) = gallery_samples.values() + views_g = np.asarray(views_g) + for (v_g_i, view_g) in enumerate(gallery_views_meta): + gallery_view_mask = (views_g == view_g) + f_g = features_g[gallery_view_mask] + y_g = labels_g[gallery_view_mask] + for condition in probe_conditions: + probe_samples_c = probe_samples[condition] + accuracy_c = accuracy[condition] + (labels_p, _, views_p, features_p) = probe_samples_c.values() + views_p = np.asarray(views_p) + for (v_p_i, view_p) in enumerate(probe_views_meta): + probe_view_mask = (views_p == view_g) + f_p = features_p[probe_view_mask] + y_p = labels_p[probe_view_mask] + # Euclidean distance + f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(1) + f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(0) + f_g_times_f_p_sum = f_g @ f_p.T + dist = torch.sqrt(F.relu( + f_g_squared_sum - 2*f_g_times_f_p_sum + f_p_squared_sum + )) + # Ranked accuracy + rank_mask = dist.argsort(1)[:, :num_ranks] + positive_mat = torch.eq(y_p.unsqueeze(1), + y_g[rank_mask]).cumsum(1).gt(0) + positive_counts = positive_mat.sum(0) + total_counts, _ = dist.size() + accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts + + return accuracy def _load_pretrained( self, @@ -235,15 +287,26 @@ class Model: gallery_dataset = self._parse_dataset_config( dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR) ) + self._gallery_dataset_meta = gallery_dataset.metadata gallery_dataloader = self._parse_dataloader_config( gallery_dataset, dataloader_config ) - probe_datasets = {k: self._parse_dataset_config( - dict(**dataset_config, **v) - ) for (k, v) in self.CASIAB_PROBE_SELECTORS.items()} - probe_dataloaders = {k: self._parse_dataloader_config( - v, dataloader_config - ) for (k, v) in probe_datasets.items()} + probe_datasets = { + condition: self._parse_dataset_config( + dict(**dataset_config, **selector) + ) + for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items() + } + self._probe_datasets_meta = { + condition: dataset.metadata + for (condition, dataset) in probe_datasets.items() + } + probe_dataloaders = { + condtion: self._parse_dataloader_config( + dataset, dataloader_config + ) + for (condtion, dataset) in probe_datasets.items() + } elif dataset_name == 'FVG': # TODO gallery_dataloader = None |