diff options
Diffstat (limited to 'utils/dataset.py')
-rw-r--r-- | utils/dataset.py | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/utils/dataset.py b/utils/dataset.py index bbd42c3..c487988 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -25,6 +25,7 @@ class CASIAB(data.Dataset): is_train: bool = True, train_size: int = 74, num_sampled_frames: int = 30, + truncate_threshold: int = 40, discard_threshold: int = 15, selector: Optional[dict[ str, Union[ClipClasses, ClipConditions, ClipViews] @@ -40,6 +41,8 @@ class CASIAB(data.Dataset): when `is_train` is False, test size will be inferred. :param num_sampled_frames: The number of sampled frames. (Training Only) + :param truncate_threshold: Truncate redundant frames larger + than this threshold. (Training Only) :param discard_threshold: Discard the sample if its number of frames is less than this threshold. :param selector: Restrict output data classes, conditions and @@ -58,6 +61,7 @@ class CASIAB(data.Dataset): self._root_dir = root_dir self._is_train = is_train self._num_sampled_frames = num_sampled_frames + self._truncate_threshold = truncate_threshold self._cache_on = cache_on self._frame_transform: transforms.Compose @@ -264,11 +268,22 @@ class CASIAB(data.Dataset): if self._is_train: num_frames = len(frame_names) - # Sample frames without replace if have enough frames + # 1# Sample with replacement if less than num_sampled_frames if num_frames < self._num_sampled_frames: frame_names = random.choices(frame_names, k=self._num_sampled_frames) + # 2# Sample without replacement if less than 40 + elif self._num_sampled_frames \ + <= num_frames < self._truncate_threshold: + frame_names = random.sample(frame_names, + k=self._num_sampled_frames) + # 3# Cut out an interval with truncate_threshold frames else: + frame_name_0_max = num_frames - self._truncate_threshold + frame_name_0 = random.randint(0, frame_name_0_max) + frame_name_1 = frame_name_0 + self._truncate_threshold + frame_names = sorted(frame_names) + frame_names = frame_names[frame_name_0:frame_name_1] frame_names = random.sample(frame_names, k=self._num_sampled_frames) |