diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-11 23:59:30 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-11 23:59:30 +0800 |
commit | 72a53806746bc7ffa2f3939721e34b5cfdb7330a (patch) | |
tree | 36c549aa32ed9e160381e47de6dbec045f6085cc /train.py | |
parent | 7188d71b2b6faf3da527c8d0ade9a32ec4893dc5 (diff) |
Add evaluation script, code review and fix some bugs
1. Add new `train_all` method for one shot calling
2. Print time used in 1k iterations
3. Correct label dimension in predict function
4. Transpose distance matrix for convenient indexing
5. Sort dictionary before generate signature
6. Extract visible CUDA setting function
Diffstat (limited to 'train.py')
-rw-r--r-- | train.py | 16 |
1 files changed, 4 insertions, 12 deletions
@@ -1,14 +1,9 @@ -import os - from config import config from models import Model from utils.dataset import ClipConditions +from utils.misc import set_visible_cuda -# Set environment variable CUDA device(s) -CUDA_VISIBLE_DEVICES = config['system'].get('CUDA_VISIBLE_DEVICES', None) -if CUDA_VISIBLE_DEVICES: - os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES - +set_visible_cuda(config['system']) model = Model(config['system'], config['model'], config['hyperparameter']) # 3 models for different conditions @@ -17,8 +12,5 @@ dataset_selectors = { 'bg': {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})}, 'cl': {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})}, } -for selector in dataset_selectors.values(): - model.fit( - dict(**config['dataset'], **{'selector': selector}), - config['dataloader'] - ) + +model.fit_all(config['dataset'], dataset_selectors, config['dataloader']) |