diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-07 21:37:33 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-07 21:37:33 +0800 |
commit | 7825f978f198e56958703f0d08f7ccbd8cef49ca (patch) | |
tree | 8510a2d3484143aec84618cbe130f7a9d7ae4596 | |
parent | 46e68c1d0168816107fd9997e1d948d3f403f5ee (diff) | |
parent | dd12098603ac415904b9a8d512889deb995a8391 (diff) |
Merge branch 'master' into python3.8
# Conflicts:
# models/model.py
-rw-r--r-- | models/model.py | 8 | ||||
-rw-r--r-- | train.py | 14 |
2 files changed, 18 insertions, 4 deletions
diff --git a/models/model.py b/models/model.py index 725988a..80d4499 100644 --- a/models/model.py +++ b/models/model.py @@ -1,5 +1,5 @@ import os -from typing import Union, Optional, Tuple, List, Dict +from typing import Union, Optional, Tuple, List, Dict, Set import numpy as np import torch @@ -220,13 +220,15 @@ class Model: return self._gen_sig(list(_config.values())) - def _gen_sig(self, values: Union[Tuple, List, str, int, float]) -> str: + def _gen_sig(self, values: Union[Tuple, List, Set, str, int, float]) -> str: strings = [] for v in values: if isinstance(v, str): strings.append(v) - elif isinstance(v, (Tuple, List)): + elif isinstance(v, (Tuple, List, Set)): strings.append(self._gen_sig(v)) + elif isinstance(v, Dict): + strings.append(self._gen_sig(list(v.values()))) else: strings.append(str(v)) return '_'.join(strings) @@ -2,6 +2,7 @@ import os from config import config from models import Model +from utils.dataset import ClipConditions # Set environment variable CUDA device(s) CUDA_VISIBLE_DEVICES = config['system'].get('CUDA_VISIBLE_DEVICES', None) @@ -9,4 +10,15 @@ if CUDA_VISIBLE_DEVICES: os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES model = Model(config['system'], config['model'], config['hyperparameter']) -model.fit(config['dataset'], config['dataloader']) + +# 3 models for different conditions +dataset_selectors = [ + {'conditions': ClipConditions({r'nm-0\d'})}, + {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})}, + {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})}, +] +for selector in dataset_selectors: + model.fit( + dict(**config['dataset'], **{'selector': selector}), + config['dataloader'] + ) |