summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/model.py30
1 files changed, 16 insertions, 14 deletions
diff --git a/models/model.py b/models/model.py
index adea626..b09d600 100644
--- a/models/model.py
+++ b/models/model.py
@@ -241,15 +241,15 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- embedding_c, embedding_p, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+ embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2)
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
y = y.repeat(self.rgb_pn.num_parts, 1)
trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm(
- embedding_c, y[:self.rgb_pn.hpm.num_parts]
+ embed_c, y[:self.rgb_pn.hpm.num_parts]
)
trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn(
- embedding_p, y[self.rgb_pn.hpm.num_parts:]
+ embed_p, y[self.rgb_pn.hpm.num_parts:]
)
losses = torch.stack((
*ae_losses,
@@ -290,13 +290,13 @@ class Model:
num_pos_pairs, num_pairs, self.curr_iter
)
# Embedding norm
- mean_hpm_embedding = embedding_c.mean(0)
+ mean_hpm_embedding = embed_c.mean(0)
mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/HPM norm', mean_hpm_norm,
self.k, self.pr * self.k, self.curr_iter
)
- mean_pa_embedding = embedding_p.mean(0)
+ mean_pa_embedding = embed_p.mean(0)
mean_pa_norm = mean_pa_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/PartNet norm', mean_pa_norm,
@@ -425,13 +425,16 @@ class Model:
unit='clips'):
gallery_samples_c.append(self._get_eval_sample(sample))
gallery_samples[condition] = default_collate(gallery_samples_c)
+ gallery_samples['meta'] = self._gallery_dataset_meta
# Probe
probe_samples_c = []
for sample in tqdm(probe_dataloader,
desc=f'Transforming probe {condition}',
unit='clips'):
probe_samples_c.append(self._get_eval_sample(sample))
- probe_samples[condition] = default_collate(probe_samples_c)
+ probe_samples_c = default_collate(probe_samples_c)
+ probe_samples_c['meta'] = self._probe_datasets_meta[condition]
+ probe_samples[condition] = probe_samples_c
return gallery_samples, probe_samples
@@ -446,15 +449,15 @@ class Model:
**{'feature': feature}
}
+ @staticmethod
def evaluate(
- self,
- gallery_samples: dict[str, Union[list[str], torch.Tensor]],
- probe_samples: dict[str, dict[str, Union[list[str], torch.Tensor]]],
+ gallery_samples: dict[str, dict[str, Union[list, torch.Tensor]]],
+ probe_samples: dict[str, dict[str, Union[list, torch.Tensor]]],
num_ranks: int = 5
) -> dict[str, torch.Tensor]:
- conditions = gallery_samples.keys()
- gallery_views_meta = self._gallery_dataset_meta['views']
- probe_views_meta = list(self._probe_datasets_meta.values())[0]['views']
+ conditions = list(probe_samples.keys())
+ gallery_views_meta = gallery_samples['meta']['views']
+ probe_views_meta = probe_samples[conditions[0]]['meta']['views']
accuracy = {
condition: torch.empty(
len(gallery_views_meta), len(probe_views_meta), num_ranks
@@ -467,7 +470,7 @@ class Model:
(labels_g, _, views_g, features_g) = gallery_samples_c.values()
views_g = np.asarray(views_g)
probe_samples_c = probe_samples[condition]
- (labels_p, _, views_p, features_p) = probe_samples_c.values()
+ (labels_p, _, views_p, features_p, _) = probe_samples_c.values()
views_p = np.asarray(views_p)
accuracy_c = accuracy[condition]
for (v_g_i, view_g) in enumerate(gallery_views_meta):
@@ -492,7 +495,6 @@ class Model:
positive_counts = positive_mat.sum(0)
total_counts, _ = dist.size()
accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts
-
return accuracy
def _load_pretrained(