Diffstat (limited to 'utils')
-rw-r--r--  utils/dataset.py  194
1 file changed, 194 insertions, 0 deletions
diff --git a/utils/dataset.py b/utils/dataset.py
new file mode 100644
index 0000000..7518127
--- /dev/null
+++ b/utils/dataset.py
@@ -0,0 +1,194 @@
+import os
+import random
+import re
+from typing import Optional, Dict, NewType, Union, List, Set
+
+import numpy as np
+import torch
+from torch.utils import data
+from torchvision.io import read_image
+import torchvision.transforms as transforms
+
+ClipLabels = NewType('ClipLabels', Set[str])
+ClipConditions = NewType('ClipConditions', Set[str])
+ClipViews = NewType('ClipViews', Set[str])
+
+default_frame_transform = transforms.Compose([
+ transforms.Resize(size=(64, 32))
+])
+
+
+class CASIAB(data.Dataset):
+ """CASIA-B multi-view gait dataset"""
+
+ def __init__(
+ self,
+ root_dir: str,
+ is_train: bool = True,
+ train_size: int = 74,
+ num_sampled_frames: int = 30,
+ selector: Optional[Dict[
+                str, Union[ClipLabels, ClipConditions, ClipViews]
+ ]] = None,
+ num_input_channels: int = 3,
+ frame_height: int = 64,
+ frame_width: int = 32,
+ device: torch.device = torch.device('cpu')
+ ):
+ """
+ :param root_dir: Directory to dataset root.
+ :param is_train: Train or test, True for train, False for test.
+ :param train_size: Number of subjects in train, when `is_train`
+ is False, test size will be inferred.
+ :param num_sampled_frames: Number of sampled frames for train
+ :param selector: Restrict data labels, conditions and views
+ :param num_input_channels Number of input channel, RBG image
+ has 3 channel, grayscale image has 1 channel
+ :param frame_height Frame height after transforms
+ :param frame_width Frame width after transforms
+ :param device Device be used for transforms
+ """
+ super(CASIAB, self).__init__()
+ self.root_dir = root_dir
+ self.is_train = is_train
+ self.train_size = train_size
+ self.num_sampled_frames = num_sampled_frames
+ self.num_input_channels = num_input_channels
+ self.frame_height = frame_height
+ self.frame_width = frame_width
+ self.device = device
+
+ self.frame_transform: transforms.Compose
+ transform_compose_list = [
+ transforms.Resize(size=(self.frame_height, self.frame_width))
+ ]
+ if self.num_input_channels == 1:
+ transform_compose_list.insert(0, transforms.Grayscale())
+ self.frame_transform = transforms.Compose(transform_compose_list)
+
+ # Labels, conditions and views corresponding to each video clip
+ self.labels: np.ndarray[np.str_]
+ self.conditions: np.ndarray[np.str_]
+ self.views: np.ndarray[np.str_]
+ # Video clip directory names
+ self._clip_names: List[str] = []
+        # Labels, conditions and views present in the dataset
+        # (sets of the three attributes above)
+        self.metadata: Dict[str, Set[str]]
+
+ clip_names = sorted(os.listdir(self.root_dir))
+
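+        # CASIA-B contains 10 clips (6 normal, 2 bag, 2 coat) x 11 views
+        # per subject, so whole subjects are selected by slicing blocks
+        # of train_size * 10 * 11 from the sorted clip directories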
+ if self.is_train:
+ clip_names = clip_names[:self.train_size * 10 * 11]
+ else: # is_test
+ clip_names = clip_names[self.train_size * 10 * 11:]
+
+ # Remove empty clips
+ for clip_name in clip_names.copy():
+ if len(os.listdir(os.path.join(self.root_dir, clip_name))) == 0:
+ print("Clip '{}' is empty.".format(clip_name))
+ clip_names.remove(clip_name)
+
+        # Clip names are constructed from label, condition and view,
+        # e.g. 002-bg-02-090 is a clip of subject #2 under the second
+        # bag-carrying condition (bg-02), viewed from a 90-degree angle
+ labels, conditions, views = [], [], []
+ if selector:
+            # Use get() so the caller's selector dict is left untouched
+            selected_labels = selector.get('labels')
+            selected_conditions = selector.get('conditions')
+            selected_views = selector.get('views')
+
+            label_regex = r'\d{3}'
+            # Non-capturing group keeps the numbered groups built below
+            # stable; CASIA-B conditions are nm-01..06, bg-01..02 and
+            # cl-01..02
+            condition_regex = r'(?:nm|bg|cl)-0[1-6]'
+            view_regex = r'\d{3}'
+
+ # Match required data using RegEx
+ if selected_labels:
+ label_regex = '|'.join(selected_labels)
+ if selected_conditions:
+ condition_regex = '|'.join(selected_conditions)
+ if selected_views:
+ view_regex = '|'.join(selected_views)
+ clip_regex = '(' + ')-('.join([
+ label_regex, condition_regex, view_regex
+ ]) + ')'
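+            # With the defaults, the resulting pattern is
+            # (\d{3})-((?:nm|bg|cl)-0[1-6])-(\d{3})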
+
+ for clip_name in clip_names:
+ match = re.fullmatch(clip_regex, clip_name)
+ if match:
+ labels.append(match.group(1))
+ conditions.append(match.group(2))
+ views.append(match.group(3))
+ self._clip_names.append(match.group(0))
+
+            # Fall back to the values actually matched for any
+            # dimension the selector left unrestricted
+            self.metadata = {
+                'labels': selected_labels or set(labels),
+                'conditions': selected_conditions or set(conditions),
+                'views': selected_views or set(views)
+            }
+ else: # Add all
+ self._clip_names += clip_names
+ for clip_name in self._clip_names:
+ split_clip_name = clip_name.split('-')
+ label = split_clip_name[0]
+ labels.append(label)
+                condition = '-'.join(split_clip_name[1:3])
+ conditions.append(condition)
+ view = split_clip_name[-1]
+ views.append(view)
+
+ self.labels = np.asarray(labels)
+ self.conditions = np.asarray(conditions)
+ self.views = np.asarray(views)
+
+ if not selector:
+ self.metadata = {
+ 'labels': set(self.labels.tolist()),
+ 'conditions': set(self.conditions.tolist()),
+ 'views': set(self.views.tolist())
+ }
+
+ def __len__(self) -> int:
+ return len(self.labels)
+
+ def __getitem__(self, index: int) -> Dict[str, Union[str, torch.Tensor]]:
+ label = self.labels[index]
+ condition = self.conditions[index]
+ view = self.views[index]
+ clip_name = self._clip_names[index]
+ clip = self._read_video(clip_name)
+
+ sample = {
+ 'label': label,
+ 'condition': condition,
+ 'view': view,
+ 'clip': clip
+ }
+
+ return sample
+
+ def _read_video(self, clip_name: str) -> torch.Tensor:
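+        """Read the sampled frames of a clip, apply the frame transform
+        on the configured device and stack them into a
+        (num_frames, channels, height, width) tensor."""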
+ frames = []
+ clip_path = os.path.join(self.root_dir, clip_name)
+ sampled_frame_names = self._sample_frames(clip_path)
+ for frame_name in sampled_frame_names:
+ frame_path = os.path.join(clip_path, frame_name)
+ frame = read_image(frame_path)
+ frame = self.frame_transform(frame.to(self.device))
+ frames.append(frame.cpu())
+ clip = torch.stack(frames)
+
+ return clip
+
+ def _sample_frames(self, clip_path: str) -> List[str]:
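+        """Sample frame file names from a clip directory. During
+        training, `num_sampled_frames` names are drawn without
+        replacement when enough frames exist, otherwise with
+        replacement; during testing all frames are kept. The sampled
+        names are returned in sorted order."""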
+ frame_names = os.listdir(clip_path)
+ if self.is_train:
+ num_frames = len(frame_names)
+ if num_frames < self.num_sampled_frames:
+ frame_names = random.choices(frame_names,
+ k=self.num_sampled_frames)
+ else:
+ frame_names = random.sample(frame_names,
+ k=self.num_sampled_frames)
+
+ return sorted(frame_names)
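+
+
+# Minimal usage sketch (illustrative only; the dataset path below is a
+# placeholder and num_input_channels=1 assumes pre-extracted silhouettes):
+#
+#     from torch.utils.data import DataLoader
+#
+#     dataset = CASIAB(
+#         '/path/to/casia-b/silhouettes',
+#         is_train=True,
+#         selector={'conditions': ClipConditions({'nm-01', 'nm-02'})},
+#         num_input_channels=1
+#     )
+#     loader = DataLoader(dataset, batch_size=1, shuffle=True)
+#     sample = next(iter(loader))
+#     # sample['clip'] has shape (1, num_sampled_frames, 1, 64, 32)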