summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-15 20:28:42 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-15 20:28:42 +0800
commit03d38fe32d668c28d7cba01f0b2e227f32b954c1 (patch)
tree812af1b0b2a1962c958b640f3c4dbb85a91cc83e /models/model.py
parent17ae5e529475a7e47fdde0ce69982b5dfe34f2eb (diff)
parenta68562cbb7f602cc75b3f8f0bf0c285d9e4e4c8b (diff)
Merge branch 'master' into python3.8
# Conflicts: # models/model.py
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py36
1 files changed, 20 insertions, 16 deletions
diff --git a/models/model.py b/models/model.py
index 4015524..fa56057 100644
--- a/models/model.py
+++ b/models/model.py
@@ -392,12 +392,12 @@ class Model:
dataset_selectors: Dict[
str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
],
- dataloader_config: DataloaderConfiguration
+ dataloader_config: DataloaderConfiguration,
+ is_train: bool = False
):
- self.is_train = False
# Split gallery and probe dataset
gallery_dataloader, probe_dataloaders = self._split_gallery_probe(
- dataset_config, dataloader_config
+ dataset_config, dataloader_config, is_train
)
# Get pretrained models at iter_
checkpoints = self._load_pretrained(
@@ -425,7 +425,6 @@ class Model:
unit='clips'):
gallery_samples_c.append(self._get_eval_sample(sample))
gallery_samples[condition] = default_collate(gallery_samples_c)
- gallery_samples['meta'] = self._gallery_dataset_meta
# Probe
probe_samples_c = []
for sample in tqdm(probe_dataloader,
@@ -435,18 +434,19 @@ class Model:
probe_samples_c = default_collate(probe_samples_c)
probe_samples_c['meta'] = self._probe_datasets_meta[condition]
probe_samples[condition] = probe_samples_c
+ gallery_samples['meta'] = self._gallery_dataset_meta
return gallery_samples, probe_samples
def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]):
- label = sample.pop('label').item()
- clip = sample.pop('clip').to(self.device)
+ label, condition, view, clip = sample.values()
with torch.no_grad():
- feature = self.rgb_pn(clip)
+ feature = self.rgb_pn(clip.to(self.device))
return {
- **{'label': label},
- **sample,
- **{'feature': feature}
+ 'label': label.item(),
+ 'condition': condition[0],
+ 'view': view[0],
+ 'feature': feature
}
@staticmethod
@@ -506,10 +506,11 @@ class Model:
]
) -> Dict[str, str]:
checkpoints = {}
- for (iter_, (condition, selector)) in zip(
- iters, dataset_selectors.items()
+ for (iter_, total_iter, (condition, selector)) in zip(
+ iters, self.total_iters, dataset_selectors.items()
):
self.curr_iter = iter_
+ self.total_iter = total_iter
self._dataset_sig = self._make_signature(
dict(**dataset_config, **selector),
popped_keys=['root_dir', 'cache_on']
@@ -521,26 +522,29 @@ class Model:
self,
dataset_config: DatasetConfiguration,
dataloader_config: DataloaderConfiguration,
+ is_train: bool = False
) -> Tuple[DataLoader, Dict[str, DataLoader]]:
dataset_name = dataset_config.get('name', 'CASIA-B')
if dataset_name == 'CASIA-B':
+ self.is_train = is_train
gallery_dataset = self._parse_dataset_config(
dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR)
)
- self._gallery_dataset_meta = gallery_dataset.metadata
- gallery_dataloader = self._parse_dataloader_config(
- gallery_dataset, dataloader_config
- )
probe_datasets = {
condition: self._parse_dataset_config(
dict(**dataset_config, **selector)
)
for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items()
}
+ self._gallery_dataset_meta = gallery_dataset.metadata
self._probe_datasets_meta = {
condition: dataset.metadata
for (condition, dataset) in probe_datasets.items()
}
+ self.is_train = False
+ gallery_dataloader = self._parse_dataloader_config(
+ gallery_dataset, dataloader_config
+ )
probe_dataloaders = {
condition: self._parse_dataloader_config(
dataset, dataloader_config