-rw-r--r--  requirements.txt        |   3
-rw-r--r--  test/dataset.py         |  33
-rw-r--r--  utils/configuration.py  |   4
-rw-r--r--  utils/dataset.py        | 121
4 files changed, 94 insertions(+), 67 deletions(-)
diff --git a/requirements.txt b/requirements.txt
index 765977f..7cf8570 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,5 @@ torch~=1.7.1
torchvision~=0.8.0a0+ecf4e9c
numpy~=1.19.4
tqdm~=4.55.0
-Pillow~=8.0.1 \ No newline at end of file
+Pillow~=8.0.1
+scikit-learn~=0.23.2 \ No newline at end of file
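The scikit-learn dependency is added for its LabelEncoder, which the dataset code below uses to map string subject IDs to contiguous integer labels and back. A minimal, self-contained sketch of that API (illustrative values only):

from sklearn.preprocessing import LabelEncoder

# Fit on the string class IDs found in the clip names, then map
# between strings and 0-based integer labels in both directions.
encoder = LabelEncoder()
encoder.fit(['001', '001', '003', '003'])
print(encoder.transform(['001', '003']))           # [0 1]
print(encoder.inverse_transform([0, 1]).tolist())  # ['001', '003']
print(encoder.classes_.tolist())                   # ['001', '003']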
diff --git a/test/dataset.py b/test/dataset.py
index bfb8563..e0fc59a 100644
--- a/test/dataset.py
+++ b/test/dataset.py
@@ -1,4 +1,4 @@
-from utils.dataset import CASIAB, ClipConditions, ClipViews
+from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
CASIAB_ROOT_DIR = '../data/CASIA-B-MRCNN/SEG'
@@ -7,6 +7,29 @@ def test_casiab():
casiab = CASIAB(CASIAB_ROOT_DIR, discard_threshold=0)
assert len(casiab) == 74 * 10 * 11
+ labels = []
+ for i in range(74):
+ labels += [i] * 10 * 11
+ assert casiab.labels.tolist() == labels
+
+ assert casiab.metadata['labels'] == [i for i in range(74)]
+
+ assert casiab.label_encoder.inverse_transform([0, 2]).tolist() == ['001',
+ '003']
+
+
+def test_casiab_test():
+ casiab_test = CASIAB(CASIAB_ROOT_DIR, is_train=False, discard_threshold=0)
+ assert len(casiab_test) == (124 - 74) * 10 * 11
+
+ labels = []
+ for i in range(124 - 74):
+ labels += [i] * 10 * 11
+ assert casiab_test.labels.tolist() == labels
+
+ assert casiab_test.label_encoder.inverse_transform([0, 2]).tolist() == [
+ '075', '077']
+
def test_casiab_nm():
nm_selector = {'conditions': ClipConditions({r'nm-0\d'})}
@@ -22,3 +45,11 @@ def test_casiab_nm_bg_90():
selector=nm_bg_90_selector,
discard_threshold=0)
assert len(casiab_nm_bg_90) == 74 * (6 + 2) * 1
+
+
+def test_casiab_class_selector():
+ class_selector = {'classes': ClipClasses({'001', '003'})}
+ casiab_class_001_003 = CASIAB(CASIAB_ROOT_DIR,
+ selector=class_selector,
+ discard_threshold=0)
+ assert len(casiab_class_001_003) == 2 * 10 * 11
diff --git a/utils/configuration.py b/utils/configuration.py
index e6bfaf2..965af94 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -2,7 +2,7 @@ from typing import TypedDict, Optional, Union
import torch
-from utils.dataset import ClipLabels, ClipConditions, ClipViews
+from utils.dataset import ClipClasses, ClipConditions, ClipViews
class SystemConfiguration(TypedDict):
@@ -17,7 +17,7 @@ class DatasetConfiguration(TypedDict):
train_size: int
num_sampled_frames: int
discard_threshold: int
- selector: Optional[dict[str, Union[ClipLabels, ClipConditions, ClipViews]]]
+ selector: Optional[dict[str, Union[ClipClasses, ClipConditions, ClipViews]]]
num_input_channels: int
frame_size: tuple[int, int]
cache_on: bool
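The configuration change only renames the ClipLabels key type to ClipClasses in the selector type hint. A hedged sketch of a selector value matching the updated hint; the concrete classes, conditions and views are illustrative only:

from utils.dataset import ClipClasses, ClipConditions, ClipViews

# Keys mirror those consumed in utils/dataset.py ('classes',
# 'conditions', 'views'); the values here are examples.
selector = {
    'classes': ClipClasses({'001', '003'}),
    'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'}),
    'views': ClipViews({'090'}),
}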
diff --git a/utils/dataset.py b/utils/dataset.py
index 1a9d595..050ac03 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -7,10 +7,11 @@ import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
+from sklearn.preprocessing import LabelEncoder
from torch.utils import data
from tqdm import tqdm
-ClipLabels = NewType('ClipLabels', set[str])
+ClipClasses = NewType('ClipClasses', set[str])
ClipConditions = NewType('ClipConditions', set[str])
ClipViews = NewType('ClipViews', set[str])
@@ -26,7 +27,7 @@ class CASIAB(data.Dataset):
num_sampled_frames: int = 30,
discard_threshold: int = 15,
selector: Optional[dict[
- str, Union[ClipLabels, ClipConditions, ClipViews]
+ str, Union[ClipClasses, ClipConditions, ClipViews]
]] = None,
num_input_channels: int = 3,
frame_size: tuple[int, int] = (64, 32),
@@ -41,7 +42,7 @@ class CASIAB(data.Dataset):
(Training Only)
:param discard_threshold: Discard the sample if its number of
frames is less than this threshold.
- :param selector: Restrict output data labels, conditions and
+ :param selector: Restrict output data classes, conditions and
views.
:param num_input_channels: The number of input channel(s),
RGB image has 3 channels, grayscale image has 1 channel.
@@ -49,122 +50,116 @@ class CASIAB(data.Dataset):
:param cache_on: Whether to preload all clips into memory; this
speeds up loading but adds a preload step and
costs a lot of RAM. Loading the entire dataset
- (is_train = True,train_size = 124, discard_threshold = 1,
+ (is_train = True, train_size = 124, discard_threshold = 1,
num_input_channels = 3, frame_height = 64, frame_width = 32)
needs about 22 GB of RAM.
"""
super(CASIAB, self).__init__()
- self.root_dir = root_dir
- self.is_train = is_train
- self.train_size = train_size
- self.num_sampled_frames = num_sampled_frames
- self.discard_threshold = discard_threshold
- self.num_input_channels = num_input_channels
- self.frame_size = frame_size
- self.cache_on = cache_on
-
- self.frame_transform: transforms.Compose
+ self._root_dir = root_dir
+ self._is_train = is_train
+ self._num_sampled_frames = num_sampled_frames
+ self._cache_on = cache_on
+
+ self._frame_transform: transforms.Compose
transform_compose_list = [
- transforms.Resize(size=self.frame_size),
+ transforms.Resize(size=frame_size),
transforms.ToTensor()
]
- if self.num_input_channels == 1:
+ if num_input_channels == 1:
transform_compose_list.insert(0, transforms.Grayscale())
- self.frame_transform = transforms.Compose(transform_compose_list)
+ self._frame_transform = transforms.Compose(transform_compose_list)
# Labels, conditions and views corresponding to each video clip
- self.labels: np.ndarray[np.str_]
+ self.labels: np.ndarray[np.int64]
self.conditions: np.ndarray[np.str_]
self.views: np.ndarray[np.str_]
- # Video clip directory names
- self._clip_names: list[str] = []
- # Labels, conditions and views in dataset,
+ # Labels, classes, conditions and views in dataset,
# unique values of the attributes above
- self.metadata = dict[str, set[str]]
+ self.metadata: dict[str, list[str]]
# Dictionaries for indexing frames and frame names by clip name
# and clip path when cache is on
self._cached_clips_frame_names: Optional[dict[str, list[str]]] = None
self._cached_clips: Optional[dict[str, torch.Tensor]] = None
- clip_names = sorted(os.listdir(self.root_dir))
+ # Video clip directory names
+ self._clip_names: list[str] = []
+ clip_names = sorted(os.listdir(self._root_dir))
- if self.is_train:
- clip_names = clip_names[:self.train_size * 10 * 11]
+ if self._is_train:
+ clip_names = clip_names[:train_size * 10 * 11]
else: # is_test
- clip_names = clip_names[self.train_size * 10 * 11:]
+ clip_names = clip_names[train_size * 10 * 11:]
- # Remove empty clips
+ # Remove clips under threshold
discard_clips_names = []
for clip_name in clip_names.copy():
- clip_path = os.path.join(self.root_dir, clip_name)
- if len(os.listdir(clip_path)) < self.discard_threshold:
+ clip_path = os.path.join(self._root_dir, clip_name)
+ if len(os.listdir(clip_path)) < discard_threshold:
discard_clips_names.append(clip_name)
clip_names.remove(clip_name)
if len(discard_clips_names) != 0:
print(', '.join(discard_clips_names[:-1]),
'and', discard_clips_names[-1], 'will be discarded.')
- # clip name constructed by label, condition and view
+ # Clip name constructed by class, condition and view
# e.g. 002-bg-02-090 means a clip of Subject #2
# in the Bag #2 condition from a 90-degree angle
- labels, conditions, views = [], [], []
+ classes, conditions, views = [], [], []
if selector:
- selected_labels = selector.pop('labels', None)
+ selected_classes = selector.pop('classes', None)
selected_conditions = selector.pop('conditions', None)
selected_views = selector.pop('views', None)
- label_regex = r'\d{3}'
- condition_regex = r'(nm|bg|cl)-0[0-4]'
+ class_regex = r'\d{3}'
+ condition_regex = r'(nm|bg|cl)-0[0-6]'
view_regex = r'\d{3}'
# Match required data using RegEx
- if selected_labels:
- label_regex = '|'.join(selected_labels)
+ if selected_classes:
+ class_regex = '|'.join(selected_classes)
if selected_conditions:
condition_regex = '|'.join(selected_conditions)
if selected_views:
view_regex = '|'.join(selected_views)
clip_re = re.compile('(' + ')-('.join((
- label_regex, condition_regex, view_regex
+ class_regex, condition_regex, view_regex
)) + ')')
for clip_name in clip_names:
match = clip_re.fullmatch(clip_name)
if match:
- labels.append(match.group(1))
+ classes.append(match.group(1))
conditions.append(match.group(2))
views.append(match.group(3))
self._clip_names.append(match.group(0))
- self.metadata = {
- 'labels': selected_labels,
- 'conditions': selected_conditions,
- 'views': selected_views
- }
else: # Add all
self._clip_names += clip_names
for clip_name in self._clip_names:
split_clip_name = clip_name.split('-')
- label = split_clip_name[0]
- labels.append(label)
+ class_ = split_clip_name[0]
+ classes.append(class_)
condition = '-'.join(split_clip_name[1:2 + 1])
conditions.append(condition)
view = split_clip_name[-1]
views.append(view)
- self.labels = np.asarray(labels)
+ # Encode classes to labels
+ self.label_encoder = LabelEncoder()
+ self.label_encoder.fit(classes)
+ self.labels = self.label_encoder.transform(classes)
self.conditions = np.asarray(conditions)
self.views = np.asarray(views)
- if not selector:
- self.metadata = {
- 'labels': set(self.labels.tolist()),
- 'conditions': set(self.conditions.tolist()),
- 'views': set(self.views.tolist())
- }
+ self.metadata = {
+ 'labels': list(dict.fromkeys(self.labels.tolist())),
+ 'classes': self.label_encoder.classes_.tolist(),
+ 'conditions': list(dict.fromkeys(self.conditions.tolist())),
+ 'views': list(dict.fromkeys(self.views.tolist()))
+ }
- if self.cache_on:
+ if self._cache_on:
self._cached_clips_frame_names = dict()
self._cached_clips = dict()
self._preload_all_video()
@@ -194,10 +189,10 @@ class CASIAB(data.Dataset):
def _read_video(self, clip_name: str,
is_caching: bool = False) -> torch.Tensor:
- clip_path = os.path.join(self.root_dir, clip_name)
+ clip_path = os.path.join(self._root_dir, clip_name)
sampled_frame_names = self._sample_frames(clip_path, is_caching)
- if self.cache_on:
+ if self._cache_on:
if is_caching:
clip = self._read_frames(clip_path, sampled_frame_names)
self._cached_clips[clip_name] = clip
@@ -221,14 +216,14 @@ class CASIAB(data.Dataset):
sampled_frame_names: list[str]
) -> torch.Tensor:
# Mask the original clip when it is long enough
- if len(frame_names) >= self.num_sampled_frames:
+ if len(frame_names) >= self._num_sampled_frames:
sampled_frame_mask = np.isin(frame_names,
sampled_frame_names)
sampled_clip = clip[sampled_frame_mask]
else: # Create an indexing filter from the beginning of the clip
sampled_index = frame_names.index(sampled_frame_names[0])
sampled_frame_filter = [sampled_index]
- for i in range(1, self.num_sampled_frames):
+ for i in range(1, self._num_sampled_frames):
if sampled_frame_names[i] != sampled_frame_names[i - 1]:
sampled_index += 1
sampled_frame_filter.append(sampled_index)
@@ -241,7 +236,7 @@ class CASIAB(data.Dataset):
for frame_name in frame_names:
frame_path = os.path.join(clip_path, frame_name)
frame = Image.open(frame_path)
- frame = self.frame_transform(frame)
+ frame = self._frame_transform(frame)
frames.append(frame)
clip = torch.stack(frames)
@@ -249,7 +244,7 @@ class CASIAB(data.Dataset):
def _sample_frames(self, clip_path: str,
is_caching: bool = False) -> list[str]:
- if self.cache_on:
+ if self._cache_on:
if is_caching:
# Sort frames in advance for loading convenience
frame_names = sorted(os.listdir(clip_path))
@@ -261,14 +256,14 @@ class CASIAB(data.Dataset):
else: # Cache off
frame_names = os.listdir(clip_path)
- if self.is_train:
+ if self._is_train:
num_frames = len(frame_names)
# Sample without replacement if there are enough frames
- if num_frames < self.num_sampled_frames:
+ if num_frames < self._num_sampled_frames:
frame_names = random.choices(frame_names,
- k=self.num_sampled_frames)
+ k=self._num_sampled_frames)
else:
frame_names = random.sample(frame_names,
- k=self.num_sampled_frames)
+ k=self._num_sampled_frames)
return sorted(frame_names)
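Taken together, a hedged usage sketch of the reworked dataset based on the tests above; the root path is the one used in the tests, and the layout <root>/<class>-<condition>-<view>/<frame> follows the loading code:

from utils.dataset import CASIAB, ClipClasses

# Restrict the training split to subjects 001 and 003; with
# discard_threshold=0 no clip is dropped, so 2 x 10 x 11 clips remain.
dataset = CASIAB('../data/CASIA-B-MRCNN/SEG',
                 is_train=True,
                 selector={'classes': ClipClasses({'001', '003'})},
                 discard_threshold=0)

print(len(dataset))                                     # 220
print(dataset.metadata['classes'])                      # ['001', '003']
print(dataset.label_encoder.inverse_transform([0, 1]))  # ['001' '003']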