diff options
| author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-23 00:16:58 +0800 | 
|---|---|---|
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-23 00:16:58 +0800 | 
| commit | f3321b6782aff20df2c029239cc0f8b9a59ffde7 (patch) | |
| tree | e4b0c952d90e6ffcf255cdc4233496fb5612b886 | |
| parent | 7b5ff0a8dd2406af74776f31b1e2afabf9fadacd (diff) | |
Evaluation bug fixes and code review
1. Return full cached clip in evaluation
2. Add multi-iter checkpoints support in evaluation
3. Remove duplicated code while transforming
| -rw-r--r-- | eval.py | 2 | ||||
| -rw-r--r-- | models/model.py | 53 | ||||
| -rw-r--r-- | utils/dataset.py | 3 | 
3 files changed, 29 insertions, 29 deletions
| @@ -14,7 +14,7 @@ dataset_selectors = {      'cl': {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},  } -accuracy = model.predict_all(config['model']['total_iter'], config['dataset'], +accuracy = model.predict_all(config['model']['total_iters'], config['dataset'],                               dataset_selectors, config['dataloader'])  rank = 5  np.set_printoptions(formatter={'float': '{:5.2f}'.format}) diff --git a/models/model.py b/models/model.py index 1a62bae..6b799ad 100644 --- a/models/model.py +++ b/models/model.py @@ -220,7 +220,7 @@ class Model:      def predict_all(              self, -            iter_: int, +            iters: tuple[int],              dataset_config: DatasetConfiguration,              dataset_selectors: dict[                  str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]] @@ -234,7 +234,7 @@ class Model:          )          # Get pretrained models at iter_          checkpoints = self._load_pretrained( -            iter_, dataset_config, dataset_selectors +            iters, dataset_config, dataset_selectors          )          # Init models          model_hp = self.hp.get('model', {}) @@ -250,37 +250,32 @@ class Model:          self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])          for sample in tqdm(gallery_dataloader,                             desc='Transforming gallery', unit='clips'): -            label = sample.pop('label').item() -            clip = sample.pop('clip').to(self.device) -            feature = self.rgb_pn(clip).detach() -            gallery_samples.append({ -                **{'label': label}, -                **sample, -                **{'feature': feature} -            }) +            gallery_samples.append(self._get_eval_sample(sample))          gallery_samples = default_collate(gallery_samples)          # Probe          for (condition, dataloader) in probe_dataloaders.items():              checkpoint = torch.load(checkpoints[condition])              self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) -            probe_samples[condition] = [] +            probe_samples_c = []              for sample in tqdm(dataloader,                                 desc=f'Transforming probe {condition}',                                 unit='clips'): -                label = sample.pop('label').item() -                clip = sample.pop('clip').to(self.device) -                feature = self.rgb_pn(clip).detach() -                probe_samples[condition].append({ -                    **{'label': label}, -                    **sample, -                    **{'feature': feature} -                }) -        for (k, v) in probe_samples.items(): -            probe_samples[k] = default_collate(v) +                probe_samples_c.append(self._get_eval_sample(sample)) +            probe_samples[condition] = default_collate(probe_samples_c)          return self._evaluate(gallery_samples, probe_samples) +    def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]): +        label = sample.pop('label').item() +        clip = sample.pop('clip').to(self.device) +        feature = self.rgb_pn(clip).detach() +        return { +            **{'label': label}, +            **sample, +            **{'feature': feature} +        } +      def _evaluate(              self,              gallery_samples: dict[str, Union[list[str], torch.Tensor]], @@ -331,20 +326,22 @@ class Model:      def _load_pretrained(              self, -            iter_: int, +            iters: tuple[int],              dataset_config: DatasetConfiguration,              dataset_selectors: dict[                  str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]              ]      ) -> dict[str, str]:          checkpoints = {} -        self.curr_iter = iter_ -        for (k, v) in dataset_selectors.items(): +        for (iter_, (condition, selector)) in zip( +                iters, dataset_selectors.items() +        ): +            self.curr_iter = self.total_iter = iter_              self._dataset_sig = self._make_signature( -                dict(**dataset_config, **v), +                dict(**dataset_config, **selector),                  popped_keys=['root_dir', 'cache_on']              ) -            checkpoints[k] = self._checkpoint_name +            checkpoints[condition] = self._checkpoint_name          return checkpoints      def _split_gallery_probe( @@ -372,10 +369,10 @@ class Model:                  for (condition, dataset) in probe_datasets.items()              }              probe_dataloaders = { -                condtion: self._parse_dataloader_config( +                condition: self._parse_dataloader_config(                      dataset, dataloader_config                  ) -                for (condtion, dataset) in probe_datasets.items() +                for (condition, dataset) in probe_datasets.items()              }          elif dataset_name == 'FVG':              # TODO diff --git a/utils/dataset.py b/utils/dataset.py index cd8b0f1..bbd42c3 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -201,6 +201,9 @@ class CASIAB(data.Dataset):                  self._cached_clips[clip_name] = clip              else:  # Load cache                  cached_clip = self._cached_clips[clip_name] +                # Return full clips while evaluating +                if not self._is_train: +                    return cached_clip                  cached_clip_frame_names \                      = self._cached_clips_frame_names[clip_path]                  # Index the original clip via sampled frame names | 
