diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-01 11:26:23 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-01 11:26:23 +0800 |
commit | 2f3a7fbef70efd2cf91b7d77b3b71ffb4de907e2 (patch) | |
tree | b7099202831f014ec27e2d840dc2c659d19df347 /models | |
parent | e04f54d0bfc8fc711e53561065d772dae1926b64 (diff) | |
parent | db0564967d8cfc03b2d3fe4f7d10eff0867e1771 (diff) |
Merge branch 'master' into python3.8
Diffstat (limited to 'models')
-rw-r--r-- | models/model.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/models/model.py b/models/model.py index 48dcfaf..f0d4f08 100644 --- a/models/model.py +++ b/models/model.py @@ -59,8 +59,6 @@ class Model: self.in_size: Tuple[int, int] = (64, 48) self.pr: Optional[int] = None self.k: Optional[int] = None - self.num_pairs: Optional[int] = None - self.num_pos_pairs: Optional[int] = None self._gallery_dataset_meta: Optional[Dict[str, List]] = None self._probe_datasets_meta: Optional[Dict[str, Dict[str, List]]] = None @@ -174,6 +172,9 @@ class Model: triplet_is_hard, triplet_is_mean, None ) + num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 + num_pos_pairs = (self.k*(self.k-1)//2) * self.pr + # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) self.triplet_loss = self.triplet_loss.to(self.device) @@ -256,12 +257,12 @@ class Model: mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0) self._add_ranked_scalars( 'Embedding/HPM distance', mean_hpm_dist, - self.num_pos_pairs, self.num_pairs, self.curr_iter + num_pos_pairs, num_pairs, self.curr_iter ) mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0) self._add_ranked_scalars( 'Embedding/ParNet distance', mean_pa_dist, - self.num_pos_pairs, self.num_pairs, self.curr_iter + num_pos_pairs, num_pairs, self.curr_iter ) # Embedding norm mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0) @@ -566,8 +567,6 @@ class Model: ) -> DataLoader: config: Dict = dataloader_config.copy() (self.pr, self.k) = config.pop('batch_size', (8, 16)) - self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 - self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr if self.is_train: triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) return DataLoader(dataset, |