diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-15 20:28:06 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-15 20:28:06 +0800 |
commit | a68562cbb7f602cc75b3f8f0bf0c285d9e4e4c8b (patch) | |
tree | c320921751d63fcb91f3482441d16fb2b9122ae0 /models/model.py | |
parent | 864fca2c9ca65847c0f1f318dfe50a1e6155e418 (diff) |
Remove redundant wrapper given by dataloader
Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/models/model.py b/models/model.py index c350d11..c1cc703 100644 --- a/models/model.py +++ b/models/model.py @@ -439,14 +439,14 @@ class Model: return gallery_samples, probe_samples def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]): - label = sample.pop('label').item() - clip = sample.pop('clip').to(self.device) + label, condition, view, clip = sample.values() with torch.no_grad(): - feature = self.rgb_pn(clip) + feature = self.rgb_pn(clip.to(self.device)) return { - **{'label': label}, - **sample, - **{'feature': feature} + 'label': label.item(), + 'condition': condition[0], + 'view': view[0], + 'feature': feature } @staticmethod |