summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/models/model.py b/models/model.py
index 89b34aa..5966ae1 100644
--- a/models/model.py
+++ b/models/model.py
@@ -392,12 +392,12 @@ class Model:
dataset_selectors: dict[
str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
],
- dataloader_config: DataloaderConfiguration
+ dataloader_config: DataloaderConfiguration,
+ is_train: bool = False
):
- self.is_train = False
# Split gallery and probe dataset
gallery_dataloader, probe_dataloaders = self._split_gallery_probe(
- dataset_config, dataloader_config
+ dataset_config, dataloader_config, is_train
)
# Get pretrained models at iter_
checkpoints = self._load_pretrained(
@@ -506,10 +506,11 @@ class Model:
]
) -> dict[str, str]:
checkpoints = {}
- for (iter_, (condition, selector)) in zip(
- iters, dataset_selectors.items()
+ for (iter_, total_iter, (condition, selector)) in zip(
+ iters, self.total_iters, dataset_selectors.items()
):
self.curr_iter = iter_
+ self.total_iter = total_iter
self._dataset_sig = self._make_signature(
dict(**dataset_config, **selector),
popped_keys=['root_dir', 'cache_on']
@@ -521,26 +522,29 @@ class Model:
self,
dataset_config: DatasetConfiguration,
dataloader_config: DataloaderConfiguration,
+ is_train: bool = False
) -> tuple[DataLoader, dict[str, DataLoader]]:
dataset_name = dataset_config.get('name', 'CASIA-B')
if dataset_name == 'CASIA-B':
+ self.is_train = is_train
gallery_dataset = self._parse_dataset_config(
dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR)
)
- self._gallery_dataset_meta = gallery_dataset.metadata
- gallery_dataloader = self._parse_dataloader_config(
- gallery_dataset, dataloader_config
- )
probe_datasets = {
condition: self._parse_dataset_config(
dict(**dataset_config, **selector)
)
for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items()
}
+ self._gallery_dataset_meta = gallery_dataset.metadata
self._probe_datasets_meta = {
condition: dataset.metadata
for (condition, dataset) in probe_datasets.items()
}
+ self.is_train = False
+ gallery_dataloader = self._parse_dataloader_config(
+ gallery_dataset, dataloader_config
+ )
probe_dataloaders = {
condition: self._parse_dataloader_config(
dataset, dataloader_config