author     Jordan Gong <jordan.gong@protonmail.com>   2020-12-26 20:26:48 +0800
committer  Jordan Gong <jordan.gong@protonmail.com>   2020-12-26 20:26:48 +0800
commit     d5f7cdab1466566d805f9cbf81c05767880886ae (patch)
tree       99a7d95582518e9ffdf44f67fa9691daaa46ee90 /config.py
parent     e7bac6ab39b9abbbcbdcf10d565df4863510f0d9 (diff)
Add config file and corresponding type hint
Diffstat (limited to 'config.py')
-rw-r--r--  config.py  67
1 file changed, 67 insertions, 0 deletions
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..0d98608
--- /dev/null
+++ b/config.py
@@ -0,0 +1,67 @@
+import torch
+
+from utils.configuration import Configuration
+
+config: Configuration = {
+    'system': {
+        # Device(s) used in training and testing (CPU or CUDA)
+        'device': torch.device('cuda'),
+        # GPU(s) used in training or testing, if CUDA is enabled
+        'CUDA_VISIBLE_DEVICES': '0',
+        # Directory used in training or testing for temporary storage
+        'save_path': 'runs',
+    },
+    # Dataset settings
+    'dataset': {
+        # Name of dataset (CASIA-B or FVG)
+        'name': 'CASIA-B',
+        # Path to dataset root
+        'path': 'dataset/output/CASIA-B',
+        # Number of subjects for training
+        'train_size': 74,
+        # Number of sampled frames per sequence (training only)
+        'num_sampled_frames': 30,
+        # Discard clips shorter than `discard_threshold`
+        'discard_threshold': 15,
+        # Number of input channels of the model
+        'num_input_channels': 3,
+        # Resolution after resize; height : width should be 2 : 1
+        'frame_size': (64, 32),
+        # Whether to cache the dataset
+        'cache_on': False,
+    },
+    # Dataloader settings
+    'dataloader': {
+        # Batch size (pr, k)
+        # `pr` denotes the number of persons
+        # `k` denotes the number of sequences per person
+        'batch_size': (8, 16),
+        # Number of worker processes for the DataLoader
+        'num_workers': 4,
+        # Pin host memory for faster RAM-to-GPU transfer if enabled
+        'pin_memory': True,
+    },
+    # Hyperparameter tuning
+    'hyperparameter': {
+        # Hidden dimension of the FC layer
+        'hidden_dim': 256,
+        # Initial learning rate of the Adam optimizer
+        'lr': 1e-4,
+        # Betas of the Adam optimizer
+        'betas': (0.9, 0.999),
+        # Batch Hard or Batch All triplet loss
+        # `hard` for BH, `all` for BA
+        'hard_or_all': 'all',
+        # Triplet loss margin
+        'margin': 0.2,
+    },
+    # Model metadata
+    'model': {
+        # Model name, used for naming checkpoints
+        'name': 'RGB-GaitPart',
+        # Iteration to restore from checkpoint
+        'restore_iter': 0,
+        # Total number of training iterations
+        'total_iter': 80000,
+    },
+}
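
Note on the type hint: the `Configuration` type imported above is defined in utils/configuration.py, which is not touched by this commit. The sketch below is only an illustration of what such a hint could look like, reconstructed from the keys and value types used in config.py; all class names here are assumptions, not the repository's actual definitions.

# Hypothetical sketch of utils/configuration.py (not part of this diff).
# Class names below are assumptions; only the field names and value types
# are taken from config.py above.
from typing import Tuple, TypedDict

import torch


class SystemConfiguration(TypedDict):
    device: torch.device
    CUDA_VISIBLE_DEVICES: str
    save_path: str


class DatasetConfiguration(TypedDict):
    name: str
    path: str
    train_size: int
    num_sampled_frames: int
    discard_threshold: int
    num_input_channels: int
    frame_size: Tuple[int, int]
    cache_on: bool


class DataloaderConfiguration(TypedDict):
    batch_size: Tuple[int, int]
    num_workers: int
    pin_memory: bool


class HyperparameterConfiguration(TypedDict):
    hidden_dim: int
    lr: float
    betas: Tuple[float, float]
    hard_or_all: str
    margin: float


class ModelConfiguration(TypedDict):
    name: str
    restore_iter: int
    total_iter: int


class Configuration(TypedDict):
    system: SystemConfiguration
    dataset: DatasetConfiguration
    dataloader: DataloaderConfiguration
    hyperparameter: HyperparameterConfiguration
    model: ModelConfiguration

With a nested TypedDict like this, a static checker such as mypy can verify every key access in config.py: config['dataloader']['num_workers'] checks as int, while a misspelled or missing key is reported as an error.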