summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-10 13:45:38 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-10 13:45:38 +0800
commitba705ac60f463b034782ba03796dbfa2776a5631 (patch)
tree984a50254777583fd384bd4d72acf12843ad8de0
parent58ef39d75098bce92654492e09edf1e83033d0c8 (diff)
parent2f7c09fb4fb985db1cbf6e2bdc6622a2c51ebfc3 (diff)
Merge branch 'master' into python3.8
-rw-r--r--config.py2
-rw-r--r--models/model.py2
-rw-r--r--utils/configuration.py1
-rw-r--r--utils/dataset.py17
4 files changed, 21 insertions, 1 deletions
diff --git a/config.py b/config.py
index 04a22b9..547d1a3 100644
--- a/config.py
+++ b/config.py
@@ -21,6 +21,8 @@ config: Configuration = {
'train_size': 74,
# Number of sampled frames per sequence (Training only)
'num_sampled_frames': 30,
+ # Truncate clips longer than `truncate_threshold`
+ 'truncate_threshold': 40,
# Discard clips shorter than `discard_threshold`
'discard_threshold': 15,
# Number of input channels of model
diff --git a/models/model.py b/models/model.py
index b5daa54..a8843c5 100644
--- a/models/model.py
+++ b/models/model.py
@@ -177,6 +177,7 @@ class Model:
print('{0:5d} loss: {1:.3f}'.format(iter_, loss))
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
+ self.scheduler.load_state_dict(checkpoint['sched_state_dict'])
# Training start
start_time = datetime.now()
@@ -249,6 +250,7 @@ class Model:
'iter': self.curr_iter,
'model_state_dict': self.rgb_pn.state_dict(),
'optim_state_dict': self.optimizer.state_dict(),
+ 'sched_state_dict': self.scheduler.state_dict(),
'loss': loss,
}, self._checkpoint_name)
diff --git a/utils/configuration.py b/utils/configuration.py
index f44bcf0..b9e6d92 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -15,6 +15,7 @@ class DatasetConfiguration(TypedDict):
root_dir: str
train_size: int
num_sampled_frames: int
+ truncate_threshold: int
discard_threshold: int
selector: Optional[Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]]
num_input_channels: int
diff --git a/utils/dataset.py b/utils/dataset.py
index 6b2a991..72cf050 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -25,6 +25,7 @@ class CASIAB(data.Dataset):
is_train: bool = True,
train_size: int = 74,
num_sampled_frames: int = 30,
+ truncate_threshold: int = 40,
discard_threshold: int = 15,
selector: Optional[Dict[
str, Union[ClipClasses, ClipConditions, ClipViews]
@@ -40,6 +41,8 @@ class CASIAB(data.Dataset):
when `is_train` is False, test size will be inferred.
:param num_sampled_frames: The number of sampled frames.
(Training Only)
+ :param truncate_threshold: Clips with at least this many frames
+ are cut to a random interval of this length. (Training Only)
:param discard_threshold: Discard the sample if its number of
frames is less than this threshold.
:param selector: Restrict output data classes, conditions and
@@ -58,6 +61,7 @@ class CASIAB(data.Dataset):
self._root_dir = root_dir
self._is_train = is_train
self._num_sampled_frames = num_sampled_frames
+ self._truncate_threshold = truncate_threshold
self._cache_on = cache_on
self._frame_transform: transforms.Compose
@@ -264,11 +268,22 @@ class CASIAB(data.Dataset):
if self._is_train:
num_frames = len(frame_names)
- # Sample frames without replace if have enough frames
+ # 1# Sample with replacement if less than num_sampled_frames
if num_frames < self._num_sampled_frames:
frame_names = random.choices(frame_names,
k=self._num_sampled_frames)
+ # 2# Sample without replacement if less than truncate_threshold
+ elif self._num_sampled_frames \
+ <= num_frames < self._truncate_threshold:
+ frame_names = random.sample(frame_names,
+ k=self._num_sampled_frames)
+ # 3# Cut out an interval with truncate_threshold frames
else:
+ frame_name_0_max = num_frames - self._truncate_threshold
+ frame_name_0 = random.randint(0, frame_name_0_max)
+ frame_name_1 = frame_name_0 + self._truncate_threshold
+ frame_names = sorted(frame_names)
+ frame_names = frame_names[frame_name_0:frame_name_1]
frame_names = random.sample(frame_names,
k=self._num_sampled_frames)