diff options
Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 65 |
1 files changed, 42 insertions, 23 deletions
diff --git a/models/model.py b/models/model.py index e9714b8..c407d6c 100644 --- a/models/model.py +++ b/models/model.py @@ -2,9 +2,11 @@ from typing import Union, Optional import numpy as np import torch +import torch.optim as optim from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate +from models import RGBPartNet from utils.configuration import DataloaderConfiguration, \ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration from utils.dataset import CASIAB @@ -22,7 +24,8 @@ class Model: self.curr_iter = self.meta['restore_iter'] self.is_train: bool = True - self.dataset_metadata: Optional[DatasetConfiguration] = None + self.train_size: int = 74 + self.in_channels: int = 3 self.pr: Optional[int] = None self.k: Optional[int] = None @@ -30,6 +33,10 @@ class Model: self._hp_sig: str = self._make_signature(self.hp) self._dataset_sig: str = 'undefined' + self.rbg_pn: Optional[RGBPartNet] = None + self.optimizer: Optional[optim.Adam] = None + self.scheduler: Optional[optim.lr_scheduler.StepLR] = None + @property def signature(self) -> str: return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig, @@ -44,23 +51,6 @@ class Model: else: return 1 - def _batch_splitter( - self, - batch: list[dict[str, Union[np.int64, str, torch.Tensor]]] - ) -> tuple[dict[str, Union[list[str], torch.Tensor]], - dict[str, Union[list[str], torch.Tensor]]]: - """ - Disentanglement need two random conditions, this function will - split pr * k * 2 samples to 2 dicts each containing pr * k - samples. labels and clip data are tensor, and others are list. - """ - _batch = [[], []] - for i in range(0, self.pr * self.k * 2, self.k * 2): - _batch[0] += batch[i:i + self.k] - _batch[1] += batch[i + self.k:i + self.k * 2] - - return default_collate(_batch[0]), default_collate(_batch[1]) - def fit( self, dataset_config: DatasetConfiguration, @@ -69,21 +59,33 @@ class Model: self.is_train = True dataset = self._parse_dataset_config(dataset_config) dataloader = self._parse_dataloader_config(dataset, dataloader_config) - for iter_i, (samples_c1, samples_c2) in enumerate(dataloader): - pass - - if iter_i == 0: + # Prepare for model, optimizer and scheduler + hp = self.hp.copy() + lr, betas = hp.pop('lr', 1e-4), hp.pop('betas', (0.9, 0.999)) + self.rbg_pn = RGBPartNet(self.train_size, self.in_channels, **hp) + self.optimizer = optim.Adam(self.rbg_pn.parameters(), lr, betas) + self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9) + + self.rbg_pn.train() + for iter_i, (x_c1, x_c2) in enumerate(dataloader): + loss = self.rbg_pn(x_c1['clip'], x_c2['clip'], x_c1['label']) + loss.backward() + self.optimizer.step() + self.scheduler.step(iter_i) + + if iter_i == self.meta['total_iter']: break def _parse_dataset_config( self, dataset_config: DatasetConfiguration ) -> Union[CASIAB]: + self.train_size = dataset_config['train_size'] + self.in_channels = dataset_config['num_input_channels'] self._dataset_sig = self._make_signature( dataset_config, popped_keys=['root_dir', 'cache_on'] ) - config: dict = dataset_config.copy() name = config.pop('name') if name == 'CASIA-B': @@ -110,6 +112,23 @@ class Model: config.pop('batch_size') return DataLoader(dataset, **config) + def _batch_splitter( + self, + batch: list[dict[str, Union[np.int64, str, torch.Tensor]]] + ) -> tuple[dict[str, Union[list[str], torch.Tensor]], + dict[str, Union[list[str], torch.Tensor]]]: + """ + Disentanglement need two random conditions, this function will + split pr * k * 2 samples to 2 dicts each containing pr * k + samples. labels and clip data are tensor, and others are list. + """ + _batch = [[], []] + for i in range(0, self.pr * self.k * 2, self.k * 2): + _batch[0] += batch[i:i + self.k] + _batch[1] += batch[i + self.k:i + self.k * 2] + + return default_collate(_batch[0]), default_collate(_batch[1]) + @staticmethod def _make_signature(config: dict, popped_keys: Optional[list] = None) -> str: |