summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--requirements.txt4
-rw-r--r--utils/dataset.py21
2 files changed, 13 insertions, 12 deletions
diff --git a/requirements.txt b/requirements.txt
index a0eaf1a..c58a9c3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,6 @@
torch~=1.7.1
torchvision~=0.8.0a0+ecf4e9c
numpy~=1.19.4
-tqdm~=4.54.1 \ No newline at end of file
+tqdm~=4.54.1
+Pillow~=8.0.1
+matplotlib~=3.3.3 \ No newline at end of file
diff --git a/utils/dataset.py b/utils/dataset.py
index f41a9c5..40d7968 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -5,8 +5,8 @@ from typing import Optional, Dict, NewType, Union, List, Set
import numpy as np
import torch
+from PIL import Image
from torch.utils import data
-from torchvision.io import read_image
import torchvision.transforms as transforms
from tqdm import tqdm
@@ -31,7 +31,6 @@ class CASIAB(data.Dataset):
num_input_channels: int = 3,
frame_height: int = 64,
frame_width: int = 32,
- device: torch.device = torch.device('cpu'),
cache_on: bool = False
):
"""
@@ -49,11 +48,12 @@ class CASIAB(data.Dataset):
RBG image has 3 channels, grayscale image has 1 channel.
:param frame_height: Frame height after transforming.
:param frame_width: Frame width after transforming.
- :param device: Device used in transforms.
:param cache_on: Preload all clips in memory or not, this will
increase loading speed, but will add a preload process and
- cost a lot of RAM. Loading the entire dataset need about
- 7 GB of RAM.
+ cost a lot of RAM. Loading the entire dataset
+ (is_train = True,train_size = 124, discard_threshold = 1,
+ num_input_channels = 3, frame_height = 64, frame_width = 32)
+ need about 22 GB of RAM.
"""
super(CASIAB, self).__init__()
self.root_dir = root_dir
@@ -64,12 +64,12 @@ class CASIAB(data.Dataset):
self.num_input_channels = num_input_channels
self.frame_height = frame_height
self.frame_width = frame_width
- self.device = device
self.cache_on = cache_on
self.frame_transform: transforms.Compose
transform_compose_list = [
- transforms.Resize(size=(self.frame_height, self.frame_width))
+ transforms.Resize(size=(self.frame_height, self.frame_width)),
+ transforms.ToTensor()
]
if self.num_input_channels == 1:
transform_compose_list.insert(0, transforms.Grayscale())
@@ -243,10 +243,9 @@ class CASIAB(data.Dataset):
frames = []
for frame_name in frame_names:
frame_path = os.path.join(clip_path, frame_name)
- frame = read_image(frame_path)
- # Transforming using CPU is not efficient
- frame = self.frame_transform(frame.to(self.device))
- frames.append(frame.cpu())
+ frame = Image.open(frame_path)
+ frame = self.frame_transform(frame)
+ frames.append(frame)
clip = torch.stack(frames)
return clip