aboutsummaryrefslogtreecommitdiff
path: root/supervised/datautils.py
blob: 196fca7cf9ad06fde58d16146ad327cfb51be3d1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
from torchvision.transforms import transforms


def color_distortion(s=1.0):
    """Return a random colour-distortion transform scaled by strength ``s``.

    The pipeline has two stages, applied in order:

    1. With probability 0.8, a ``ColorJitter`` whose brightness, contrast
       and saturation factors are ``0.8 * s`` and whose hue factor is
       ``0.2 * s``.
    2. With probability 0.2, conversion to grayscale.

    Args:
        s: Strength multiplier for the jitter factors (default 1.0).

    Returns:
        A ``transforms.Compose`` that chains the two stages above.
    """
    # NOTE(review): these constants look like the SimCLR colour-distortion
    # recipe. Confirm before changing them.
    brightness = contrast = saturation = 0.8 * s
    hue = 0.2 * s
    jitter = transforms.ColorJitter(brightness, contrast, saturation, hue)

    stages = [
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
    ]
    return transforms.Compose(stages)