summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/dataset.py33
1 files changed, 32 insertions, 1 deletions
diff --git a/test/dataset.py b/test/dataset.py
index bfb8563..e0fc59a 100644
--- a/test/dataset.py
+++ b/test/dataset.py
@@ -1,4 +1,4 @@
-from utils.dataset import CASIAB, ClipConditions, ClipViews
+from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
CASIAB_ROOT_DIR = '../data/CASIA-B-MRCNN/SEG'
@@ -7,6 +7,29 @@ def test_casiab():
casiab = CASIAB(CASIAB_ROOT_DIR, discard_threshold=0)
assert len(casiab) == 74 * 10 * 11
+ labels = []
+ for i in range(74):
+ labels += [i] * 10 * 11
+ assert casiab.labels.tolist() == labels
+
+ assert casiab.metadata['labels'] == [i for i in range(74)]
+
+ assert casiab.label_encoder.inverse_transform([0, 2]).tolist() == ['001',
+ '003']
+
+
+def test_casiab_test():
+ casiab_test = CASIAB(CASIAB_ROOT_DIR, is_train=False, discard_threshold=0)
+ assert len(casiab_test) == (124 - 74) * 10 * 11
+
+ labels = []
+ for i in range(124 - 74):
+ labels += [i] * 10 * 11
+ assert casiab_test.labels.tolist() == labels
+
+ assert casiab_test.label_encoder.inverse_transform([0, 2]).tolist() == [
+ '075', '077']
+
def test_casiab_nm():
nm_selector = {'conditions': ClipConditions({r'nm-0\d'})}
@@ -22,3 +45,11 @@ def test_casiab_nm_bg_90():
selector=nm_bg_90_selector,
discard_threshold=0)
assert len(casiab_nm_bg_90) == 74 * (6 + 2) * 1
+
+
+def test_caisab_class_selector():
+ class_selector = {'classes': ClipClasses({'001', '003'})}
+ casiab_class_001_003 = CASIAB(CASIAB_ROOT_DIR,
+ selector=class_selector,
+ discard_threshold=0)
+ assert len(casiab_class_001_003) == 2 * 10 * 11