diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-01 11:22:16 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-01 11:22:16 +0800 |
commit | db0564967d8cfc03b2d3fe4f7d10eff0867e1771 (patch) | |
tree | 478fab37bb052ec40586f0c8418fffade828bfe2 | |
parent | 4bdc37bbd86a83647bbbda7bd1367c08e6c6f6d4 (diff) |
Move pairs variable to local
-rw-r--r-- | models/model.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/models/model.py b/models/model.py index 34cb816..b942eb8 100644 --- a/models/model.py +++ b/models/model.py @@ -59,8 +59,6 @@ class Model: self.in_size: tuple[int, int] = (64, 48) self.pr: Optional[int] = None self.k: Optional[int] = None - self.num_pairs: Optional[int] = None - self.num_pos_pairs: Optional[int] = None self._gallery_dataset_meta: Optional[dict[str, list]] = None self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None @@ -174,6 +172,9 @@ class Model: triplet_is_hard, triplet_is_mean, None ) + num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 + num_pos_pairs = (self.k*(self.k-1)//2) * self.pr + # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) self.triplet_loss = self.triplet_loss.to(self.device) @@ -256,12 +257,12 @@ class Model: mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0) self._add_ranked_scalars( 'Embedding/HPM distance', mean_hpm_dist, - self.num_pos_pairs, self.num_pairs, self.curr_iter + num_pos_pairs, num_pairs, self.curr_iter ) mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0) self._add_ranked_scalars( 'Embedding/ParNet distance', mean_pa_dist, - self.num_pos_pairs, self.num_pairs, self.curr_iter + num_pos_pairs, num_pairs, self.curr_iter ) # Embedding norm mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0) @@ -566,8 +567,6 @@ class Model: ) -> DataLoader: config: dict = dataloader_config.copy() (self.pr, self.k) = config.pop('batch_size', (8, 16)) - self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 - self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr if self.is_train: triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) return DataLoader(dataset, |