diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-15 20:29:42 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-15 20:29:42 +0800 |
commit | 6a24467e5c1417a9e41db583ea11cc72d561669c (patch) | |
tree | 09c4e67ea3447cfc66513d4400ec8eddb76bcdd5 /models/model.py | |
parent | f953000c5b9490c040b68cbc233518c756ac00ab (diff) | |
parent | 03d38fe32d668c28d7cba01f0b2e227f32b954c1 (diff) |
Merge branch 'python3.8' into python3.7
# Conflicts:
# models/model.py
Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 36 |
1 files changed, 20 insertions, 16 deletions
diff --git a/models/model.py b/models/model.py index b11d92a..8cb1459 100644 --- a/models/model.py +++ b/models/model.py @@ -389,12 +389,12 @@ class Model: dataset_selectors: Dict[ str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]] ], - dataloader_config: Dict + dataloader_config: Dict, + is_train: bool = False ): - self.is_train = False # Split gallery and probe dataset gallery_dataloader, probe_dataloaders = self._split_gallery_probe( - dataset_config, dataloader_config + dataset_config, dataloader_config, is_train ) # Get pretrained models at iter_ checkpoints = self._load_pretrained( @@ -422,7 +422,6 @@ class Model: unit='clips'): gallery_samples_c.append(self._get_eval_sample(sample)) gallery_samples[condition] = default_collate(gallery_samples_c) - gallery_samples['meta'] = self._gallery_dataset_meta # Probe probe_samples_c = [] for sample in tqdm(probe_dataloader, @@ -432,18 +431,19 @@ class Model: probe_samples_c = default_collate(probe_samples_c) probe_samples_c['meta'] = self._probe_datasets_meta[condition] probe_samples[condition] = probe_samples_c + gallery_samples['meta'] = self._gallery_dataset_meta return gallery_samples, probe_samples def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]): - label = sample.pop('label').item() - clip = sample.pop('clip').to(self.device) + label, condition, view, clip = sample.values() with torch.no_grad(): - feature = self.rgb_pn(clip) + feature = self.rgb_pn(clip.to(self.device)) return { - **{'label': label}, - **sample, - **{'feature': feature} + 'label': label.item(), + 'condition': condition[0], + 'view': view[0], + 'feature': feature } @staticmethod @@ -503,10 +503,11 @@ class Model: ] ) -> Dict[str, str]: checkpoints = {} - for (iter_, (condition, selector)) in zip( - iters, dataset_selectors.items() + for (iter_, total_iter, (condition, selector)) in zip( + iters, self.total_iters, dataset_selectors.items() ): self.curr_iter = iter_ + self.total_iter = total_iter self._dataset_sig = self._make_signature( dict(**dataset_config, **selector), popped_keys=['root_dir', 'cache_on'] @@ -518,26 +519,29 @@ class Model: self, dataset_config: Dict, dataloader_config: Dict, + is_train: bool = False ) -> Tuple[DataLoader, Dict[str, DataLoader]]: dataset_name = dataset_config.get('name', 'CASIA-B') if dataset_name == 'CASIA-B': + self.is_train = is_train gallery_dataset = self._parse_dataset_config( dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR) ) - self._gallery_dataset_meta = gallery_dataset.metadata - gallery_dataloader = self._parse_dataloader_config( - gallery_dataset, dataloader_config - ) probe_datasets = { condition: self._parse_dataset_config( dict(**dataset_config, **selector) ) for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items() } + self._gallery_dataset_meta = gallery_dataset.metadata self._probe_datasets_meta = { condition: dataset.metadata for (condition, dataset) in probe_datasets.items() } + self.is_train = False + gallery_dataloader = self._parse_dataloader_config( + gallery_dataset, dataloader_config + ) probe_dataloaders = { condition: self._parse_dataloader_config( dataset, dataloader_config |