diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-18 20:16:56 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-18 20:16:56 +0800 |
commit | 8ed106fd05007124dff421603e8afb93aa2bbdbc (patch) | |
tree | b855ff396d3acc0551cff14738d849e32bb289ac /libs/datautils.py | |
parent | b475ecfa28c603010f550b0a8ad9204a5840b65f (diff) |
Implement multi-crop dataset wrapper
Diffstat (limited to 'libs/datautils.py')
-rw-r--r-- | libs/datautils.py | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/libs/datautils.py b/libs/datautils.py index 843f669..feae481 100644 --- a/libs/datautils.py +++ b/libs/datautils.py @@ -1,5 +1,8 @@ +from typing import Optional + import numpy as np import torch +from torch.utils.data import Dataset from torchvision.transforms import transforms @@ -65,3 +68,50 @@ class RandomGaussianBlur(object): img, sigma=np.random.uniform(*self.sigma_range) ) return img + + +class MultiCropDatasetWrapper(Dataset): + """ + Modified from Facebook SwAV at: https://github.com/facebookresearch/swav/blob/06b1b7cbaf6ba2a792300d79c7299db98b93b7f9/src/multicropdataset.py#L18 + """ + + def __init__( + self, + dataset: Dataset, + n_crops: list[int], + crop_sizes: list[tuple[int, int]], + crop_scale_ranges: list[tuple[float, float]], + other_transforms: Optional[transforms.Compose] = None, + ): + assert len(crop_sizes) == len(n_crops) + assert len(crop_scale_ranges) == len(n_crops) + + if hasattr(dataset, 'transform') and dataset.transform is not None: + raise AttributeError('Please pass the transform to wrapper.') + + self.dataset = dataset + + trans = [] + for crop_size, crop_scale_range, n_crop in zip( + crop_sizes, crop_scale_ranges, n_crops + ): + rand_resize_crop = transforms.RandomResizedCrop( + crop_size, scale=crop_scale_range + ) + if other_transforms is not None: + trans_i = transforms.Compose([ + rand_resize_crop, other_transforms + ]) + else: + trans_i = rand_resize_crop + trans += [trans_i] * n_crop + self.transform = trans + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + img, target = self.dataset[index] + multi_crops = list(map(lambda trans: trans(img), self.transform)) + + return multi_crops, target |