summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..17cd0f6
--- /dev/null
+++ b/train.py
@@ -0,0 +1,12 @@
+import os
+
+from config import config
+from models import Model
+
+# Set environment variable CUDA device(s)
+CUDA_VISIBLE_DEVICES = config['system'].get('CUDA_VISIBLE_DEVICES', None)
+if CUDA_VISIBLE_DEVICES:
+ os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
+
+model = Model(config['system'], config['model'], config['hyperparameter'])
+model.fit(config['dataset'], config['dataloader'])