author    Jordan Gong <jordan.gong@protonmail.com>  2020-12-26 20:26:48 +0800
committer Jordan Gong <jordan.gong@protonmail.com>  2020-12-26 20:26:48 +0800
commit    d5f7cdab1466566d805f9cbf81c05767880886ae (patch)
tree      99a7d95582518e9ffdf44f67fa9691daaa46ee90
parent    e7bac6ab39b9abbbcbdcf10d565df4863510f0d9 (diff)
Add config file and corresponding type hint
-rw-r--r--  config.py               67
-rw-r--r--  utils/configuration.py  48
2 files changed, 115 insertions, 0 deletions
diff --git a/config.py b/config.py
new file mode 100644
index 0000000..0d98608
--- /dev/null
+++ b/config.py
@@ -0,0 +1,67 @@
+import torch
+
+from utils.configuration import Configuration
+
+config: Configuration = {
+    'system': {
+        # Device(s) used in training and testing (CPU or CUDA)
+        'device': torch.device('cuda'),
+        # GPU(s) used in training or testing, if CUDA enabled
+        'CUDA_VISIBLE_DEVICES': '0',
+        # Directory used in training or testing for temporary storage
+        'save_path': 'runs',
+    },
+    # Dataset settings
+    'dataset': {
+        # Name of dataset (CASIA-B or FVG)
+        'name': 'CASIA-B',
+        # Path to dataset root
+        'path': 'dataset/output/CASIA-B',
+        # The number of subjects for training
+        'train_size': 74,
+        # Number of sampled frames per sequence (Training only)
+        'num_sampled_frames': 30,
+        # Discard clips shorter than `discard_threshold`
+        'discard_threshold': 15,
+        # Number of input channels of model
+        'num_input_channels': 3,
+        # Resolution after resize, height : width should be 2 : 1
+        'frame_size': (64, 32),
+        # Cache dataset or not
+        'cache_on': False,
+    },
+    # Dataloader settings
+    'dataloader': {
+        # Batch size (pr, k)
+        # `pr` denotes number of persons
+        # `k` denotes number of sequences per person
+        'batch_size': (8, 16),
+        # Number of workers of Dataloader
+        'num_workers': 4,
+        # Faster data transfer from RAM to GPU if enabled
+        'pin_memory': True,
+    },
+    # Hyperparameter tuning
+    'hyperparameter': {
+        # Hidden dimension of FC
+        'hidden_dim': 256,
+        # Initial learning rate of Adam Optimizer
+        'lr': 1e-4,
+        # Betas of Adam Optimizer
+        'betas': (0.9, 0.999),
+        # Batch Hard or Batch All Triplet loss
+        # `hard` for BH, `all` for BA
+        'hard_or_all': 'all',
+        # Triplet loss margin
+        'margin': 0.2,
+    },
+    # Model metadata
+    'model': {
+        # Model name, used for naming checkpoints
+        'name': 'RGB-GaitPart',
+        # Iteration to restore from checkpoint
+        'restore_iter': 0,
+        # Total iterations for training
+        'total_iter': 80000,
+    },
+}
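
The commit only defines the configuration itself; the following is a minimal sketch, not part of this commit, of how the nested dict might be consumed downstream. The Linear placeholder module and the variable names are illustrative assumptions.

import os

import torch
from torch.optim import Adam

from config import config

# Restrict visible GPUs before any CUDA context is created (assumes CUDA has
# not been initialised yet at this point).
os.environ['CUDA_VISIBLE_DEVICES'] = config['system']['CUDA_VISIBLE_DEVICES']

hp = config['hyperparameter']
model = torch.nn.Linear(hp['hidden_dim'], hp['hidden_dim'])  # placeholder module
model = model.to(config['system']['device'])
optimizer = Adam(model.parameters(), lr=hp['lr'], betas=hp['betas'])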
diff --git a/utils/configuration.py b/utils/configuration.py
new file mode 100644
index 0000000..32b9bec
--- /dev/null
+++ b/utils/configuration.py
@@ -0,0 +1,48 @@
+from typing import TypedDict, Tuple
+
+import torch
+
+
+class SystemConfiguration(TypedDict):
+    device: torch.device
+    CUDA_VISIBLE_DEVICES: str
+    save_path: str
+
+
+class DatasetConfiguration(TypedDict):
+    name: str
+    path: str
+    train_size: int
+    num_sampled_frames: int
+    discard_threshold: int
+    num_input_channels: int
+    frame_size: Tuple[int, int]
+    cache_on: bool
+
+
+class DataloaderConfiguration(TypedDict):
+    batch_size: Tuple[int, int]
+    num_workers: int
+    pin_memory: bool
+
+
+class HyperparameterConfiguration(TypedDict):
+    hidden_dim: int
+    lr: float
+    betas: Tuple[float, float]
+    hard_or_all: str
+    margin: float
+
+
+class ModelConfiguration(TypedDict):
+    name: str
+    restore_iter: int
+    total_iter: int
+
+
+class Configuration(TypedDict):
+    system: SystemConfiguration
+    dataset: DatasetConfiguration
+    dataloader: DataloaderConfiguration
+    hyperparameter: HyperparameterConfiguration
+    model: ModelConfiguration
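
As a small illustration, not part of the commit, of what these TypedDict declarations provide: a static checker such as mypy validates keys and value types, while at runtime the objects remain plain dicts and no extra checks are performed. The `bad` variable below is a deliberate counter-example.

from utils.configuration import HyperparameterConfiguration

# Well-formed: every key is present and every value matches its annotation.
ok: HyperparameterConfiguration = {
    'hidden_dim': 256,
    'lr': 1e-4,
    'betas': (0.9, 0.999),
    'hard_or_all': 'hard',
    'margin': 0.2,
}

# Deliberately wrong: mypy reports 'lr' as "str" where "float" is expected,
# but running this module raises no error because TypedDict adds no runtime checks.
bad: HyperparameterConfiguration = {
    'hidden_dim': 256,
    'lr': '1e-4',
    'betas': (0.9, 0.999),
    'hard_or_all': 'hard',
    'margin': 0.2,
}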