summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-07 18:37:43 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-07 18:37:43 +0800
commit4a284084c253b9114fc02e1782962556ff113761 (patch)
treed6ceff8da68b224186d84772ee6153353675bcfe /train.py
parenta27af5dfd58e7b48cf3bd063fa2b4b51ed1e0277 (diff)
Add typical training script and some bug fixes
1. Resolve deprecated scheduler stepping issue 2. Make losses in the same scale(replace mean with sum in separate triplet loss, enlarge pose similarity loss 10x) 3. Add ReLU when compute distance in triplet loss 4. Remove classes except Model from `models` package init
Diffstat (limited to 'train.py')
-rw-r--r--train.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..17cd0f6
--- /dev/null
+++ b/train.py
@@ -0,0 +1,12 @@
+import os
+
+from config import config
+from models import Model
+
+# Set environment variable CUDA device(s)
+CUDA_VISIBLE_DEVICES = config['system'].get('CUDA_VISIBLE_DEVICES', None)
+if CUDA_VISIBLE_DEVICES:
+ os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
+
+model = Model(config['system'], config['model'], config['hyperparameter'])
+model.fit(config['dataset'], config['dataloader'])