diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-12 11:29:02 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-12 11:29:02 +0800 |
commit | 966d4431c037b0c4641aa2a5fc22f05be064b331 (patch) | |
tree | 0239ba89d31857a7f086acf627fc1bbf167855a9 /train.py | |
parent | 7825f978f198e56958703f0d08f7ccbd8cef49ca (diff) | |
parent | 36cf502afe9b93efe31c244030270b0a62e644b8 (diff) |
Merge branch 'master' into python3.8
# Conflicts:
# models/model.py
Diffstat (limited to 'train.py')
-rw-r--r-- | train.py | 26 |
1 files changed, 9 insertions, 17 deletions
@@ -1,24 +1,16 @@ -import os - from config import config from models import Model from utils.dataset import ClipConditions +from utils.misc import set_visible_cuda -# Set environment variable CUDA device(s) -CUDA_VISIBLE_DEVICES = config['system'].get('CUDA_VISIBLE_DEVICES', None) -if CUDA_VISIBLE_DEVICES: - os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES - +set_visible_cuda(config['system']) model = Model(config['system'], config['model'], config['hyperparameter']) # 3 models for different conditions -dataset_selectors = [ - {'conditions': ClipConditions({r'nm-0\d'})}, - {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})}, - {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})}, -] -for selector in dataset_selectors: - model.fit( - dict(**config['dataset'], **{'selector': selector}), - config['dataloader'] - ) +dataset_selectors = { + 'nm': {'conditions': ClipConditions({r'nm-0\d'})}, + 'bg': {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})}, + 'cl': {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})}, +} + +model.fit_all(config['dataset'], dataset_selectors, config['dataloader']) |