summaryrefslogtreecommitdiff
path: root/utils
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-19 18:14:07 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-19 18:14:07 +0800
commitca194fbc53207b7b1c93044bd5fc766f5b9fad75 (patch)
treee66e61f759d70ecb298ec907f55652dad335a87c /utils
parentdaad9a2ada5fb096886a4c1a4a611ce46a5979d5 (diff)
Add cache switch, allowing load all data into RAM before sampling
Diffstat (limited to 'utils')
-rw-r--r--utils/dataset.py128
1 files changed, 106 insertions, 22 deletions
diff --git a/utils/dataset.py b/utils/dataset.py
index 7518127..fad94c0 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -8,6 +8,7 @@ import torch
from torch.utils import data
from torchvision.io import read_image
import torchvision.transforms as transforms
+from tqdm import tqdm
ClipLabels = NewType('ClipLabels', Set[str])
ClipConditions = NewType('ClipConditions', Set[str])
@@ -27,36 +28,48 @@ class CASIAB(data.Dataset):
is_train: bool = True,
train_size: int = 74,
num_sampled_frames: int = 30,
+ discard_threshold: int = 15,
selector: Optional[Dict[
str, Union[ClipLabels, ClipConditions, ClipLabels]
]] = None,
num_input_channels: int = 3,
frame_height: int = 64,
frame_width: int = 32,
- device: torch.device = torch.device('cpu')
+ device: torch.device = torch.device('cpu'),
+ cache_on: bool = False
):
"""
:param root_dir: Directory to dataset root.
:param is_train: Train or test, True for train, False for test.
- :param train_size: Number of subjects in train, when `is_train`
- is False, test size will be inferred.
- :param num_sampled_frames: Number of sampled frames for train
- :param selector: Restrict data labels, conditions and views
- :param num_input_channels Number of input channel, RBG image
- has 3 channel, grayscale image has 1 channel
- :param frame_height Frame height after transforms
- :param frame_width Frame width after transforms
- :param device Device be used for transforms
+ :param train_size: The number of subjects used for training,
+ when `is_train` is False, test size will be inferred.
+ :param num_sampled_frames: The number of sampled frames.
+ (Training Only)
+ :param discard_threshold: Discard the sample if its number of
+ frames is less than this threshold.
+ :param selector: Restrict output data labels, conditions and
+ views.
+ :param num_input_channels: The number of input channel(s),
+ RBG image has 3 channels, grayscale image has 1 channel.
+ :param frame_height: Frame height after transforming.
+ :param frame_width: Frame width after transforming.
+ :param device: Device used in transforms.
+ :param cache_on: Preload all clips in memory or not, this will
+ increase loading speed, but will add a preload process and
+ cost a lot of RAM. Loading the entire dataset need about
+ 7 GB of RAM.
"""
super(CASIAB, self).__init__()
self.root_dir = root_dir
self.is_train = is_train
self.train_size = train_size
self.num_sampled_frames = num_sampled_frames
+ self.discard_threshold = discard_threshold
self.num_input_channels = num_input_channels
self.frame_height = frame_height
self.frame_width = frame_width
self.device = device
+ self.cache_on = cache_on
self.frame_transform: transforms.Compose
transform_compose_list = [
@@ -84,10 +97,14 @@ class CASIAB(data.Dataset):
clip_names = clip_names[self.train_size * 10 * 11:]
# Remove empty clips
+ discard_clips_names = []
for clip_name in clip_names.copy():
- if len(os.listdir(os.path.join(self.root_dir, clip_name))) == 0:
- print("Clip '{}' is empty.".format(clip_name))
+ clip_path = os.path.join(self.root_dir, clip_name)
+ if len(os.listdir(clip_path)) < self.discard_threshold:
+ discard_clips_names.append(clip_name)
clip_names.remove(clip_name)
+ print(', '.join(discard_clips_names[:-1]),
+ 'and', discard_clips_names[-1], 'will be discarded.')
# clip name constructed by label, condition and view
# e.g 002-bg-02-090 means clip from Subject #2
@@ -109,12 +126,12 @@ class CASIAB(data.Dataset):
condition_regex = '|'.join(selected_conditions)
if selected_views:
view_regex = '|'.join(selected_views)
- clip_regex = '(' + ')-('.join([
+ clip_re = re.compile('(' + ')-('.join((
label_regex, condition_regex, view_regex
- ]) + ')'
+ )) + ')')
for clip_name in clip_names:
- match = re.fullmatch(clip_regex, clip_name)
+ match = clip_re.fullmatch(clip_name)
if match:
labels.append(match.group(1))
conditions.append(match.group(2))
@@ -148,6 +165,13 @@ class CASIAB(data.Dataset):
'views': set(self.views.tolist())
}
+ self._cached_clips_frame_names: Optional[Dict[str, List[str]]] = None
+ self._cached_clips: Optional[Dict[str, torch.Tensor]] = None
+ if self.cache_on:
+ self._cached_clips_frame_names = dict()
+ self._cached_clips = dict()
+ self._preload_all_video()
+
def __len__(self) -> int:
return len(self.labels)
@@ -157,7 +181,6 @@ class CASIAB(data.Dataset):
view = self.views[index]
clip_name = self._clip_names[index]
clip = self._read_video(clip_name)
-
sample = {
'label': label,
'condition': condition,
@@ -167,23 +190,84 @@ class CASIAB(data.Dataset):
return sample
- def _read_video(self, clip_name: str) -> torch.Tensor:
- frames = []
+ def _preload_all_video(self):
+ for clip_name in tqdm(self._clip_names,
+ desc='Preloading dataset', unit='clips'):
+ self._read_video(clip_name, is_caching=True)
+
+ def _read_video(self, clip_name: str,
+ is_caching: bool = False) -> torch.Tensor:
clip_path = os.path.join(self.root_dir, clip_name)
- sampled_frame_names = self._sample_frames(clip_path)
- for frame_name in sampled_frame_names:
+ sampled_frame_names = self._sample_frames(clip_path, is_caching)
+
+ if self.cache_on:
+ if is_caching:
+ clip = self._read_frames(clip_path, sampled_frame_names)
+ self._cached_clips[clip_name] = clip
+ else: # Load cache
+ cached_clip = self._cached_clips[clip_name]
+ cached_clip_frame_names \
+ = self._cached_clips_frame_names[clip_path]
+ # Index the original clip via sampled frame names
+ clip = self._load_cached_video(cached_clip,
+ cached_clip_frame_names,
+ sampled_frame_names)
+ else: # Cache off
+ clip = self._read_frames(clip_path, sampled_frame_names)
+
+ return clip
+
+ def _load_cached_video(
+ self,
+ clip: torch.Tensor,
+ frame_names: List[str],
+ sampled_frame_names: List[str]
+ ) -> torch.Tensor:
+ # Mask the original clip when it is long enough
+ if len(frame_names) >= self.num_sampled_frames:
+ sampled_frame_mask = np.isin(frame_names,
+ sampled_frame_names)
+ sampled_clip = clip[sampled_frame_mask]
+ else: # Create a indexing filter from the beginning of clip
+ sampled_index = frame_names.index(sampled_frame_names[0])
+ sampled_frame_filter = [sampled_index]
+ for i in range(1, self.num_sampled_frames):
+ if sampled_frame_names[i] != sampled_frame_names[i - 1]:
+ sampled_index += 1
+ sampled_frame_filter.append(sampled_index)
+ sampled_clip = clip[sampled_frame_filter]
+
+ return sampled_clip
+
+ def _read_frames(self, clip_path, frame_names):
+ frames = []
+ for frame_name in frame_names:
frame_path = os.path.join(clip_path, frame_name)
frame = read_image(frame_path)
+ # Transforming using CPU is not efficient
frame = self.frame_transform(frame.to(self.device))
frames.append(frame.cpu())
clip = torch.stack(frames)
return clip
- def _sample_frames(self, clip_path: str) -> List[str]:
- frame_names = os.listdir(clip_path)
+ def _sample_frames(self, clip_path: str,
+ is_caching: bool = False) -> List[str]:
+ if self.cache_on:
+ if is_caching:
+ # Sort frame in advance for loading convenience
+ frame_names = sorted(os.listdir(clip_path))
+ self._cached_clips_frame_names[clip_path] = frame_names
+ # Load all without sampling
+ return frame_names
+ else: # Load cache
+ frame_names = self._cached_clips_frame_names[clip_path]
+ else: # Cache off
+ frame_names = os.listdir(clip_path)
+
if self.is_train:
num_frames = len(frame_names)
+ # Sample frames without replace if have enough frames
if num_frames < self.num_sampled_frames:
frame_names = random.choices(frame_names,
k=self.num_sampled_frames)