diff options
Diffstat (limited to 'utils/dataset.py')
-rw-r--r-- | utils/dataset.py | 21 |
1 files changed, 10 insertions, 11 deletions
diff --git a/utils/dataset.py b/utils/dataset.py index f41a9c5..40d7968 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -5,8 +5,8 @@ from typing import Optional, Dict, NewType, Union, List, Set import numpy as np import torch +from PIL import Image from torch.utils import data -from torchvision.io import read_image import torchvision.transforms as transforms from tqdm import tqdm @@ -31,7 +31,6 @@ class CASIAB(data.Dataset): num_input_channels: int = 3, frame_height: int = 64, frame_width: int = 32, - device: torch.device = torch.device('cpu'), cache_on: bool = False ): """ @@ -49,11 +48,12 @@ class CASIAB(data.Dataset): RBG image has 3 channels, grayscale image has 1 channel. :param frame_height: Frame height after transforming. :param frame_width: Frame width after transforming. - :param device: Device used in transforms. :param cache_on: Preload all clips in memory or not, this will increase loading speed, but will add a preload process and - cost a lot of RAM. Loading the entire dataset need about - 7 GB of RAM. + cost a lot of RAM. Loading the entire dataset + (is_train = True,train_size = 124, discard_threshold = 1, + num_input_channels = 3, frame_height = 64, frame_width = 32) + need about 22 GB of RAM. """ super(CASIAB, self).__init__() self.root_dir = root_dir @@ -64,12 +64,12 @@ class CASIAB(data.Dataset): self.num_input_channels = num_input_channels self.frame_height = frame_height self.frame_width = frame_width - self.device = device self.cache_on = cache_on self.frame_transform: transforms.Compose transform_compose_list = [ - transforms.Resize(size=(self.frame_height, self.frame_width)) + transforms.Resize(size=(self.frame_height, self.frame_width)), + transforms.ToTensor() ] if self.num_input_channels == 1: transform_compose_list.insert(0, transforms.Grayscale()) @@ -243,10 +243,9 @@ class CASIAB(data.Dataset): frames = [] for frame_name in frame_names: frame_path = os.path.join(clip_path, frame_name) - frame = read_image(frame_path) - # Transforming using CPU is not efficient - frame = self.frame_transform(frame.to(self.device)) - frames.append(frame.cpu()) + frame = Image.open(frame_path) + frame = self.frame_transform(frame) + frames.append(frame) clip = torch.stack(frames) return clip |