diff options
-rw-r--r-- | requirements.txt | 4 | ||||
-rw-r--r-- | utils/dataset.py | 21 |
2 files changed, 13 insertions, 12 deletions
diff --git a/requirements.txt b/requirements.txt index a0eaf1a..c58a9c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ torch~=1.7.1 torchvision~=0.8.0a0+ecf4e9c numpy~=1.19.4 -tqdm~=4.54.1
\ No newline at end of file +tqdm~=4.54.1 +Pillow~=8.0.1 +matplotlib~=3.3.3
\ No newline at end of file diff --git a/utils/dataset.py b/utils/dataset.py index f41a9c5..40d7968 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -5,8 +5,8 @@ from typing import Optional, Dict, NewType, Union, List, Set import numpy as np import torch +from PIL import Image from torch.utils import data -from torchvision.io import read_image import torchvision.transforms as transforms from tqdm import tqdm @@ -31,7 +31,6 @@ class CASIAB(data.Dataset): num_input_channels: int = 3, frame_height: int = 64, frame_width: int = 32, - device: torch.device = torch.device('cpu'), cache_on: bool = False ): """ @@ -49,11 +48,12 @@ class CASIAB(data.Dataset): RBG image has 3 channels, grayscale image has 1 channel. :param frame_height: Frame height after transforming. :param frame_width: Frame width after transforming. - :param device: Device used in transforms. :param cache_on: Preload all clips in memory or not, this will increase loading speed, but will add a preload process and - cost a lot of RAM. Loading the entire dataset need about - 7 GB of RAM. + cost a lot of RAM. Loading the entire dataset + (is_train = True,train_size = 124, discard_threshold = 1, + num_input_channels = 3, frame_height = 64, frame_width = 32) + need about 22 GB of RAM. """ super(CASIAB, self).__init__() self.root_dir = root_dir @@ -64,12 +64,12 @@ class CASIAB(data.Dataset): self.num_input_channels = num_input_channels self.frame_height = frame_height self.frame_width = frame_width - self.device = device self.cache_on = cache_on self.frame_transform: transforms.Compose transform_compose_list = [ - transforms.Resize(size=(self.frame_height, self.frame_width)) + transforms.Resize(size=(self.frame_height, self.frame_width)), + transforms.ToTensor() ] if self.num_input_channels == 1: transform_compose_list.insert(0, transforms.Grayscale()) @@ -243,10 +243,9 @@ class CASIAB(data.Dataset): frames = [] for frame_name in frame_names: frame_path = os.path.join(clip_path, frame_name) - frame = read_image(frame_path) - # Transforming using CPU is not efficient - frame = self.frame_transform(frame.to(self.device)) - frames.append(frame.cpu()) + frame = Image.open(frame_path) + frame = self.frame_transform(frame) + frames.append(frame) clip = torch.stack(frames) return clip |