From a68562cbb7f602cc75b3f8f0bf0c285d9e4e4c8b Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 15 Mar 2021 20:28:06 +0800 Subject: Remove redundant wrapper given by dataloader --- models/model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'models/model.py') diff --git a/models/model.py b/models/model.py index c350d11..c1cc703 100644 --- a/models/model.py +++ b/models/model.py @@ -439,14 +439,14 @@ class Model: return gallery_samples, probe_samples def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]): - label = sample.pop('label').item() - clip = sample.pop('clip').to(self.device) + label, condition, view, clip = sample.values() with torch.no_grad(): - feature = self.rgb_pn(clip) + feature = self.rgb_pn(clip.to(self.device)) return { - **{'label': label}, - **sample, - **{'feature': feature} + 'label': label.item(), + 'condition': condition[0], + 'view': view[0], + 'feature': feature } @staticmethod -- cgit v1.2.3