diff options
-rw-r--r-- | models/model.py | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/models/model.py b/models/model.py index 7bc1da9..2715b26 100644 --- a/models/model.py +++ b/models/model.py @@ -276,6 +276,24 @@ class Model: ], dataloader_config: DataloaderConfiguration, ) -> Dict[str, torch.Tensor]: + # Transform data to features + gallery_samples, probe_samples = self.transform( + iters, dataset_config, dataset_selectors, dataloader_config + ) + # Evaluate features + accuracy = self.evaluate(gallery_samples, probe_samples) + + return accuracy + + def transform( + self, + iters: Tuple[int], + dataset_config: DatasetConfiguration, + dataset_selectors: Dict[ + str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]] + ], + dataloader_config: DataloaderConfiguration + ): self.is_train = False # Split gallery and probe dataset gallery_dataloader, probe_dataloaders = self._split_gallery_probe( @@ -285,15 +303,16 @@ class Model: checkpoints = self._load_pretrained( iters, dataset_config, dataset_selectors ) + # Init models model_hp = self.hp.get('model', {}) self.rgb_pn = RGBPartNet(ae_in_channels=self.in_channels, **model_hp) # Try to accelerate computation using CUDA or others + self.rgb_pn = nn.DataParallel(self.rgb_pn) self.rgb_pn = self.rgb_pn.to(self.device) - self.rgb_pn.eval() - gallery_samples, probe_samples = [], {} + gallery_samples, probe_samples = [], {} # Gallery checkpoint = torch.load(list(checkpoints.values())[0]) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) @@ -301,7 +320,6 @@ class Model: desc='Transforming gallery', unit='clips'): gallery_samples.append(self._get_eval_sample(sample)) gallery_samples = default_collate(gallery_samples) - # Probe for (condition, dataloader) in probe_dataloaders.items(): checkpoint = torch.load(checkpoints[condition]) @@ -313,7 +331,7 @@ class Model: probe_samples_c.append(self._get_eval_sample(sample)) probe_samples[condition] = default_collate(probe_samples_c) - return self._evaluate(gallery_samples, probe_samples) + return gallery_samples, probe_samples def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]): label = sample.pop('label').item() @@ -325,7 +343,7 @@ class Model: **{'feature': feature} } - def _evaluate( + def evaluate( self, gallery_samples: Dict[str, Union[List[str], torch.Tensor]], probe_samples: Dict[str, Dict[str, Union[List[str], torch.Tensor]]], |