summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py60
1 files changed, 32 insertions, 28 deletions
diff --git a/models/model.py b/models/model.py
index 5a8c0e8..d11617b 100644
--- a/models/model.py
+++ b/models/model.py
@@ -113,6 +113,13 @@ class Model:
self.curr_iters, self.total_iters, dataset_selectors.items()
):
print(f'Training model {condition} ...')
+ # Skip finished model
+ if curr_iter == total_iter:
+ continue
+ # Check invalid restore iter
+ elif curr_iter > total_iter:
+ raise ValueError("Restore iter '{}' should less than total "
+ "iter '{}'".format(curr_iter, total_iter))
self.curr_iter = curr_iter
self.total_iter = total_iter
self.fit(
@@ -210,7 +217,7 @@ class Model:
def predict_all(
self,
- iter_: int,
+ iters: Tuple[int],
dataset_config: Dict,
dataset_selectors: Dict[
str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
@@ -224,7 +231,7 @@ class Model:
)
# Get pretrained models at iter_
checkpoints = self._load_pretrained(
- iter_, dataset_config, dataset_selectors
+ iters, dataset_config, dataset_selectors
)
# Init models
model_hp = self.hp.get('model', {})
@@ -240,37 +247,32 @@ class Model:
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
for sample in tqdm(gallery_dataloader,
desc='Transforming gallery', unit='clips'):
- label = sample.pop('label').item()
- clip = sample.pop('clip').to(self.device)
- feature = self.rgb_pn(clip).detach()
- gallery_samples.append({
- **{'label': label},
- **sample,
- **{'feature': feature}
- })
+ gallery_samples.append(self._get_eval_sample(sample))
gallery_samples = default_collate(gallery_samples)
# Probe
for (condition, dataloader) in probe_dataloaders.items():
checkpoint = torch.load(checkpoints[condition])
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
- probe_samples[condition] = []
+ probe_samples_c = []
for sample in tqdm(dataloader,
desc=f'Transforming probe {condition}',
unit='clips'):
- label = sample.pop('label').item()
- clip = sample.pop('clip').to(self.device)
- feature = self.rgb_pn(clip).detach()
- probe_samples[condition].append({
- **{'label': label},
- **sample,
- **{'feature': feature}
- })
- for (k, v) in probe_samples.items():
- probe_samples[k] = default_collate(v)
+ probe_samples_c.append(self._get_eval_sample(sample))
+ probe_samples[condition] = default_collate(probe_samples_c)
return self._evaluate(gallery_samples, probe_samples)
+ def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]):
+ label = sample.pop('label').item()
+ clip = sample.pop('clip').to(self.device)
+ feature = self.rgb_pn(clip).detach()
+ return {
+ **{'label': label},
+ **sample,
+ **{'feature': feature}
+ }
+
def _evaluate(
self,
gallery_samples: Dict[str, Union[List[str], torch.Tensor]],
@@ -321,20 +323,22 @@ class Model:
def _load_pretrained(
self,
- iter_: int,
+ iters: Tuple[int],
dataset_config: Dict,
dataset_selectors: Dict[
str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
]
) -> Dict[str, str]:
checkpoints = {}
- self.curr_iter = iter_
- for (k, v) in dataset_selectors.items():
+ for (iter_, (condition, selector)) in zip(
+ iters, dataset_selectors.items()
+ ):
+ self.curr_iter = self.total_iter = iter_
self._dataset_sig = self._make_signature(
- dict(**dataset_config, **v),
+ dict(**dataset_config, **selector),
popped_keys=['root_dir', 'cache_on']
)
- checkpoints[k] = self._checkpoint_name
+ checkpoints[condition] = self._checkpoint_name
return checkpoints
def _split_gallery_probe(
@@ -362,10 +366,10 @@ class Model:
for (condition, dataset) in probe_datasets.items()
}
probe_dataloaders = {
- condtion: self._parse_dataloader_config(
+ condition: self._parse_dataloader_config(
dataset, dataloader_config
)
- for (condtion, dataset) in probe_datasets.items()
+ for (condition, dataset) in probe_datasets.items()
}
elif dataset_name == 'FVG':
# TODO