summaryrefslogtreecommitdiff
path: root/config.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-04-03 23:07:23 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-04-03 23:07:23 +0800
commit258efcafe4d34ed5ffeebcaab9389f75a17e4717 (patch)
tree0f4ffe75990b63b8e17956eeec269e3589852769 /config.py
parent4049566103a00aa6d5a0b1f73569bdc5435714ca (diff)
parentf6f133fa7b926ce0c7d28bbf0ba4de41b3708d4a (diff)
Merge branch 'disentangling_only' into disentangling_only_py3.8
# Conflicts: # models/model.py
Diffstat (limited to 'config.py')
-rw-r--r--config.py23
1 files changed, 15 insertions, 8 deletions
diff --git a/config.py b/config.py
index afd40d5..dc4e0ba 100644
--- a/config.py
+++ b/config.py
@@ -9,7 +9,9 @@ config: Configuration = {
# Directory used in training or testing for temporary storage
'save_dir': 'runs/dis_only',
# Recorde disentangled image or not
- 'image_log_on': True
+ 'image_log_on': True,
+ # The number of subjects for validating (Part of testing set)
+ 'val_size': 10,
},
# Dataset settings
'dataset': {
@@ -37,7 +39,7 @@ config: Configuration = {
# Batch size (pr, k)
# `pr` denotes number of persons
# `k` denotes number of sequences per person
- 'batch_size': (2, 2),
+ 'batch_size': (4, 6),
# Number of workers of Dataloader
'num_workers': 4,
# Faster data transfer from RAM to GPU if enabled
@@ -61,15 +63,20 @@ config: Configuration = {
# Term added to the denominator
# 'eps': 1e-8,
# Weight decay (L2 penalty)
- # 'weight_decay': 0,
+ 'weight_decay': 0.001,
# Use AMSGrad or not
# 'amsgrad': False,
},
'scheduler': {
- # Period of learning rate decay
- 'step_size': 500,
- # Multiplicative factor of decay
- 'gamma': 0.9,
+ # Step start to decay
+ 'start_step': 500,
+ # Multiplicative factor of decay in the end
+ 'final_gamma': 0.01,
+
+ # Local parameters (override global ones)
+ # 'hpm': {
+ # 'final_gamma': 0.001
+ # }
}
},
# Model metadata
@@ -83,6 +90,6 @@ config: Configuration = {
# Restoration iteration (multiple models, e.g. nm, bg and cl)
'restore_iters': (0, 0, 0),
# Total iteration for training (multiple models)
- 'total_iters': (80_000, 80_000, 80_000),
+ 'total_iters': (30_000, 40_000, 60_000),
},
}