summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.py2
-rw-r--r--models/model.py2
-rw-r--r--utils/dataset.py17
3 files changed, 20 insertions, 1 deletions
diff --git a/config.py b/config.py
index 45a9068..2470826 100644
--- a/config.py
+++ b/config.py
@@ -19,6 +19,8 @@ config = {
'train_size': 74,
# Number of sampled frames per sequence (Training only)
'num_sampled_frames': 30,
+ # Truncate clips longer than `truncate_threshold`
+ 'truncate_threshold': 40,
# Discard clips shorter than `discard_threshold`
'discard_threshold': 15,
# Number of input channels of model
diff --git a/models/model.py b/models/model.py
index 5f079b8..eae15e3 100644
--- a/models/model.py
+++ b/models/model.py
@@ -174,6 +174,7 @@ class Model:
print('{0:5d} loss: {1:.3f}'.format(iter_, loss))
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
+ self.scheduler.load_state_dict(checkpoint['sched_state_dict'])
# Training start
start_time = datetime.now()
@@ -246,6 +247,7 @@ class Model:
'iter': self.curr_iter,
'model_state_dict': self.rgb_pn.state_dict(),
'optim_state_dict': self.optimizer.state_dict(),
+ 'sched_state_dict': self.scheduler.state_dict(),
'loss': loss,
}, self._checkpoint_name)
diff --git a/utils/dataset.py b/utils/dataset.py
index 6b2a991..72cf050 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -25,6 +25,7 @@ class CASIAB(data.Dataset):
is_train: bool = True,
train_size: int = 74,
num_sampled_frames: int = 30,
+ truncate_threshold: int = 40,
discard_threshold: int = 15,
selector: Optional[Dict[
str, Union[ClipClasses, ClipConditions, ClipViews]
@@ -40,6 +41,8 @@ class CASIAB(data.Dataset):
when `is_train` is False, test size will be inferred.
:param num_sampled_frames: The number of sampled frames.
(Training Only)
+ :param truncate_threshold: Truncate redundant frames larger
+ than this threshold. (Training Only)
:param discard_threshold: Discard the sample if its number of
frames is less than this threshold.
:param selector: Restrict output data classes, conditions and
@@ -58,6 +61,7 @@ class CASIAB(data.Dataset):
self._root_dir = root_dir
self._is_train = is_train
self._num_sampled_frames = num_sampled_frames
+ self._truncate_threshold = truncate_threshold
self._cache_on = cache_on
self._frame_transform: transforms.Compose
@@ -264,11 +268,22 @@ class CASIAB(data.Dataset):
if self._is_train:
num_frames = len(frame_names)
- # Sample frames without replace if have enough frames
+ # 1# Sample with replacement if less than num_sampled_frames
if num_frames < self._num_sampled_frames:
frame_names = random.choices(frame_names,
k=self._num_sampled_frames)
+ # 2# Sample without replacement if less than 40
+ elif self._num_sampled_frames \
+ <= num_frames < self._truncate_threshold:
+ frame_names = random.sample(frame_names,
+ k=self._num_sampled_frames)
+ # 3# Cut out an interval with truncate_threshold frames
else:
+ frame_name_0_max = num_frames - self._truncate_threshold
+ frame_name_0 = random.randint(0, frame_name_0_max)
+ frame_name_1 = frame_name_0 + self._truncate_threshold
+ frame_names = sorted(frame_names)
+ frame_names = frame_names[frame_name_0:frame_name_1]
frame_names = random.sample(frame_names,
k=self._num_sampled_frames)