diff options
Diffstat (limited to 'train.py')
-rw-r--r-- | train.py | 16 |
1 files changed, 4 insertions, 12 deletions
@@ -1,14 +1,9 @@ -import os - from config import config from models import Model from utils.dataset import ClipConditions +from utils.misc import set_visible_cuda -# Set environment variable CUDA device(s) -CUDA_VISIBLE_DEVICES = config['system'].get('CUDA_VISIBLE_DEVICES', None) -if CUDA_VISIBLE_DEVICES: - os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES - +set_visible_cuda(config['system']) model = Model(config['system'], config['model'], config['hyperparameter']) # 3 models for different conditions @@ -17,8 +12,5 @@ dataset_selectors = { 'bg': {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})}, 'cl': {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})}, } -for selector in dataset_selectors.values(): - model.fit( - dict(**config['dataset'], **{'selector': selector}), - config['dataloader'] - ) + +model.fit_all(config['dataset'], dataset_selectors, config['dataloader']) |