diff options
Diffstat (limited to 'test/dataset.py')
-rw-r--r-- | test/dataset.py | 33 |
1 files changed, 32 insertions, 1 deletions
diff --git a/test/dataset.py b/test/dataset.py index bfb8563..e0fc59a 100644 --- a/test/dataset.py +++ b/test/dataset.py @@ -1,4 +1,4 @@ -from utils.dataset import CASIAB, ClipConditions, ClipViews +from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses CASIAB_ROOT_DIR = '../data/CASIA-B-MRCNN/SEG' @@ -7,6 +7,29 @@ def test_casiab(): casiab = CASIAB(CASIAB_ROOT_DIR, discard_threshold=0) assert len(casiab) == 74 * 10 * 11 + labels = [] + for i in range(74): + labels += [i] * 10 * 11 + assert casiab.labels.tolist() == labels + + assert casiab.metadata['labels'] == [i for i in range(74)] + + assert casiab.label_encoder.inverse_transform([0, 2]).tolist() == ['001', + '003'] + + +def test_casiab_test(): + casiab_test = CASIAB(CASIAB_ROOT_DIR, is_train=False, discard_threshold=0) + assert len(casiab_test) == (124 - 74) * 10 * 11 + + labels = [] + for i in range(124 - 74): + labels += [i] * 10 * 11 + assert casiab_test.labels.tolist() == labels + + assert casiab_test.label_encoder.inverse_transform([0, 2]).tolist() == [ + '075', '077'] + def test_casiab_nm(): nm_selector = {'conditions': ClipConditions({r'nm-0\d'})} @@ -22,3 +45,11 @@ def test_casiab_nm_bg_90(): selector=nm_bg_90_selector, discard_threshold=0) assert len(casiab_nm_bg_90) == 74 * (6 + 2) * 1 + + +def test_caisab_class_selector(): + class_selector = {'classes': ClipClasses({'001', '003'})} + casiab_class_001_003 = CASIAB(CASIAB_ROOT_DIR, + selector=class_selector, + discard_threshold=0) + assert len(casiab_class_001_003) == 2 * 10 * 11 |