| author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-04 13:43:38 +0800 | 
|---|---|---|
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-04 13:43:38 +0800 | 
| commit | 9a4a62598928f0befdd21ab7a562ee980eff092d (patch) | |
| tree | a4be9fef8da438ac8993461da158196d36023090 /models/model.py | |
| parent | b010eabd0f2d24e7fd6f6a72aa0eac9e4592b0e4 (diff) | |
| parent | a88e71765c5fe93ccf062ed88aa8e980a054c415 (diff) | |
Merge branch 'python3.8' into python3.7
Diffstat (limited to 'models/model.py')
| -rw-r--r-- | models/model.py | 14 | 
1 file changed, 9 insertions, 5 deletions
```diff
diff --git a/models/model.py b/models/model.py
index 0eb9823..e924f41 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,4 +1,5 @@
 import os
+import random
 from datetime import datetime
 from typing import Union, Optional, Tuple, List, Dict, Set
@@ -196,14 +197,17 @@ class Model:
         self.writer = SummaryWriter(self._log_name)
+        # Set seeds for reproducibility
+        random.seed(0)
+        torch.manual_seed(0)
         self.rgb_pn.train()
         # Init weights at first iter
         if self.curr_iter == 0:
             self.rgb_pn.apply(self.init_weights)
         else:  # Load saved state dicts
             checkpoint = torch.load(self._checkpoint_name)
-            iter_, loss = checkpoint['iter'], checkpoint['loss']
-            print('{0:5d} loss: {1:.3f}'.format(iter_, loss))
+            random.setstate(checkpoint['rand_states'][0])
+            torch.set_rng_state(checkpoint['rand_states'][1])
             self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
             self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
             self.scheduler.load_state_dict(checkpoint['sched_state_dict'])
@@ -317,11 +321,10 @@ class Model:
             if self.curr_iter % 1000 == 0:
                 torch.save({
-                    'iter': self.curr_iter,
+                    'rand_states': (random.getstate(), torch.get_rng_state()),
                     'model_state_dict': self.rgb_pn.state_dict(),
                     'optim_state_dict': self.optimizer.state_dict(),
                     'sched_state_dict': self.scheduler.state_dict(),
-                    'loss': loss,
                 }, self._checkpoint_name)
             if self.curr_iter == self.total_iter:
@@ -417,7 +420,8 @@ class Model:
     def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]):
         label = sample.pop('label').item()
         clip = sample.pop('clip').to(self.device)
-        feature = self.rgb_pn(clip).detach()
+        with torch.no_grad():
+            feature = self.rgb_pn(clip)
         return {
             **{'label': label},
             **sample,
```
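The checkpoint change above swaps the iteration/loss bookkeeping for the Python and PyTorch RNG states, so a resumed run continues from the same point in the random sequence that the fixed seeds (`random.seed(0)`, `torch.manual_seed(0)`) started. A minimal standalone sketch of this save/restore pattern, with illustrative `model`, `optimizer`, and `path` names that are not taken from the repository:

```python
import random

import torch


def save_checkpoint(model, optimizer, path):
    # Persist both RNG states next to the model/optimizer states.
    torch.save({
        'rand_states': (random.getstate(), torch.get_rng_state()),
        'model_state_dict': model.state_dict(),
        'optim_state_dict': optimizer.state_dict(),
    }, path)


def resume_checkpoint(model, optimizer, path):
    checkpoint = torch.load(path)
    # Restore RNG states first so subsequent random draws match the saved run.
    random.setstate(checkpoint['rand_states'][0])
    torch.set_rng_state(checkpoint['rand_states'][1])
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optim_state_dict'])
```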

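The last hunk replaces `.detach()` with a `torch.no_grad()` block for evaluation. Under `no_grad` the forward pass records no autograd graph at all, which lowers memory use, whereas `.detach()` only cuts the result loose after the graph has already been built. A small sketch of the difference, using a toy `nn.Linear` as a stand-in for `rgb_pn`:

```python
import torch
import torch.nn as nn

net = nn.Linear(16, 8)     # stand-in for the real network
clip = torch.randn(4, 16)  # stand-in for an input batch

# .detach(): the graph is built during the forward pass, then discarded.
feature_detached = net(clip).detach()

# torch.no_grad(): no graph is recorded in the first place.
with torch.no_grad():
    feature_no_grad = net(clip)

assert not feature_detached.requires_grad
assert not feature_no_grad.requires_grad
```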