diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/model.py | 30 | ||||
-rw-r--r-- | models/rgb_part_net.py | 2 |
2 files changed, 17 insertions, 15 deletions
diff --git a/models/model.py b/models/model.py index 25a78d6..acc78d3 100644 --- a/models/model.py +++ b/models/model.py @@ -238,15 +238,15 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - embedding_c, embedding_p, ae_losses, images = self.rgb_pn(x_c1, x_c2) + embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) # Duplicate labels for each part y = y.repeat(self.rgb_pn.num_parts, 1) trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm( - embedding_c, y[:self.rgb_pn.hpm.num_parts] + embed_c, y[:self.rgb_pn.hpm.num_parts] ) trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn( - embedding_p, y[self.rgb_pn.hpm.num_parts:] + embed_p, y[self.rgb_pn.hpm.num_parts:] ) losses = torch.stack(( *ae_losses, @@ -287,13 +287,13 @@ class Model: num_pos_pairs, num_pairs, self.curr_iter ) # Embedding norm - mean_hpm_embedding = embedding_c.mean(0) + mean_hpm_embedding = embed_c.mean(0) mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/HPM norm', mean_hpm_norm, self.k, self.pr * self.k, self.curr_iter ) - mean_pa_embedding = embedding_p.mean(0) + mean_pa_embedding = embed_p.mean(0) mean_pa_norm = mean_pa_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/PartNet norm', mean_pa_norm, @@ -422,13 +422,16 @@ class Model: unit='clips'): gallery_samples_c.append(self._get_eval_sample(sample)) gallery_samples[condition] = default_collate(gallery_samples_c) + gallery_samples['meta'] = self._gallery_dataset_meta # Probe probe_samples_c = [] for sample in tqdm(probe_dataloader, desc=f'Transforming probe {condition}', unit='clips'): probe_samples_c.append(self._get_eval_sample(sample)) - probe_samples[condition] = default_collate(probe_samples_c) + probe_samples_c = default_collate(probe_samples_c) + probe_samples_c['meta'] = self._probe_datasets_meta[condition] + probe_samples[condition] = probe_samples_c return gallery_samples, probe_samples @@ -443,15 +446,15 @@ class Model: **{'feature': feature} } + @staticmethod def evaluate( - self, - gallery_samples: Dict[str, Union[List[str], torch.Tensor]], - probe_samples: Dict[str, Dict[str, Union[List[str], torch.Tensor]]], + gallery_samples: Dict[str, Dict[str, Union[List, torch.Tensor]]], + probe_samples: Dict[str, Dict[str, Union[List, torch.Tensor]]], num_ranks: int = 5 ) -> Dict[str, torch.Tensor]: - conditions = gallery_samples.keys() - gallery_views_meta = self._gallery_dataset_meta['views'] - probe_views_meta = list(self._probe_datasets_meta.values())[0]['views'] + conditions = list(probe_samples.keys()) + gallery_views_meta = gallery_samples['meta']['views'] + probe_views_meta = probe_samples[conditions[0]]['meta']['views'] accuracy = { condition: torch.empty( len(gallery_views_meta), len(probe_views_meta), num_ranks @@ -464,7 +467,7 @@ class Model: (labels_g, _, views_g, features_g) = gallery_samples_c.values() views_g = np.asarray(views_g) probe_samples_c = probe_samples[condition] - (labels_p, _, views_p, features_p) = probe_samples_c.values() + (labels_p, _, views_p, features_p, _) = probe_samples_c.values() views_p = np.asarray(views_p) accuracy_c = accuracy[condition] for (v_g_i, view_g) in enumerate(gallery_views_meta): @@ -489,7 +492,6 @@ class Model: positive_counts = positive_mat.sum(0) total_counts, _ = dist.size() accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts - return accuracy def _load_pretrained( diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 3a5777e..6e72e38 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -20,7 +20,7 @@ class RGBPartNet(nn.Module): hpm_use_max_pool: bool = True, tfa_squeeze_ratio: int = 4, tfa_num_parts: int = 16, - embedding_dims: tuple[int] = (256, 256), + embedding_dims: Tuple[int] = (256, 256), image_log_on: bool = False ): super().__init__() |