summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-15 20:28:06 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-15 20:28:06 +0800
commita68562cbb7f602cc75b3f8f0bf0c285d9e4e4c8b (patch)
treec320921751d63fcb91f3482441d16fb2b9122ae0 /models
parent864fca2c9ca65847c0f1f318dfe50a1e6155e418 (diff)
Remove redundant wrapper given by dataloader
Diffstat (limited to 'models')
-rw-r--r--models/model.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/models/model.py b/models/model.py
index c350d11..c1cc703 100644
--- a/models/model.py
+++ b/models/model.py
@@ -439,14 +439,14 @@ class Model:
return gallery_samples, probe_samples
def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]):
- label = sample.pop('label').item()
- clip = sample.pop('clip').to(self.device)
+ label, condition, view, clip = sample.values()
with torch.no_grad():
- feature = self.rgb_pn(clip)
+ feature = self.rgb_pn(clip.to(self.device))
return {
- **{'label': label},
- **sample,
- **{'feature': feature}
+ 'label': label.item(),
+ 'condition': condition[0],
+ 'view': view[0],
+ 'feature': feature
}
@staticmethod