Diffstat (limited to 'models/model.py')
-rw-r--r--  models/model.py  107
1 file changed, 104 insertions(+), 3 deletions(-)
diff --git a/models/model.py b/models/model.py
index cb0e756..ca1497f 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,15 +1,47 @@
-from typing import Union
+from typing import Union, Optional
import torch
+from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
+from utils.configuration import DataloaderConfiguration, \
+ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration
+from utils.dataset import CASIAB
+from utils.sampler import TripletSampler
+
class Model:
def __init__(
self,
- batch_size: tuple[int, int]
+ model_config: ModelConfiguration,
+ hyperparameter_config: HyperparameterConfiguration
):
- (self.pr, self.k) = batch_size
+ self.meta = model_config
+ self.hp = hyperparameter_config
+ self.curr_iter = self.meta['restore_iter']
+
+ self.is_train: bool = True
+ self.dataset_metadata: Optional[DatasetConfiguration] = None
+ self.pr: Optional[int] = None
+ self.k: Optional[int] = None
+
+ self._model_sig: str = self._make_signature(self.meta, ['restore_iter'])
+ self._hp_sig: str = self._make_signature(self.hp)
+ self._dataset_sig: str = 'undefined'
+
+ @property
+ def signature(self) -> str:
+ return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig,
+ self._dataset_sig, str(self.batch_size)))
+
+ @property
+ def batch_size(self) -> int:
+ if self.is_train:
+ if self.pr and self.k:
+ return self.pr * self.k
+ raise AttributeError('No dataset loaded')
+ else:
+ return 1
def _batch_splitter(
self,
@@ -30,3 +62,72 @@ class Model:
default_collate(batch[i + self.k:i + self.k * 2])))
return _batch
+
+ def fit(
+ self,
+ dataset_config: DatasetConfiguration,
+ dataloader_config: DataloaderConfiguration,
+ ):
+ self.is_train = True
+ dataset = self._parse_dataset_config(dataset_config)
+ dataloader = self._parse_dataloader_config(dataset, dataloader_config)
+ for iter_i, samples_batched in enumerate(dataloader):
+ for sub_i, (subject_c1, subject_c2) in enumerate(samples_batched):
+ pass
+
+ if sub_i == 0:
+ break
+ if iter_i == 0:
+ break
+
+ def _parse_dataset_config(
+ self,
+ dataset_config: DatasetConfiguration
+ ) -> Union[CASIAB]:
+ self._dataset_sig = self._make_signature(
+ dataset_config,
+ popped_keys=['root_dir', 'cache_on']
+ )
+
+ config: dict = dataset_config.copy()
+ name = config.pop('name')
+ if name == 'CASIA-B':
+ return CASIAB(**config, is_train=self.is_train)
+ elif name == 'FVG':
+ # TODO
+ pass
+ raise ValueError('Invalid dataset: {0}'.format(name))
+
+ def _parse_dataloader_config(
+ self,
+ dataset: Union[CASIAB],
+ dataloader_config: DataloaderConfiguration
+ ) -> DataLoader:
+ config: dict = dataloader_config.copy()
+ if self.is_train:
+ (self.pr, self.k) = config.pop('batch_size')
+ triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
+ return DataLoader(dataset,
+ batch_sampler=triplet_sampler,
+ collate_fn=self._batch_splitter,
+ **config)
+ else: # is_test
+ config.pop('batch_size')
+ return DataLoader(dataset, **config)
+
+ @staticmethod
+ def _make_signature(config: dict,
+ popped_keys: Optional[list] = None) -> str:
+ _config = config.copy()
+ for (key, value) in config.items():
+ if popped_keys and key in popped_keys:
+ _config.pop(key)
+ continue
+ if isinstance(value, str):
+ pass
+ elif isinstance(value, (tuple, list)):
+ _config[key] = '_'.join([str(v) for v in value])
+ else:
+ _config[key] = str(value)
+
+ return '_'.join(_config.values())
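
The fit loop above consumes _batch_splitter output as per-subject (subject_c1, subject_c2) pairs. Below is a minimal, self-contained sketch of that splitting step, assuming TripletSampler (utils/sampler.py, not shown in this diff) lays each batch out as pr subjects, with each subject's k condition-1 clips followed by its k condition-2 clips:

import torch
from torch.utils.data.dataloader import default_collate

pr, k = 2, 3                          # pr subjects, k clips per condition
# Flat batch: for each subject, k condition-1 clips then k condition-2 clips
batch = [torch.full((4,), float(n)) for n in range(pr * k * 2)]

_batch = []
for i in range(0, pr * k * 2, k * 2):                         # one step per subject
    _batch.append((default_collate(batch[i:i + k]),           # condition 1
                   default_collate(batch[i + k:i + k * 2])))  # condition 2

for subject_c1, subject_c2 in _batch:
    print(subject_c1.shape, subject_c2.shape)   # torch.Size([3, 4]) each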
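
_make_signature flattens a configuration dict into an underscore-joined string, skipping keys (popped_keys) that should not affect the signature, such as machine-local paths. A standalone sketch follows; only name, root_dir, and cache_on appear in this diff, so the remaining keys are hypothetical:

def make_signature(config: dict, popped_keys=None) -> str:
    # Same logic as Model._make_signature above, reproduced standalone
    _config = config.copy()
    for key, value in config.items():
        if popped_keys and key in popped_keys:
            _config.pop(key)        # excluded keys leave no trace
            continue
        if isinstance(value, str):
            pass                    # strings are kept verbatim
        elif isinstance(value, (tuple, list)):
            _config[key] = '_'.join(str(v) for v in value)
        else:
            _config[key] = str(value)
    return '_'.join(_config.values())

dataset_config = {
    'name': 'CASIA-B',
    'root_dir': '/data/CASIA-B',    # popped: machine-specific
    'cache_on': True,               # popped: no effect on the data
    'num_frames': 30,               # hypothetical key
    'frame_size': (64, 48),         # hypothetical key
}
print(make_signature(dataset_config, popped_keys=['root_dir', 'cache_on']))
# -> CASIA-B_30_64_48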
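
Putting the pieces together, a hypothetical call site might look like the following. The *Configuration types come from utils/configuration.py, which is not part of this diff, so every key shown beyond restore_iter, name, root_dir, cache_on, and batch_size is an assumption:

model = Model(
    model_config={'name': 'gait-recognition', 'restore_iter': 0},
    hyperparameter_config={'lr': 1e-4},       # hypothetical hyperparameter
)
model.fit(
    dataset_config={'name': 'CASIA-B',
                    'root_dir': '/data/CASIA-B',
                    'cache_on': False},
    dataloader_config={'batch_size': (4, 8),  # (pr, k): 4 subjects x 8 clips
                       'num_workers': 4},     # forwarded to DataLoader
)
print(model.signature)   # joins the model, iteration, hp and dataset
                         # signatures with batch_size = pr * k = 32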