diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-10 13:46:24 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-10 13:46:24 +0800 |
commit | 156fb6d957efda8b897c172d70eccc0d2016b2bf (patch) | |
tree | 5775eeb8e682d6012bb5294267283e325c21fdb9 | |
parent | 045fdb1d8f381ef1dafdec33e87fc2b6736615e4 (diff) | |
parent | ba705ac60f463b034782ba03796dbfa2776a5631 (diff) |
Merge branch 'python3.8' into python3.7
# Conflicts:
# utils/configuration.py
-rw-r--r-- | config.py | 2 | ||||
-rw-r--r-- | models/model.py | 2 | ||||
-rw-r--r-- | utils/dataset.py | 17 |
3 files changed, 20 insertions, 1 deletions
@@ -19,6 +19,8 @@ config = { 'train_size': 74, # Number of sampled frames per sequence (Training only) 'num_sampled_frames': 30, + # Truncate clips longer than `truncate_threshold` + 'truncate_threshold': 40, # Discard clips shorter than `discard_threshold` 'discard_threshold': 15, # Number of input channels of model diff --git a/models/model.py b/models/model.py index 5f079b8..eae15e3 100644 --- a/models/model.py +++ b/models/model.py @@ -174,6 +174,7 @@ class Model: print('{0:5d} loss: {1:.3f}'.format(iter_, loss)) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optim_state_dict']) + self.scheduler.load_state_dict(checkpoint['sched_state_dict']) # Training start start_time = datetime.now() @@ -246,6 +247,7 @@ class Model: 'iter': self.curr_iter, 'model_state_dict': self.rgb_pn.state_dict(), 'optim_state_dict': self.optimizer.state_dict(), + 'sched_state_dict': self.scheduler.state_dict(), 'loss': loss, }, self._checkpoint_name) diff --git a/utils/dataset.py b/utils/dataset.py index 6b2a991..72cf050 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -25,6 +25,7 @@ class CASIAB(data.Dataset): is_train: bool = True, train_size: int = 74, num_sampled_frames: int = 30, + truncate_threshold: int = 40, discard_threshold: int = 15, selector: Optional[Dict[ str, Union[ClipClasses, ClipConditions, ClipViews] @@ -40,6 +41,8 @@ class CASIAB(data.Dataset): when `is_train` is False, test size will be inferred. :param num_sampled_frames: The number of sampled frames. (Training Only) + :param truncate_threshold: Truncate redundant frames larger + than this threshold. (Training Only) :param discard_threshold: Discard the sample if its number of frames is less than this threshold. :param selector: Restrict output data classes, conditions and @@ -58,6 +61,7 @@ class CASIAB(data.Dataset): self._root_dir = root_dir self._is_train = is_train self._num_sampled_frames = num_sampled_frames + self._truncate_threshold = truncate_threshold self._cache_on = cache_on self._frame_transform: transforms.Compose @@ -264,11 +268,22 @@ class CASIAB(data.Dataset): if self._is_train: num_frames = len(frame_names) - # Sample frames without replace if have enough frames + # 1# Sample with replacement if less than num_sampled_frames if num_frames < self._num_sampled_frames: frame_names = random.choices(frame_names, k=self._num_sampled_frames) + # 2# Sample without replacement if less than 40 + elif self._num_sampled_frames \ + <= num_frames < self._truncate_threshold: + frame_names = random.sample(frame_names, + k=self._num_sampled_frames) + # 3# Cut out an interval with truncate_threshold frames else: + frame_name_0_max = num_frames - self._truncate_threshold + frame_name_0 = random.randint(0, frame_name_0_max) + frame_name_1 = frame_name_0 + self._truncate_threshold + frame_names = sorted(frame_names) + frame_names = frame_names[frame_name_0:frame_name_1] frame_names = random.sample(frame_names, k=self._num_sampled_frames) |