summaryrefslogtreecommitdiff
path: root/utils/dataset.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-29 20:38:18 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-29 20:38:18 +0800
commitb8653d54efe2a8c94ae408c0c2da9bdd0b43ecdd (patch)
tree0aca443c5f2b0387fae48aa43611ca92d6015bbe /utils/dataset.py
parent6e94fdb587656074dc2e65a80e51b8446f834b41 (diff)
Encode class names to label and some access improvement
1. Encode class names using LabelEncoder from sklearn 2. Remove unneeded class variables 3. Protect some variables from being accessed in userspace
Diffstat (limited to 'utils/dataset.py')
-rw-r--r--utils/dataset.py121
1 files changed, 58 insertions, 63 deletions
diff --git a/utils/dataset.py b/utils/dataset.py
index 1a9d595..050ac03 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -7,10 +7,11 @@ import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
+from sklearn.preprocessing import LabelEncoder
from torch.utils import data
from tqdm import tqdm
-ClipLabels = NewType('ClipLabels', set[str])
+ClipClasses = NewType('ClipClasses', set[str])
ClipConditions = NewType('ClipConditions', set[str])
ClipViews = NewType('ClipViews', set[str])
@@ -26,7 +27,7 @@ class CASIAB(data.Dataset):
num_sampled_frames: int = 30,
discard_threshold: int = 15,
selector: Optional[dict[
- str, Union[ClipLabels, ClipConditions, ClipViews]
+ str, Union[ClipClasses, ClipConditions, ClipViews]
]] = None,
num_input_channels: int = 3,
frame_size: tuple[int, int] = (64, 32),
@@ -41,7 +42,7 @@ class CASIAB(data.Dataset):
(Training Only)
:param discard_threshold: Discard the sample if its number of
frames is less than this threshold.
- :param selector: Restrict output data labels, conditions and
+ :param selector: Restrict output data classes, conditions and
views.
:param num_input_channels: The number of input channel(s),
RBG image has 3 channels, grayscale image has 1 channel.
@@ -49,122 +50,116 @@ class CASIAB(data.Dataset):
:param cache_on: Preload all clips in memory or not, this will
increase loading speed, but will add a preload process and
cost a lot of RAM. Loading the entire dataset
- (is_train = True,train_size = 124, discard_threshold = 1,
+ (is_train = True, train_size = 124, discard_threshold = 1,
num_input_channels = 3, frame_height = 64, frame_width = 32)
need about 22 GB of RAM.
"""
super(CASIAB, self).__init__()
- self.root_dir = root_dir
- self.is_train = is_train
- self.train_size = train_size
- self.num_sampled_frames = num_sampled_frames
- self.discard_threshold = discard_threshold
- self.num_input_channels = num_input_channels
- self.frame_size = frame_size
- self.cache_on = cache_on
-
- self.frame_transform: transforms.Compose
+ self._root_dir = root_dir
+ self._is_train = is_train
+ self._num_sampled_frames = num_sampled_frames
+ self._cache_on = cache_on
+
+ self._frame_transform: transforms.Compose
transform_compose_list = [
- transforms.Resize(size=self.frame_size),
+ transforms.Resize(size=frame_size),
transforms.ToTensor()
]
- if self.num_input_channels == 1:
+ if num_input_channels == 1:
transform_compose_list.insert(0, transforms.Grayscale())
- self.frame_transform = transforms.Compose(transform_compose_list)
+ self._frame_transform = transforms.Compose(transform_compose_list)
# Labels, conditions and views corresponding to each video clip
- self.labels: np.ndarray[np.str_]
+ self.labels: np.ndarray[np.int64]
self.conditions: np.ndarray[np.str_]
self.views: np.ndarray[np.str_]
- # Video clip directory names
- self._clip_names: list[str] = []
- # Labels, conditions and views in dataset,
+ # Labels, classes, conditions and views in dataset,
# set of three attributes above
- self.metadata = dict[str, set[str]]
+ self.metadata = dict[str, list[str]]
# Dictionaries for indexing frames and frame names by clip name
# and chip path when cache is on
self._cached_clips_frame_names: Optional[dict[str, list[str]]] = None
self._cached_clips: Optional[dict[str, torch.Tensor]] = None
- clip_names = sorted(os.listdir(self.root_dir))
+ # Video clip directory names
+ self._clip_names: list[str] = []
+ clip_names = sorted(os.listdir(self._root_dir))
- if self.is_train:
- clip_names = clip_names[:self.train_size * 10 * 11]
+ if self._is_train:
+ clip_names = clip_names[:train_size * 10 * 11]
else: # is_test
- clip_names = clip_names[self.train_size * 10 * 11:]
+ clip_names = clip_names[train_size * 10 * 11:]
- # Remove empty clips
+ # Remove clips under threshold
discard_clips_names = []
for clip_name in clip_names.copy():
- clip_path = os.path.join(self.root_dir, clip_name)
- if len(os.listdir(clip_path)) < self.discard_threshold:
+ clip_path = os.path.join(self._root_dir, clip_name)
+ if len(os.listdir(clip_path)) < discard_threshold:
discard_clips_names.append(clip_name)
clip_names.remove(clip_name)
if len(discard_clips_names) != 0:
print(', '.join(discard_clips_names[:-1]),
'and', discard_clips_names[-1], 'will be discarded.')
- # clip name constructed by label, condition and view
+ # Clip name constructed by class, condition and view
# e.g 002-bg-02-090 means clip from Subject #2
# in Bag #2 condition from 90 degree angle
- labels, conditions, views = [], [], []
+ classes, conditions, views = [], [], []
if selector:
- selected_labels = selector.pop('labels', None)
+ selected_classes = selector.pop('classes', None)
selected_conditions = selector.pop('conditions', None)
selected_views = selector.pop('views', None)
- label_regex = r'\d{3}'
- condition_regex = r'(nm|bg|cl)-0[0-4]'
+ class_regex = r'\d{3}'
+ condition_regex = r'(nm|bg|cl)-0[0-6]'
view_regex = r'\d{3}'
# Match required data using RegEx
- if selected_labels:
- label_regex = '|'.join(selected_labels)
+ if selected_classes:
+ class_regex = '|'.join(selected_classes)
if selected_conditions:
condition_regex = '|'.join(selected_conditions)
if selected_views:
view_regex = '|'.join(selected_views)
clip_re = re.compile('(' + ')-('.join((
- label_regex, condition_regex, view_regex
+ class_regex, condition_regex, view_regex
)) + ')')
for clip_name in clip_names:
match = clip_re.fullmatch(clip_name)
if match:
- labels.append(match.group(1))
+ classes.append(match.group(1))
conditions.append(match.group(2))
views.append(match.group(3))
self._clip_names.append(match.group(0))
- self.metadata = {
- 'labels': selected_labels,
- 'conditions': selected_conditions,
- 'views': selected_views
- }
else: # Add all
self._clip_names += clip_names
for clip_name in self._clip_names:
split_clip_name = clip_name.split('-')
- label = split_clip_name[0]
- labels.append(label)
+ class_ = split_clip_name[0]
+ classes.append(class_)
condition = '-'.join(split_clip_name[1:2 + 1])
conditions.append(condition)
view = split_clip_name[-1]
views.append(view)
- self.labels = np.asarray(labels)
+ # Encode classes to labels
+ self.label_encoder = LabelEncoder()
+ self.label_encoder.fit(classes)
+ self.labels = self.label_encoder.transform(classes)
self.conditions = np.asarray(conditions)
self.views = np.asarray(views)
- if not selector:
- self.metadata = {
- 'labels': set(self.labels.tolist()),
- 'conditions': set(self.conditions.tolist()),
- 'views': set(self.views.tolist())
- }
+ self.metadata = {
+ 'labels': list(dict.fromkeys(self.labels.tolist())),
+ 'classes': self.label_encoder.classes_.tolist(),
+ 'conditions': list(dict.fromkeys(self.conditions.tolist())),
+ 'views': list(dict.fromkeys(self.views.tolist()))
+ }
- if self.cache_on:
+ if self._cache_on:
self._cached_clips_frame_names = dict()
self._cached_clips = dict()
self._preload_all_video()
@@ -194,10 +189,10 @@ class CASIAB(data.Dataset):
def _read_video(self, clip_name: str,
is_caching: bool = False) -> torch.Tensor:
- clip_path = os.path.join(self.root_dir, clip_name)
+ clip_path = os.path.join(self._root_dir, clip_name)
sampled_frame_names = self._sample_frames(clip_path, is_caching)
- if self.cache_on:
+ if self._cache_on:
if is_caching:
clip = self._read_frames(clip_path, sampled_frame_names)
self._cached_clips[clip_name] = clip
@@ -221,14 +216,14 @@ class CASIAB(data.Dataset):
sampled_frame_names: list[str]
) -> torch.Tensor:
# Mask the original clip when it is long enough
- if len(frame_names) >= self.num_sampled_frames:
+ if len(frame_names) >= self._num_sampled_frames:
sampled_frame_mask = np.isin(frame_names,
sampled_frame_names)
sampled_clip = clip[sampled_frame_mask]
else: # Create a indexing filter from the beginning of clip
sampled_index = frame_names.index(sampled_frame_names[0])
sampled_frame_filter = [sampled_index]
- for i in range(1, self.num_sampled_frames):
+ for i in range(1, self._num_sampled_frames):
if sampled_frame_names[i] != sampled_frame_names[i - 1]:
sampled_index += 1
sampled_frame_filter.append(sampled_index)
@@ -241,7 +236,7 @@ class CASIAB(data.Dataset):
for frame_name in frame_names:
frame_path = os.path.join(clip_path, frame_name)
frame = Image.open(frame_path)
- frame = self.frame_transform(frame)
+ frame = self._frame_transform(frame)
frames.append(frame)
clip = torch.stack(frames)
@@ -249,7 +244,7 @@ class CASIAB(data.Dataset):
def _sample_frames(self, clip_path: str,
is_caching: bool = False) -> list[str]:
- if self.cache_on:
+ if self._cache_on:
if is_caching:
# Sort frame in advance for loading convenience
frame_names = sorted(os.listdir(clip_path))
@@ -261,14 +256,14 @@ class CASIAB(data.Dataset):
else: # Cache off
frame_names = os.listdir(clip_path)
- if self.is_train:
+ if self._is_train:
num_frames = len(frame_names)
# Sample frames without replace if have enough frames
- if num_frames < self.num_sampled_frames:
+ if num_frames < self._num_sampled_frames:
frame_names = random.choices(frame_names,
- k=self.num_sampled_frames)
+ k=self._num_sampled_frames)
else:
frame_names = random.sample(frame_names,
- k=self.num_sampled_frames)
+ k=self._num_sampled_frames)
return sorted(frame_names)