summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/models/model.py b/models/model.py
index 83b970a..7aff6c4 100644
--- a/models/model.py
+++ b/models/model.py
@@ -423,7 +423,8 @@ class Model:
def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]):
label = sample.pop('label').item()
clip = sample.pop('clip').to(self.device)
- feature = self.rgb_pn(clip).detach()
+ with torch.no_grad():
+ feature = self.rgb_pn(clip)
return {
**{'label': label},
**sample,