summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-07 20:54:44 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-07 20:54:44 +0800
commitdd12098603ac415904b9a8d512889deb995a8391 (patch)
tree368aba05057e8a4f1fafef7e4de65a48bd85b992 /train.py
parent4a284084c253b9114fc02e1782962556ff113761 (diff)
Train different models in different conditions
Diffstat (limited to 'train.py')
-rw-r--r--train.py14
1 files changed, 13 insertions, 1 deletions
diff --git a/train.py b/train.py
index 17cd0f6..d921839 100644
--- a/train.py
+++ b/train.py
@@ -2,6 +2,7 @@ import os
from config import config
from models import Model
+from utils.dataset import ClipConditions
# Set environment variable CUDA device(s)
CUDA_VISIBLE_DEVICES = config['system'].get('CUDA_VISIBLE_DEVICES', None)
@@ -9,4 +10,15 @@ if CUDA_VISIBLE_DEVICES:
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
model = Model(config['system'], config['model'], config['hyperparameter'])
-model.fit(config['dataset'], config['dataloader'])
+
+# 3 models for different conditions
+dataset_selectors = [
+ {'conditions': ClipConditions({r'nm-0\d'})},
+ {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})},
+ {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},
+]
+for selector in dataset_selectors:
+ model.fit(
+ dict(**config['dataset'], **{'selector': selector}),
+ config['dataloader']
+ )