from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses

# Root directory handed to every CASIAB(...) constructed in this module.
# The path is relative, so it is resolved against the working directory the
# tests are launched from.
# NOTE(review): judging by the directory name this presumably holds Mask R-CNN
# segmentations of CASIA-B -- confirm against the data preparation step.
CASIAB_ROOT_DIR = '../data/CASIA-B-MRCNN/SEG'


def test_casiab():
    """Check size, labels and label decoding of the default CASIAB split.

    Expects 74 classes with 10 * 11 clips each, integer labels grouped per
    class in ascending order, and encoded labels 0 and 2 decoding to '001'
    and '003'.

    NOTE(review): the default split is presumably the training split (the
    sibling test passes ``is_train=False``), and 10 * 11 presumably means
    conditions x views -- confirm against the CASIAB implementation.
    """
    num_classes = 74
    clips_per_class = 10 * 11

    dataset = CASIAB(CASIAB_ROOT_DIR, discard_threshold=0)
    assert len(dataset) == num_classes * clips_per_class

    # Every class index repeated once per clip, classes in ascending order.
    expected_labels = [
        class_idx
        for class_idx in range(num_classes)
        for _ in range(clips_per_class)
    ]
    assert dataset.labels.tolist() == expected_labels

    assert dataset.metadata['labels'] == list(range(num_classes))

    decoded = dataset.label_encoder.inverse_transform([0, 2])
    assert decoded.tolist() == ['001', '003']


def test_casiab_test():
    """Check size, labels and label decoding with ``is_train=False``.

    Expects 124 - 74 = 50 classes with 10 * 11 clips each, integer labels
    restarting at 0 and grouped per class, and encoded labels 0 and 2
    decoding to '075' and '077'.
    """
    num_classes = 124 - 74
    clips_per_class = 10 * 11

    dataset = CASIAB(CASIAB_ROOT_DIR, is_train=False, discard_threshold=0)
    assert len(dataset) == num_classes * clips_per_class

    # Every class index repeated once per clip, classes in ascending order.
    expected_labels = [
        class_idx
        for class_idx in range(num_classes)
        for _ in range(clips_per_class)
    ]
    assert dataset.labels.tolist() == expected_labels

    decoded = dataset.label_encoder.inverse_transform([0, 2])
    assert decoded.tolist() == ['075', '077']


def test_casiab_nm():
    """Check the dataset size when only ``nm-0\\d`` conditions are selected.

    Expects 74 * 6 * 11 clips to remain after applying the selector.
    """
    conditions = ClipConditions({r'nm-0\d'})
    dataset = CASIAB(
        CASIAB_ROOT_DIR,
        selector={'conditions': conditions},
        discard_threshold=0,
    )
    expected_size = 74 * 6 * 11
    assert len(dataset) == expected_size


def test_casiab_nm_bg_90():
    """Check the dataset size for ``nm``/``bg`` conditions at view '090'.

    Combines a conditions selector with a views selector and expects
    74 * (6 + 2) * 1 clips to remain.
    """
    conditions = ClipConditions({r'nm-0\d', r'bg-0\d'})
    views = ClipViews({'090'})
    dataset = CASIAB(
        CASIAB_ROOT_DIR,
        selector={'conditions': conditions, 'views': views},
        discard_threshold=0,
    )
    expected_size = 74 * (6 + 2) * 1
    assert len(dataset) == expected_size


def test_caisab_class_selector():
    """Check the dataset size when only classes '001' and '003' are selected.

    Expects 2 * 10 * 11 clips to remain after applying the selector.
    """
    wanted_classes = ClipClasses({'001', '003'})
    dataset = CASIAB(
        CASIAB_ROOT_DIR,
        selector={'classes': wanted_classes},
        discard_threshold=0,
    )
    expected_size = 2 * 10 * 11
    assert len(dataset) == expected_size