summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-12 11:29:02 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-12 11:29:02 +0800
commit966d4431c037b0c4641aa2a5fc22f05be064b331 (patch)
tree0239ba89d31857a7f086acf627fc1bbf167855a9 /train.py
parent7825f978f198e56958703f0d08f7ccbd8cef49ca (diff)
parent36cf502afe9b93efe31c244030270b0a62e644b8 (diff)
Merge branch 'master' into python3.8
# Conflicts: # models/model.py
Diffstat (limited to 'train.py')
-rw-r--r--train.py26
1 files changed, 9 insertions, 17 deletions
diff --git a/train.py b/train.py
index d921839..d91dcd0 100644
--- a/train.py
+++ b/train.py
@@ -1,24 +1,16 @@
-import os
-
from config import config
from models import Model
from utils.dataset import ClipConditions
+from utils.misc import set_visible_cuda
-# Set environment variable CUDA device(s)
-CUDA_VISIBLE_DEVICES = config['system'].get('CUDA_VISIBLE_DEVICES', None)
-if CUDA_VISIBLE_DEVICES:
- os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
-
+set_visible_cuda(config['system'])
model = Model(config['system'], config['model'], config['hyperparameter'])
# 3 models for different conditions
-dataset_selectors = [
- {'conditions': ClipConditions({r'nm-0\d'})},
- {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})},
- {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},
-]
-for selector in dataset_selectors:
- model.fit(
- dict(**config['dataset'], **{'selector': selector}),
- config['dataloader']
- )
+dataset_selectors = {
+ 'nm': {'conditions': ClipConditions({r'nm-0\d'})},
+ 'bg': {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})},
+ 'cl': {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},
+}
+
+model.fit_all(config['dataset'], dataset_selectors, config['dataloader'])