summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/models/model.py b/models/model.py
index c350d11..c1cc703 100644
--- a/models/model.py
+++ b/models/model.py
@@ -439,14 +439,14 @@ class Model:
return gallery_samples, probe_samples
def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]):
- label = sample.pop('label').item()
- clip = sample.pop('clip').to(self.device)
+ label, condition, view, clip = sample.values()
with torch.no_grad():
- feature = self.rgb_pn(clip)
+ feature = self.rgb_pn(clip.to(self.device))
return {
- **{'label': label},
- **sample,
- **{'feature': feature}
+ 'label': label.item(),
+ 'condition': condition[0],
+ 'view': view[0],
+ 'feature': feature
}
@staticmethod