diff options
Diffstat (limited to 'train.py')
-rw-r--r-- | train.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/train.py b/train.py new file mode 100644 index 0000000..17cd0f6 --- /dev/null +++ b/train.py @@ -0,0 +1,12 @@ +import os + +from config import config +from models import Model + +# Set environment variable CUDA device(s) +CUDA_VISIBLE_DEVICES = config['system'].get('CUDA_VISIBLE_DEVICES', None) +if CUDA_VISIBLE_DEVICES: + os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES + +model = Model(config['system'], config['model'], config['hyperparameter']) +model.fit(config['dataset'], config['dataloader']) |