diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-10 13:45:38 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-10 13:45:38 +0800 |
commit | ba705ac60f463b034782ba03796dbfa2776a5631 (patch) | |
tree | 984a50254777583fd384bd4d72acf12843ad8de0 | |
parent | 58ef39d75098bce92654492e09edf1e83033d0c8 (diff) | |
parent | 2f7c09fb4fb985db1cbf6e2bdc6622a2c51ebfc3 (diff) |
Merge branch 'master' into python3.8
-rw-r--r-- | config.py | 2 | ||||
-rw-r--r-- | models/model.py | 2 | ||||
-rw-r--r-- | utils/configuration.py | 1 | ||||
-rw-r--r-- | utils/dataset.py | 17 |
4 files changed, 21 insertions, 1 deletions
@@ -21,6 +21,8 @@ config: Configuration = { 'train_size': 74, # Number of sampled frames per sequence (Training only) 'num_sampled_frames': 30, + # Truncate clips longer than `truncate_threshold` + 'truncate_threshold': 40, # Discard clips shorter than `discard_threshold` 'discard_threshold': 15, # Number of input channels of model diff --git a/models/model.py b/models/model.py index b5daa54..a8843c5 100644 --- a/models/model.py +++ b/models/model.py @@ -177,6 +177,7 @@ class Model: print('{0:5d} loss: {1:.3f}'.format(iter_, loss)) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optim_state_dict']) + self.scheduler.load_state_dict(checkpoint['sched_state_dict']) # Training start start_time = datetime.now() @@ -249,6 +250,7 @@ class Model: 'iter': self.curr_iter, 'model_state_dict': self.rgb_pn.state_dict(), 'optim_state_dict': self.optimizer.state_dict(), + 'sched_state_dict': self.scheduler.state_dict(), 'loss': loss, }, self._checkpoint_name) diff --git a/utils/configuration.py b/utils/configuration.py index f44bcf0..b9e6d92 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -15,6 +15,7 @@ class DatasetConfiguration(TypedDict): root_dir: str train_size: int num_sampled_frames: int + truncate_threshold: int discard_threshold: int selector: Optional[Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]] num_input_channels: int diff --git a/utils/dataset.py b/utils/dataset.py index 6b2a991..72cf050 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -25,6 +25,7 @@ class CASIAB(data.Dataset): is_train: bool = True, train_size: int = 74, num_sampled_frames: int = 30, + truncate_threshold: int = 40, discard_threshold: int = 15, selector: Optional[Dict[ str, Union[ClipClasses, ClipConditions, ClipViews] @@ -40,6 +41,8 @@ class CASIAB(data.Dataset): when `is_train` is False, test size will be inferred. :param num_sampled_frames: The number of sampled frames. (Training Only) + :param truncate_threshold: Truncate redundant frames larger + than this threshold. (Training Only) :param discard_threshold: Discard the sample if its number of frames is less than this threshold. :param selector: Restrict output data classes, conditions and @@ -58,6 +61,7 @@ class CASIAB(data.Dataset): self._root_dir = root_dir self._is_train = is_train self._num_sampled_frames = num_sampled_frames + self._truncate_threshold = truncate_threshold self._cache_on = cache_on self._frame_transform: transforms.Compose @@ -264,11 +268,22 @@ class CASIAB(data.Dataset): if self._is_train: num_frames = len(frame_names) - # Sample frames without replace if have enough frames + # 1# Sample with replacement if less than num_sampled_frames if num_frames < self._num_sampled_frames: frame_names = random.choices(frame_names, k=self._num_sampled_frames) + # 2# Sample without replacement if less than 40 + elif self._num_sampled_frames \ + <= num_frames < self._truncate_threshold: + frame_names = random.sample(frame_names, + k=self._num_sampled_frames) + # 3# Cut out an interval with truncate_threshold frames else: + frame_name_0_max = num_frames - self._truncate_threshold + frame_name_0 = random.randint(0, frame_name_0_max) + frame_name_1 = frame_name_0 + self._truncate_threshold + frame_names = sorted(frame_names) + frame_names = frame_names[frame_name_0:frame_name_1] frame_names = random.sample(frame_names, k=self._num_sampled_frames) |