Diffstat (limited to 'utils/dataset.py')
-rw-r--r-- | utils/dataset.py | 128
1 files changed, 106 insertions, 22 deletions
diff --git a/utils/dataset.py b/utils/dataset.py
index 7518127..fad94c0 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -8,6 +8,7 @@ import torch
 from torch.utils import data
 from torchvision.io import read_image
 import torchvision.transforms as transforms
+from tqdm import tqdm
 
 ClipLabels = NewType('ClipLabels', Set[str])
 ClipConditions = NewType('ClipConditions', Set[str])
@@ -27,36 +28,48 @@ class CASIAB(data.Dataset):
             is_train: bool = True,
             train_size: int = 74,
             num_sampled_frames: int = 30,
+            discard_threshold: int = 15,
             selector: Optional[Dict[
                 str, Union[ClipLabels, ClipConditions, ClipLabels]
             ]] = None,
             num_input_channels: int = 3,
             frame_height: int = 64,
             frame_width: int = 32,
-            device: torch.device = torch.device('cpu')
+            device: torch.device = torch.device('cpu'),
+            cache_on: bool = False
     ):
         """
         :param root_dir: Directory to dataset root.
         :param is_train: Train or test, True for train, False for test.
-        :param train_size: Number of subjects in train, when `is_train`
-            is False, test size will be inferred.
-        :param num_sampled_frames: Number of sampled frames for train
-        :param selector: Restrict data labels, conditions and views
-        :param num_input_channels Number of input channel, RBG image
-            has 3 channel, grayscale image has 1 channel
-        :param frame_height Frame height after transforms
-        :param frame_width Frame width after transforms
-        :param device Device be used for transforms
+        :param train_size: The number of subjects used for training;
+            when `is_train` is False, the test size will be inferred.
+        :param num_sampled_frames: The number of sampled frames.
+            (Training only)
+        :param discard_threshold: Discard a sample if its number of
+            frames is less than this threshold.
+        :param selector: Restrict output data labels, conditions and
+            views.
+        :param num_input_channels: The number of input channels; an
+            RGB image has 3 channels, a grayscale image has 1 channel.
+        :param frame_height: Frame height after transforming.
+        :param frame_width: Frame width after transforming.
+        :param device: Device used in transforms.
+        :param cache_on: Whether to preload all clips into memory.
+            This speeds up loading but adds a preload pass and costs
+            a lot of RAM; loading the entire dataset needs about
+            7 GB of RAM.
         """
         super(CASIAB, self).__init__()
         self.root_dir = root_dir
         self.is_train = is_train
         self.train_size = train_size
         self.num_sampled_frames = num_sampled_frames
+        self.discard_threshold = discard_threshold
         self.num_input_channels = num_input_channels
         self.frame_height = frame_height
         self.frame_width = frame_width
         self.device = device
+        self.cache_on = cache_on
 
         self.frame_transform: transforms.Compose
         transform_compose_list = [
@@ -84,10 +97,14 @@ class CASIAB(data.Dataset):
             clip_names = clip_names[self.train_size * 10 * 11:]
 
         # Remove empty clips
+        discard_clips_names = []
         for clip_name in clip_names.copy():
-            if len(os.listdir(os.path.join(self.root_dir, clip_name))) == 0:
-                print("Clip '{}' is empty.".format(clip_name))
+            clip_path = os.path.join(self.root_dir, clip_name)
+            if len(os.listdir(clip_path)) < self.discard_threshold:
+                discard_clips_names.append(clip_name)
                 clip_names.remove(clip_name)
+        print(', '.join(discard_clips_names[:-1]),
+              'and', discard_clips_names[-1], 'will be discarded.')
 
         # clip name constructed by label, condition and view
         # e.g 002-bg-02-090 means clip from Subject #2
@@ -109,12 +126,12 @@ class CASIAB(data.Dataset):
                 condition_regex = '|'.join(selected_conditions)
             if selected_views:
                 view_regex = '|'.join(selected_views)
-            clip_regex = '(' + ')-('.join([
+            clip_re = re.compile('(' + ')-('.join((
                 label_regex, condition_regex, view_regex
-            ]) + ')'
+            )) + ')')
 
         for clip_name in clip_names:
-            match = re.fullmatch(clip_regex, clip_name)
+            match = clip_re.fullmatch(clip_name)
             if match:
                 labels.append(match.group(1))
                 conditions.append(match.group(2))
@@ -148,6 +165,13 @@ class CASIAB(data.Dataset):
             'views': set(self.views.tolist())
         }
 
+        self._cached_clips_frame_names: Optional[Dict[str, List[str]]] = None
+        self._cached_clips: Optional[Dict[str, torch.Tensor]] = None
+        if self.cache_on:
+            self._cached_clips_frame_names = dict()
+            self._cached_clips = dict()
+            self._preload_all_video()
+
     def __len__(self) -> int:
         return len(self.labels)
 
@@ -157,7 +181,6 @@ class CASIAB(data.Dataset):
         view = self.views[index]
        clip_name = self._clip_names[index]
         clip = self._read_video(clip_name)
-
         sample = {
             'label': label,
             'condition': condition,
@@ -167,23 +190,84 @@ class CASIAB(data.Dataset):
 
         return sample
 
-    def _read_video(self, clip_name: str) -> torch.Tensor:
-        frames = []
+    def _preload_all_video(self):
+        for clip_name in tqdm(self._clip_names,
+                              desc='Preloading dataset', unit='clips'):
+            self._read_video(clip_name, is_caching=True)
+
+    def _read_video(self, clip_name: str,
+                    is_caching: bool = False) -> torch.Tensor:
         clip_path = os.path.join(self.root_dir, clip_name)
-        sampled_frame_names = self._sample_frames(clip_path)
-        for frame_name in sampled_frame_names:
+        sampled_frame_names = self._sample_frames(clip_path, is_caching)
+
+        if self.cache_on:
+            if is_caching:
+                clip = self._read_frames(clip_path, sampled_frame_names)
+                self._cached_clips[clip_name] = clip
+            else:  # Load cache
+                cached_clip = self._cached_clips[clip_name]
+                cached_clip_frame_names \
+                    = self._cached_clips_frame_names[clip_path]
+                # Index the original clip via sampled frame names
+                clip = self._load_cached_video(cached_clip,
+                                               cached_clip_frame_names,
+                                               sampled_frame_names)
+        else:  # Cache off
+            clip = self._read_frames(clip_path, sampled_frame_names)
+
+        return clip
+
+    def _load_cached_video(
+            self,
+            clip: torch.Tensor,
+            frame_names: List[str],
+            sampled_frame_names: List[str]
+    ) -> torch.Tensor:
+        # Mask the original clip when it is long enough
+        if len(frame_names) >= self.num_sampled_frames:
+            sampled_frame_mask = np.isin(frame_names,
+                                         sampled_frame_names)
+            sampled_clip = clip[sampled_frame_mask]
+        else:  # Create an indexing filter from the beginning of the clip
+            sampled_index = frame_names.index(sampled_frame_names[0])
+            sampled_frame_filter = [sampled_index]
+            for i in range(1, self.num_sampled_frames):
+                if sampled_frame_names[i] != sampled_frame_names[i - 1]:
+                    sampled_index += 1
+                sampled_frame_filter.append(sampled_index)
+            sampled_clip = clip[sampled_frame_filter]
+
+        return sampled_clip
+
+    def _read_frames(self, clip_path, frame_names):
+        frames = []
+        for frame_name in frame_names:
             frame_path = os.path.join(clip_path, frame_name)
             frame = read_image(frame_path)
+            # Transforming on the CPU is not efficient; use self.device
             frame = self.frame_transform(frame.to(self.device))
             frames.append(frame.cpu())
         clip = torch.stack(frames)
 
         return clip
 
-    def _sample_frames(self, clip_path: str) -> List[str]:
-        frame_names = os.listdir(clip_path)
+    def _sample_frames(self, clip_path: str,
+                       is_caching: bool = False) -> List[str]:
+        if self.cache_on:
+            if is_caching:
+                # Sort frames in advance for loading convenience
+                frame_names = sorted(os.listdir(clip_path))
+                self._cached_clips_frame_names[clip_path] = frame_names
+                # Load all frames without sampling
+                return frame_names
+            else:  # Load cache
+                frame_names = self._cached_clips_frame_names[clip_path]
+        else:  # Cache off
+            frame_names = os.listdir(clip_path)
+
         if self.is_train:
             num_frames = len(frame_names)
+            # Sample without replacement only if there are enough frames
            if num_frames < self.num_sampled_frames:
                frame_names = random.choices(frame_names,
                                             k=self.num_sampled_frames)
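
For reference, a minimal usage sketch of the constructor after this change. The dataset path, the device choice, and the key names printed at the end are illustrative assumptions, not part of the commit:

    import torch
    from utils.dataset import CASIAB

    # Hypothetical dataset location; adjust to the local CASIA-B layout
    dataset = CASIAB(
        root_dir='data/CASIA-B-silhouettes',
        is_train=True,
        num_sampled_frames=30,
        discard_threshold=15,   # clips with fewer than 15 frames are dropped
        cache_on=True,          # preload all clips (roughly 7 GB of RAM)
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    )

    sample = dataset[0]
    # Per __getitem__ in the diff, each sample is a dict containing at
    # least 'label' and 'condition' alongside the transformed clip
    print(sample['label'], sample['condition'])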
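The cache-lookup step in _load_cached_video may be easier to see in isolation. The toy sketch below is standalone, with made-up frame names and a small tensor standing in for a cached clip: a boolean membership mask via np.isin when the clip has at least num_sampled_frames frames, and a plain index list (a simplified equivalent of the incremental filter) when sampled names repeat because sampling used replacement:

    import numpy as np
    import torch

    cached_clip = torch.arange(6).reshape(6, 1)      # stand-in for 6 cached frames
    frame_names = ['001.png', '002.png', '003.png',
                   '004.png', '005.png', '006.png']  # sorted at caching time

    # Long enough clip: mask the cached frames that were sampled
    sampled = ['002.png', '004.png', '005.png']
    mask = np.isin(frame_names, sampled)
    print(cached_clip[mask].flatten().tolist())      # [1, 3, 4]

    # Short clip: names sampled with replacement repeat, so repeated
    # indices pick the same cached frame more than once
    sampled_with_repeats = ['002.png', '002.png', '003.png']
    indices = [frame_names.index(name) for name in sampled_with_repeats]
    print(cached_clip[indices].flatten().tolist())   # [1, 1, 2]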