author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-19 22:39:49 +0800
---|---|---
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-19 22:39:49 +0800
commit | d12dd6b04a4e7c2b1ee43ab6f36f25d0c35ca364 (patch) |
tree | 71b5209ce4b5cfb1d09b89fe133028bbfa481dc9 |
parent | 4aa9044122878a8e2b887a8b170c036983431559 (diff) |
New branch with auto-encoder only
-rw-r--r-- | config.py | 38
-rw-r--r-- | eval.py | 30
-rw-r--r-- | models/hpm.py | 55
-rw-r--r-- | models/layers.py | 96
-rw-r--r-- | models/model.py | 114
-rw-r--r-- | models/part_net.py | 151
-rw-r--r-- | models/rgb_part_net.py | 63
-rw-r--r-- | test/hpm.py | 23
-rw-r--r-- | test/part_net.py | 71
-rw-r--r-- | utils/configuration.py | 24
-rw-r--r-- | utils/triplet_loss.py | 36
11 files changed, 20 insertions, 681 deletions
```diff
diff --git a/config.py b/config.py
--- a/config.py
+++ b/config.py
@@ -7,9 +7,9 @@ config: Configuration = {
         # GPU(s) used in training or testing if available
         'CUDA_VISIBLE_DEVICES': '0',
         # Directory used in training or testing for temporary storage
-        'save_dir': 'runs',
+        'save_dir': 'runs/dis_only',
         # Recorde disentangled image or not
-        'image_log_on': False
+        'image_log_on': True
     },
     # Dataset settings
     'dataset': {
@@ -37,7 +37,7 @@ config: Configuration = {
         # Batch size (pr, k)
         # `pr` denotes number of persons
         # `k` denotes number of sequences per person
-        'batch_size': (4, 8),
+        'batch_size': (2, 2),
         # Number of workers of Dataloader
         'num_workers': 4,
         # Faster data transfer from RAM to GPU if enabled
@@ -49,35 +49,10 @@ config: Configuration = {
         # Auto-encoder feature channels coefficient
         'ae_feature_channels': 64,
         # Appearance, canonical and pose feature dimensions
-        'f_a_c_p_dims': (128, 128, 64),
-        # Use 1x1 convolution in dimensionality reduction
-        'hpm_use_1x1conv': False,
-        # HPM pyramid scales, of which sum is number of parts
-        'hpm_scales': (1, 2, 4),
-        # Global pooling method
-        'hpm_use_avg_pool': True,
-        'hpm_use_max_pool': False,
-        # FConv feature channels coefficient
-        'fpfe_feature_channels': 32,
-        # FConv blocks kernel sizes
-        'fpfe_kernel_sizes': ((5, 3), (3, 3), (3, 3)),
-        # FConv blocks paddings
-        'fpfe_paddings': ((2, 1), (1, 1), (1, 1)),
-        # FConv blocks halving
-        'fpfe_halving': (0, 2, 3),
-        # Attention squeeze ratio
-        'tfa_squeeze_ratio': 4,
-        # Number of parts after Part Net
-        'tfa_num_parts': 16,
-        # Embedding dimension for each part
-        'embedding_dims': 256,
-        # Triplet loss margins for HPM and PartNet
-        'triplet_margins': (0.2, 0.2),
+        'f_a_c_p_dims': (192, 192, 96),
     },
     'optimizer': {
         # Global parameters
-        # Iteration start to optimize non-disentangling parts
-        # 'start_iter': 0,
         # Initial learning rate of Adam Optimizer
         'lr': 1e-4,
         # Coefficients used for computing running averages of
@@ -89,11 +64,6 @@ config: Configuration = {
         # 'weight_decay': 0,
         # Use AMSGrad or not
         # 'amsgrad': False,
-
-        # Local parameters (override global ones)
-        'auto_encoder': {
-            'weight_decay': 0.001
-        },
     },
     'scheduler': {
         # Period of learning rate decay
```

```diff
diff --git a/eval.py b/eval.py
deleted file mode 100644
index c0505f8..0000000
--- a/eval.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import numpy as np
-
-from config import config
-from models import Model
-from utils.dataset import ClipConditions
-from utils.misc import set_visible_cuda
-
-set_visible_cuda(config['system'])
-model = Model(config['system'], config['model'], config['hyperparameter'])
-
-dataset_selectors = {
-    'nm': {'conditions': ClipConditions({r'nm-0\d'})},
-    'bg': {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})},
-    'cl': {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},
-}
-
-accuracy = model.predict_all(config['model']['total_iters'], config['dataset'],
-                             dataset_selectors, config['dataloader'])
-rank = 5
-np.set_printoptions(formatter={'float': '{:5.2f}'.format})
-for n in range(rank):
-    print(f'===Rank-{n + 1} Accuracy===')
-    for (condition, accuracy_c) in accuracy.items():
-        acc_excl_identical_view = accuracy_c[:, :, n].fill_diagonal_(0)
-        num_gallery_views = (acc_excl_identical_view != 0).sum(0)
-        acc_each_angle = acc_excl_identical_view.sum(0) / num_gallery_views
-        print('{0}: {1} mean: {2:5.2f}'.format(
-            condition, acc_each_angle.cpu().numpy() * 100,
-            acc_each_angle.mean() * 100)
-        )
```
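The deleted eval.py aggregated rank-n accuracy by zeroing identical-view pairs before averaging over the remaining gallery views. The sketch below replays that aggregation in isolation; the random tensor stands in for the per-condition accuracy that the removed `Model.evaluate()` produced, so shapes and values are illustrative only.

```python
# Minimal sketch of the rank-n aggregation the deleted eval.py performed.
# `accuracy_c` stands in for one condition's (gallery views x probe views
# x ranks) accuracy tensor; the random values are illustrative only.
import torch

num_views, num_ranks = 11, 5
accuracy_c = torch.rand(num_views, num_views, num_ranks)

n = 0  # rank-1
# Zero out identical-view pairs, then average over the remaining galleries
acc_excl_identical_view = accuracy_c[:, :, n].fill_diagonal_(0)
num_gallery_views = (acc_excl_identical_view != 0).sum(0)
acc_each_angle = acc_excl_identical_view.sum(0) / num_gallery_views
print(acc_each_angle * 100, acc_each_angle.mean() * 100)
```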
```diff
diff --git a/models/hpm.py b/models/hpm.py
deleted file mode 100644
index 9879cfb..0000000
--- a/models/hpm.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import torch
-import torch.nn as nn
-
-from models.layers import HorizontalPyramidPooling
-
-
-class HorizontalPyramidMatching(nn.Module):
-    def __init__(
-            self,
-            in_channels: int,
-            out_channels: int = 128,
-            use_1x1conv: bool = False,
-            scales: tuple[int, ...] = (1, 2, 4),
-            use_avg_pool: bool = True,
-            use_max_pool: bool = False,
-            **kwargs
-    ):
-        super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.use_1x1conv = use_1x1conv
-        self.scales = scales
-        self.use_avg_pool = use_avg_pool
-        self.use_max_pool = use_max_pool
-
-        self.pyramids = nn.ModuleList([
-            self._make_pyramid(scale, **kwargs) for scale in self.scales
-        ])
-
-    def _make_pyramid(self, scale: int, **kwargs):
-        pyramid = nn.ModuleList([
-            HorizontalPyramidPooling(self.in_channels,
-                                     self.out_channels,
-                                     use_1x1conv=self.use_1x1conv,
-                                     use_avg_pool=self.use_avg_pool,
-                                     use_max_pool=self.use_max_pool,
-                                     **kwargs)
-            for _ in range(scale)
-        ])
-        return pyramid
-
-    def forward(self, x):
-        n, c, h, w = x.size()
-        feature = []
-        for scale, pyramid in zip(self.scales, self.pyramids):
-            h_per_hpp = h // scale
-            for hpp_index, hpp in enumerate(pyramid):
-                h_filter = torch.arange(hpp_index * h_per_hpp,
-                                        (hpp_index + 1) * h_per_hpp)
-                x_slice = x[:, :, h_filter, :]
-                x_slice = hpp(x_slice)
-                x_slice = x_slice.view(n, -1)
-                feature.append(x_slice)
-        x = torch.stack(feature)
-        return x
```

```diff
diff --git a/models/layers.py b/models/layers.py
index ef53a95..1b4640f 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -1,6 +1,5 @@
 from typing import Union
 
-import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -97,98 +96,3 @@ class BasicLinear(nn.Module):
         x = self.fc(x)
         x = self.bn(x)
         return x
-
-
-class FocalConv2d(BasicConv2d):
-    def __init__(
-            self,
-            in_channels: int,
-            out_channels: int,
-            kernel_size: Union[int, tuple[int, int]],
-            halving: int,
-            **kwargs
-    ):
-        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
-        self.halving = halving
-
-    def forward(self, x):
-        h = x.size(2)
-        split_size = h // 2 ** self.halving
-        z = x.split(split_size, dim=2)
-        z = torch.cat([self.conv(_) for _ in z], dim=2)
-        return F.leaky_relu(z, inplace=True)
-
-
-class FocalConv2dBlock(nn.Module):
-    def __init__(
-            self,
-            in_channels: int,
-            out_channels: int,
-            kernel_sizes: tuple[int, int],
-            paddings: tuple[int, int],
-            halving: int,
-            use_pool: bool = True,
-            **kwargs
-    ):
-        super().__init__()
-        self.use_pool = use_pool
-        self.fconv1 = FocalConv2d(in_channels, out_channels, kernel_sizes[0],
-                                  halving, padding=paddings[0], **kwargs)
-        self.fconv2 = FocalConv2d(out_channels, out_channels, kernel_sizes[1],
-                                  halving, padding=paddings[1], **kwargs)
-        self.max_pool = nn.MaxPool2d(2)
-
-    def forward(self, x):
-        x = self.fconv1(x)
-        x = self.fconv2(x)
-        if self.use_pool:
-            x = self.max_pool(x)
-        return x
-
-
-class BasicConv1d(nn.Module):
-    def __init__(
-            self,
-            in_channels: int,
-            out_channels: int,
-            kernel_size: Union[int, tuple[int]],
-            **kwargs
-    ):
-        super().__init__()
-        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
-                              bias=False, **kwargs)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
-class HorizontalPyramidPooling(nn.Module):
-    def __init__(
-            self,
-            in_channels: int,
-            out_channels: int,
-            use_1x1conv: bool = False,
-            use_avg_pool: bool = True,
-            use_max_pool: bool = False,
-            **kwargs
-    ):
-        super().__init__()
-        self.use_1x1conv = use_1x1conv
-        if use_1x1conv:
-            self.conv = BasicConv2d(in_channels, out_channels, 1, **kwargs)
-        self.use_avg_pool = use_avg_pool
-        self.use_max_pool = use_max_pool
-        assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.'
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.max_pool = nn.AdaptiveMaxPool2d(1)
-
-    def forward(self, x):
-        if self.use_avg_pool and self.use_max_pool:
-            x = self.avg_pool(x) + self.max_pool(x)
-        elif self.use_avg_pool and not self.use_max_pool:
-            x = self.avg_pool(x)
-        elif not self.use_avg_pool and self.use_max_pool:
-            x = self.max_pool(x)
-        if self.use_1x1conv:
-            x = self.conv(x)
-        return x
```
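The removed HorizontalPyramidMatching cut the feature map into 1 + 2 + 4 horizontal strips (one group per pyramid scale) and pooled each strip down to a single vector per sample. A minimal sketch of that slicing, using plain functional average pooling (the module's default, since `use_max_pool` defaulted to `False`) and a random feature map in place of the encoder output:

```python
# Standalone sketch of the slicing the deleted HorizontalPyramidMatching
# performed: at each scale the feature map is cut into `scale` horizontal
# strips, and each strip is pooled to one vector per sample.
import torch
import torch.nn.functional as F

n, c, h, w = 4, 256, 32, 16
x = torch.rand(n, c, h, w)  # stand-in for the encoder feature map

features = []
for scale in (1, 2, 4):
    for strip in x.split(h // scale, dim=2):  # strips of height h / scale
        features.append(F.adaptive_avg_pool2d(strip, 1).view(n, -1))
feature = torch.stack(features)               # (1 + 2 + 4, n, c)
assert tuple(feature.size()) == (7, n, c)
```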
```diff
diff --git a/models/model.py b/models/model.py
index 82d6461..3f24936 100644
--- a/models/model.py
+++ b/models/model.py
@@ -5,7 +5,6 @@ from typing import Union, Optional
 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 import torch.optim as optim
 from torch.utils.data import DataLoader
 from torch.utils.data.dataloader import default_collate
@@ -142,29 +141,16 @@
         # Prepare for model, optimizer and scheduler
         model_hp = self.hp.get('model', {})
         optim_hp: dict = self.hp.get('optimizer', {}).copy()
-        start_iter = optim_hp.pop('start_iter', 0)
-        ae_optim_hp = optim_hp.pop('auto_encoder', {})
-        pn_optim_hp = optim_hp.pop('part_net', {})
-        hpm_optim_hp = optim_hp.pop('hpm', {})
-        fc_optim_hp = optim_hp.pop('fc', {})
         sched_hp = self.hp.get('scheduler', {})
         self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
                                  image_log_on=self.image_log_on)
         # Try to accelerate computation using CUDA or others
         self.rgb_pn = self.rgb_pn.to(self.device)
-        self.optimizer = optim.Adam([
-            {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
-            {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
-            {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
-            {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
-        ], **optim_hp)
+        self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp)
         sched_gamma = sched_hp.get('gamma', 0.9)
         sched_step_size = sched_hp.get('step_size', 500)
         self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
             lambda epoch: sched_gamma ** (epoch // sched_step_size),
-            lambda epoch: 0 if epoch < start_iter else 1,
-            lambda epoch: 0 if epoch < start_iter else 1,
-            lambda epoch: 0 if epoch < start_iter else 1,
         ])
         self.writer = SummaryWriter(self._log_name)
 
@@ -182,10 +168,10 @@
         # Training start
         start_time = datetime.now()
-        running_loss = torch.zeros(5, device=self.device)
+        running_loss = torch.zeros(3, device=self.device)
         print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
               f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
-              f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
+              f"{'LR':^9}")
         for (batch_c1, batch_c2) in dataloader:
             self.curr_iter += 1
             # Zero the parameter gradients
@@ -193,10 +179,7 @@
             # forward + backward + optimize
             x_c1 = batch_c1['clip'].to(self.device)
             x_c2 = batch_c2['clip'].to(self.device)
-            y = batch_c1['label'].to(self.device)
-            # Duplicate labels for each part
-            y = y.unsqueeze(1).repeat(1, self.rgb_pn.num_total_parts)
-            losses, images = self.rgb_pn(x_c1, x_c2, y)
+            losses, images = self.rgb_pn(x_c1, x_c2)
             loss = losses.sum()
             loss.backward()
             self.optimizer.step()
@@ -206,19 +189,16 @@
             # Write losses to TensorBoard
             self.writer.add_scalar('Loss/all', loss, self.curr_iter)
             self.writer.add_scalars('Loss/details', dict(zip([
-                'Cross reconstruction loss', 'Canonical consistency loss',
-                'Pose similarity loss', 'Batch All triplet loss (HPM)',
-                'Batch All triplet loss (PartNet)'
+                'Cross reconstruction loss',
+                'Canonical consistency loss',
+                'Pose similarity loss'
             ], losses)), self.curr_iter)
 
             if self.curr_iter % 100 == 0:
-                lrs = self.scheduler.get_last_lr()
+                lr = self.scheduler.get_last_lr()[0]
                 # Write learning rates
                 self.writer.add_scalar(
-                    'Learning rate/Auto-encoder', lrs[0], self.curr_iter
-                )
-                self.writer.add_scalar(
-                    'Learning rate/Others', lrs[1], self.curr_iter
+                    'Learning rate/Auto-encoder', lr, self.curr_iter
                 )
                 # Write disentangled images
                 if self.image_log_on:
@@ -241,8 +221,8 @@
                 hour, minute = divmod(remaining_minute, 60)
                 print(f'{hour:02}:{minute:02}:{second:02}',
                       f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
-                      '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
-                      '{:.3e} {:.3e}'.format(lrs[0], lrs[1]))
+                      '{:f} {:f} {:f}'.format(*running_loss / 100),
+                      f'{lr:.3e}')
                 running_loss.zero_()
 
             # Step scheduler
@@ -261,24 +241,6 @@
                 self.writer.close()
                 break
 
-    def predict_all(
-            self,
-            iters: tuple[int],
-            dataset_config: DatasetConfiguration,
-            dataset_selectors: dict[
-                str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
-            ],
-            dataloader_config: DataloaderConfiguration,
-    ) -> dict[str, torch.Tensor]:
-        # Transform data to features
-        gallery_samples, probe_samples = self.transform(
-            iters, dataset_config, dataset_selectors, dataloader_config
-        )
-        # Evaluate features
-        accuracy = self.evaluate(gallery_samples, probe_samples)
-
-        return accuracy
-
     def transform(
             self,
             iters: tuple[int],
@@ -329,61 +291,13 @@
     def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]):
         label = sample.pop('label').item()
         clip = sample.pop('clip').to(self.device)
-        feature = self.rgb_pn(clip).detach()
+        x_c, x_p = self.rgb_pn(clip).detach()
         return {
             **{'label': label},
             **sample,
-            **{'feature': feature}
-        }
-
-    def evaluate(
-            self,
-            gallery_samples: dict[str, Union[list[str], torch.Tensor]],
-            probe_samples: dict[str, dict[str, Union[list[str], torch.Tensor]]],
-            num_ranks: int = 5
-    ) -> dict[str, torch.Tensor]:
-        probe_conditions = self._probe_datasets_meta.keys()
-        gallery_views_meta = self._gallery_dataset_meta['views']
-        probe_views_meta = list(self._probe_datasets_meta.values())[0]['views']
-        accuracy = {
-            condition: torch.empty(
-                len(gallery_views_meta), len(probe_views_meta), num_ranks
-            )
-            for condition in self._probe_datasets_meta.keys()
+            **{'cano_feature': x_c, 'pose_feature': x_p}
         }
-        (labels_g, _, views_g, features_g) = gallery_samples.values()
-        views_g = np.asarray(views_g)
-        for (v_g_i, view_g) in enumerate(gallery_views_meta):
-            gallery_view_mask = (views_g == view_g)
-            f_g = features_g[gallery_view_mask]
-            y_g = labels_g[gallery_view_mask]
-            for condition in probe_conditions:
-                probe_samples_c = probe_samples[condition]
-                accuracy_c = accuracy[condition]
-                (labels_p, _, views_p, features_p) = probe_samples_c.values()
-                views_p = np.asarray(views_p)
-                for (v_p_i, view_p) in enumerate(probe_views_meta):
-                    probe_view_mask = (views_p == view_p)
-                    f_p = features_p[probe_view_mask]
-                    y_p = labels_p[probe_view_mask]
-                    # Euclidean distance
-                    f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(1)
-                    f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(0)
-                    f_p_times_f_g_sum = f_p @ f_g.T
-                    dist = torch.sqrt(F.relu(
-                        f_p_squared_sum - 2*f_p_times_f_g_sum + f_g_squared_sum
-                    ))
-                    # Ranked accuracy
-                    rank_mask = dist.argsort(1)[:, :num_ranks]
-                    positive_mat = torch.eq(y_p.unsqueeze(1),
-                                            y_g[rank_mask]).cumsum(1).gt(0)
-                    positive_counts = positive_mat.sum(0)
-                    total_counts, _ = dist.size()
-                    accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts
-
-        return accuracy
-
     def _load_pretrained(
             self,
             iters: tuple[int],
@@ -452,8 +366,6 @@ class Model:
                 nn.init.zeros_(m.bias)
             elif isinstance(m, nn.Linear):
                 nn.init.xavier_uniform_(m.weight)
-            elif isinstance(m, RGBPartNet):
-                nn.init.xavier_uniform_(m.fc_mat)
 
     def _parse_dataset_config(
             self,
```
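With the per-module parameter groups gone, the optimizer/scheduler pair reduces to a single Adam group over all parameters with a staircase decay. A self-contained sketch of the resulting schedule, with a toy linear layer standing in for RGBPartNet and `gamma`/`step_size` set to the defaults shown above:

```python
# Self-contained sketch of the simplified optimizer/scheduler pair: one Adam
# group over all parameters, decayed by `gamma` every `step_size` iterations.
# The tiny module is a stand-in; gamma and step_size follow the defaults above.
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(8, 2)  # stand-in for RGBPartNet
optimizer = optim.Adam(model.parameters(), lr=1e-4)
sched_gamma, sched_step_size = 0.9, 500
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: sched_gamma ** (epoch // sched_step_size)
)

for it in range(1, 1001):
    optimizer.step()      # backward() omitted in this sketch
    scheduler.step()      # stepped once per iteration, as in Model.fit()
    if it % 500 == 0:
        print(it, scheduler.get_last_lr()[0])  # 1e-4 * 0.9 ** (it // 500)
```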
```diff
diff --git a/models/part_net.py b/models/part_net.py
deleted file mode 100644
index 62a2bac..0000000
--- a/models/part_net.py
+++ /dev/null
@@ -1,151 +0,0 @@
-import copy
-
-import torch
-import torch.nn as nn
-
-from models.layers import BasicConv1d, FocalConv2dBlock
-
-
-class FrameLevelPartFeatureExtractor(nn.Module):
-
-    def __init__(
-            self,
-            in_channels: int = 3,
-            feature_channels: int = 32,
-            kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)),
-            paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
-            halving: tuple[int, ...] = (0, 2, 3)
-    ):
-        super().__init__()
-        num_blocks = len(kernel_sizes)
-        out_channels = [feature_channels * 2 ** i for i in range(num_blocks)]
-        in_channels = [in_channels] + out_channels[:-1]
-        use_pools = [True] * (num_blocks - 1) + [False]
-        params = (in_channels, out_channels, kernel_sizes,
-                  paddings, halving, use_pools)
-
-        self.fconv_blocks = nn.ModuleList([
-            FocalConv2dBlock(*_params) for _params in zip(*params)
-        ])
-
-    def forward(self, x):
-        # Flatten frames in all batches
-        n, t, c, h, w = x.size()
-        x = x.view(n * t, c, h, w)
-
-        for fconv_block in self.fconv_blocks:
-            x = fconv_block(x)
-        return x
-
-
-class TemporalFeatureAggregator(nn.Module):
-    def __init__(
-            self,
-            in_channels: int,
-            squeeze_ratio: int = 4,
-            num_part: int = 16
-    ):
-        super().__init__()
-        hidden_dim = in_channels // squeeze_ratio
-        self.num_part = num_part
-
-        # MTB1
-        conv3x1 = nn.Sequential(
-            BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1),
-            nn.LeakyReLU(inplace=True),
-            BasicConv1d(hidden_dim, in_channels, kernel_size=1, padding=0)
-        )
-        self.conv1d3x1 = self._parted(conv3x1)
-        self.avg_pool3x1 = nn.AvgPool1d(kernel_size=3, stride=1, padding=1)
-        self.max_pool3x1 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1)
-
-        # MTB2
-        conv3x3 = nn.Sequential(
-            BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1),
-            nn.LeakyReLU(inplace=True),
-            BasicConv1d(hidden_dim, in_channels, kernel_size=3, padding=1)
-        )
-        self.conv1d3x3 = self._parted(conv3x3)
-        self.avg_pool3x3 = nn.AvgPool1d(kernel_size=5, stride=1, padding=2)
-        self.max_pool3x3 = nn.MaxPool1d(kernel_size=5, stride=1, padding=2)
-
-    def _parted(self, module: nn.Module):
-        """Duplicate module `part_num` times."""
-        return nn.ModuleList([copy.deepcopy(module)
-                              for _ in range(self.num_part)])
-
-    def forward(self, x):
-        # p, n, t, c
-        x = x.transpose(2, 3)
-        p, n, c, t = x.size()
-        feature = x.split(1, dim=0)
-        feature = [f.squeeze(0) for f in feature]
-        x = x.view(-1, c, t)
-
-        # MTB1: ConvNet1d & Sigmoid
-        logits3x1 = torch.stack(
-            [conv(f) for conv, f in zip(self.conv1d3x1, feature)]
-        )
-        scores3x1 = torch.sigmoid(logits3x1)
-        # MTB1: Template Function
-        feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x)
-        feature3x1 = feature3x1.view(p, n, c, t)
-        feature3x1 = feature3x1 * scores3x1
-
-        # MTB2: ConvNet1d & Sigmoid
-        logits3x3 = torch.stack(
-            [conv(f) for conv, f in zip(self.conv1d3x3, feature)]
-        )
-        scores3x3 = torch.sigmoid(logits3x3)
-        # MTB2: Template Function
-        feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x)
-        feature3x3 = feature3x3.view(p, n, c, t)
-        feature3x3 = feature3x3 * scores3x3
-
-        # Temporal Pooling
-        ret = (feature3x1 + feature3x3).max(-1)[0]
-        return ret
-
-
-class PartNet(nn.Module):
-    def __init__(
-            self,
-            in_channels: int = 3,
-            feature_channels: int = 32,
-            kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)),
-            paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
-            halving: tuple[int, ...] = (0, 2, 3),
-            squeeze_ratio: int = 4,
-            num_part: int = 16
-    ):
-        super().__init__()
-        self.num_part = num_part
-        self.fpfe = FrameLevelPartFeatureExtractor(
-            in_channels, feature_channels, kernel_sizes, paddings, halving
-        )
-
-        num_fconv_blocks = len(self.fpfe.fconv_blocks)
-        self.tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1)
-        self.tfa = TemporalFeatureAggregator(
-            self.tfa_in_channels, squeeze_ratio, self.num_part
-        )
-
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.max_pool = nn.AdaptiveMaxPool2d(1)
-
-    def forward(self, x):
-        n, t, _, _, _ = x.size()
-        x = self.fpfe(x)
-        # n * t x c x h x w
-
-        # Horizontal Pooling
-        _, c, h, w = x.size()
-        split_size = h // self.num_part
-        x = x.split(split_size, dim=2)
-        x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
-        x = [x_.view(n, t, c) for x_ in x]
-        x = torch.stack(x)
-
-        # p, n, t, c
-        x = self.tfa(x)
-        return x
```
```diff
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 67acac3..f18d675 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -2,9 +2,6 @@ import torch
 import torch.nn as nn
 
 from models.auto_encoder import AutoEncoder
-from models.hpm import HorizontalPyramidMatching
-from models.part_net import PartNet
-from utils.triplet_loss import BatchAllTripletLoss
 
 
 class RGBPartNet(nn.Module):
@@ -14,80 +11,26 @@ class RGBPartNet(nn.Module):
             ae_in_size: tuple[int, int] = (64, 48),
             ae_feature_channels: int = 64,
             f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
-            hpm_use_1x1conv: bool = False,
-            hpm_scales: tuple[int, ...] = (1, 2, 4),
-            hpm_use_avg_pool: bool = True,
-            hpm_use_max_pool: bool = True,
-            fpfe_feature_channels: int = 32,
-            fpfe_kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)),
-            fpfe_paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
-            fpfe_halving: tuple[int, ...] = (0, 2, 3),
-            tfa_squeeze_ratio: int = 4,
-            tfa_num_parts: int = 16,
-            embedding_dims: int = 256,
-            triplet_margins: tuple[float, float] = (0.2, 0.2),
             image_log_on: bool = False
     ):
         super().__init__()
         (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
-        self.hpm_num_parts = sum(hpm_scales)
         self.image_log_on = image_log_on
 
         self.ae = AutoEncoder(
             ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
         )
-        self.pn = PartNet(
-            ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
-            fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_parts
-        )
-        out_channels = self.pn.tfa_in_channels
-        self.hpm = HorizontalPyramidMatching(
-            ae_feature_channels * 2, out_channels, hpm_use_1x1conv,
-            hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
-        )
-        self.num_total_parts = self.hpm_num_parts + tfa_num_parts
-        empty_fc = torch.empty(self.num_total_parts,
-                               out_channels, embedding_dims)
-        self.fc_mat = nn.Parameter(empty_fc)
-
-        (hpm_margin, pn_margin) = triplet_margins
-        self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin)
-        self.pn_ba_trip = BatchAllTripletLoss(pn_margin)
-
-    def fc(self, x):
-        return x @ self.fc_mat
 
-    def forward(self, x_c1, x_c2=None, y=None):
+    def forward(self, x_c1, x_c2=None):
         # Step 1: Disentanglement
         # n, t, c, h, w
         ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2)
 
-        # Step 2.a: Static Gait Feature Aggregation & HPM
-        # n, c, h, w
-        x_c = self.hpm(x_c)
-        # p, n, c
-
-        # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
-        # n, t, c, h, w
-        x_p = self.pn(x_p)
-        # p, n, c
-
-        # Step 3: Cat feature map together and fc
-        x = torch.cat((x_c, x_p))
-        x = self.fc(x)
-
         if self.training:
-            y = y.T
-            hpm_ba_trip = self.hpm_ba_trip(
-                x[:self.hpm_num_parts], y[:self.hpm_num_parts]
-            )
-            pn_ba_trip = self.pn_ba_trip(
-                x[self.hpm_num_parts:], y[self.hpm_num_parts:]
-            )
-            losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip))
+            losses = torch.stack(losses)
             return losses, images
         else:
-            return x.unsqueeze(1).view(-1)
+            return x_c, x_p
 
     def _disentangle(self, x_c1_t2, x_c2_t2=None):
         n, t, c, h, w = x_c1_t2.size()
```
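Callers now see a two-mode interface: training returns the stacked disentanglement losses (plus any logged images), while evaluation returns the canonical and pose features directly instead of a fused part embedding. The stub below is a hypothetical stand-in that only mirrors those signatures, since the real module needs the AutoEncoder to run:

```python
# Hypothetical stand-in mirroring the slimmed-down RGBPartNet interface;
# shapes and loss values are made up for illustration.
import torch
import torch.nn as nn


class DummyRGBPartNet(nn.Module):
    def forward(self, x_c1, x_c2=None):
        losses = (torch.rand(()), torch.rand(()), torch.rand(()))
        x_c, x_p = torch.rand(4, 128), torch.rand(4, 64)
        if self.training:
            # Training: stacked (xrecon, cano_cons, pose_sim) losses + images
            return torch.stack(losses), None
        # Evaluation: canonical and pose features, no fused embedding anymore
        return x_c, x_p


rgb_pn = DummyRGBPartNet()
losses, _ = rgb_pn(torch.rand(4, 30, 3, 64, 48), torch.rand(4, 30, 3, 64, 48))
loss = losses.sum()  # summed exactly as in Model.fit() above

rgb_pn.eval()
x_c, x_p = rgb_pn(torch.rand(4, 30, 3, 64, 48))
```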
```diff
diff --git a/test/hpm.py b/test/hpm.py
deleted file mode 100644
index 0aefbb8..0000000
--- a/test/hpm.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import torch
-
-from models.hpm import HorizontalPyramidMatching
-
-T, N, C, H, W = 15, 4, 256, 32, 16
-
-
-def test_default_hpm():
-    hpm = HorizontalPyramidMatching(in_channels=C)
-    x = torch.rand(T, N, C, H, W)
-    x = hpm(x)
-    assert tuple(x.size()) == (1 + 2 + 4, T, N, 128)
-
-
-def test_custom_hpm():
-    hpm = HorizontalPyramidMatching(in_channels=2048,
-                                    out_channels=256,
-                                    scales=(1, 2, 4, 8),
-                                    use_avg_pool=True,
-                                    use_max_pool=False)
-    x = torch.rand(T, N, 2048, H, W)
-    x = hpm(x)
-    assert tuple(x.size()) == (1 + 2 + 4 + 8, T, N, 256)
```

```diff
diff --git a/test/part_net.py b/test/part_net.py
deleted file mode 100644
index 25e92ae..0000000
--- a/test/part_net.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import torch
-
-from models.part_net import FrameLevelPartFeatureExtractor, \
-    TemporalFeatureAggregator, PartNet
-
-T, N, C, H, W = 15, 4, 3, 64, 32
-
-
-def test_default_fpfe():
-    fpfe = FrameLevelPartFeatureExtractor()
-    x = torch.rand(T, N, C, H, W)
-    x = fpfe(x)
-
-    assert tuple(x.size()) == (T * N, 32 * 4, 16, 8)
-
-
-def test_custom_fpfe():
-    feature_channels = 64
-    fpfe = FrameLevelPartFeatureExtractor(
-        in_channels=1,
-        feature_channels=feature_channels,
-        kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)),
-        paddings=((2, 1), (1, 1), (1, 1), (1, 1)),
-        halving=(1, 1, 3, 3)
-    )
-    x = torch.rand(T, N, 1, H, W)
-    x = fpfe(x)
-
-    assert tuple(x.size()) == (T * N, feature_channels * 8, 8, 4)
-
-
-def test_default_tfa():
-    in_channels = 32 * 4
-    tfa = TemporalFeatureAggregator(in_channels)
-    x = torch.rand(16, T, N, in_channels)
-    x = tfa(x)
-
-    assert tuple(x.size()) == (16, N, in_channels)
-
-
-def test_custom_tfa():
-    in_channels = 64 * 8
-    num_part = 8
-    tfa = TemporalFeatureAggregator(in_channels=in_channels,
-                                    squeeze_ratio=8, num_part=num_part)
-    x = torch.rand(num_part, T, N, in_channels)
-    x = tfa(x)
-
-    assert tuple(x.size()) == (num_part, N, in_channels)
-
-
-def test_default_part_net():
-    pa = PartNet()
-    x = torch.rand(T, N, C, H, W)
-    x = pa(x)
-
-    assert tuple(x.size()) == (16, N, 32 * 4)
-
-
-def test_custom_part_net():
-    feature_channels = 64
-    pa = PartNet(in_channels=1, feature_channels=feature_channels,
-                 kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)),
-                 paddings=((2, 1), (1, 1), (1, 1), (1, 1)),
-                 halving=(1, 1, 3, 3),
-                 squeeze_ratio=8,
-                 num_part=8)
-    x = torch.rand(T, N, 1, H, W)
-    x = pa(x)
-
-    assert tuple(x.size()) == (8, N, pa.tfa_in_channels)
```

```diff
diff --git a/utils/configuration.py b/utils/configuration.py
index 435d815..1b7c8d3 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -32,26 +32,6 @@ class DataloaderConfiguration(TypedDict):
 class ModelHPConfiguration(TypedDict):
     ae_feature_channels: int
     f_a_c_p_dims: tuple[int, int, int]
-    hpm_scales: tuple[int, ...]
-    hpm_use_1x1conv: bool
-    hpm_use_avg_pool: bool
-    hpm_use_max_pool: bool
-    fpfe_feature_channels: int
-    fpfe_kernel_sizes: tuple[tuple, ...]
-    fpfe_paddings: tuple[tuple, ...]
-    fpfe_halving: tuple[int, ...]
-    tfa_squeeze_ratio: int
-    tfa_num_parts: int
-    embedding_dims: int
-    triplet_margins: tuple[float, float]
-
-
-class SubOptimizerHPConfiguration(TypedDict):
-    lr: int
-    betas: tuple[float, float]
-    eps: float
-    weight_decay: float
-    amsgrad: bool
 
 
 class OptimizerHPConfiguration(TypedDict):
@@ -61,10 +41,6 @@ class OptimizerHPConfiguration(TypedDict):
     eps: float
     weight_decay: float
     amsgrad: bool
-    auto_encoder: SubOptimizerHPConfiguration
-    part_net: SubOptimizerHPConfiguration
-    hpm: SubOptimizerHPConfiguration
-    fc: SubOptimizerHPConfiguration
 
 
 class SchedulerHPConfiguration(TypedDict):
```

```diff
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
deleted file mode 100644
index 954def2..0000000
--- a/utils/triplet_loss.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class BatchAllTripletLoss(nn.Module):
-    def __init__(self, margin: float = 0.2):
-        super().__init__()
-        self.margin = margin
-
-    def forward(self, x, y):
-        p, n, c = x.size()
-
-        # Euclidean distance p x n x n
-        x_squared_sum = torch.sum(x ** 2, dim=2)
-        x1_squared_sum = x_squared_sum.unsqueeze(2)
-        x2_squared_sum = x_squared_sum.unsqueeze(1)
-        x1_times_x2_sum = x @ x.transpose(1, 2)
-        dist = torch.sqrt(
-            F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum)
-        )
-
-        hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
-        hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
-        all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1)
-        all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1)
-        positive_negative_dist = all_hard_positive - all_hard_negative
-        all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1)
-
-        # Non-zero parted mean
-        non_zero_counts = (all_loss != 0).sum(1)
-        parted_loss_mean = all_loss.sum(1) / non_zero_counts
-        parted_loss_mean[non_zero_counts == 0] = 0
-
-        loss = parted_loss_mean.mean()
-        return loss
```
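The deleted BatchAllTripletLoss built its pairwise distance matrix from the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 rather than an explicit loop. A short sketch checking that trick against torch.cdist on random stand-in part features; the loose tolerance absorbs the floating-point cancellation that the F.relu guards against:

```python
# Verify the squared-norm expansion the deleted BatchAllTripletLoss used to
# build its (p, n, n) pairwise Euclidean distance matrix.
import torch
import torch.nn.functional as F

p, n, c = 8, 16, 256
x = torch.rand(p, n, c)  # stand-in for p part embeddings of n samples

x_squared_sum = torch.sum(x ** 2, dim=2)
dist = torch.sqrt(F.relu(              # relu clamps tiny negative values
    x_squared_sum.unsqueeze(2)         # produced by float cancellation
    - 2 * (x @ x.transpose(1, 2))
    + x_squared_sum.unsqueeze(1)
))
assert torch.allclose(dist, torch.cdist(x, x), atol=1e-2)
```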