 models/model.py | 30 ++++++++++++++--------------
 1 file changed, 16 insertions(+), 14 deletions(-)
diff --git a/models/model.py b/models/model.py
index 2608236..9b73f32 100644
--- a/models/model.py
+++ b/models/model.py
@@ -241,15 +241,15 @@ class Model:
             # forward + backward + optimize
             x_c1 = batch_c1['clip'].to(self.device)
             x_c2 = batch_c2['clip'].to(self.device)
-            embedding_c, embedding_p, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+            embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2)
             y = batch_c1['label'].to(self.device)
             # Duplicate labels for each part
             y = y.repeat(self.rgb_pn.num_parts, 1)
             trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm(
-                embedding_c, y[:self.rgb_pn.hpm.num_parts]
+                embed_c, y[:self.rgb_pn.hpm.num_parts]
             )
             trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn(
-                embedding_p, y[self.rgb_pn.hpm.num_parts:]
+                embed_p, y[self.rgb_pn.hpm.num_parts:]
             )
             losses = torch.stack((
                 *ae_losses,
@@ -290,13 +290,13 @@ class Model:
                 num_pos_pairs, num_pairs, self.curr_iter
             )
             # Embedding norm
-            mean_hpm_embedding = embedding_c.mean(0)
+            mean_hpm_embedding = embed_c.mean(0)
             mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
             self._add_ranked_scalars(
                 'Embedding/HPM norm', mean_hpm_norm,
                 self.k, self.pr * self.k, self.curr_iter
             )
-            mean_pa_embedding = embedding_p.mean(0)
+            mean_pa_embedding = embed_p.mean(0)
             mean_pa_norm = mean_pa_embedding.norm(dim=-1)
             self._add_ranked_scalars(
                 'Embedding/PartNet norm', mean_pa_norm,
@@ -425,13 +425,16 @@ class Model:
                                unit='clips'):
                 gallery_samples_c.append(self._get_eval_sample(sample))
             gallery_samples[condition] = default_collate(gallery_samples_c)
+            gallery_samples['meta'] = self._gallery_dataset_meta
             # Probe
             probe_samples_c = []
             for sample in tqdm(probe_dataloader,
                                desc=f'Transforming probe {condition}',
                                unit='clips'):
                 probe_samples_c.append(self._get_eval_sample(sample))
-            probe_samples[condition] = default_collate(probe_samples_c)
+            probe_samples_c = default_collate(probe_samples_c)
+            probe_samples_c['meta'] = self._probe_datasets_meta[condition]
+            probe_samples[condition] = probe_samples_c

         return gallery_samples, probe_samples

@@ -446,15 +449,15 @@ class Model:
             **{'feature': feature}
         }

+    @staticmethod
     def evaluate(
-            self,
-            gallery_samples: Dict[str, Union[List[str], torch.Tensor]],
-            probe_samples: Dict[str, Dict[str, Union[List[str], torch.Tensor]]],
+            gallery_samples: Dict[str, Dict[str, Union[List, torch.Tensor]]],
+            probe_samples: Dict[str, Dict[str, Union[List, torch.Tensor]]],
             num_ranks: int = 5
     ) -> Dict[str, torch.Tensor]:
-        conditions = gallery_samples.keys()
-        gallery_views_meta = self._gallery_dataset_meta['views']
-        probe_views_meta = list(self._probe_datasets_meta.values())[0]['views']
+        conditions = list(probe_samples.keys())
+        gallery_views_meta = gallery_samples['meta']['views']
+        probe_views_meta = probe_samples[conditions[0]]['meta']['views']
         accuracy = {
             condition: torch.empty(
                 len(gallery_views_meta), len(probe_views_meta), num_ranks
@@ -467,7 +470,7 @@ class Model:
             (labels_g, _, views_g, features_g) = gallery_samples_c.values()
             views_g = np.asarray(views_g)
             probe_samples_c = probe_samples[condition]
-            (labels_p, _, views_p, features_p) = probe_samples_c.values()
+            (labels_p, _, views_p, features_p, _) = probe_samples_c.values()
             views_p = np.asarray(views_p)
             accuracy_c = accuracy[condition]
             for (v_g_i, view_g) in enumerate(gallery_views_meta):
@@ -492,7 +495,6 @@ class Model:
                     positive_counts = positive_mat.sum(0)
                     total_counts, _ = dist.size()
                     accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts
-
         return accuracy

     def _load_pretrained(
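For context: the commit moves dataset metadata into the transformed sample dicts under a 'meta' key, which is what lets evaluate() drop self and become a @staticmethod, and is why the probe unpack gains a trailing underscore. The snippet below is a minimal sketch of the resulting data shape, not code from the repository; the condition name 'nm', the view list, labels, and feature tensors are made-up placeholders, and the per-condition dict keys are assumed from the four-value unpack in the diff.

import torch

views = ['000', '018']                        # placeholder view list

def fake_condition_samples(with_meta):
    # Mimics the dict default_collate() yields per condition; key order
    # (label, condition, view, feature) is an assumption from the diff.
    samples = {
        'label': torch.tensor([0, 1]),
        'condition': ['nm', 'nm'],
        'view': views,
        'feature': torch.randn(2, 8),
    }
    if with_meta:
        samples['meta'] = {'views': views}    # probe meta rides per condition
    return samples

gallery_samples = {
    'nm': fake_condition_samples(with_meta=False),
    'meta': {'views': views},                 # one gallery-wide meta entry
}
probe_samples = {'nm': fake_condition_samples(with_meta=True)}

# 'meta' is inserted last, so dict.values() yields it last; hence the
# probe unpack in evaluate() grows a trailing underscore:
(labels_p, _, views_p, features_p, _) = probe_samples['nm'].values()

# With all metadata inside its arguments, evaluate() reads nothing from
# self and can be called on the class itself, e.g.:
# accuracy = Model.evaluate(gallery_samples, probe_samples, num_ranks=5)

Note that the gallery 'meta' entry sits beside the condition keys in gallery_samples, which is why the patched evaluate() derives conditions from probe_samples rather than from gallery_samples.keys().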
