From fe6cd66d19c16153322577fb13779020934cf1e2 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 15 Mar 2021 17:15:33 +0800 Subject: Support transforming on training datasets --- models/model.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/models/model.py b/models/model.py index 89b34aa..5966ae1 100644 --- a/models/model.py +++ b/models/model.py @@ -392,12 +392,12 @@ class Model: dataset_selectors: dict[ str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]] ], - dataloader_config: DataloaderConfiguration + dataloader_config: DataloaderConfiguration, + is_train: bool = False ): - self.is_train = False # Split gallery and probe dataset gallery_dataloader, probe_dataloaders = self._split_gallery_probe( - dataset_config, dataloader_config + dataset_config, dataloader_config, is_train ) # Get pretrained models at iter_ checkpoints = self._load_pretrained( @@ -506,10 +506,11 @@ class Model: ] ) -> dict[str, str]: checkpoints = {} - for (iter_, (condition, selector)) in zip( - iters, dataset_selectors.items() + for (iter_, total_iter, (condition, selector)) in zip( + iters, self.total_iters, dataset_selectors.items() ): self.curr_iter = iter_ + self.total_iter = total_iter self._dataset_sig = self._make_signature( dict(**dataset_config, **selector), popped_keys=['root_dir', 'cache_on'] @@ -521,26 +522,29 @@ class Model: self, dataset_config: DatasetConfiguration, dataloader_config: DataloaderConfiguration, + is_train: bool = False ) -> tuple[DataLoader, dict[str, DataLoader]]: dataset_name = dataset_config.get('name', 'CASIA-B') if dataset_name == 'CASIA-B': + self.is_train = is_train gallery_dataset = self._parse_dataset_config( dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR) ) - self._gallery_dataset_meta = gallery_dataset.metadata - gallery_dataloader = self._parse_dataloader_config( - gallery_dataset, dataloader_config - ) probe_datasets = { condition: self._parse_dataset_config( dict(**dataset_config, **selector) ) for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items() } + self._gallery_dataset_meta = gallery_dataset.metadata self._probe_datasets_meta = { condition: dataset.metadata for (condition, dataset) in probe_datasets.items() } + self.is_train = False + gallery_dataloader = self._parse_dataloader_config( + gallery_dataset, dataloader_config + ) probe_dataloaders = { condition: self._parse_dataloader_config( dataset, dataloader_config -- cgit v1.2.3 From 864fca2c9ca65847c0f1f318dfe50a1e6155e418 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 15 Mar 2021 19:41:28 +0800 Subject: Fix redundant gallery_dataset_meta assignment --- models/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/model.py b/models/model.py index 5966ae1..c350d11 100644 --- a/models/model.py +++ b/models/model.py @@ -425,7 +425,6 @@ class Model: unit='clips'): gallery_samples_c.append(self._get_eval_sample(sample)) gallery_samples[condition] = default_collate(gallery_samples_c) - gallery_samples['meta'] = self._gallery_dataset_meta # Probe probe_samples_c = [] for sample in tqdm(probe_dataloader, @@ -435,6 +434,7 @@ class Model: probe_samples_c = default_collate(probe_samples_c) probe_samples_c['meta'] = self._probe_datasets_meta[condition] probe_samples[condition] = probe_samples_c + gallery_samples['meta'] = self._gallery_dataset_meta return gallery_samples, probe_samples -- cgit v1.2.3 From a68562cbb7f602cc75b3f8f0bf0c285d9e4e4c8b Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 15 Mar 2021 20:28:06 +0800 Subject: Remove redundant wrapper given by dataloader --- models/model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/models/model.py b/models/model.py index c350d11..c1cc703 100644 --- a/models/model.py +++ b/models/model.py @@ -439,14 +439,14 @@ class Model: return gallery_samples, probe_samples def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]): - label = sample.pop('label').item() - clip = sample.pop('clip').to(self.device) + label, condition, view, clip = sample.values() with torch.no_grad(): - feature = self.rgb_pn(clip) + feature = self.rgb_pn(clip.to(self.device)) return { - **{'label': label}, - **sample, - **{'feature': feature} + 'label': label.item(), + 'condition': condition[0], + 'view': view[0], + 'feature': feature } @staticmethod -- cgit v1.2.3