summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-04 13:43:17 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-04 13:43:17 +0800
commita88e71765c5fe93ccf062ed88aa8e980a054c415 (patch)
tree674a9264f6e4bd81fe30a073d4cf6e0b1ba3c1f8 /models
parent6a79a0eee0401318554d6859a733254b770e8e87 (diff)
parent8578a141969720ec93b9bc172c8f20d0ef66ed16 (diff)
Merge branch 'master' into python3.8
Diffstat (limited to 'models')
-rw-r--r--models/model.py14
1 files changed, 9 insertions, 5 deletions
diff --git a/models/model.py b/models/model.py
index 3f6a49d..7ce189c 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,4 +1,5 @@
import os
+import random
from datetime import datetime
from typing import Union, Optional, Tuple, List, Dict, Set
@@ -199,14 +200,17 @@ class Model:
self.writer = SummaryWriter(self._log_name)
+ # Set seeds for reproducibility
+ random.seed(0)
+ torch.manual_seed(0)
self.rgb_pn.train()
# Init weights at first iter
if self.curr_iter == 0:
self.rgb_pn.apply(self.init_weights)
else: # Load saved state dicts
checkpoint = torch.load(self._checkpoint_name)
- iter_, loss = checkpoint['iter'], checkpoint['loss']
- print('{0:5d} loss: {1:.3f}'.format(iter_, loss))
+ random.setstate(checkpoint['rand_states'][0])
+ torch.set_rng_state(checkpoint['rand_states'][1])
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
self.scheduler.load_state_dict(checkpoint['sched_state_dict'])
@@ -320,11 +324,10 @@ class Model:
if self.curr_iter % 1000 == 0:
torch.save({
- 'iter': self.curr_iter,
+ 'rand_states': (random.getstate(), torch.get_rng_state()),
'model_state_dict': self.rgb_pn.state_dict(),
'optim_state_dict': self.optimizer.state_dict(),
'sched_state_dict': self.scheduler.state_dict(),
- 'loss': loss,
}, self._checkpoint_name)
if self.curr_iter == self.total_iter:
@@ -420,7 +423,8 @@ class Model:
def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]):
label = sample.pop('label').item()
clip = sample.pop('clip').to(self.device)
- feature = self.rgb_pn(clip).detach()
+ with torch.no_grad():
+ feature = self.rgb_pn(clip)
return {
**{'label': label},
**sample,