From b8653d54efe2a8c94ae408c0c2da9bdd0b43ecdd Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Tue, 29 Dec 2020 20:38:18 +0800 Subject: Encode class names to label and some access improvement 1. Encode class names using LabelEncoder from sklearn 2. Remove unneeded class variables 3. Protect some variables from being accessed in userspace --- requirements.txt | 3 +- test/dataset.py | 33 +++++++++++++- utils/configuration.py | 4 +- utils/dataset.py | 121 ++++++++++++++++++++++++------------------------- 4 files changed, 94 insertions(+), 67 deletions(-) diff --git a/requirements.txt b/requirements.txt index 765977f..7cf8570 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ torch~=1.7.1 torchvision~=0.8.0a0+ecf4e9c numpy~=1.19.4 tqdm~=4.55.0 -Pillow~=8.0.1 \ No newline at end of file +Pillow~=8.0.1 +scikit-learn~=0.23.2 \ No newline at end of file diff --git a/test/dataset.py b/test/dataset.py index bfb8563..e0fc59a 100644 --- a/test/dataset.py +++ b/test/dataset.py @@ -1,4 +1,4 @@ -from utils.dataset import CASIAB, ClipConditions, ClipViews +from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses CASIAB_ROOT_DIR = '../data/CASIA-B-MRCNN/SEG' @@ -7,6 +7,29 @@ def test_casiab(): casiab = CASIAB(CASIAB_ROOT_DIR, discard_threshold=0) assert len(casiab) == 74 * 10 * 11 + labels = [] + for i in range(74): + labels += [i] * 10 * 11 + assert casiab.labels.tolist() == labels + + assert casiab.metadata['labels'] == [i for i in range(74)] + + assert casiab.label_encoder.inverse_transform([0, 2]).tolist() == ['001', + '003'] + + +def test_casiab_test(): + casiab_test = CASIAB(CASIAB_ROOT_DIR, is_train=False, discard_threshold=0) + assert len(casiab_test) == (124 - 74) * 10 * 11 + + labels = [] + for i in range(124 - 74): + labels += [i] * 10 * 11 + assert casiab_test.labels.tolist() == labels + + assert casiab_test.label_encoder.inverse_transform([0, 2]).tolist() == [ + '075', '077'] + def test_casiab_nm(): nm_selector = {'conditions': ClipConditions({r'nm-0\d'})} @@ -22,3 +45,11 @@ def test_casiab_nm_bg_90(): selector=nm_bg_90_selector, discard_threshold=0) assert len(casiab_nm_bg_90) == 74 * (6 + 2) * 1 + + +def test_caisab_class_selector(): + class_selector = {'classes': ClipClasses({'001', '003'})} + casiab_class_001_003 = CASIAB(CASIAB_ROOT_DIR, + selector=class_selector, + discard_threshold=0) + assert len(casiab_class_001_003) == 2 * 10 * 11 diff --git a/utils/configuration.py b/utils/configuration.py index e6bfaf2..965af94 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -2,7 +2,7 @@ from typing import TypedDict, Optional, Union import torch -from utils.dataset import ClipLabels, ClipConditions, ClipViews +from utils.dataset import ClipClasses, ClipConditions, ClipViews class SystemConfiguration(TypedDict): @@ -17,7 +17,7 @@ class DatasetConfiguration(TypedDict): train_size: int num_sampled_frames: int discard_threshold: int - selector: Optional[dict[str, Union[ClipLabels, ClipConditions, ClipViews]]] + selector: Optional[dict[str, Union[ClipClasses, ClipConditions, ClipViews]]] num_input_channels: int frame_size: tuple[int, int] cache_on: bool diff --git a/utils/dataset.py b/utils/dataset.py index 1a9d595..050ac03 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -7,10 +7,11 @@ import numpy as np import torch import torchvision.transforms as transforms from PIL import Image +from sklearn.preprocessing import LabelEncoder from torch.utils import data from tqdm import tqdm -ClipLabels = NewType('ClipLabels', set[str]) +ClipClasses = NewType('ClipClasses', set[str]) ClipConditions = NewType('ClipConditions', set[str]) ClipViews = NewType('ClipViews', set[str]) @@ -26,7 +27,7 @@ class CASIAB(data.Dataset): num_sampled_frames: int = 30, discard_threshold: int = 15, selector: Optional[dict[ - str, Union[ClipLabels, ClipConditions, ClipViews] + str, Union[ClipClasses, ClipConditions, ClipViews] ]] = None, num_input_channels: int = 3, frame_size: tuple[int, int] = (64, 32), @@ -41,7 +42,7 @@ class CASIAB(data.Dataset): (Training Only) :param discard_threshold: Discard the sample if its number of frames is less than this threshold. - :param selector: Restrict output data labels, conditions and + :param selector: Restrict output data classes, conditions and views. :param num_input_channels: The number of input channel(s), RBG image has 3 channels, grayscale image has 1 channel. @@ -49,122 +50,116 @@ class CASIAB(data.Dataset): :param cache_on: Preload all clips in memory or not, this will increase loading speed, but will add a preload process and cost a lot of RAM. Loading the entire dataset - (is_train = True,train_size = 124, discard_threshold = 1, + (is_train = True, train_size = 124, discard_threshold = 1, num_input_channels = 3, frame_height = 64, frame_width = 32) need about 22 GB of RAM. """ super(CASIAB, self).__init__() - self.root_dir = root_dir - self.is_train = is_train - self.train_size = train_size - self.num_sampled_frames = num_sampled_frames - self.discard_threshold = discard_threshold - self.num_input_channels = num_input_channels - self.frame_size = frame_size - self.cache_on = cache_on - - self.frame_transform: transforms.Compose + self._root_dir = root_dir + self._is_train = is_train + self._num_sampled_frames = num_sampled_frames + self._cache_on = cache_on + + self._frame_transform: transforms.Compose transform_compose_list = [ - transforms.Resize(size=self.frame_size), + transforms.Resize(size=frame_size), transforms.ToTensor() ] - if self.num_input_channels == 1: + if num_input_channels == 1: transform_compose_list.insert(0, transforms.Grayscale()) - self.frame_transform = transforms.Compose(transform_compose_list) + self._frame_transform = transforms.Compose(transform_compose_list) # Labels, conditions and views corresponding to each video clip - self.labels: np.ndarray[np.str_] + self.labels: np.ndarray[np.int64] self.conditions: np.ndarray[np.str_] self.views: np.ndarray[np.str_] - # Video clip directory names - self._clip_names: list[str] = [] - # Labels, conditions and views in dataset, + # Labels, classes, conditions and views in dataset, # set of three attributes above - self.metadata = dict[str, set[str]] + self.metadata = dict[str, list[str]] # Dictionaries for indexing frames and frame names by clip name # and chip path when cache is on self._cached_clips_frame_names: Optional[dict[str, list[str]]] = None self._cached_clips: Optional[dict[str, torch.Tensor]] = None - clip_names = sorted(os.listdir(self.root_dir)) + # Video clip directory names + self._clip_names: list[str] = [] + clip_names = sorted(os.listdir(self._root_dir)) - if self.is_train: - clip_names = clip_names[:self.train_size * 10 * 11] + if self._is_train: + clip_names = clip_names[:train_size * 10 * 11] else: # is_test - clip_names = clip_names[self.train_size * 10 * 11:] + clip_names = clip_names[train_size * 10 * 11:] - # Remove empty clips + # Remove clips under threshold discard_clips_names = [] for clip_name in clip_names.copy(): - clip_path = os.path.join(self.root_dir, clip_name) - if len(os.listdir(clip_path)) < self.discard_threshold: + clip_path = os.path.join(self._root_dir, clip_name) + if len(os.listdir(clip_path)) < discard_threshold: discard_clips_names.append(clip_name) clip_names.remove(clip_name) if len(discard_clips_names) != 0: print(', '.join(discard_clips_names[:-1]), 'and', discard_clips_names[-1], 'will be discarded.') - # clip name constructed by label, condition and view + # Clip name constructed by class, condition and view # e.g 002-bg-02-090 means clip from Subject #2 # in Bag #2 condition from 90 degree angle - labels, conditions, views = [], [], [] + classes, conditions, views = [], [], [] if selector: - selected_labels = selector.pop('labels', None) + selected_classes = selector.pop('classes', None) selected_conditions = selector.pop('conditions', None) selected_views = selector.pop('views', None) - label_regex = r'\d{3}' - condition_regex = r'(nm|bg|cl)-0[0-4]' + class_regex = r'\d{3}' + condition_regex = r'(nm|bg|cl)-0[0-6]' view_regex = r'\d{3}' # Match required data using RegEx - if selected_labels: - label_regex = '|'.join(selected_labels) + if selected_classes: + class_regex = '|'.join(selected_classes) if selected_conditions: condition_regex = '|'.join(selected_conditions) if selected_views: view_regex = '|'.join(selected_views) clip_re = re.compile('(' + ')-('.join(( - label_regex, condition_regex, view_regex + class_regex, condition_regex, view_regex )) + ')') for clip_name in clip_names: match = clip_re.fullmatch(clip_name) if match: - labels.append(match.group(1)) + classes.append(match.group(1)) conditions.append(match.group(2)) views.append(match.group(3)) self._clip_names.append(match.group(0)) - self.metadata = { - 'labels': selected_labels, - 'conditions': selected_conditions, - 'views': selected_views - } else: # Add all self._clip_names += clip_names for clip_name in self._clip_names: split_clip_name = clip_name.split('-') - label = split_clip_name[0] - labels.append(label) + class_ = split_clip_name[0] + classes.append(class_) condition = '-'.join(split_clip_name[1:2 + 1]) conditions.append(condition) view = split_clip_name[-1] views.append(view) - self.labels = np.asarray(labels) + # Encode classes to labels + self.label_encoder = LabelEncoder() + self.label_encoder.fit(classes) + self.labels = self.label_encoder.transform(classes) self.conditions = np.asarray(conditions) self.views = np.asarray(views) - if not selector: - self.metadata = { - 'labels': set(self.labels.tolist()), - 'conditions': set(self.conditions.tolist()), - 'views': set(self.views.tolist()) - } + self.metadata = { + 'labels': list(dict.fromkeys(self.labels.tolist())), + 'classes': self.label_encoder.classes_.tolist(), + 'conditions': list(dict.fromkeys(self.conditions.tolist())), + 'views': list(dict.fromkeys(self.views.tolist())) + } - if self.cache_on: + if self._cache_on: self._cached_clips_frame_names = dict() self._cached_clips = dict() self._preload_all_video() @@ -194,10 +189,10 @@ class CASIAB(data.Dataset): def _read_video(self, clip_name: str, is_caching: bool = False) -> torch.Tensor: - clip_path = os.path.join(self.root_dir, clip_name) + clip_path = os.path.join(self._root_dir, clip_name) sampled_frame_names = self._sample_frames(clip_path, is_caching) - if self.cache_on: + if self._cache_on: if is_caching: clip = self._read_frames(clip_path, sampled_frame_names) self._cached_clips[clip_name] = clip @@ -221,14 +216,14 @@ class CASIAB(data.Dataset): sampled_frame_names: list[str] ) -> torch.Tensor: # Mask the original clip when it is long enough - if len(frame_names) >= self.num_sampled_frames: + if len(frame_names) >= self._num_sampled_frames: sampled_frame_mask = np.isin(frame_names, sampled_frame_names) sampled_clip = clip[sampled_frame_mask] else: # Create a indexing filter from the beginning of clip sampled_index = frame_names.index(sampled_frame_names[0]) sampled_frame_filter = [sampled_index] - for i in range(1, self.num_sampled_frames): + for i in range(1, self._num_sampled_frames): if sampled_frame_names[i] != sampled_frame_names[i - 1]: sampled_index += 1 sampled_frame_filter.append(sampled_index) @@ -241,7 +236,7 @@ class CASIAB(data.Dataset): for frame_name in frame_names: frame_path = os.path.join(clip_path, frame_name) frame = Image.open(frame_path) - frame = self.frame_transform(frame) + frame = self._frame_transform(frame) frames.append(frame) clip = torch.stack(frames) @@ -249,7 +244,7 @@ class CASIAB(data.Dataset): def _sample_frames(self, clip_path: str, is_caching: bool = False) -> list[str]: - if self.cache_on: + if self._cache_on: if is_caching: # Sort frame in advance for loading convenience frame_names = sorted(os.listdir(clip_path)) @@ -261,14 +256,14 @@ class CASIAB(data.Dataset): else: # Cache off frame_names = os.listdir(clip_path) - if self.is_train: + if self._is_train: num_frames = len(frame_names) # Sample frames without replace if have enough frames - if num_frames < self.num_sampled_frames: + if num_frames < self._num_sampled_frames: frame_names = random.choices(frame_names, - k=self.num_sampled_frames) + k=self._num_sampled_frames) else: frame_names = random.sample(frame_names, - k=self.num_sampled_frames) + k=self._num_sampled_frames) return sorted(frame_names) -- cgit v1.2.3