summaryrefslogtreecommitdiff
path: root/utils
diff options
context:
space:
mode:
Diffstat (limited to 'utils')
-rw-r--r--utils/dataset.py13
1 files changed, 5 insertions, 8 deletions
diff --git a/utils/dataset.py b/utils/dataset.py
index 40d7968..ecdd2d9 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -1,7 +1,7 @@
import os
import random
import re
-from typing import Optional, Dict, NewType, Union, List, Set
+from typing import Optional, Dict, NewType, Union, List, Set, Tuple
import numpy as np
import torch
@@ -29,8 +29,7 @@ class CASIAB(data.Dataset):
str, Union[ClipLabels, ClipConditions, ClipLabels]
]] = None,
num_input_channels: int = 3,
- frame_height: int = 64,
- frame_width: int = 32,
+ frame_size: Tuple[int, int] = (64, 32),
cache_on: bool = False
):
"""
@@ -46,8 +45,7 @@ class CASIAB(data.Dataset):
views.
:param num_input_channels: The number of input channel(s),
RGB image has 3 channels, grayscale image has 1 channel.
- :param frame_height: Frame height after transforming.
- :param frame_width: Frame width after transforming.
+ :param frame_size: Frame height and width after transforming.
:param cache_on: Preload all clips in memory or not, this will
increase loading speed, but will add a preload process and
cost a lot of RAM. Loading the entire dataset
@@ -62,13 +60,12 @@ class CASIAB(data.Dataset):
self.num_sampled_frames = num_sampled_frames
self.discard_threshold = discard_threshold
self.num_input_channels = num_input_channels
- self.frame_height = frame_height
- self.frame_width = frame_width
+ self.frame_size = frame_size
self.cache_on = cache_on
self.frame_transform: transforms.Compose
transform_compose_list = [
- transforms.Resize(size=(self.frame_height, self.frame_width)),
+ transforms.Resize(size=self.frame_size),
transforms.ToTensor()
]
if self.num_input_channels == 1: