From d5f7cdab1466566d805f9cbf81c05767880886ae Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 26 Dec 2020 20:26:48 +0800 Subject: Add config file and corresponding type hint --- config.py | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++ utils/configuration.py | 48 ++++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+) create mode 100644 config.py create mode 100644 utils/configuration.py diff --git a/config.py b/config.py new file mode 100644 index 0000000..0d98608 --- /dev/null +++ b/config.py @@ -0,0 +1,67 @@ +import torch + +from utils.configuration import Configuration + +config: Configuration = { + 'system': { + # Device(s) used in training and testing (CPU or CUDA) + 'device': torch.device('cuda'), + # GPU(s) used in training or testing, if CUDA enabled + 'CUDA_VISIBLE_DEVICES': '0', + # Directory used in training or testing for temporary storage + 'save_path': 'runs', + }, + # Dataset settings + 'dataset': { + # Name of dataset (CASIA-B or FVG) + 'name': 'CASIA-B', + # Path to dataset root + 'path': 'dataset/output/CASIA-B', + # The number of subjects for training + 'train_size': 74, + # Number of sampled frames per sequence (Training only) + 'num_sampled_frames': 30, + # Discard clips shorter than `discard_threshold` + 'discard_threshold': 15, + # Number of input channels of model + 'num_input_channels': 3, + # Resolution after resize, height : width should be 2 : 1 + 'frame_size': (64, 32), + # Cache dataset or not + 'cache_on': False, + }, + # Dataloader settings + 'dataloader': { + # Batch size (pr, k) + # `pr` denotes number of persons + # `k` denotes number of sequences per person + 'batch_size': (8, 16), + # Number of workers of Dataloader + 'num_workers': 4, + # Faster data transfer from RAM to GPU if enabled + 'pin_memory': True, + }, + # Hyperparameter tuning + 'hyperparameter': { + # Hidden dimension of FC + 'hidden_dim': 256, + # Initial learning rate of Adam Optimizer + 'lr': 1e-4, + # Betas of Adam Optimizer + 'betas': (0.9, 0.999), + # Batch Hard or Batch Full Triplet loss + # `hard` for BH, `all` for BA + 'hard_or_all': 'all', + # Triplet loss margin + 'margin': 0.2, + }, + # Model metadata + 'model': { + # Model name, used for naming checkpoint + 'name': 'RGB-GaitPart', + # Restoration iteration from checkpoint + 'restore_iter': 0, + # Total iteration for training + 'total_iter': 80000, + }, +} diff --git a/utils/configuration.py b/utils/configuration.py new file mode 100644 index 0000000..32b9bec --- /dev/null +++ b/utils/configuration.py @@ -0,0 +1,48 @@ +from typing import TypedDict, Tuple + +import torch + + +class SystemConfiguration(TypedDict): + device: torch.device + CUDA_VISIBLE_DEVICES: str + save_path: str + + +class DatasetConfiguration(TypedDict): + name: str + path: str + train_size: int + num_sampled_frames: int + discard_threshold: int + num_input_channels: int + frame_size: Tuple[int, int] + cache_on: bool + + +class DataloaderConfiguration(TypedDict): + batch_size: Tuple[int, int] + num_workers: int + pin_memory: bool + + +class HyperparameterConfiguration(TypedDict): + hidden_dim: int + lr: int + betas: Tuple[float, float] + hard_or_all: str + margin: float + + +class ModelConfiguration(TypedDict): + name: str + restore_iter: int + total_iter: int + + +class Configuration(TypedDict): + system: SystemConfiguration + dataset: DatasetConfiguration + dataloader: DataloaderConfiguration + hyperparameter: HyperparameterConfiguration + model: ModelConfiguration -- cgit v1.2.3