diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-04 13:43:17 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-04 13:43:17 +0800 |
commit | a88e71765c5fe93ccf062ed88aa8e980a054c415 (patch) | |
tree | 674a9264f6e4bd81fe30a073d4cf6e0b1ba3c1f8 /models/model.py | |
parent | 6a79a0eee0401318554d6859a733254b770e8e87 (diff) | |
parent | 8578a141969720ec93b9bc172c8f20d0ef66ed16 (diff) |
Merge branch 'master' into python3.8
Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/models/model.py b/models/model.py index 3f6a49d..7ce189c 100644 --- a/models/model.py +++ b/models/model.py @@ -1,4 +1,5 @@ import os +import random from datetime import datetime from typing import Union, Optional, Tuple, List, Dict, Set @@ -199,14 +200,17 @@ class Model: self.writer = SummaryWriter(self._log_name) + # Set seeds for reproducibility + random.seed(0) + torch.manual_seed(0) self.rgb_pn.train() # Init weights at first iter if self.curr_iter == 0: self.rgb_pn.apply(self.init_weights) else: # Load saved state dicts checkpoint = torch.load(self._checkpoint_name) - iter_, loss = checkpoint['iter'], checkpoint['loss'] - print('{0:5d} loss: {1:.3f}'.format(iter_, loss)) + random.setstate(checkpoint['rand_states'][0]) + torch.set_rng_state(checkpoint['rand_states'][1]) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optim_state_dict']) self.scheduler.load_state_dict(checkpoint['sched_state_dict']) @@ -320,11 +324,10 @@ class Model: if self.curr_iter % 1000 == 0: torch.save({ - 'iter': self.curr_iter, + 'rand_states': (random.getstate(), torch.get_rng_state()), 'model_state_dict': self.rgb_pn.state_dict(), 'optim_state_dict': self.optimizer.state_dict(), 'sched_state_dict': self.scheduler.state_dict(), - 'loss': loss, }, self._checkpoint_name) if self.curr_iter == self.total_iter: @@ -420,7 +423,8 @@ class Model: def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]): label = sample.pop('label').item() clip = sample.pop('clip').to(self.device) - feature = self.rgb_pn(clip).detach() + with torch.no_grad(): + feature = self.rgb_pn(clip) return { **{'label': label}, **sample, |