From 2f7c09fb4fb985db1cbf6e2bdc6622a2c51ebfc3 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 10 Feb 2021 13:43:15 +0800 Subject: Implement new sampling technique mentioned in GaitPart[1] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [1]C. Fan et al., “GaitPart: Temporal Part-Based Model for Gait Recognition,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2020, pp. 14225–14233. --- utils/dataset.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) (limited to 'utils/dataset.py') diff --git a/utils/dataset.py b/utils/dataset.py index bbd42c3..c487988 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -25,6 +25,7 @@ class CASIAB(data.Dataset): is_train: bool = True, train_size: int = 74, num_sampled_frames: int = 30, + truncate_threshold: int = 40, discard_threshold: int = 15, selector: Optional[dict[ str, Union[ClipClasses, ClipConditions, ClipViews] @@ -40,6 +41,8 @@ class CASIAB(data.Dataset): when `is_train` is False, test size will be inferred. :param num_sampled_frames: The number of sampled frames. (Training Only) + :param truncate_threshold: Truncate redundant frames larger + than this threshold. (Training Only) :param discard_threshold: Discard the sample if its number of frames is less than this threshold. :param selector: Restrict output data classes, conditions and @@ -58,6 +61,7 @@ class CASIAB(data.Dataset): self._root_dir = root_dir self._is_train = is_train self._num_sampled_frames = num_sampled_frames + self._truncate_threshold = truncate_threshold self._cache_on = cache_on self._frame_transform: transforms.Compose @@ -264,11 +268,22 @@ class CASIAB(data.Dataset): if self._is_train: num_frames = len(frame_names) - # Sample frames without replace if have enough frames + # 1# Sample with replacement if less than num_sampled_frames if num_frames < self._num_sampled_frames: frame_names = random.choices(frame_names, k=self._num_sampled_frames) + # 2# Sample without replacement if less than 40 + elif self._num_sampled_frames \ + <= num_frames < self._truncate_threshold: + frame_names = random.sample(frame_names, + k=self._num_sampled_frames) + # 3# Cut out an interval with truncate_threshold frames else: + frame_name_0_max = num_frames - self._truncate_threshold + frame_name_0 = random.randint(0, frame_name_0_max) + frame_name_1 = frame_name_0 + self._truncate_threshold + frame_names = sorted(frame_names) + frame_names = frame_names[frame_name_0:frame_name_1] frame_names = random.sample(frame_names, k=self._num_sampled_frames) -- cgit v1.2.3