diff options
Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 49 |
1 files changed, 25 insertions, 24 deletions
diff --git a/models/model.py b/models/model.py index 7ce189c..e83cc7f 100644 --- a/models/model.py +++ b/models/model.py @@ -399,20 +399,20 @@ class Model: self.rgb_pn = self.rgb_pn.to(self.device) self.rgb_pn.eval() - gallery_samples, probe_samples = [], {} - # Gallery - checkpoint = torch.load(list(checkpoints.values())[0]) - self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) - for sample in tqdm(gallery_dataloader, - desc='Transforming gallery', unit='clips'): - gallery_samples.append(self._get_eval_sample(sample)) - gallery_samples = default_collate(gallery_samples) - # Probe - for (condition, dataloader) in probe_dataloaders.items(): + gallery_samples, probe_samples = {}, {} + for (condition, probe_dataloader) in probe_dataloaders.items(): checkpoint = torch.load(checkpoints[condition]) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) + # Gallery + gallery_samples_c = [] + for sample in tqdm(gallery_dataloader, + desc=f'Transforming gallery {condition}', + unit='clips'): + gallery_samples_c.append(self._get_eval_sample(sample)) + gallery_samples[condition] = default_collate(gallery_samples_c) + # Probe probe_samples_c = [] - for sample in tqdm(dataloader, + for sample in tqdm(probe_dataloader, desc=f'Transforming probe {condition}', unit='clips'): probe_samples_c.append(self._get_eval_sample(sample)) @@ -437,27 +437,28 @@ class Model: probe_samples: Dict[str, Dict[str, Union[List[str], torch.Tensor]]], num_ranks: int = 5 ) -> Dict[str, torch.Tensor]: - probe_conditions = self._probe_datasets_meta.keys() + conditions = gallery_samples.keys() gallery_views_meta = self._gallery_dataset_meta['views'] probe_views_meta = list(self._probe_datasets_meta.values())[0]['views'] accuracy = { condition: torch.empty( len(gallery_views_meta), len(probe_views_meta), num_ranks ) - for condition in self._probe_datasets_meta.keys() + for condition in conditions } - (labels_g, _, views_g, features_g) = gallery_samples.values() - views_g = np.asarray(views_g) - for (v_g_i, view_g) in enumerate(gallery_views_meta): - gallery_view_mask = (views_g == view_g) - f_g = features_g[gallery_view_mask] - y_g = labels_g[gallery_view_mask] - for condition in probe_conditions: - probe_samples_c = probe_samples[condition] - accuracy_c = accuracy[condition] - (labels_p, _, views_p, features_p) = probe_samples_c.values() - views_p = np.asarray(views_p) + for condition in conditions: + gallery_samples_c = gallery_samples[condition] + (labels_g, _, views_g, features_g) = gallery_samples_c.values() + views_g = np.asarray(views_g) + probe_samples_c = probe_samples[condition] + (labels_p, _, views_p, features_p) = probe_samples_c.values() + views_p = np.asarray(views_p) + accuracy_c = accuracy[condition] + for (v_g_i, view_g) in enumerate(gallery_views_meta): + gallery_view_mask = (views_g == view_g) + f_g = features_g[gallery_view_mask] + y_g = labels_g[gallery_view_mask] for (v_p_i, view_p) in enumerate(probe_views_meta): probe_view_mask = (views_p == view_p) f_p = features_p[probe_view_mask] |