summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/model.py83
1 files changed, 73 insertions, 10 deletions
diff --git a/models/model.py b/models/model.py
index cabfa97..456c2f1 100644
--- a/models/model.py
+++ b/models/model.py
@@ -4,6 +4,7 @@ from typing import Union, Optional
import numpy as np
import torch
import torch.nn as nn
+import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
@@ -53,6 +54,9 @@ class Model:
self.pr: Optional[int] = None
self.k: Optional[int] = None
+ self._gallery_dataset_meta: Optional[dict[str, list]] = None
+ self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None
+
self._model_sig: str = self._make_signature(self.meta, ['restore_iter'])
self._hp_sig: str = self._make_signature(self.hp)
self._dataset_sig: str = 'undefined'
@@ -164,7 +168,7 @@ class Model:
str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
],
dataloader_config: DataloaderConfiguration,
- ):
+ ) -> dict[str, torch.Tensor]:
self.is_train = False
# Split gallery and probe dataset
gallery_dataloader, probe_dataloaders = self._split_gallery_probe(
@@ -189,7 +193,7 @@ class Model:
for sample in tqdm(gallery_dataloader,
desc='Transforming gallery', unit='clips'):
clip = sample.pop('clip').to(self.device)
- feature = self.rgb_pn(clip).detach().cpu()
+ feature = self.rgb_pn(clip).detach()
gallery_samples.append({**sample, **{'feature': feature}})
gallery_samples = default_collate(gallery_samples)
@@ -200,12 +204,60 @@ class Model:
for sample in tqdm(dataloader,
desc=f'Transforming probe {name}', unit='clips'):
clip = sample.pop('clip').to(self.device)
- feature = self.rgb_pn(clip)
+ feature = self.rgb_pn(clip).detach()
probe_samples[name].append({**sample, **{'feature': feature}})
for (k, v) in probe_samples.items():
probe_samples[k] = default_collate(v)
- # TODO Implement evaluation function here
+ return self._evaluate(gallery_samples, probe_samples)
+
+ def _evaluate(
+ self,
+ gallery_samples: dict[str, Union[list[str], torch.Tensor]],
+ probe_samples: dict[str, dict[str, Union[list[str], torch.Tensor]]],
+ num_ranks: int = 5
+ ) -> dict[str, torch.Tensor]:
+ probe_conditions = self._probe_datasets_meta.keys()
+ gallery_views_meta = self._gallery_dataset_meta['views']
+ probe_views_meta = list(self._probe_datasets_meta.values())[0]['views']
+ accuracy = {
+ condition: torch.empty(
+ len(gallery_views_meta), len(probe_views_meta), num_ranks
+ )
+ for condition in self._probe_datasets_meta.keys()
+ }
+
+ (labels_g, _, views_g, features_g) = gallery_samples.values()
+ views_g = np.asarray(views_g)
+ for (v_g_i, view_g) in enumerate(gallery_views_meta):
+ gallery_view_mask = (views_g == view_g)
+ f_g = features_g[gallery_view_mask]
+ y_g = labels_g[gallery_view_mask]
+ for condition in probe_conditions:
+ probe_samples_c = probe_samples[condition]
+ accuracy_c = accuracy[condition]
+ (labels_p, _, views_p, features_p) = probe_samples_c.values()
+ views_p = np.asarray(views_p)
+ for (v_p_i, view_p) in enumerate(probe_views_meta):
+ probe_view_mask = (views_p == view_g)
+ f_p = features_p[probe_view_mask]
+ y_p = labels_p[probe_view_mask]
+ # Euclidean distance
+ f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(1)
+ f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(0)
+ f_g_times_f_p_sum = f_g @ f_p.T
+ dist = torch.sqrt(F.relu(
+ f_g_squared_sum - 2*f_g_times_f_p_sum + f_p_squared_sum
+ ))
+ # Ranked accuracy
+ rank_mask = dist.argsort(1)[:, :num_ranks]
+ positive_mat = torch.eq(y_p.unsqueeze(1),
+ y_g[rank_mask]).cumsum(1).gt(0)
+ positive_counts = positive_mat.sum(0)
+ total_counts, _ = dist.size()
+ accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts
+
+ return accuracy
def _load_pretrained(
self,
@@ -235,15 +287,26 @@ class Model:
gallery_dataset = self._parse_dataset_config(
dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR)
)
+ self._gallery_dataset_meta = gallery_dataset.metadata
gallery_dataloader = self._parse_dataloader_config(
gallery_dataset, dataloader_config
)
- probe_datasets = {k: self._parse_dataset_config(
- dict(**dataset_config, **v)
- ) for (k, v) in self.CASIAB_PROBE_SELECTORS.items()}
- probe_dataloaders = {k: self._parse_dataloader_config(
- v, dataloader_config
- ) for (k, v) in probe_datasets.items()}
+ probe_datasets = {
+ condition: self._parse_dataset_config(
+ dict(**dataset_config, **selector)
+ )
+ for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items()
+ }
+ self._probe_datasets_meta = {
+ condition: dataset.metadata
+ for (condition, dataset) in probe_datasets.items()
+ }
+ probe_dataloaders = {
+ condtion: self._parse_dataloader_config(
+ dataset, dataloader_config
+ )
+ for (condtion, dataset) in probe_datasets.items()
+ }
elif dataset_name == 'FVG':
# TODO
gallery_dataloader = None