summaryrefslogtreecommitdiff
path: root/utils/dataset.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-21 20:00:51 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-21 20:00:51 +0800
commit315495ffe14986cfce90e92203cc317eef5ba5cf (patch)
treed261a573117d405a5b3faabd0dc549e95a1cde7b /utils/dataset.py
parent0d4c20a8104a77f2f61f5ce0ac46c3a26d61b9c1 (diff)
Change image loading technique
1. Use Pillow.Image.open instead of torchvision.io.read_image to read image 2. Transforming PIL images instead of tensors which performs better and device option is removed 3. Images are normalized now
Diffstat (limited to 'utils/dataset.py')
-rw-r--r--utils/dataset.py21
1 files changed, 10 insertions, 11 deletions
diff --git a/utils/dataset.py b/utils/dataset.py
index f41a9c5..40d7968 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -5,8 +5,8 @@ from typing import Optional, Dict, NewType, Union, List, Set
import numpy as np
import torch
+from PIL import Image
from torch.utils import data
-from torchvision.io import read_image
import torchvision.transforms as transforms
from tqdm import tqdm
@@ -31,7 +31,6 @@ class CASIAB(data.Dataset):
num_input_channels: int = 3,
frame_height: int = 64,
frame_width: int = 32,
- device: torch.device = torch.device('cpu'),
cache_on: bool = False
):
"""
@@ -49,11 +48,12 @@ class CASIAB(data.Dataset):
RBG image has 3 channels, grayscale image has 1 channel.
:param frame_height: Frame height after transforming.
:param frame_width: Frame width after transforming.
- :param device: Device used in transforms.
:param cache_on: Preload all clips in memory or not, this will
increase loading speed, but will add a preload process and
- cost a lot of RAM. Loading the entire dataset need about
- 7 GB of RAM.
+ cost a lot of RAM. Loading the entire dataset
+ (is_train = True,train_size = 124, discard_threshold = 1,
+ num_input_channels = 3, frame_height = 64, frame_width = 32)
+ need about 22 GB of RAM.
"""
super(CASIAB, self).__init__()
self.root_dir = root_dir
@@ -64,12 +64,12 @@ class CASIAB(data.Dataset):
self.num_input_channels = num_input_channels
self.frame_height = frame_height
self.frame_width = frame_width
- self.device = device
self.cache_on = cache_on
self.frame_transform: transforms.Compose
transform_compose_list = [
- transforms.Resize(size=(self.frame_height, self.frame_width))
+ transforms.Resize(size=(self.frame_height, self.frame_width)),
+ transforms.ToTensor()
]
if self.num_input_channels == 1:
transform_compose_list.insert(0, transforms.Grayscale())
@@ -243,10 +243,9 @@ class CASIAB(data.Dataset):
frames = []
for frame_name in frame_names:
frame_path = os.path.join(clip_path, frame_name)
- frame = read_image(frame_path)
- # Transforming using CPU is not efficient
- frame = self.frame_transform(frame.to(self.device))
- frames.append(frame.cpu())
+ frame = Image.open(frame_path)
+ frame = self.frame_transform(frame)
+ frames.append(frame)
clip = torch.stack(frames)
return clip