aboutsummaryrefslogtreecommitdiff
path: root/libs/datautils.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-03-18 20:16:56 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-03-18 20:16:56 +0800
commit8ed106fd05007124dff421603e8afb93aa2bbdbc (patch)
treeb855ff396d3acc0551cff14738d849e32bb289ac /libs/datautils.py
parentb475ecfa28c603010f550b0a8ad9204a5840b65f (diff)
Implement multi-crop dataset wrapper
Diffstat (limited to 'libs/datautils.py')
-rw-r--r--libs/datautils.py50
1 files changed, 50 insertions, 0 deletions
diff --git a/libs/datautils.py b/libs/datautils.py
index 843f669..feae481 100644
--- a/libs/datautils.py
+++ b/libs/datautils.py
@@ -1,5 +1,8 @@
+from typing import Optional
+
import numpy as np
import torch
+from torch.utils.data import Dataset
from torchvision.transforms import transforms
@@ -65,3 +68,50 @@ class RandomGaussianBlur(object):
img, sigma=np.random.uniform(*self.sigma_range)
)
return img
+
+
+class MultiCropDatasetWrapper(Dataset):
+ """
+ Modified from Facebook SwAV at: https://github.com/facebookresearch/swav/blob/06b1b7cbaf6ba2a792300d79c7299db98b93b7f9/src/multicropdataset.py#L18
+ """
+
+ def __init__(
+ self,
+ dataset: Dataset,
+ n_crops: list[int],
+ crop_sizes: list[tuple[int, int]],
+ crop_scale_ranges: list[tuple[float, float]],
+ other_transforms: Optional[transforms.Compose] = None,
+ ):
+ assert len(crop_sizes) == len(n_crops)
+ assert len(crop_scale_ranges) == len(n_crops)
+
+ if hasattr(dataset, 'transform') and dataset.transform is not None:
+ raise AttributeError('Please pass the transform to wrapper.')
+
+ self.dataset = dataset
+
+ trans = []
+ for crop_size, crop_scale_range, n_crop in zip(
+ crop_sizes, crop_scale_ranges, n_crops
+ ):
+ rand_resize_crop = transforms.RandomResizedCrop(
+ crop_size, scale=crop_scale_range
+ )
+ if other_transforms is not None:
+ trans_i = transforms.Compose([
+ rand_resize_crop, other_transforms
+ ])
+ else:
+ trans_i = rand_resize_crop
+ trans += [trans_i] * n_crop
+ self.transform = trans
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getitem__(self, index):
+ img, target = self.dataset[index]
+ multi_crops = list(map(lambda trans: trans(img), self.transform))
+
+ return multi_crops, target