Diffstat (limited to 'utils/dataset.py')
-rw-r--r-- | utils/dataset.py | 121 |
1 file changed, 58 insertions, 63 deletions
diff --git a/utils/dataset.py b/utils/dataset.py
index 1a9d595..050ac03 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -7,10 +7,11 @@
 import numpy as np
 import torch
 import torchvision.transforms as transforms
 from PIL import Image
+from sklearn.preprocessing import LabelEncoder
 from torch.utils import data
 from tqdm import tqdm
 
-ClipLabels = NewType('ClipLabels', set[str])
+ClipClasses = NewType('ClipClasses', set[str])
 ClipConditions = NewType('ClipConditions', set[str])
 ClipViews = NewType('ClipViews', set[str])
@@ -26,7 +27,7 @@ class CASIAB(data.Dataset):
             num_sampled_frames: int = 30,
             discard_threshold: int = 15,
             selector: Optional[dict[
-                str, Union[ClipLabels, ClipConditions, ClipViews]
+                str, Union[ClipClasses, ClipConditions, ClipViews]
             ]] = None,
             num_input_channels: int = 3,
             frame_size: tuple[int, int] = (64, 32),
@@ -41,7 +42,7 @@ class CASIAB(data.Dataset):
             (Training Only)
         :param discard_threshold: Discard the sample if its number of
             frames is less than this threshold.
-        :param selector: Restrict output data labels, conditions and
+        :param selector: Restrict output data classes, conditions and
             views.
         :param num_input_channels: The number of input channel(s),
             RBG image has 3 channels, grayscale image has 1 channel.
@@ -49,122 +50,116 @@ class CASIAB(data.Dataset):
         :param cache_on: Preload all clips in memory or not, this will
             increase loading speed, but will add a preload process and
             cost a lot of RAM. Loading the entire dataset
-            (is_train = True,train_size = 124, discard_threshold = 1,
+            (is_train = True, train_size = 124, discard_threshold = 1,
             num_input_channels = 3, frame_height = 64, frame_width = 32)
             need about 22 GB of RAM.
         """
         super(CASIAB, self).__init__()
-        self.root_dir = root_dir
-        self.is_train = is_train
-        self.train_size = train_size
-        self.num_sampled_frames = num_sampled_frames
-        self.discard_threshold = discard_threshold
-        self.num_input_channels = num_input_channels
-        self.frame_size = frame_size
-        self.cache_on = cache_on
-
-        self.frame_transform: transforms.Compose
+        self._root_dir = root_dir
+        self._is_train = is_train
+        self._num_sampled_frames = num_sampled_frames
+        self._cache_on = cache_on
+
+        self._frame_transform: transforms.Compose
         transform_compose_list = [
-            transforms.Resize(size=self.frame_size),
+            transforms.Resize(size=frame_size),
             transforms.ToTensor()
         ]
-        if self.num_input_channels == 1:
+        if num_input_channels == 1:
             transform_compose_list.insert(0, transforms.Grayscale())
-        self.frame_transform = transforms.Compose(transform_compose_list)
+        self._frame_transform = transforms.Compose(transform_compose_list)
 
         # Labels, conditions and views corresponding to each video clip
-        self.labels: np.ndarray[np.str_]
+        self.labels: np.ndarray[np.int64]
         self.conditions: np.ndarray[np.str_]
         self.views: np.ndarray[np.str_]
-        # Video clip directory names
-        self._clip_names: list[str] = []
-        # Labels, conditions and views in dataset,
+        # Labels, classes, conditions and views in dataset,
         #   set of three attributes above
-        self.metadata = dict[str, set[str]]
+        self.metadata = dict[str, list[str]]
 
         # Dictionaries for indexing frames and frame names by clip name
         #   and chip path when cache is on
         self._cached_clips_frame_names: Optional[dict[str, list[str]]] = None
         self._cached_clips: Optional[dict[str, torch.Tensor]] = None
 
-        clip_names = sorted(os.listdir(self.root_dir))
+        # Video clip directory names
+        self._clip_names: list[str] = []
+        clip_names = sorted(os.listdir(self._root_dir))
 
-        if self.is_train:
-            clip_names = clip_names[:self.train_size * 10 * 11]
+        if self._is_train:
+            clip_names = clip_names[:train_size * 10 * 11]
         else:  # is_test
-            clip_names = clip_names[self.train_size * 10 * 11:]
+            clip_names = clip_names[train_size * 10 * 11:]
 
-        # Remove empty clips
+        # Remove clips under threshold
        discard_clips_names = []
        for clip_name in clip_names.copy():
-            clip_path = os.path.join(self.root_dir, clip_name)
-            if len(os.listdir(clip_path)) < self.discard_threshold:
+            clip_path = os.path.join(self._root_dir, clip_name)
+            if len(os.listdir(clip_path)) < discard_threshold:
                 discard_clips_names.append(clip_name)
                 clip_names.remove(clip_name)
         if len(discard_clips_names) != 0:
             print(', '.join(discard_clips_names[:-1]),
                   'and', discard_clips_names[-1],
                   'will be discarded.')
 
-        # clip name constructed by label, condition and view
+        # Clip name constructed by class, condition and view
         #   e.g 002-bg-02-090 means clip from Subject #2
         #       in Bag #2 condition from 90 degree angle
-        labels, conditions, views = [], [], []
+        classes, conditions, views = [], [], []
         if selector:
-            selected_labels = selector.pop('labels', None)
+            selected_classes = selector.pop('classes', None)
             selected_conditions = selector.pop('conditions', None)
             selected_views = selector.pop('views', None)
 
-            label_regex = r'\d{3}'
-            condition_regex = r'(nm|bg|cl)-0[0-4]'
+            class_regex = r'\d{3}'
+            condition_regex = r'(nm|bg|cl)-0[0-6]'
             view_regex = r'\d{3}'
 
             # Match required data using RegEx
-            if selected_labels:
-                label_regex = '|'.join(selected_labels)
+            if selected_classes:
+                class_regex = '|'.join(selected_classes)
             if selected_conditions:
                 condition_regex = '|'.join(selected_conditions)
             if selected_views:
                 view_regex = '|'.join(selected_views)
             clip_re = re.compile('(' + ')-('.join((
-                label_regex, condition_regex, view_regex
+                class_regex, condition_regex, view_regex
             )) + ')')
 
             for clip_name in clip_names:
                 match = clip_re.fullmatch(clip_name)
                 if match:
-                    labels.append(match.group(1))
+                    classes.append(match.group(1))
                     conditions.append(match.group(2))
                     views.append(match.group(3))
                     self._clip_names.append(match.group(0))
-            self.metadata = {
-                'labels': selected_labels,
-                'conditions': selected_conditions,
-                'views': selected_views
-            }
         else:  # Add all
             self._clip_names += clip_names
             for clip_name in self._clip_names:
                 split_clip_name = clip_name.split('-')
-                label = split_clip_name[0]
-                labels.append(label)
+                class_ = split_clip_name[0]
+                classes.append(class_)
                 condition = '-'.join(split_clip_name[1:2 + 1])
                 conditions.append(condition)
                 view = split_clip_name[-1]
                 views.append(view)
 
-        self.labels = np.asarray(labels)
+        # Encode classes to labels
+        self.label_encoder = LabelEncoder()
+        self.label_encoder.fit(classes)
+        self.labels = self.label_encoder.transform(classes)
         self.conditions = np.asarray(conditions)
         self.views = np.asarray(views)
 
-        if not selector:
-            self.metadata = {
-                'labels': set(self.labels.tolist()),
-                'conditions': set(self.conditions.tolist()),
-                'views': set(self.views.tolist())
-            }
+        self.metadata = {
+            'labels': list(dict.fromkeys(self.labels.tolist())),
+            'classes': self.label_encoder.classes_.tolist(),
+            'conditions': list(dict.fromkeys(self.conditions.tolist())),
+            'views': list(dict.fromkeys(self.views.tolist()))
+        }
 
-        if self.cache_on:
+        if self._cache_on:
             self._cached_clips_frame_names = dict()
             self._cached_clips = dict()
             self._preload_all_video()
@@ -194,10 +189,10 @@ class CASIAB(data.Dataset):
 
     def _read_video(self, clip_name: str,
                     is_caching: bool = False) -> torch.Tensor:
-        clip_path = os.path.join(self.root_dir, clip_name)
+        clip_path = os.path.join(self._root_dir, clip_name)
         sampled_frame_names = self._sample_frames(clip_path, is_caching)
 
-        if self.cache_on:
+        if self._cache_on:
             if is_caching:
                 clip = self._read_frames(clip_path, sampled_frame_names)
                 self._cached_clips[clip_name] = clip
@@ -221,14 +216,14 @@ class CASIAB(data.Dataset):
             sampled_frame_names: list[str]
     ) -> torch.Tensor:
         # Mask the original clip when it is long enough
-        if len(frame_names) >= self.num_sampled_frames:
+        if len(frame_names) >= self._num_sampled_frames:
             sampled_frame_mask = np.isin(frame_names, sampled_frame_names)
             sampled_clip = clip[sampled_frame_mask]
         else:
             # Create a indexing filter from the beginning of clip
             sampled_index = frame_names.index(sampled_frame_names[0])
             sampled_frame_filter = [sampled_index]
-            for i in range(1, self.num_sampled_frames):
+            for i in range(1, self._num_sampled_frames):
                 if sampled_frame_names[i] != sampled_frame_names[i - 1]:
                     sampled_index += 1
                 sampled_frame_filter.append(sampled_index)
@@ -241,7 +236,7 @@ class CASIAB(data.Dataset):
         for frame_name in frame_names:
             frame_path = os.path.join(clip_path, frame_name)
             frame = Image.open(frame_path)
-            frame = self.frame_transform(frame)
+            frame = self._frame_transform(frame)
             frames.append(frame)
         clip = torch.stack(frames)
 
@@ -249,7 +244,7 @@ class CASIAB(data.Dataset):
 
     def _sample_frames(self, clip_path: str,
                        is_caching: bool = False) -> list[str]:
-        if self.cache_on:
+        if self._cache_on:
             if is_caching:
                 # Sort frame in advance for loading convenience
                 frame_names = sorted(os.listdir(clip_path))
@@ -261,14 +256,14 @@ class CASIAB(data.Dataset):
         else:  # Cache off
             frame_names = os.listdir(clip_path)
 
-        if self.is_train:
+        if self._is_train:
             num_frames = len(frame_names)
             # Sample frames without replace if have enough frames
-            if num_frames < self.num_sampled_frames:
+            if num_frames < self._num_sampled_frames:
                 frame_names = random.choices(frame_names,
-                                             k=self.num_sampled_frames)
+                                             k=self._num_sampled_frames)
             else:
                 frame_names = random.sample(frame_names,
-                                            k=self.num_sampled_frames)
+                                            k=self._num_sampled_frames)
 
         return sorted(frame_names)
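For context, a minimal usage sketch of the reworked selector keys and the LabelEncoder-based labels introduced by this commit. The dataset path and the particular classes, conditions and views chosen below are hypothetical examples, not part of the commit:

    from utils.dataset import CASIAB, ClipClasses, ClipConditions, ClipViews

    # Hypothetical silhouette root; point this at the actual CASIA-B data.
    dataset = CASIAB(
        root_dir='/data/CASIA-B-silhouettes',
        is_train=True,
        train_size=74,
        selector={
            'classes': ClipClasses({'001', '002', '003'}),
            'conditions': ClipConditions({r'nm-0[1-6]', r'bg-0[1-2]'}),
            'views': ClipViews({'000', '090', '180'}),
        },
        num_input_channels=1,
        frame_size=(64, 32),
    )

    # Subject IDs (classes) are now integer-encoded by sklearn's
    # LabelEncoder; metadata exposes both forms.
    print(dataset.metadata['classes'])                    # ['001', '002', '003']
    print(dataset.metadata['labels'])                     # [0, 1, 2]
    print(dataset.label_encoder.inverse_transform([0]))   # ['001']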