author     Jordan Gong <jordan.gong@protonmail.com>  2021-03-25 12:23:52 +0800
committer  Jordan Gong <jordan.gong@protonmail.com>  2021-03-25 12:23:52 +0800
commit     6aad65a708f518073c07b3f896d0107a17198d14 (patch)
tree       1769491b1f54a7509663a7793dea99db2cb748bf
parent     1bf22d4d440b1c5843a7a968ad7dc3f434c84771 (diff)
parent     f065f227b0bcdac61db9240f7df6ea2b748a89b7 (diff)
Merge branch 'python3.8' into python3.7
# Conflicts:
#   utils/configuration.py
-rw-r--r--  config.py        6
-rw-r--r--  models/model.py  8
2 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/config.py b/config.py
index b9df470..25846a2 100644
--- a/config.py
+++ b/config.py
@@ -7,7 +7,9 @@ config = {
# Directory used in training or testing for temporary storage
'save_dir': 'runs',
# Record disentangled images or not
- 'image_log_on': False
+ 'image_log_on': False,
+ # The number of subjects for validating (Part of testing set)
+ 'val_size': 10,
},
# Dataset settings
'dataset': {
@@ -17,8 +19,6 @@ config = {
'root_dir': 'data/CASIA-B-MRCNN-V2/SEG',
# The number of subjects for training
'train_size': 74,
- # The number of subjects for validating (Part of testing set)
- 'val_size': 10,
# Number of sampled frames per sequence (Training only)
'num_sampled_frames': 30,
# Truncate clips longer than `truncate_threshold`
diff --git a/models/model.py b/models/model.py
index 7eaaaf0..78a9c0f 100644
--- a/models/model.py
+++ b/models/model.py
@@ -76,6 +76,7 @@ class Model:
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
self.image_log_on = system_config.get('image_log_on', False)
+ self.val_size = system_config.get('val_size', 10)

self.CASIAB_GALLERY_SELECTOR = {
'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})}
@@ -144,13 +145,12 @@ class Model:
self.is_train = True
# Validation dataset
# (the first `val_size` subjects from evaluation set)
- val_size = dataset_config.pop('val_size', 10)
val_dataset_config = copy.deepcopy(dataset_config)
train_size = dataset_config.get('train_size', 74)
- val_dataset_config['train_size'] = train_size + val_size
+ val_dataset_config['train_size'] = train_size + self.val_size
val_dataset_config['selector']['classes'] = ClipClasses({
str(c).zfill(3)
- for c in range(train_size + 1, train_size + val_size + 1)
+ for c in range(train_size + 1, train_size + self.val_size + 1)
})
val_dataset = self._parse_dataset_config(val_dataset_config)
val_dataloader = iter(self._parse_dataloader_config(
@@ -566,7 +566,7 @@ class Model:
for (iter_, total_iter, (condition, selector)) in zip(
iters, self.total_iters, dataset_selectors.items()
):
- self.curr_iter = iter_
+ self.curr_iter = iter_ - 1
self.total_iter = total_iter
self._dataset_sig = self._make_signature(
dict(**dataset_config, **selector),
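
For context, a minimal runnable sketch of how the relocated `val_size` option fits together after this merge. The keys, the defaults (10 and 74), and the class-selection expression mirror the diff above; the flattened `config` dictionary and the `print` at the end are illustrative assumptions, not part of the commit.

# Minimal sketch, assuming the config layout shown in the diff above.
# `val_size` now lives under the system section rather than the dataset
# section; defaults match the .get() fallbacks in models/model.py.
config = {
    'system': {
        'save_dir': 'runs',
        'image_log_on': False,
        # The number of subjects for validating (part of testing set)
        'val_size': 10,
    },
    'dataset': {
        'root_dir': 'data/CASIA-B-MRCNN-V2/SEG',
        'train_size': 74,
    },
}

val_size = config['system'].get('val_size', 10)
train_size = config['dataset'].get('train_size', 74)

# Validation subjects are the first `val_size` subjects after the
# training split, zero-padded to three digits as in model.py:
# train_size=74 -> validation classes '075' ... '084'
val_classes = {
    str(c).zfill(3)
    for c in range(train_size + 1, train_size + val_size + 1)
}
print(sorted(val_classes))  # ['075', '076', ..., '084']

Reading `val_size` once in `Model.__init__` (as `self.val_size`) replaces the earlier `dataset_config.pop('val_size', 10)`, so the dataset config no longer carries a validation-only key when it is passed to `_parse_dataset_config`.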