import os
import random
import re
from typing import Optional, Dict, NewType, Union, List, Set

import numpy as np
import torch
from torch.utils import data
from torchvision.io import read_image
import torchvision.transforms as transforms
from tqdm import tqdm

ClipLabels = NewType('ClipLabels', Set[str])
ClipConditions = NewType('ClipConditions', Set[str])
ClipViews = NewType('ClipViews', Set[str])

default_frame_transform = transforms.Compose([
    transforms.Resize(size=(64, 32))
])


class CASIAB(data.Dataset):
    """CASIA-B multi-view gait dataset"""

    def __init__(
            self,
            root_dir: str,
            is_train: bool = True,
            train_size: int = 74,
            num_sampled_frames: int = 30,
            discard_threshold: int = 15,
            selector: Optional[Dict[
                str, Union[ClipLabels, ClipConditions, ClipViews]
            ]] = None,
            num_input_channels: int = 3,
            frame_height: int = 64,
            frame_width: int = 32,
            device: torch.device = torch.device('cpu'),
            cache_on: bool = False
    ):
        """
        :param root_dir: Directory of the dataset root.
        :param is_train: Train or test, True for train, False for test.
        :param train_size: The number of subjects used for training;
            when `is_train` is False, the test size is inferred from it.
        :param num_sampled_frames: The number of frames sampled per
            clip. (Training only)
        :param discard_threshold: Discard a clip if its number of
            frames is less than this threshold.
        :param selector: Restrict output data to the given labels,
            conditions and views.
        :param num_input_channels: The number of input channel(s); an
            RGB image has 3 channels, a grayscale image has 1 channel.
        :param frame_height: Frame height after transforming.
        :param frame_width: Frame width after transforming.
        :param device: Device used in transforms.
        :param cache_on: Whether to preload all clips into memory. This
            speeds up loading, but adds a preload pass and costs a lot
            of RAM. Loading the entire dataset needs about 7 GB of RAM.
""" super(CASIAB, self).__init__() self.root_dir = root_dir self.is_train = is_train self.train_size = train_size self.num_sampled_frames = num_sampled_frames self.discard_threshold = discard_threshold self.num_input_channels = num_input_channels self.frame_height = frame_height self.frame_width = frame_width self.device = device self.cache_on = cache_on self.frame_transform: transforms.Compose transform_compose_list = [ transforms.Resize(size=(self.frame_height, self.frame_width)) ] if self.num_input_channels == 1: transform_compose_list.insert(0, transforms.Grayscale()) self.frame_transform = transforms.Compose(transform_compose_list) # Labels, conditions and views corresponding to each video clip self.labels: np.ndarray[np.str_] self.conditions: np.ndarray[np.str_] self.views: np.ndarray[np.str_] # Video clip directory names self._clip_names: List[str] = [] # Labels, conditions and views in dataset, # set of three attributes above self.metadata = Dict[str, Set[str]] clip_names = sorted(os.listdir(self.root_dir)) if self.is_train: clip_names = clip_names[:self.train_size * 10 * 11] else: # is_test clip_names = clip_names[self.train_size * 10 * 11:] # Remove empty clips discard_clips_names = [] for clip_name in clip_names.copy(): clip_path = os.path.join(self.root_dir, clip_name) if len(os.listdir(clip_path)) < self.discard_threshold: discard_clips_names.append(clip_name) clip_names.remove(clip_name) if len(discard_clips_names) != 0: print(', '.join(discard_clips_names[:-1]), 'and', discard_clips_names[-1], 'will be discarded.') # clip name constructed by label, condition and view # e.g 002-bg-02-090 means clip from Subject #2 # in Bag #2 condition from 90 degree angle labels, conditions, views = [], [], [] if selector: selected_labels = selector.pop('labels', None) selected_conditions = selector.pop('conditions', None) selected_views = selector.pop('views', None) label_regex = r'\d{3}' condition_regex = r'(nm|bg|cl)-0[0-4]' view_regex = r'\d{3}' # Match required data using RegEx if selected_labels: label_regex = '|'.join(selected_labels) if selected_conditions: condition_regex = '|'.join(selected_conditions) if selected_views: view_regex = '|'.join(selected_views) clip_re = re.compile('(' + ')-('.join(( label_regex, condition_regex, view_regex )) + ')') for clip_name in clip_names: match = clip_re.fullmatch(clip_name) if match: labels.append(match.group(1)) conditions.append(match.group(2)) views.append(match.group(3)) self._clip_names.append(match.group(0)) self.metadata = { 'labels': selected_labels, 'conditions': selected_conditions, 'views': selected_views } else: # Add all self._clip_names += clip_names for clip_name in self._clip_names: split_clip_name = clip_name.split('-') label = split_clip_name[0] labels.append(label) condition = '-'.join(split_clip_name[1:2 + 1]) conditions.append(condition) view = split_clip_name[-1] views.append(view) self.labels = np.asarray(labels) self.conditions = np.asarray(conditions) self.views = np.asarray(views) if not selector: self.metadata = { 'labels': set(self.labels.tolist()), 'conditions': set(self.conditions.tolist()), 'views': set(self.views.tolist()) } self._cached_clips_frame_names: Optional[Dict[str, List[str]]] = None self._cached_clips: Optional[Dict[str, torch.Tensor]] = None if self.cache_on: self._cached_clips_frame_names = dict() self._cached_clips = dict() self._preload_all_video() def __len__(self) -> int: return len(self.labels) def __getitem__(self, index: int) -> Dict[str, Union[str, torch.Tensor]]: label = self.labels[index] 
        condition = self.conditions[index]
        view = self.views[index]
        clip_name = self._clip_names[index]
        clip = self._read_video(clip_name)
        sample = {
            'label': label,
            'condition': condition,
            'view': view,
            'clip': clip
        }

        return sample

    def _preload_all_video(self):
        for clip_name in tqdm(self._clip_names,
                              desc='Preloading dataset', unit='clips'):
            self._read_video(clip_name, is_caching=True)

    def _read_video(self, clip_name: str,
                    is_caching: bool = False) -> torch.Tensor:
        clip_path = os.path.join(self.root_dir, clip_name)
        sampled_frame_names = self._sample_frames(clip_path, is_caching)

        if self.cache_on:
            if is_caching:
                clip = self._read_frames(clip_path, sampled_frame_names)
                self._cached_clips[clip_name] = clip
            else:  # Load cache
                cached_clip = self._cached_clips[clip_name]
                cached_clip_frame_names \
                    = self._cached_clips_frame_names[clip_path]
                # Index the cached clip via sampled frame names
                clip = self._load_cached_video(cached_clip,
                                               cached_clip_frame_names,
                                               sampled_frame_names)
        else:  # Cache off
            clip = self._read_frames(clip_path, sampled_frame_names)

        return clip

    def _load_cached_video(
            self,
            clip: torch.Tensor,
            frame_names: List[str],
            sampled_frame_names: List[str]
    ) -> torch.Tensor:
        if len(frame_names) >= self.num_sampled_frames:
            # Sampling was done without replacement, so a boolean mask
            # selects each sampled frame exactly once
            sampled_frame_mask = np.isin(frame_names,
                                         sampled_frame_names)
            sampled_clip = clip[sampled_frame_mask]
        else:
            # The clip is shorter than the number of sampled frames, so
            # frames were sampled with replacement; map every sampled
            # frame name back to its index in the cached clip
            sampled_frame_filter = [frame_names.index(name)
                                    for name in sampled_frame_names]
            sampled_clip = clip[sampled_frame_filter]

        return sampled_clip

    def _read_frames(self, clip_path: str,
                     frame_names: List[str]) -> torch.Tensor:
        frames = []
        for frame_name in frame_names:
            frame_path = os.path.join(clip_path, frame_name)
            frame = read_image(frame_path)
            # Transforming on the CPU is slow, so move the frame to the
            # configured device before transforming
            frame = self.frame_transform(frame.to(self.device))
            frames.append(frame.cpu())
        clip = torch.stack(frames)

        return clip

    def _sample_frames(self, clip_path: str,
                       is_caching: bool = False) -> List[str]:
        if self.cache_on:
            if is_caching:
                # Sort frames in advance for loading convenience
                frame_names = sorted(os.listdir(clip_path))
                self._cached_clips_frame_names[clip_path] = frame_names
                # Load all frames without sampling
                return frame_names
            else:  # Load cache
                frame_names = self._cached_clips_frame_names[clip_path]
        else:  # Cache off
            frame_names = os.listdir(clip_path)

        if self.is_train:
            num_frames = len(frame_names)
            # Sample without replacement when there are enough frames,
            # otherwise sample with replacement
            if num_frames < self.num_sampled_frames:
                frame_names = random.choices(frame_names,
                                             k=self.num_sampled_frames)
            else:
                frame_names = random.sample(frame_names,
                                            k=self.num_sampled_frames)

        return sorted(frame_names)
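

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: it assumes
    # the dataset root contains one directory of frame images per clip,
    # named like 001-nm-01-000. The path below is hypothetical; adjust
    # it to your local copy of CASIA-B.
    dataset = CASIAB(
        root_dir='/path/to/CASIA-B/silhouettes',
        is_train=True,
        train_size=74,
        num_sampled_frames=30,
        selector={
            'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'}),
            'views': ClipViews({'000', '090', '180'})
        },
        num_input_channels=1,
        cache_on=False
    )
    sample = dataset[0]
    # clip has shape (num_sampled_frames, channels, height, width)
    print(len(dataset), sample['label'], sample['condition'],
          sample['view'], sample['clip'].shape)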