author    Jordan Gong <jordan.gong@protonmail.com>    2020-12-29 20:38:18 +0800
committer Jordan Gong <jordan.gong@protonmail.com>    2020-12-29 20:38:18 +0800
commit    b8653d54efe2a8c94ae408c0c2da9bdd0b43ecdd (patch)
tree      0aca443c5f2b0387fae48aa43611ca92d6015bbe
parent    6e94fdb587656074dc2e65a80e51b8446f834b41 (diff)
Encode class names to labels and some access improvements
1. Encode class names using LabelEncoder from sklearn
2. Remove unneeded class variables
3. Protect some variables from being accessed in userspace
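
For context, a minimal sketch of the sklearn LabelEncoder behavior this commit relies on: it assigns consecutive integers to the sorted unique class names and can map them back. The subject IDs below are illustrative, not taken from the dataset.

    # Sketch of LabelEncoder round-tripping string class names and
    # integer labels; IDs here are examples only.
    from sklearn.preprocessing import LabelEncoder

    encoder = LabelEncoder()
    encoder.fit(['001', '003', '075'])                 # class names (strings)
    labels = encoder.transform(['001', '003', '075'])  # -> array([0, 1, 2])
    names = encoder.inverse_transform([0, 2])          # -> array(['001', '075'])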
-rw-r--r--  requirements.txt          3
-rw-r--r--  test/dataset.py          33
-rw-r--r--  utils/configuration.py    4
-rw-r--r--  utils/dataset.py        121
4 files changed, 94 insertions, 67 deletions
diff --git a/requirements.txt b/requirements.txt
index 765977f..7cf8570 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,5 @@ torch~=1.7.1
torchvision~=0.8.0a0+ecf4e9c
numpy~=1.19.4
tqdm~=4.55.0
-Pillow~=8.0.1
\ No newline at end of file
+Pillow~=8.0.1
+scikit-learn~=0.23.2
\ No newline at end of file
diff --git a/test/dataset.py b/test/dataset.py
index bfb8563..e0fc59a 100644
--- a/test/dataset.py
+++ b/test/dataset.py
@@ -1,4 +1,4 @@
-from utils.dataset import CASIAB, ClipConditions, ClipViews
+from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
CASIAB_ROOT_DIR = '../data/CASIA-B-MRCNN/SEG'
@@ -7,6 +7,29 @@ def test_casiab():
casiab = CASIAB(CASIAB_ROOT_DIR, discard_threshold=0)
assert len(casiab) == 74 * 10 * 11
+ labels = []
+ for i in range(74):
+ labels += [i] * 10 * 11
+ assert casiab.labels.tolist() == labels
+
+ assert casiab.metadata['labels'] == [i for i in range(74)]
+
+ assert casiab.label_encoder.inverse_transform([0, 2]).tolist() == ['001',
+ '003']
+
+
+def test_casiab_test():
+ casiab_test = CASIAB(CASIAB_ROOT_DIR, is_train=False, discard_threshold=0)
+ assert len(casiab_test) == (124 - 74) * 10 * 11
+
+ labels = []
+ for i in range(124 - 74):
+ labels += [i] * 10 * 11
+ assert casiab_test.labels.tolist() == labels
+
+ assert casiab_test.label_encoder.inverse_transform([0, 2]).tolist() == [
+ '075', '077']
+
def test_casiab_nm():
nm_selector = {'conditions': ClipConditions({r'nm-0\d'})}
@@ -22,3 +45,11 @@ def test_casiab_nm_bg_90():
selector=nm_bg_90_selector,
discard_threshold=0)
assert len(casiab_nm_bg_90) == 74 * (6 + 2) * 1
+
+
+def test_casiab_class_selector():
+ class_selector = {'classes': ClipClasses({'001', '003'})}
+ casiab_class_001_003 = CASIAB(CASIAB_ROOT_DIR,
+ selector=class_selector,
+ discard_threshold=0)
+ assert len(casiab_class_001_003) == 2 * 10 * 11
diff --git a/utils/configuration.py b/utils/configuration.py
index e6bfaf2..965af94 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -2,7 +2,7 @@ from typing import TypedDict, Optional, Union
import torch
-from utils.dataset import ClipLabels, ClipConditions, ClipViews
+from utils.dataset import ClipClasses, ClipConditions, ClipViews
class SystemConfiguration(TypedDict):
@@ -17,7 +17,7 @@ class DatasetConfiguration(TypedDict):
train_size: int
num_sampled_frames: int
discard_threshold: int
- selector: Optional[dict[str, Union[ClipLabels, ClipConditions, ClipViews]]]
+ selector: Optional[dict[str, Union[ClipClasses, ClipConditions, ClipViews]]]
num_input_channels: int
frame_size: tuple[int, int]
cache_on: bool
diff --git a/utils/dataset.py b/utils/dataset.py
index 1a9d595..050ac03 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -7,10 +7,11 @@ import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
+from sklearn.preprocessing import LabelEncoder
from torch.utils import data
from tqdm import tqdm
-ClipLabels = NewType('ClipLabels', set[str])
+ClipClasses = NewType('ClipClasses', set[str])
ClipConditions = NewType('ClipConditions', set[str])
ClipViews = NewType('ClipViews', set[str])
@@ -26,7 +27,7 @@ class CASIAB(data.Dataset):
num_sampled_frames: int = 30,
discard_threshold: int = 15,
selector: Optional[dict[
- str, Union[ClipLabels, ClipConditions, ClipViews]
+ str, Union[ClipClasses, ClipConditions, ClipViews]
]] = None,
num_input_channels: int = 3,
frame_size: tuple[int, int] = (64, 32),
@@ -41,7 +42,7 @@ class CASIAB(data.Dataset):
(Training Only)
:param discard_threshold: Discard the sample if its number of
frames is less than this threshold.
- :param selector: Restrict output data labels, conditions and
+ :param selector: Restrict output data classes, conditions and
views.
:param num_input_channels: The number of input channel(s),
RGB image has 3 channels, grayscale image has 1 channel.
@@ -49,122 +50,116 @@ class CASIAB(data.Dataset):
:param cache_on: Preload all clips in memory or not; this will
increase loading speed, but will add a preload process and
cost a lot of RAM. Loading the entire dataset
- (is_train = True,train_size = 124, discard_threshold = 1,
+ (is_train = True, train_size = 124, discard_threshold = 1,
num_input_channels = 3, frame_height = 64, frame_width = 32)
needs about 22 GB of RAM.
"""
super(CASIAB, self).__init__()
- self.root_dir = root_dir
- self.is_train = is_train
- self.train_size = train_size
- self.num_sampled_frames = num_sampled_frames
- self.discard_threshold = discard_threshold
- self.num_input_channels = num_input_channels
- self.frame_size = frame_size
- self.cache_on = cache_on
-
- self.frame_transform: transforms.Compose
+ self._root_dir = root_dir
+ self._is_train = is_train
+ self._num_sampled_frames = num_sampled_frames
+ self._cache_on = cache_on
+
+ self._frame_transform: transforms.Compose
transform_compose_list = [
- transforms.Resize(size=self.frame_size),
+ transforms.Resize(size=frame_size),
transforms.ToTensor()
]
- if self.num_input_channels == 1:
+ if num_input_channels == 1:
transform_compose_list.insert(0, transforms.Grayscale())
- self.frame_transform = transforms.Compose(transform_compose_list)
+ self._frame_transform = transforms.Compose(transform_compose_list)
# Labels, conditions and views corresponding to each video clip
- self.labels: np.ndarray[np.str_]
+ self.labels: np.ndarray[np.int64]
self.conditions: np.ndarray[np.str_]
self.views: np.ndarray[np.str_]
- # Video clip directory names
- self._clip_names: list[str] = []
- # Labels, conditions and views in dataset,
+ # Labels, classes, conditions and views in dataset,
# set of three attributes above
- self.metadata = dict[str, set[str]]
- self.metadata: dict[str, list[str]]
# Dictionaries for indexing frames and frame names by clip name
# and clip path when cache is on
self._cached_clips_frame_names: Optional[dict[str, list[str]]] = None
self._cached_clips: Optional[dict[str, torch.Tensor]] = None
- clip_names = sorted(os.listdir(self.root_dir))
+ # Video clip directory names
+ self._clip_names: list[str] = []
+ clip_names = sorted(os.listdir(self._root_dir))
- if self.is_train:
- clip_names = clip_names[:self.train_size * 10 * 11]
+ if self._is_train:
+ clip_names = clip_names[:train_size * 10 * 11]
else: # is_test
- clip_names = clip_names[self.train_size * 10 * 11:]
+ clip_names = clip_names[train_size * 10 * 11:]
- # Remove empty clips
+ # Remove clips under threshold
discard_clips_names = []
for clip_name in clip_names.copy():
- clip_path = os.path.join(self.root_dir, clip_name)
- if len(os.listdir(clip_path)) < self.discard_threshold:
+ clip_path = os.path.join(self._root_dir, clip_name)
+ if len(os.listdir(clip_path)) < discard_threshold:
discard_clips_names.append(clip_name)
clip_names.remove(clip_name)
if len(discard_clips_names) != 0:
print(', '.join(discard_clips_names[:-1]),
'and', discard_clips_names[-1], 'will be discarded.')
- # clip name constructed by label, condition and view
+ # Clip name constructed by class, condition and view
# e.g. 002-bg-02-090 means a clip of Subject #2
# in the Bag #2 condition from a 90-degree angle
- labels, conditions, views = [], [], []
+ classes, conditions, views = [], [], []
if selector:
- selected_labels = selector.pop('labels', None)
+ selected_classes = selector.pop('classes', None)
selected_conditions = selector.pop('conditions', None)
selected_views = selector.pop('views', None)
- label_regex = r'\d{3}'
- condition_regex = r'(nm|bg|cl)-0[0-4]'
+ class_regex = r'\d{3}'
+ condition_regex = r'(nm|bg|cl)-0[0-6]'
view_regex = r'\d{3}'
# Match required data using RegEx
- if selected_labels:
- label_regex = '|'.join(selected_labels)
+ if selected_classes:
+ class_regex = '|'.join(selected_classes)
if selected_conditions:
condition_regex = '|'.join(selected_conditions)
if selected_views:
view_regex = '|'.join(selected_views)
clip_re = re.compile('(' + ')-('.join((
- label_regex, condition_regex, view_regex
+ class_regex, condition_regex, view_regex
)) + ')')
for clip_name in clip_names:
match = clip_re.fullmatch(clip_name)
if match:
- labels.append(match.group(1))
+ classes.append(match.group(1))
conditions.append(match.group(2))
views.append(match.group(3))
self._clip_names.append(match.group(0))
- self.metadata = {
- 'labels': selected_labels,
- 'conditions': selected_conditions,
- 'views': selected_views
- }
else: # Add all
self._clip_names += clip_names
for clip_name in self._clip_names:
split_clip_name = clip_name.split('-')
- label = split_clip_name[0]
- labels.append(label)
+ class_ = split_clip_name[0]
+ classes.append(class_)
condition = '-'.join(split_clip_name[1:2 + 1])
conditions.append(condition)
view = split_clip_name[-1]
views.append(view)
- self.labels = np.asarray(labels)
+ # Encode classes to labels
+ self.label_encoder = LabelEncoder()
+ self.label_encoder.fit(classes)
+ self.labels = self.label_encoder.transform(classes)
self.conditions = np.asarray(conditions)
self.views = np.asarray(views)
- if not selector:
- self.metadata = {
- 'labels': set(self.labels.tolist()),
- 'conditions': set(self.conditions.tolist()),
- 'views': set(self.views.tolist())
- }
+ self.metadata = {
+ 'labels': list(dict.fromkeys(self.labels.tolist())),
+ 'classes': self.label_encoder.classes_.tolist(),
+ 'conditions': list(dict.fromkeys(self.conditions.tolist())),
+ 'views': list(dict.fromkeys(self.views.tolist()))
+ }
- if self.cache_on:
+ if self._cache_on:
self._cached_clips_frame_names = dict()
self._cached_clips = dict()
self._preload_all_video()
@@ -194,10 +189,10 @@ class CASIAB(data.Dataset):
def _read_video(self, clip_name: str,
is_caching: bool = False) -> torch.Tensor:
- clip_path = os.path.join(self.root_dir, clip_name)
+ clip_path = os.path.join(self._root_dir, clip_name)
sampled_frame_names = self._sample_frames(clip_path, is_caching)
- if self.cache_on:
+ if self._cache_on:
if is_caching:
clip = self._read_frames(clip_path, sampled_frame_names)
self._cached_clips[clip_name] = clip
@@ -221,14 +216,14 @@ class CASIAB(data.Dataset):
sampled_frame_names: list[str]
) -> torch.Tensor:
# Mask the original clip when it is long enough
- if len(frame_names) >= self.num_sampled_frames:
+ if len(frame_names) >= self._num_sampled_frames:
sampled_frame_mask = np.isin(frame_names,
sampled_frame_names)
sampled_clip = clip[sampled_frame_mask]
else: # Create an indexing filter from the beginning of the clip
sampled_index = frame_names.index(sampled_frame_names[0])
sampled_frame_filter = [sampled_index]
- for i in range(1, self.num_sampled_frames):
+ for i in range(1, self._num_sampled_frames):
if sampled_frame_names[i] != sampled_frame_names[i - 1]:
sampled_index += 1
sampled_frame_filter.append(sampled_index)
@@ -241,7 +236,7 @@ class CASIAB(data.Dataset):
for frame_name in frame_names:
frame_path = os.path.join(clip_path, frame_name)
frame = Image.open(frame_path)
- frame = self.frame_transform(frame)
+ frame = self._frame_transform(frame)
frames.append(frame)
clip = torch.stack(frames)
@@ -249,7 +244,7 @@ class CASIAB(data.Dataset):
def _sample_frames(self, clip_path: str,
is_caching: bool = False) -> list[str]:
- if self.cache_on:
+ if self._cache_on:
if is_caching:
# Sort frame in advance for loading convenience
frame_names = sorted(os.listdir(clip_path))
@@ -261,14 +256,14 @@ class CASIAB(data.Dataset):
else: # Cache off
frame_names = os.listdir(clip_path)
- if self.is_train:
+ if self._is_train:
num_frames = len(frame_names)
# Sample frames without replacement if there are enough frames
- if num_frames < self.num_sampled_frames:
+ if num_frames < self._num_sampled_frames:
frame_names = random.choices(frame_names,
- k=self.num_sampled_frames)
+ k=self._num_sampled_frames)
else:
frame_names = random.sample(frame_names,
- k=self.num_sampled_frames)
+ k=self._num_sampled_frames)
return sorted(frame_names)
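
For reference, a hypothetical usage sketch of the dataset API after this commit; the root directory path and selector values mirror the tests above, and the printed results assume the same data layout.

    # Hypothetical usage of the updated CASIAB dataset; the path and
    # selector values are taken from the tests, not guaranteed elsewhere.
    from utils.dataset import CASIAB, ClipClasses, ClipConditions

    casiab = CASIAB('../data/CASIA-B-MRCNN/SEG',
                    selector={'classes': ClipClasses({'001', '003'}),
                              'conditions': ClipConditions({r'nm-0\d'})},
                    discard_threshold=0)
    print(casiab.metadata['classes'])                   # ['001', '003']
    print(casiab.label_encoder.inverse_transform([0]))  # ['001']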