From b8653d54efe2a8c94ae408c0c2da9bdd0b43ecdd Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Tue, 29 Dec 2020 20:38:18 +0800 Subject: Encode class names to label and some access improvement 1. Encode class names using LabelEncoder from sklearn 2. Remove unneeded class variables 3. Protect some variables from being accessed in userspace --- utils/dataset.py | 121 ++++++++++++++++++++++++++----------------------------- 1 file changed, 58 insertions(+), 63 deletions(-) (limited to 'utils/dataset.py') diff --git a/utils/dataset.py b/utils/dataset.py index 1a9d595..050ac03 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -7,10 +7,11 @@ import numpy as np import torch import torchvision.transforms as transforms from PIL import Image +from sklearn.preprocessing import LabelEncoder from torch.utils import data from tqdm import tqdm -ClipLabels = NewType('ClipLabels', set[str]) +ClipClasses = NewType('ClipClasses', set[str]) ClipConditions = NewType('ClipConditions', set[str]) ClipViews = NewType('ClipViews', set[str]) @@ -26,7 +27,7 @@ class CASIAB(data.Dataset): num_sampled_frames: int = 30, discard_threshold: int = 15, selector: Optional[dict[ - str, Union[ClipLabels, ClipConditions, ClipViews] + str, Union[ClipClasses, ClipConditions, ClipViews] ]] = None, num_input_channels: int = 3, frame_size: tuple[int, int] = (64, 32), @@ -41,7 +42,7 @@ class CASIAB(data.Dataset): (Training Only) :param discard_threshold: Discard the sample if its number of frames is less than this threshold. - :param selector: Restrict output data labels, conditions and + :param selector: Restrict output data classes, conditions and views. :param num_input_channels: The number of input channel(s), RBG image has 3 channels, grayscale image has 1 channel. @@ -49,122 +50,116 @@ class CASIAB(data.Dataset): :param cache_on: Preload all clips in memory or not, this will increase loading speed, but will add a preload process and cost a lot of RAM. Loading the entire dataset - (is_train = True,train_size = 124, discard_threshold = 1, + (is_train = True, train_size = 124, discard_threshold = 1, num_input_channels = 3, frame_height = 64, frame_width = 32) need about 22 GB of RAM. """ super(CASIAB, self).__init__() - self.root_dir = root_dir - self.is_train = is_train - self.train_size = train_size - self.num_sampled_frames = num_sampled_frames - self.discard_threshold = discard_threshold - self.num_input_channels = num_input_channels - self.frame_size = frame_size - self.cache_on = cache_on - - self.frame_transform: transforms.Compose + self._root_dir = root_dir + self._is_train = is_train + self._num_sampled_frames = num_sampled_frames + self._cache_on = cache_on + + self._frame_transform: transforms.Compose transform_compose_list = [ - transforms.Resize(size=self.frame_size), + transforms.Resize(size=frame_size), transforms.ToTensor() ] - if self.num_input_channels == 1: + if num_input_channels == 1: transform_compose_list.insert(0, transforms.Grayscale()) - self.frame_transform = transforms.Compose(transform_compose_list) + self._frame_transform = transforms.Compose(transform_compose_list) # Labels, conditions and views corresponding to each video clip - self.labels: np.ndarray[np.str_] + self.labels: np.ndarray[np.int64] self.conditions: np.ndarray[np.str_] self.views: np.ndarray[np.str_] - # Video clip directory names - self._clip_names: list[str] = [] - # Labels, conditions and views in dataset, + # Labels, classes, conditions and views in dataset, # set of three attributes above - self.metadata = dict[str, set[str]] + self.metadata = dict[str, list[str]] # Dictionaries for indexing frames and frame names by clip name # and chip path when cache is on self._cached_clips_frame_names: Optional[dict[str, list[str]]] = None self._cached_clips: Optional[dict[str, torch.Tensor]] = None - clip_names = sorted(os.listdir(self.root_dir)) + # Video clip directory names + self._clip_names: list[str] = [] + clip_names = sorted(os.listdir(self._root_dir)) - if self.is_train: - clip_names = clip_names[:self.train_size * 10 * 11] + if self._is_train: + clip_names = clip_names[:train_size * 10 * 11] else: # is_test - clip_names = clip_names[self.train_size * 10 * 11:] + clip_names = clip_names[train_size * 10 * 11:] - # Remove empty clips + # Remove clips under threshold discard_clips_names = [] for clip_name in clip_names.copy(): - clip_path = os.path.join(self.root_dir, clip_name) - if len(os.listdir(clip_path)) < self.discard_threshold: + clip_path = os.path.join(self._root_dir, clip_name) + if len(os.listdir(clip_path)) < discard_threshold: discard_clips_names.append(clip_name) clip_names.remove(clip_name) if len(discard_clips_names) != 0: print(', '.join(discard_clips_names[:-1]), 'and', discard_clips_names[-1], 'will be discarded.') - # clip name constructed by label, condition and view + # Clip name constructed by class, condition and view # e.g 002-bg-02-090 means clip from Subject #2 # in Bag #2 condition from 90 degree angle - labels, conditions, views = [], [], [] + classes, conditions, views = [], [], [] if selector: - selected_labels = selector.pop('labels', None) + selected_classes = selector.pop('classes', None) selected_conditions = selector.pop('conditions', None) selected_views = selector.pop('views', None) - label_regex = r'\d{3}' - condition_regex = r'(nm|bg|cl)-0[0-4]' + class_regex = r'\d{3}' + condition_regex = r'(nm|bg|cl)-0[0-6]' view_regex = r'\d{3}' # Match required data using RegEx - if selected_labels: - label_regex = '|'.join(selected_labels) + if selected_classes: + class_regex = '|'.join(selected_classes) if selected_conditions: condition_regex = '|'.join(selected_conditions) if selected_views: view_regex = '|'.join(selected_views) clip_re = re.compile('(' + ')-('.join(( - label_regex, condition_regex, view_regex + class_regex, condition_regex, view_regex )) + ')') for clip_name in clip_names: match = clip_re.fullmatch(clip_name) if match: - labels.append(match.group(1)) + classes.append(match.group(1)) conditions.append(match.group(2)) views.append(match.group(3)) self._clip_names.append(match.group(0)) - self.metadata = { - 'labels': selected_labels, - 'conditions': selected_conditions, - 'views': selected_views - } else: # Add all self._clip_names += clip_names for clip_name in self._clip_names: split_clip_name = clip_name.split('-') - label = split_clip_name[0] - labels.append(label) + class_ = split_clip_name[0] + classes.append(class_) condition = '-'.join(split_clip_name[1:2 + 1]) conditions.append(condition) view = split_clip_name[-1] views.append(view) - self.labels = np.asarray(labels) + # Encode classes to labels + self.label_encoder = LabelEncoder() + self.label_encoder.fit(classes) + self.labels = self.label_encoder.transform(classes) self.conditions = np.asarray(conditions) self.views = np.asarray(views) - if not selector: - self.metadata = { - 'labels': set(self.labels.tolist()), - 'conditions': set(self.conditions.tolist()), - 'views': set(self.views.tolist()) - } + self.metadata = { + 'labels': list(dict.fromkeys(self.labels.tolist())), + 'classes': self.label_encoder.classes_.tolist(), + 'conditions': list(dict.fromkeys(self.conditions.tolist())), + 'views': list(dict.fromkeys(self.views.tolist())) + } - if self.cache_on: + if self._cache_on: self._cached_clips_frame_names = dict() self._cached_clips = dict() self._preload_all_video() @@ -194,10 +189,10 @@ class CASIAB(data.Dataset): def _read_video(self, clip_name: str, is_caching: bool = False) -> torch.Tensor: - clip_path = os.path.join(self.root_dir, clip_name) + clip_path = os.path.join(self._root_dir, clip_name) sampled_frame_names = self._sample_frames(clip_path, is_caching) - if self.cache_on: + if self._cache_on: if is_caching: clip = self._read_frames(clip_path, sampled_frame_names) self._cached_clips[clip_name] = clip @@ -221,14 +216,14 @@ class CASIAB(data.Dataset): sampled_frame_names: list[str] ) -> torch.Tensor: # Mask the original clip when it is long enough - if len(frame_names) >= self.num_sampled_frames: + if len(frame_names) >= self._num_sampled_frames: sampled_frame_mask = np.isin(frame_names, sampled_frame_names) sampled_clip = clip[sampled_frame_mask] else: # Create a indexing filter from the beginning of clip sampled_index = frame_names.index(sampled_frame_names[0]) sampled_frame_filter = [sampled_index] - for i in range(1, self.num_sampled_frames): + for i in range(1, self._num_sampled_frames): if sampled_frame_names[i] != sampled_frame_names[i - 1]: sampled_index += 1 sampled_frame_filter.append(sampled_index) @@ -241,7 +236,7 @@ class CASIAB(data.Dataset): for frame_name in frame_names: frame_path = os.path.join(clip_path, frame_name) frame = Image.open(frame_path) - frame = self.frame_transform(frame) + frame = self._frame_transform(frame) frames.append(frame) clip = torch.stack(frames) @@ -249,7 +244,7 @@ class CASIAB(data.Dataset): def _sample_frames(self, clip_path: str, is_caching: bool = False) -> list[str]: - if self.cache_on: + if self._cache_on: if is_caching: # Sort frame in advance for loading convenience frame_names = sorted(os.listdir(clip_path)) @@ -261,14 +256,14 @@ class CASIAB(data.Dataset): else: # Cache off frame_names = os.listdir(clip_path) - if self.is_train: + if self._is_train: num_frames = len(frame_names) # Sample frames without replace if have enough frames - if num_frames < self.num_sampled_frames: + if num_frames < self._num_sampled_frames: frame_names = random.choices(frame_names, - k=self.num_sampled_frames) + k=self._num_sampled_frames) else: frame_names = random.sample(frame_names, - k=self.num_sampled_frames) + k=self._num_sampled_frames) return sorted(frame_names) -- cgit v1.2.3