From 4f81e3f44ed60459ddbd311cb7356932974c800d Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 10 Feb 2021 11:41:49 +0800 Subject: Save scheduler state_dict --- models/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/models/model.py b/models/model.py index 70c43b3..9cb6a8e 100644 --- a/models/model.py +++ b/models/model.py @@ -177,6 +177,7 @@ class Model: print('{0:5d} loss: {1:.3f}'.format(iter_, loss)) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optim_state_dict']) + self.scheduler.load_state_dict(checkpoint['sched_state_dict']) # Training start start_time = datetime.now() @@ -249,6 +250,7 @@ class Model: 'iter': self.curr_iter, 'model_state_dict': self.rgb_pn.state_dict(), 'optim_state_dict': self.optimizer.state_dict(), + 'sched_state_dict': self.scheduler.state_dict(), 'loss': loss, }, self._checkpoint_name) -- cgit v1.2.3 From 2f7c09fb4fb985db1cbf6e2bdc6622a2c51ebfc3 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Wed, 10 Feb 2021 13:43:15 +0800 Subject: Implement new sampling technique mentioned in GaitPart[1] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [1]C. Fan et al., “GaitPart: Temporal Part-Based Model for Gait Recognition,” in Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2020, pp. 14225–14233. --- config.py | 2 ++ utils/configuration.py | 1 + utils/dataset.py | 17 ++++++++++++++++- 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/config.py b/config.py index 04a22b9..547d1a3 100644 --- a/config.py +++ b/config.py @@ -21,6 +21,8 @@ config: Configuration = { 'train_size': 74, # Number of sampled frames per sequence (Training only) 'num_sampled_frames': 30, + # Truncate clips longer than `truncate_threshold` + 'truncate_threshold': 40, # Discard clips shorter than `discard_threshold` 'discard_threshold': 15, # Number of input channels of model diff --git a/utils/configuration.py b/utils/configuration.py index 4ab1520..435d815 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -15,6 +15,7 @@ class DatasetConfiguration(TypedDict): root_dir: str train_size: int num_sampled_frames: int + truncate_threshold: int discard_threshold: int selector: Optional[dict[str, Union[ClipClasses, ClipConditions, ClipViews]]] num_input_channels: int diff --git a/utils/dataset.py b/utils/dataset.py index bbd42c3..c487988 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -25,6 +25,7 @@ class CASIAB(data.Dataset): is_train: bool = True, train_size: int = 74, num_sampled_frames: int = 30, + truncate_threshold: int = 40, discard_threshold: int = 15, selector: Optional[dict[ str, Union[ClipClasses, ClipConditions, ClipViews] @@ -40,6 +41,8 @@ class CASIAB(data.Dataset): when `is_train` is False, test size will be inferred. :param num_sampled_frames: The number of sampled frames. (Training Only) + :param truncate_threshold: Truncate redundant frames larger + than this threshold. (Training Only) :param discard_threshold: Discard the sample if its number of frames is less than this threshold. :param selector: Restrict output data classes, conditions and @@ -58,6 +61,7 @@ class CASIAB(data.Dataset): self._root_dir = root_dir self._is_train = is_train self._num_sampled_frames = num_sampled_frames + self._truncate_threshold = truncate_threshold self._cache_on = cache_on self._frame_transform: transforms.Compose @@ -264,11 +268,22 @@ class CASIAB(data.Dataset): if self._is_train: num_frames = len(frame_names) - # Sample frames without replace if have enough frames + # 1# Sample with replacement if less than num_sampled_frames if num_frames < self._num_sampled_frames: frame_names = random.choices(frame_names, k=self._num_sampled_frames) + # 2# Sample without replacement if less than 40 + elif self._num_sampled_frames \ + <= num_frames < self._truncate_threshold: + frame_names = random.sample(frame_names, + k=self._num_sampled_frames) + # 3# Cut out an interval with truncate_threshold frames else: + frame_name_0_max = num_frames - self._truncate_threshold + frame_name_0 = random.randint(0, frame_name_0_max) + frame_name_1 = frame_name_0 + self._truncate_threshold + frame_names = sorted(frame_names) + frame_names = frame_names[frame_name_0:frame_name_1] frame_names = random.sample(frame_names, k=self._num_sampled_frames) -- cgit v1.2.3