blob: 196fca7cf9ad06fde58d16146ad327cfb51be3d1 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
|
from torchvision.transforms import transforms
def color_distortion(s=1.0):
    """Build the random color-distortion transform pipeline.

    The pipeline first applies a color jitter (wrapped in ``RandomApply``
    with ``p=0.8``) and then a ``RandomGrayscale`` step with ``p=0.2``.

    Args:
        s: Strength of the color distortion. Brightness, contrast and
            saturation jitter are set to ``0.8 * s``; hue jitter is set
            to ``0.2 * s``.

    Returns:
        The ``transforms.Compose`` object chaining the two random steps.
    """
    strength = 0.8 * s
    jitter = transforms.ColorJitter(
        brightness=strength,
        contrast=strength,
        saturation=strength,
        hue=0.2 * s,
    )
    return transforms.Compose(
        [
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
        ]
    )
|