summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--eval.py2
-rw-r--r--models/model.py53
-rw-r--r--utils/dataset.py3
3 files changed, 29 insertions, 29 deletions
diff --git a/eval.py b/eval.py
index 7b68220..ec48066 100644
--- a/eval.py
+++ b/eval.py
@@ -14,7 +14,7 @@ dataset_selectors = {
'cl': {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},
}
-accuracy = model.predict_all(config['model']['total_iter'], config['dataset'],
+accuracy = model.predict_all(config['model']['total_iters'], config['dataset'],
dataset_selectors, config['dataloader'])
rank = 5
np.set_printoptions(formatter={'float': '{:5.2f}'.format})
diff --git a/models/model.py b/models/model.py
index 1a62bae..6b799ad 100644
--- a/models/model.py
+++ b/models/model.py
@@ -220,7 +220,7 @@ class Model:
def predict_all(
self,
- iter_: int,
+ iters: tuple[int],
dataset_config: DatasetConfiguration,
dataset_selectors: dict[
str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
@@ -234,7 +234,7 @@ class Model:
)
# Get pretrained models at iter_
checkpoints = self._load_pretrained(
- iter_, dataset_config, dataset_selectors
+ iters, dataset_config, dataset_selectors
)
# Init models
model_hp = self.hp.get('model', {})
@@ -250,37 +250,32 @@ class Model:
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
for sample in tqdm(gallery_dataloader,
desc='Transforming gallery', unit='clips'):
- label = sample.pop('label').item()
- clip = sample.pop('clip').to(self.device)
- feature = self.rgb_pn(clip).detach()
- gallery_samples.append({
- **{'label': label},
- **sample,
- **{'feature': feature}
- })
+ gallery_samples.append(self._get_eval_sample(sample))
gallery_samples = default_collate(gallery_samples)
# Probe
for (condition, dataloader) in probe_dataloaders.items():
checkpoint = torch.load(checkpoints[condition])
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
- probe_samples[condition] = []
+ probe_samples_c = []
for sample in tqdm(dataloader,
desc=f'Transforming probe {condition}',
unit='clips'):
- label = sample.pop('label').item()
- clip = sample.pop('clip').to(self.device)
- feature = self.rgb_pn(clip).detach()
- probe_samples[condition].append({
- **{'label': label},
- **sample,
- **{'feature': feature}
- })
- for (k, v) in probe_samples.items():
- probe_samples[k] = default_collate(v)
+ probe_samples_c.append(self._get_eval_sample(sample))
+ probe_samples[condition] = default_collate(probe_samples_c)
return self._evaluate(gallery_samples, probe_samples)
+ def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]):
+ label = sample.pop('label').item()
+ clip = sample.pop('clip').to(self.device)
+ feature = self.rgb_pn(clip).detach()
+ return {
+ **{'label': label},
+ **sample,
+ **{'feature': feature}
+ }
+
def _evaluate(
self,
gallery_samples: dict[str, Union[list[str], torch.Tensor]],
@@ -331,20 +326,22 @@ class Model:
def _load_pretrained(
self,
- iter_: int,
+ iters: tuple[int],
dataset_config: DatasetConfiguration,
dataset_selectors: dict[
str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
]
) -> dict[str, str]:
checkpoints = {}
- self.curr_iter = iter_
- for (k, v) in dataset_selectors.items():
+ for (iter_, (condition, selector)) in zip(
+ iters, dataset_selectors.items()
+ ):
+ self.curr_iter = self.total_iter = iter_
self._dataset_sig = self._make_signature(
- dict(**dataset_config, **v),
+ dict(**dataset_config, **selector),
popped_keys=['root_dir', 'cache_on']
)
- checkpoints[k] = self._checkpoint_name
+ checkpoints[condition] = self._checkpoint_name
return checkpoints
def _split_gallery_probe(
@@ -372,10 +369,10 @@ class Model:
for (condition, dataset) in probe_datasets.items()
}
probe_dataloaders = {
- condtion: self._parse_dataloader_config(
+ condition: self._parse_dataloader_config(
dataset, dataloader_config
)
- for (condtion, dataset) in probe_datasets.items()
+ for (condition, dataset) in probe_datasets.items()
}
elif dataset_name == 'FVG':
# TODO
diff --git a/utils/dataset.py b/utils/dataset.py
index cd8b0f1..bbd42c3 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -201,6 +201,9 @@ class CASIAB(data.Dataset):
self._cached_clips[clip_name] = clip
else: # Load cache
cached_clip = self._cached_clips[clip_name]
+ # Return full clips while evaluating
+ if not self._is_train:
+ return cached_clip
cached_clip_frame_names \
= self._cached_clips_frame_names[clip_path]
# Index the original clip via sampled frame names