Diffstat (limited to 'utils/dataset.py')
-rw-r--r-- | utils/dataset.py | 13 |
1 file changed, 5 insertions, 8 deletions
diff --git a/utils/dataset.py b/utils/dataset.py
index 40d7968..ecdd2d9 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -1,7 +1,7 @@
 import os
 import random
 import re
-from typing import Optional, Dict, NewType, Union, List, Set
+from typing import Optional, Dict, NewType, Union, List, Set, Tuple
 
 import numpy as np
 import torch
@@ -29,8 +29,7 @@ class CASIAB(data.Dataset):
                 str, Union[ClipLabels, ClipConditions, ClipLabels]
             ]] = None,
             num_input_channels: int = 3,
-            frame_height: int = 64,
-            frame_width: int = 32,
+            frame_size: Tuple[int, int] = (64, 32),
             cache_on: bool = False
     ):
         """
@@ -46,8 +45,7 @@ class CASIAB(data.Dataset):
            views.
        :param num_input_channels: The number of input channel(s),
            RBG image has 3 channels, grayscale image has 1 channel.
-        :param frame_height: Frame height after transforming.
-        :param frame_width: Frame width after transforming.
+        :param frame_size: Frame height and width after transforming.
        :param cache_on: Preload all clips in memory or not, this
            will increase loading speed, but will add a preload
            process and cost a lot of RAM. Loading the entire dataset
@@ -62,13 +60,12 @@ class CASIAB(data.Dataset):
         self.num_sampled_frames = num_sampled_frames
         self.discard_threshold = discard_threshold
         self.num_input_channels = num_input_channels
-        self.frame_height = frame_height
-        self.frame_width = frame_width
+        self.frame_size = frame_size
         self.cache_on = cache_on
 
         self.frame_transform: transforms.Compose
         transform_compose_list = [
-            transforms.Resize(size=(self.frame_height, self.frame_width)),
+            transforms.Resize(size=self.frame_size),
             transforms.ToTensor()
         ]
         if self.num_input_channels == 1:
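
For callers, the change is a plain signature swap: the two integer keyword arguments are replaced by a single (height, width) tuple, which the constructor now forwards unchanged to torchvision's transforms.Resize. A minimal usage sketch follows; only the keyword arguments visible in this diff are real, and root_dir is a hypothetical stand-in for whatever positional arguments CASIAB actually takes.

# Hypothetical call site, before this commit:
#   dataset = CASIAB(root_dir, frame_height=64, frame_width=32, cache_on=False)
#
# After this commit, height and width travel as one tuple:
from utils.dataset import CASIAB

dataset = CASIAB(
    root_dir,                 # hypothetical; full signature not shown in this diff
    num_input_channels=3,
    frame_size=(64, 32),      # (height, width) after transforming
    cache_on=False
)

Passing the tuple straight through means the dataset no longer has to reassemble (self.frame_height, self.frame_width) itself, and the default (64, 32) preserves the previous behaviour.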