summaryrefslogtreecommitdiff
path: root/test/dataset.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-29 20:38:18 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-29 20:38:18 +0800
commitb8653d54efe2a8c94ae408c0c2da9bdd0b43ecdd (patch)
tree0aca443c5f2b0387fae48aa43611ca92d6015bbe /test/dataset.py
parent6e94fdb587656074dc2e65a80e51b8446f834b41 (diff)
Encode class names to label and some access improvement
1. Encode class names using LabelEncoder from sklearn 2. Remove unneeded class variables 3. Protect some variables from being accessed in userspace
Diffstat (limited to 'test/dataset.py')
-rw-r--r--test/dataset.py33
1 files changed, 32 insertions, 1 deletions
diff --git a/test/dataset.py b/test/dataset.py
index bfb8563..e0fc59a 100644
--- a/test/dataset.py
+++ b/test/dataset.py
@@ -1,4 +1,4 @@
-from utils.dataset import CASIAB, ClipConditions, ClipViews
+from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
CASIAB_ROOT_DIR = '../data/CASIA-B-MRCNN/SEG'
@@ -7,6 +7,29 @@ def test_casiab():
casiab = CASIAB(CASIAB_ROOT_DIR, discard_threshold=0)
assert len(casiab) == 74 * 10 * 11
+ labels = []
+ for i in range(74):
+ labels += [i] * 10 * 11
+ assert casiab.labels.tolist() == labels
+
+ assert casiab.metadata['labels'] == [i for i in range(74)]
+
+ assert casiab.label_encoder.inverse_transform([0, 2]).tolist() == ['001',
+ '003']
+
+
+def test_casiab_test():
+ casiab_test = CASIAB(CASIAB_ROOT_DIR, is_train=False, discard_threshold=0)
+ assert len(casiab_test) == (124 - 74) * 10 * 11
+
+ labels = []
+ for i in range(124 - 74):
+ labels += [i] * 10 * 11
+ assert casiab_test.labels.tolist() == labels
+
+ assert casiab_test.label_encoder.inverse_transform([0, 2]).tolist() == [
+ '075', '077']
+
def test_casiab_nm():
nm_selector = {'conditions': ClipConditions({r'nm-0\d'})}
@@ -22,3 +45,11 @@ def test_casiab_nm_bg_90():
selector=nm_bg_90_selector,
discard_threshold=0)
assert len(casiab_nm_bg_90) == 74 * (6 + 2) * 1
+
+
+def test_caisab_class_selector():
+ class_selector = {'classes': ClipClasses({'001', '003'})}
+ casiab_class_001_003 = CASIAB(CASIAB_ROOT_DIR,
+ selector=class_selector,
+ discard_threshold=0)
+ assert len(casiab_class_001_003) == 2 * 10 * 11