summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- config.py 32
-rw-r--r-- models/__init__.py 2
-rw-r--r-- models/model.py 65
-rw-r--r-- models/rgb_part_net.py 1
-rw-r--r-- utils/configuration.py 16
5 files changed, 81 insertions, 35 deletions
diff --git a/config.py b/config.py
index 634b9f3..ad737e8 100644
--- a/config.py
+++ b/config.py
@@ -43,17 +43,35 @@ config: Configuration = {
},
# Hyperparameter tuning
'hyperparameter': {
- # Hidden dimension of FC
- 'hidden_dim': 256,
+ # Auto-encoder feature channels coefficient
+ 'ae_feature_channels': 64,
+ # Appearance, canonical and pose feature dimensions
+ 'f_a_c_p_dims': (128, 128, 64),
+ # HPM pyramid scales; their sum is the number of parts
+ 'hpm_scales': (1, 2, 4),
+ # Global pooling method
+ 'hpm_use_avg_pool': True,
+ 'hpm_use_max_pool': True,
+ # FConv feature channels coefficient
+ 'fpfe_feature_channels': 32,
+ # FConv blocks kernel sizes
+ 'fpfe_kernel_sizes': ((5, 3), (3, 3), (3, 3)),
+ # FConv blocks paddings
+ 'fpfe_paddings': ((2, 1), (1, 1), (1, 1)),
+ # FConv blocks halving
+ 'fpfe_halving': (0, 2, 3),
+ # Attention squeeze ratio
+ 'tfa_squeeze_ratio': 4,
+ # Number of parts after Part Net
+ 'tfa_num_parts': 16,
+ # Embedding dimension for each part
+ 'embedding_dims': 256,
+ # Triplet loss margin
+ 'triplet_margin': 0.2,
# Initial learning rate of Adam Optimizer
'lr': 1e-4,
# Betas of Adam Optimizer
'betas': (0.9, 0.999),
- # Batch Hard or Batch Full Triplet loss
- # `hard` for BH, `all` for BA
- 'hard_or_all': 'all',
- # Triplet loss margin
- 'margin': 0.2,
},
# Model metadata
'model': {
diff --git a/models/__init__.py b/models/__init__.py
index 51c86af..c1b9fe8 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,4 +1,4 @@
-from .model import Model
from .auto_encoder import AutoEncoder
from .hpm import HorizontalPyramidMatching
from .part_net import PartNet
+from .rgb_part_net import RGBPartNet
diff --git a/models/model.py b/models/model.py
index e9714b8..c407d6c 100644
--- a/models/model.py
+++ b/models/model.py
@@ -2,9 +2,11 @@ from typing import Union, Optional
import numpy as np
import torch
+import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
+from models import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration
from utils.dataset import CASIAB
@@ -22,7 +24,8 @@ class Model:
self.curr_iter = self.meta['restore_iter']
self.is_train: bool = True
- self.dataset_metadata: Optional[DatasetConfiguration] = None
+ self.train_size: int = 74
+ self.in_channels: int = 3
self.pr: Optional[int] = None
self.k: Optional[int] = None
@@ -30,6 +33,10 @@ class Model:
self._hp_sig: str = self._make_signature(self.hp)
self._dataset_sig: str = 'undefined'
+ self.rbg_pn: Optional[RGBPartNet] = None
+ self.optimizer: Optional[optim.Adam] = None
+ self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
+
@property
def signature(self) -> str:
return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig,
@@ -44,23 +51,6 @@ class Model:
else:
return 1
- def _batch_splitter(
- self,
- batch: list[dict[str, Union[np.int64, str, torch.Tensor]]]
- ) -> tuple[dict[str, Union[list[str], torch.Tensor]],
- dict[str, Union[list[str], torch.Tensor]]]:
- """
- Disentanglement need two random conditions, this function will
- split pr * k * 2 samples to 2 dicts each containing pr * k
- samples. labels and clip data are tensor, and others are list.
- """
- _batch = [[], []]
- for i in range(0, self.pr * self.k * 2, self.k * 2):
- _batch[0] += batch[i:i + self.k]
- _batch[1] += batch[i + self.k:i + self.k * 2]
-
- return default_collate(_batch[0]), default_collate(_batch[1])
-
def fit(
self,
dataset_config: DatasetConfiguration,
@@ -69,21 +59,33 @@ class Model:
self.is_train = True
dataset = self._parse_dataset_config(dataset_config)
dataloader = self._parse_dataloader_config(dataset, dataloader_config)
- for iter_i, (samples_c1, samples_c2) in enumerate(dataloader):
- pass
-
- if iter_i == 0:
+ # Prepare for model, optimizer and scheduler
+ hp = self.hp.copy()
+ lr, betas = hp.pop('lr', 1e-4), hp.pop('betas', (0.9, 0.999))
+ self.rbg_pn = RGBPartNet(self.train_size, self.in_channels, **hp)
+ self.optimizer = optim.Adam(self.rbg_pn.parameters(), lr, betas)
+ self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9)
+
+ self.rbg_pn.train()
+ for iter_i, (x_c1, x_c2) in enumerate(dataloader):
+ loss = self.rbg_pn(x_c1['clip'], x_c2['clip'], x_c1['label'])
+ loss.backward()
+ self.optimizer.step()
+ self.scheduler.step(iter_i)
+
+ if iter_i == self.meta['total_iter']:
break
def _parse_dataset_config(
self,
dataset_config: DatasetConfiguration
) -> Union[CASIAB]:
+ self.train_size = dataset_config['train_size']
+ self.in_channels = dataset_config['num_input_channels']
self._dataset_sig = self._make_signature(
dataset_config,
popped_keys=['root_dir', 'cache_on']
)
-
config: dict = dataset_config.copy()
name = config.pop('name')
if name == 'CASIA-B':
@@ -110,6 +112,23 @@ class Model:
config.pop('batch_size')
return DataLoader(dataset, **config)
+ def _batch_splitter(
+ self,
+ batch: list[dict[str, Union[np.int64, str, torch.Tensor]]]
+ ) -> tuple[dict[str, Union[list[str], torch.Tensor]],
+ dict[str, Union[list[str], torch.Tensor]]]:
+ """
+ Disentanglement needs two random conditions, so this function
+ splits pr * k * 2 samples into 2 dicts, each containing pr * k
+ samples. Labels and clip data are tensors; the others are lists.
+ """
+ _batch = [[], []]
+ for i in range(0, self.pr * self.k * 2, self.k * 2):
+ _batch[0] += batch[i:i + self.k]
+ _batch[1] += batch[i + self.k:i + self.k * 2]
+
+ return default_collate(_batch[0]), default_collate(_batch[1])
+
@staticmethod
def _make_signature(config: dict,
popped_keys: Optional[list] = None) -> str:
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 02345d6..0f3b4f4 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -76,7 +76,6 @@ class RGBPartNet(nn.Module):
if self.training:
# TODO Implement Batch All triplet loss function
batch_all_triplet_loss = torch.tensor(0.)
- print(*losses, batch_all_triplet_loss)
loss = torch.sum(torch.stack((*losses, batch_all_triplet_loss)))
return loss
else:
diff --git a/utils/configuration.py b/utils/configuration.py
index 965af94..3e98343 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -30,11 +30,21 @@ class DataloaderConfiguration(TypedDict):
class HyperparameterConfiguration(TypedDict):
- hidden_dim: int
+ ae_feature_channels: int
+ f_a_c_p_dims: tuple[int, int, int]
+ hpm_scales: tuple[int, ...]
+ hpm_use_avg_pool: bool
+ hpm_use_max_pool: bool
+ fpfe_feature_channels: int
+ fpfe_kernel_sizes: tuple[tuple, ...]
+ fpfe_paddings: tuple[tuple, ...]
+ fpfe_halving: tuple[int, ...]
+ tfa_squeeze_ratio: int
+ tfa_num_parts: int
+ embedding_dims: int
+ triplet_margin: float
lr: int
betas: tuple[float, float]
- hard_or_all: str
- margin: float
class ModelConfiguration(TypedDict):