Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 107 |
1 file changed, 104 insertions(+), 3 deletions(-)
diff --git a/models/model.py b/models/model.py
index cb0e756..ca1497f 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,15 +1,47 @@
-from typing import Union
+from typing import Union, Optional
 
 import torch
+from torch.utils.data import DataLoader
 from torch.utils.data.dataloader import default_collate
 
+from utils.configuration import DataloaderConfiguration, \
+    HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration
+from utils.dataset import CASIAB
+from utils.sampler import TripletSampler
+
 
 class Model:
     def __init__(
             self,
-            batch_size: tuple[int, int]
+            model_config: ModelConfiguration,
+            hyperparameter_config: HyperparameterConfiguration
     ):
-        (self.pr, self.k) = batch_size
+        self.meta = model_config
+        self.hp = hyperparameter_config
+        self.curr_iter = self.meta['restore_iter']
+
+        self.is_train: bool = True
+        self.dataset_metadata: Optional[DatasetConfiguration] = None
+        self.pr: Optional[int] = None
+        self.k: Optional[int] = None
+
+        self._model_sig: str = self._make_signature(self.meta, ['restore_iter'])
+        self._hp_sig: str = self._make_signature(self.hp)
+        self._dataset_sig: str = 'undefined'
+
+    @property
+    def signature(self) -> str:
+        return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig,
+                         self._dataset_sig, str(self.batch_size)))
+
+    @property
+    def batch_size(self) -> int:
+        if self.is_train:
+            if self.pr and self.k:
+                return self.pr * self.k
+            raise AttributeError('No dataset loaded')
+        else:
+            return 1
 
     def _batch_splitter(
             self,
@@ -30,3 +62,72 @@ class Model:
                           default_collate(batch[i + self.k:i + self.k * 2])))
 
         return _batch
+
+    def fit(
+            self,
+            dataset_config: DatasetConfiguration,
+            dataloader_config: DataloaderConfiguration,
+    ):
+        self.is_train = True
+        dataset = self._parse_dataset_config(dataset_config)
+        dataloader = self._parse_dataloader_config(dataset, dataloader_config)
+        for iter_i, samples_batched in enumerate(dataloader):
+            for sub_i, (subject_c1, subject_c2) in enumerate(samples_batched):
+                pass
+
+                if sub_i == 0:
+                    break
+            if iter_i == 0:
+                break
+
+    def _parse_dataset_config(
+            self,
+            dataset_config: DatasetConfiguration
+    ) -> Union[CASIAB]:
+        self._dataset_sig = self._make_signature(
+            dataset_config,
+            popped_keys=['root_dir', 'cache_on']
+        )
+
+        config: dict = dataset_config.copy()
+        name = config.pop('name')
+        if name == 'CASIA-B':
+            return CASIAB(**config, is_train=self.is_train)
+        elif name == 'FVG':
+            # TODO
+            pass
+        raise ValueError('Invalid dataset: {0}'.format(name))
+
+    def _parse_dataloader_config(
+            self,
+            dataset: Union[CASIAB],
+            dataloader_config: DataloaderConfiguration
+    ) -> DataLoader:
+        config: dict = dataloader_config.copy()
+        if self.is_train:
+            (self.pr, self.k) = config.pop('batch_size')
+            triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
+            return DataLoader(dataset,
+                              batch_sampler=triplet_sampler,
+                              collate_fn=self._batch_splitter,
+                              **config)
+        else:  # is_test
+            config.pop('batch_size')
+            return DataLoader(dataset, **config)
+
+    @staticmethod
+    def _make_signature(config: dict,
+                        popped_keys: Optional[list] = None) -> str:
+        _config = config.copy()
+        for (key, value) in config.items():
+            if popped_keys and key in popped_keys:
+                _config.pop(key)
+                continue
+            if isinstance(value, str):
+                pass
+            elif isinstance(value, (tuple, list)):
+                _config[key] = '_'.join([str(v) for v in value])
+            else:
+                _config[key] = str(value)
+
+        return '_'.join(_config.values())
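
The commit adds the training entry point but no caller. For orientation, a hypothetical invocation is sketched below, assuming the configuration objects behave as plain dicts (model.py only performs dict operations on them); apart from 'restore_iter', 'name', 'root_dir', 'cache_on' and 'batch_size', every key and value is illustrative and not taken from utils.configuration.

from models.model import Model

# Hypothetical configs: only the keys actually read in model.py are known.
model_config = {'restore_iter': 0}
hyperparameter_config = {'lr': 1e-4}            # assumed key

dataset_config = {
    'name': 'CASIA-B',               # routes _parse_dataset_config to CASIAB
    'root_dir': '/path/to/CASIA-B',  # excluded from the dataset signature
    'cache_on': False,               # likewise excluded
}
dataloader_config = {
    'batch_size': (4, 8),  # (pr, k): 4 subjects x 8 samples = 32 per batch
    'num_workers': 4,      # forwarded verbatim to torch DataLoader
}

model = Model(model_config, hyperparameter_config)
model.fit(dataset_config, dataloader_config)
print(model.signature)   # <model>_<curr_iter>_<hp>_<dataset>_<batch_size>

Note that batch_size is a property: during training it is pr * k and raises AttributeError until a dataloader config has been parsed; at test time it is fixed to 1.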
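
The signature string itself comes from _make_signature, which underscore-joins a config's values: strings pass through, tuples and lists are themselves underscore-joined, anything else is stringified, and keys listed in popped_keys are dropped. A standalone sketch of that flattening, with made-up values:

# Reimplements the _make_signature flattening for illustration only.
config = {
    'restore_iter': 0,      # dropped via popped_keys=['restore_iter']
    'lr': 1e-4,             # scalar -> str(value) -> '0.0001'
    'betas': (0.9, 0.999),  # tuple  -> '0.9_0.999'
    'name': 'example',      # str    -> kept as-is
}

parts = []
for key, value in config.items():
    if key in ['restore_iter']:   # popped_keys
        continue
    if isinstance(value, (tuple, list)):
        parts.append('_'.join(str(v) for v in value))
    else:
        parts.append(str(value))

print('_'.join(parts))  # -> 0.0001_0.9_0.999_example

One consequence worth noting: because only values (not keys) are joined, two configs that differ in key names but share values produce the same signature.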