summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-07 21:37:33 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-07 21:37:33 +0800
commit7825f978f198e56958703f0d08f7ccbd8cef49ca (patch)
tree8510a2d3484143aec84618cbe130f7a9d7ae4596
parent46e68c1d0168816107fd9997e1d948d3f403f5ee (diff)
parentdd12098603ac415904b9a8d512889deb995a8391 (diff)
Merge branch 'master' into python3.8
# Conflicts: # models/model.py
-rw-r--r--models/model.py8
-rw-r--r--train.py14
2 files changed, 18 insertions, 4 deletions
diff --git a/models/model.py b/models/model.py
index 725988a..80d4499 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,5 +1,5 @@
import os
-from typing import Union, Optional, Tuple, List, Dict
+from typing import Union, Optional, Tuple, List, Dict, Set
import numpy as np
import torch
@@ -220,13 +220,15 @@ class Model:
return self._gen_sig(list(_config.values()))
- def _gen_sig(self, values: Union[Tuple, List, str, int, float]) -> str:
+ def _gen_sig(self, values: Union[Tuple, List, Set, str, int, float]) -> str:
strings = []
for v in values:
if isinstance(v, str):
strings.append(v)
- elif isinstance(v, (Tuple, List)):
+ elif isinstance(v, (Tuple, List, Set)):
strings.append(self._gen_sig(v))
+ elif isinstance(v, Dict):
+ strings.append(self._gen_sig(list(v.values())))
else:
strings.append(str(v))
return '_'.join(strings)
diff --git a/train.py b/train.py
index 17cd0f6..d921839 100644
--- a/train.py
+++ b/train.py
@@ -2,6 +2,7 @@ import os
from config import config
from models import Model
+from utils.dataset import ClipConditions
# Set environment variable CUDA device(s)
CUDA_VISIBLE_DEVICES = config['system'].get('CUDA_VISIBLE_DEVICES', None)
@@ -9,4 +10,15 @@ if CUDA_VISIBLE_DEVICES:
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
model = Model(config['system'], config['model'], config['hyperparameter'])
-model.fit(config['dataset'], config['dataloader'])
+
+# 3 models for different conditions
+dataset_selectors = [
+ {'conditions': ClipConditions({r'nm-0\d'})},
+ {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})},
+ {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},
+]
+for selector in dataset_selectors:
+ model.fit(
+ dict(**config['dataset'], **{'selector': selector}),
+ config['dataloader']
+ )