diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-10 14:11:15 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-10 14:11:15 +0800 |
commit | f2f7713efa03a877bc96ced37314b4c4a6dc1963 (patch) | |
tree | 12da1b5337d1fde51be523cb3b19a9df44d84982 | |
parent | 786d66d47c77250dd8d9cb11b6be710ce18a50fb (diff) | |
parent | 1b8d1614168ce6590c5e029c7f1007ac9b17048c (diff) |
Merge branch 'master' into data_parallel
-rw-r--r-- | models/model.py | 49 | ||||
-rw-r--r-- | utils/dataset.py | 6 |
2 files changed, 28 insertions, 27 deletions
diff --git a/models/model.py b/models/model.py index c2d70db..f515e05 100644 --- a/models/model.py +++ b/models/model.py @@ -421,20 +421,20 @@ class Model: self.rgb_pn = self.rgb_pn.to(self.device) self.rgb_pn.eval() - gallery_samples, probe_samples = [], {} - # Gallery - checkpoint = torch.load(list(checkpoints.values())[0]) - self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) - for sample in tqdm(gallery_dataloader, - desc='Transforming gallery', unit='clips'): - gallery_samples.append(self._get_eval_sample(sample)) - gallery_samples = default_collate(gallery_samples) - # Probe - for (condition, dataloader) in probe_dataloaders.items(): + gallery_samples, probe_samples = {}, {} + for (condition, probe_dataloader) in probe_dataloaders.items(): checkpoint = torch.load(checkpoints[condition]) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) + # Gallery + gallery_samples_c = [] + for sample in tqdm(gallery_dataloader, + desc=f'Transforming gallery {condition}', + unit='clips'): + gallery_samples_c.append(self._get_eval_sample(sample)) + gallery_samples[condition] = default_collate(gallery_samples_c) + # Probe probe_samples_c = [] - for sample in tqdm(dataloader, + for sample in tqdm(probe_dataloader, desc=f'Transforming probe {condition}', unit='clips'): probe_samples_c.append(self._get_eval_sample(sample)) @@ -459,27 +459,28 @@ class Model: probe_samples: dict[str, dict[str, Union[list[str], torch.Tensor]]], num_ranks: int = 5 ) -> dict[str, torch.Tensor]: - probe_conditions = self._probe_datasets_meta.keys() + conditions = gallery_samples.keys() gallery_views_meta = self._gallery_dataset_meta['views'] probe_views_meta = list(self._probe_datasets_meta.values())[0]['views'] accuracy = { condition: torch.empty( len(gallery_views_meta), len(probe_views_meta), num_ranks ) - for condition in self._probe_datasets_meta.keys() + for condition in conditions } - (labels_g, _, views_g, features_g) = gallery_samples.values() - views_g = np.asarray(views_g) - for (v_g_i, view_g) in enumerate(gallery_views_meta): - gallery_view_mask = (views_g == view_g) - f_g = features_g[gallery_view_mask] - y_g = labels_g[gallery_view_mask] - for condition in probe_conditions: - probe_samples_c = probe_samples[condition] - accuracy_c = accuracy[condition] - (labels_p, _, views_p, features_p) = probe_samples_c.values() - views_p = np.asarray(views_p) + for condition in conditions: + gallery_samples_c = gallery_samples[condition] + (labels_g, _, views_g, features_g) = gallery_samples_c.values() + views_g = np.asarray(views_g) + probe_samples_c = probe_samples[condition] + (labels_p, _, views_p, features_p) = probe_samples_c.values() + views_p = np.asarray(views_p) + accuracy_c = accuracy[condition] + for (v_g_i, view_g) in enumerate(gallery_views_meta): + gallery_view_mask = (views_g == view_g) + f_g = features_g[gallery_view_mask] + y_g = labels_g[gallery_view_mask] for (v_p_i, view_p) in enumerate(probe_views_meta): probe_view_mask = (views_p == view_p) f_p = features_p[probe_view_mask] diff --git a/utils/dataset.py b/utils/dataset.py index c487988..387c211 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -111,9 +111,9 @@ class CASIAB(data.Dataset): # in Bag #2 condition from 90 degree angle classes, conditions, views = [], [], [] if selector: - selected_classes = selector.pop('classes', None) - selected_conditions = selector.pop('conditions', None) - selected_views = selector.pop('views', None) + selected_classes = selector.get('classes', None) + selected_conditions = selector.get('conditions', None) + selected_views = selector.get('views', None) class_regex = r'\d{3}' condition_regex = r'(nm|bg|cl)-0[0-6]' |