summaryrefslogtreecommitdiff
path: root/utils
diff options
context:
space:
mode:
Diffstat (limited to 'utils')
-rw-r--r--utils/dataset.py17
1 files changed, 16 insertions, 1 deletions
diff --git a/utils/dataset.py b/utils/dataset.py
index 6b2a991..72cf050 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -25,6 +25,7 @@ class CASIAB(data.Dataset):
is_train: bool = True,
train_size: int = 74,
num_sampled_frames: int = 30,
+ truncate_threshold: int = 40,
discard_threshold: int = 15,
selector: Optional[Dict[
str, Union[ClipClasses, ClipConditions, ClipViews]
@@ -40,6 +41,8 @@ class CASIAB(data.Dataset):
when `is_train` is False, test size will be inferred.
:param num_sampled_frames: The number of sampled frames.
(Training Only)
+ :param truncate_threshold: Truncate redundant frames larger
+ than this threshold. (Training Only)
:param discard_threshold: Discard the sample if its number of
frames is less than this threshold.
:param selector: Restrict output data classes, conditions and
@@ -58,6 +61,7 @@ class CASIAB(data.Dataset):
self._root_dir = root_dir
self._is_train = is_train
self._num_sampled_frames = num_sampled_frames
+ self._truncate_threshold = truncate_threshold
self._cache_on = cache_on
self._frame_transform: transforms.Compose
@@ -264,11 +268,22 @@ class CASIAB(data.Dataset):
if self._is_train:
num_frames = len(frame_names)
- # Sample frames without replace if have enough frames
+ # 1# Sample with replacement if less than num_sampled_frames
if num_frames < self._num_sampled_frames:
frame_names = random.choices(frame_names,
k=self._num_sampled_frames)
+ # 2# Sample without replacement if less than 40
+ elif self._num_sampled_frames \
+ <= num_frames < self._truncate_threshold:
+ frame_names = random.sample(frame_names,
+ k=self._num_sampled_frames)
+ # 3# Cut out an interval with truncate_threshold frames
else:
+ frame_name_0_max = num_frames - self._truncate_threshold
+ frame_name_0 = random.randint(0, frame_name_0_max)
+ frame_name_1 = frame_name_0 + self._truncate_threshold
+ frame_names = sorted(frame_names)
+ frame_names = frame_names[frame_name_0:frame_name_1]
frame_names = random.sample(frame_names,
k=self._num_sampled_frames)