summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--models/model.py11
-rw-r--r--utils/triplet_loss.py5
2 files changed, 9 insertions, 7 deletions
diff --git a/models/model.py b/models/model.py
index 48dcfaf..f0d4f08 100644
--- a/models/model.py
+++ b/models/model.py
@@ -59,8 +59,6 @@ class Model:
self.in_size: Tuple[int, int] = (64, 48)
self.pr: Optional[int] = None
self.k: Optional[int] = None
- self.num_pairs: Optional[int] = None
- self.num_pos_pairs: Optional[int] = None
self._gallery_dataset_meta: Optional[Dict[str, List]] = None
self._probe_datasets_meta: Optional[Dict[str, Dict[str, List]]] = None
@@ -174,6 +172,9 @@ class Model:
triplet_is_hard, triplet_is_mean, None
)
+ num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
+ num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
+
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
self.triplet_loss = self.triplet_loss.to(self.device)
@@ -256,12 +257,12 @@ class Model:
mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0)
self._add_ranked_scalars(
'Embedding/HPM distance', mean_hpm_dist,
- self.num_pos_pairs, self.num_pairs, self.curr_iter
+ num_pos_pairs, num_pairs, self.curr_iter
)
mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0)
self._add_ranked_scalars(
'Embedding/ParNet distance', mean_pa_dist,
- self.num_pos_pairs, self.num_pairs, self.curr_iter
+ num_pos_pairs, num_pairs, self.curr_iter
)
# Embedding norm
mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0)
@@ -566,8 +567,6 @@ class Model:
) -> DataLoader:
config: Dict = dataloader_config.copy()
(self.pr, self.k) = config.pop('batch_size', (8, 16))
- self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
- self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
if self.is_train:
triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
return DataLoader(dataset,
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 77c7234..60faa0c 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -68,7 +68,10 @@ class BatchTripletLoss(nn.Module):
@staticmethod
def _all_distance(dist, y, p, n):
- positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
+ # Unmask identical samples
+ positive_mask = torch.eye(
+ n, dtype=torch.bool, device=y.device
+ ) ^ (y.unsqueeze(1) == y.unsqueeze(2))
negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
all_positive = dist[positive_mask].view(p, n, -1, 1)
all_negative = dist[negative_mask].view(p, n, 1, -1)