summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-16 11:26:36 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-16 11:27:06 +0800
commitdb97f6218ddce59d3a4abf38bcd70f163ec9da10 (patch)
treea6d5feecb0ef8c494f0512cd91333b5b1374cf7a /models
parentbd42d8e6f3561031f957a976db10d2eb8b3d9849 (diff)
parentc7bff4f1868b8dcf3f1ef104dde477d8d4244316 (diff)
Merge branch 'python3.8' into data_parallel_py3.8
Diffstat (limited to 'models')
-rw-r--r--models/model.py28
1 file changed, 23 insertions, 5 deletions
diff --git a/models/model.py b/models/model.py
index 7bc1da9..2715b26 100644
--- a/models/model.py
+++ b/models/model.py
@@ -276,6 +276,24 @@ class Model:
],
dataloader_config: DataloaderConfiguration,
) -> Dict[str, torch.Tensor]:
+ # Transform data to features
+ gallery_samples, probe_samples = self.transform(
+ iters, dataset_config, dataset_selectors, dataloader_config
+ )
+ # Evaluate features
+ accuracy = self.evaluate(gallery_samples, probe_samples)
+
+ return accuracy
+
+ def transform(
+ self,
+ iters: Tuple[int],
+ dataset_config: DatasetConfiguration,
+ dataset_selectors: Dict[
+ str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
+ ],
+ dataloader_config: DataloaderConfiguration
+ ):
self.is_train = False
# Split gallery and probe dataset
gallery_dataloader, probe_dataloaders = self._split_gallery_probe(
@@ -285,15 +303,16 @@ class Model:
checkpoints = self._load_pretrained(
iters, dataset_config, dataset_selectors
)
+
# Init models
model_hp = self.hp.get('model', {})
self.rgb_pn = RGBPartNet(ae_in_channels=self.in_channels, **model_hp)
# Try to accelerate computation using CUDA or others
+ self.rgb_pn = nn.DataParallel(self.rgb_pn)
self.rgb_pn = self.rgb_pn.to(self.device)
-
self.rgb_pn.eval()
- gallery_samples, probe_samples = [], {}
+ gallery_samples, probe_samples = [], {}
# Gallery
checkpoint = torch.load(list(checkpoints.values())[0])
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
@@ -301,7 +320,6 @@ class Model:
desc='Transforming gallery', unit='clips'):
gallery_samples.append(self._get_eval_sample(sample))
gallery_samples = default_collate(gallery_samples)
-
# Probe
for (condition, dataloader) in probe_dataloaders.items():
checkpoint = torch.load(checkpoints[condition])
@@ -313,7 +331,7 @@ class Model:
probe_samples_c.append(self._get_eval_sample(sample))
probe_samples[condition] = default_collate(probe_samples_c)
- return self._evaluate(gallery_samples, probe_samples)
+ return gallery_samples, probe_samples
def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]):
label = sample.pop('label').item()
@@ -325,7 +343,7 @@ class Model:
**{'feature': feature}
}
- def _evaluate(
+ def evaluate(
self,
gallery_samples: Dict[str, Union[List[str], torch.Tensor]],
probe_samples: Dict[str, Dict[str, Union[List[str], torch.Tensor]]],