diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-18 22:37:21 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-18 22:37:21 +0800 |
commit | 1db7d5cefbd14f14c8393e862d08fa9c620f90f6 (patch) | |
tree | cfc83c82a49a4a5553817e1ffdf8d83329dc24d2 /utils/dataset.py | |
parent | 5b42006e37e4c1d285eb627efd9453ffddf02785 (diff) |
Implement CASIA-B dataset
Diffstat (limited to 'utils/dataset.py')
-rw-r--r-- | utils/dataset.py | 194 |
1 files changed, 194 insertions, 0 deletions
diff --git a/utils/dataset.py b/utils/dataset.py new file mode 100644 index 0000000..7518127 --- /dev/null +++ b/utils/dataset.py @@ -0,0 +1,194 @@ +import os +import random +import re +from typing import Optional, Dict, NewType, Union, List, Set + +import numpy as np +import torch +from torch.utils import data +from torchvision.io import read_image +import torchvision.transforms as transforms + +ClipLabels = NewType('ClipLabels', Set[str]) +ClipConditions = NewType('ClipConditions', Set[str]) +ClipViews = NewType('ClipViews', Set[str]) + +default_frame_transform = transforms.Compose([ + transforms.Resize(size=(64, 32)) +]) + + +class CASIAB(data.Dataset): + """CASIA-B multi-view gait dataset""" + + def __init__( + self, + root_dir: str, + is_train: bool = True, + train_size: int = 74, + num_sampled_frames: int = 30, + selector: Optional[Dict[ + str, Union[ClipLabels, ClipConditions, ClipLabels] + ]] = None, + num_input_channels: int = 3, + frame_height: int = 64, + frame_width: int = 32, + device: torch.device = torch.device('cpu') + ): + """ + :param root_dir: Directory to dataset root. + :param is_train: Train or test, True for train, False for test. + :param train_size: Number of subjects in train, when `is_train` + is False, test size will be inferred. + :param num_sampled_frames: Number of sampled frames for train + :param selector: Restrict data labels, conditions and views + :param num_input_channels Number of input channel, RBG image + has 3 channel, grayscale image has 1 channel + :param frame_height Frame height after transforms + :param frame_width Frame width after transforms + :param device Device be used for transforms + """ + super(CASIAB, self).__init__() + self.root_dir = root_dir + self.is_train = is_train + self.train_size = train_size + self.num_sampled_frames = num_sampled_frames + self.num_input_channels = num_input_channels + self.frame_height = frame_height + self.frame_width = frame_width + self.device = device + + self.frame_transform: transforms.Compose + transform_compose_list = [ + transforms.Resize(size=(self.frame_height, self.frame_width)) + ] + if self.num_input_channels == 1: + transform_compose_list.insert(0, transforms.Grayscale()) + self.frame_transform = transforms.Compose(transform_compose_list) + + # Labels, conditions and views corresponding to each video clip + self.labels: np.ndarray[np.str_] + self.conditions: np.ndarray[np.str_] + self.views: np.ndarray[np.str_] + # Video clip directory names + self._clip_names: List[str] = [] + # Labels, conditions and views in dataset, + # set of three attributes above + self.metadata = Dict[str, Set[str]] + + clip_names = sorted(os.listdir(self.root_dir)) + + if self.is_train: + clip_names = clip_names[:self.train_size * 10 * 11] + else: # is_test + clip_names = clip_names[self.train_size * 10 * 11:] + + # Remove empty clips + for clip_name in clip_names.copy(): + if len(os.listdir(os.path.join(self.root_dir, clip_name))) == 0: + print("Clip '{}' is empty.".format(clip_name)) + clip_names.remove(clip_name) + + # clip name constructed by label, condition and view + # e.g 002-bg-02-090 means clip from Subject #2 + # in Bag #2 condition from 90 degree angle + labels, conditions, views = [], [], [] + if selector: + selected_labels = selector.pop('labels', None) + selected_conditions = selector.pop('conditions', None) + selected_views = selector.pop('views', None) + + label_regex = r'\d{3}' + condition_regex = r'(nm|bg|cl)-0[0-4]' + view_regex = r'\d{3}' + + # Match required data using RegEx + if selected_labels: + label_regex = '|'.join(selected_labels) + if selected_conditions: + condition_regex = '|'.join(selected_conditions) + if selected_views: + view_regex = '|'.join(selected_views) + clip_regex = '(' + ')-('.join([ + label_regex, condition_regex, view_regex + ]) + ')' + + for clip_name in clip_names: + match = re.fullmatch(clip_regex, clip_name) + if match: + labels.append(match.group(1)) + conditions.append(match.group(2)) + views.append(match.group(3)) + self._clip_names.append(match.group(0)) + + self.metadata = { + 'labels': selected_labels, + 'conditions': selected_conditions, + 'views': selected_views + } + else: # Add all + self._clip_names += clip_names + for clip_name in self._clip_names: + split_clip_name = clip_name.split('-') + label = split_clip_name[0] + labels.append(label) + condition = '-'.join(split_clip_name[1:2 + 1]) + conditions.append(condition) + view = split_clip_name[-1] + views.append(view) + + self.labels = np.asarray(labels) + self.conditions = np.asarray(conditions) + self.views = np.asarray(views) + + if not selector: + self.metadata = { + 'labels': set(self.labels.tolist()), + 'conditions': set(self.conditions.tolist()), + 'views': set(self.views.tolist()) + } + + def __len__(self) -> int: + return len(self.labels) + + def __getitem__(self, index: int) -> Dict[str, Union[str, torch.Tensor]]: + label = self.labels[index] + condition = self.conditions[index] + view = self.views[index] + clip_name = self._clip_names[index] + clip = self._read_video(clip_name) + + sample = { + 'label': label, + 'condition': condition, + 'view': view, + 'clip': clip + } + + return sample + + def _read_video(self, clip_name: str) -> torch.Tensor: + frames = [] + clip_path = os.path.join(self.root_dir, clip_name) + sampled_frame_names = self._sample_frames(clip_path) + for frame_name in sampled_frame_names: + frame_path = os.path.join(clip_path, frame_name) + frame = read_image(frame_path) + frame = self.frame_transform(frame.to(self.device)) + frames.append(frame.cpu()) + clip = torch.stack(frames) + + return clip + + def _sample_frames(self, clip_path: str) -> List[str]: + frame_names = os.listdir(clip_path) + if self.is_train: + num_frames = len(frame_names) + if num_frames < self.num_sampled_frames: + frame_names = random.choices(frame_names, + k=self.num_sampled_frames) + else: + frame_names = random.sample(frame_names, + k=self.num_sampled_frames) + + return sorted(frame_names) |