Diffstat (limited to 'models/model.py')
-rw-r--r--  models/model.py  66
1 file changed, 43 insertions(+), 23 deletions(-)
diff --git a/models/model.py b/models/model.py
index e9714b8..c407d6c 100644
--- a/models/model.py
+++ b/models/model.py
@@ -2,9 +2,11 @@ from typing import Union, Optional
import numpy as np
import torch
+import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
+from models import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration
from utils.dataset import CASIAB
@@ -22,7 +24,8 @@ class Model:
self.curr_iter = self.meta['restore_iter']
self.is_train: bool = True
- self.dataset_metadata: Optional[DatasetConfiguration] = None
+ self.train_size: int = 74
+ self.in_channels: int = 3
self.pr: Optional[int] = None
self.k: Optional[int] = None
@@ -30,6 +33,10 @@ class Model:
self._hp_sig: str = self._make_signature(self.hp)
self._dataset_sig: str = 'undefined'
+ self.rgb_pn: Optional[RGBPartNet] = None
+ self.optimizer: Optional[optim.Adam] = None
+ self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
+
@property
def signature(self) -> str:
return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig,
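As an illustration, the signature property just underscore-joins the component signatures; a tiny sketch with made-up values (the strings below are assumptions, not actual output of _make_signature):

model_sig, curr_iter = 'RGBPartNet', 0
hp_sig, dataset_sig = 'lr(0.0001)', 'CASIA-B(74)'
signature = '_'.join((model_sig, str(curr_iter), hp_sig, dataset_sig))
print(signature)  # RGBPartNet_0_lr(0.0001)_CASIA-B(74)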
@@ -44,23 +51,6 @@ class Model:
else:
return 1
- def _batch_splitter(
- self,
- batch: list[dict[str, Union[np.int64, str, torch.Tensor]]]
- ) -> tuple[dict[str, Union[list[str], torch.Tensor]],
- dict[str, Union[list[str], torch.Tensor]]]:
- """
- Disentanglement need two random conditions, this function will
- split pr * k * 2 samples to 2 dicts each containing pr * k
- samples. labels and clip data are tensor, and others are list.
- """
- _batch = [[], []]
- for i in range(0, self.pr * self.k * 2, self.k * 2):
- _batch[0] += batch[i:i + self.k]
- _batch[1] += batch[i + self.k:i + self.k * 2]
-
- return default_collate(_batch[0]), default_collate(_batch[1])
-
def fit(
self,
dataset_config: DatasetConfiguration,
@@ -69,21 +59,34 @@
self.is_train = True
dataset = self._parse_dataset_config(dataset_config)
dataloader = self._parse_dataloader_config(dataset, dataloader_config)
- for iter_i, (samples_c1, samples_c2) in enumerate(dataloader):
- pass
-
- if iter_i == 0:
+ # Prepare the model, optimizer and scheduler
+ hp = self.hp.copy()
+ lr, betas = hp.pop('lr', 1e-4), hp.pop('betas', (0.9, 0.999))
+ self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp)
+ self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas)
+ self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9)
+
+ self.rgb_pn.train()
+ for iter_i, (x_c1, x_c2) in enumerate(dataloader):
+ loss = self.rgb_pn(x_c1['clip'], x_c2['clip'], x_c1['label'])
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+ self.scheduler.step()
+
+ if iter_i == self.meta['total_iter']:
break
def _parse_dataset_config(
self,
dataset_config: DatasetConfiguration
) -> Union[CASIAB]:
+ self.train_size = dataset_config['train_size']
+ self.in_channels = dataset_config['num_input_channels']
self._dataset_sig = self._make_signature(
dataset_config,
popped_keys=['root_dir', 'cache_on']
)
-
config: dict = dataset_config.copy()
name = config.pop('name')
if name == 'CASIA-B':
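The loop above follows the standard PyTorch update order: zero the gradients, backpropagate, step the optimizer, then advance the scheduler. A minimal self-contained sketch of the same pattern; the model, data shapes, and loss here are placeholder assumptions, not the real RGBPartNet:

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(8, 1)  # stand-in for RGBPartNet
optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999))
# Multiply the learning rate by 0.9 every 500 steps, as in fit() above
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.9)

model.train()
for iter_i in range(1000):
    x = torch.randn(4, 8)           # stand-in for x_c1['clip'], x_c2['clip']
    loss = model(x).pow(2).mean()   # stand-in for the disentanglement loss
    optimizer.zero_grad()           # clear gradients from the previous step
    loss.backward()                 # accumulate fresh gradients
    optimizer.step()                # apply the Adam update
    scheduler.step()                # advance the LR schedule by one step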
@@ -110,6 +112,23 @@ class Model:
config.pop('batch_size')
return DataLoader(dataset, **config)
+ def _batch_splitter(
+ self,
+ batch: list[dict[str, Union[np.int64, str, torch.Tensor]]]
+ ) -> tuple[dict[str, Union[list[str], torch.Tensor]],
+ dict[str, Union[list[str], torch.Tensor]]]:
+ """
+ Disentanglement needs two random conditions; this function splits
+ pr * k * 2 samples into 2 dicts, each containing pr * k samples.
+ Labels and clip data are tensors, while the other fields are lists.
+ """
+ _batch = [[], []]
+ for i in range(0, self.pr * self.k * 2, self.k * 2):
+ _batch[0] += batch[i:i + self.k]
+ _batch[1] += batch[i + self.k:i + self.k * 2]
+
+ return default_collate(_batch[0]), default_collate(_batch[1])
+
@staticmethod
def _make_signature(config: dict,
popped_keys: Optional[list] = None) -> str:
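As a usage note, _batch_splitter assumes the sampler yields pr * k * 2 samples laid out as alternating runs of k samples per condition. A toy sketch of the same splitting logic (pr = 2, k = 2; the dict fields are simplified stand-ins for the real dataset entries):

import torch
from torch.utils.data.dataloader import default_collate

pr, k = 2, 2
# 8 toy samples: for each of the pr subjects, k samples under condition 0
# followed by k samples under condition 1, mirroring the sampler's layout
batch = [{'label': torch.tensor(p), 'clip': torch.full((3,), float(c))}
         for p in range(pr) for c in range(2) for _ in range(k)]

splits = ([], [])
for i in range(0, pr * k * 2, k * 2):
    splits[0] += batch[i:i + k]          # first k samples -> condition 1
    splits[1] += batch[i + k:i + k * 2]  # next k samples -> condition 2
x_c1, x_c2 = default_collate(splits[0]), default_collate(splits[1])
print(x_c1['clip'][:, 0])  # tensor([0., 0., 0., 0.])
print(x_c2['clip'][:, 0])  # tensor([1., 1., 1., 1.])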