diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/auto_encoder.py | 4 | ||||
-rw-r--r-- | models/hpm.py | 54 | ||||
-rw-r--r-- | models/layers.py | 87 | ||||
-rw-r--r-- | models/model.py | 282 | ||||
-rw-r--r-- | models/part_net.py | 149 | ||||
-rw-r--r-- | models/rgb_part_net.py | 121 |
6 files changed, 77 insertions, 620 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 4fece69..91071dd 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -106,15 +106,13 @@ class Decoder(nn.Module): self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels, is_last_layer=True) - def forward(self, f_appearance, f_canonical, f_pose, is_feature_map=False): + def forward(self, f_appearance, f_canonical, f_pose): x = torch.cat((f_appearance, f_canonical, f_pose), dim=1) x = self.fc(x) x = x.view(-1, self.feature_channels * 8, self.h_0, self.w_0) x = F.relu(x, inplace=True) x = self.trans_conv1(x) x = self.trans_conv2(x) - if is_feature_map: - return x x = self.trans_conv3(x) x = torch.sigmoid(self.trans_conv4(x)) diff --git a/models/hpm.py b/models/hpm.py deleted file mode 100644 index 8186b20..0000000 --- a/models/hpm.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -import torch.nn as nn - -from models.layers import HorizontalPyramidPooling - - -class HorizontalPyramidMatching(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int = 128, - scales: tuple[int, ...] = (1, 2, 4), - use_avg_pool: bool = True, - use_max_pool: bool = False, - ): - super().__init__() - self.scales = scales - self.num_parts = sum(scales) - self.use_avg_pool = use_avg_pool - self.use_max_pool = use_max_pool - - self.pyramids = nn.ModuleList([ - self._make_pyramid(scale) for scale in scales - ]) - self.fc_mat = nn.Parameter( - torch.empty(self.num_parts, in_channels, out_channels) - ) - - def _make_pyramid(self, scale: int): - pyramid = nn.ModuleList([ - HorizontalPyramidPooling(self.use_avg_pool, self.use_max_pool) - for _ in range(scale) - ]) - return pyramid - - def forward(self, x): - n, c, h, w = x.size() - feature = [] - for scale, pyramid in zip(self.scales, self.pyramids): - h_per_hpp = h // scale - for hpp_index, hpp in enumerate(pyramid): - h_filter = torch.arange(hpp_index * h_per_hpp, - (hpp_index + 1) * h_per_hpp) - x_slice = x[:, :, h_filter, :] - x_slice = hpp(x_slice) - x_slice = x_slice.view(n, -1) - feature.append(x_slice) - x = torch.stack(feature) - - # p, n, c - x = x @ self.fc_mat - # p, n, d - - return x diff --git a/models/layers.py b/models/layers.py index c609698..5306769 100644 --- a/models/layers.py +++ b/models/layers.py @@ -1,6 +1,5 @@ from typing import Union -import torch import torch.nn as nn import torch.nn.functional as F @@ -99,89 +98,3 @@ class BasicLinear(nn.Module): x = self.fc(x) x = self.bn(x) return x - - -class FocalConv2d(BasicConv2d): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, tuple[int, int]], - halving: int, - **kwargs - ): - super().__init__(in_channels, out_channels, kernel_size, **kwargs) - self.halving = halving - - def forward(self, x): - h = x.size(2) - split_size = h // 2 ** self.halving - z = x.split(split_size, dim=2) - z = torch.cat([self.conv(_) for _ in z], dim=2) - return F.leaky_relu(z, inplace=True) - - -class FocalConv2dBlock(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_sizes: tuple[int, int], - paddings: tuple[int, int], - halving: int, - use_pool: bool = True, - **kwargs - ): - super().__init__() - self.use_pool = use_pool - self.fconv1 = FocalConv2d(in_channels, out_channels, kernel_sizes[0], - halving, padding=paddings[0], **kwargs) - self.fconv2 = FocalConv2d(out_channels, out_channels, kernel_sizes[1], - halving, padding=paddings[1], **kwargs) - self.max_pool = nn.MaxPool2d(2) - - def forward(self, x): - x = self.fconv1(x) - x = self.fconv2(x) - if self.use_pool: - x = self.max_pool(x) - return x - - -class BasicConv1d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, tuple[int]], - **kwargs - ): - super().__init__() - self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, - bias=False, **kwargs) - - def forward(self, x): - return self.conv(x) - - -class HorizontalPyramidPooling(nn.Module): - def __init__( - self, - use_avg_pool: bool = True, - use_max_pool: bool = False, - ): - super().__init__() - self.use_avg_pool = use_avg_pool - self.use_max_pool = use_max_pool - assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.' - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.max_pool = nn.AdaptiveMaxPool2d(1) - - def forward(self, x): - if self.use_avg_pool and self.use_max_pool: - x = self.avg_pool(x) + self.max_pool(x) - elif self.use_avg_pool and not self.use_max_pool: - x = self.avg_pool(x) - elif not self.use_avg_pool and self.use_max_pool: - x = self.max_pool(x) - return x diff --git a/models/model.py b/models/model.py index ceadb92..25c8a4f 100644 --- a/models/model.py +++ b/models/model.py @@ -6,22 +6,18 @@ from typing import Union, Optional import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from models.hpm import HorizontalPyramidMatching -from models.part_net import PartNet from models.rgb_part_net import RGBPartNet from utils.configuration import DataloaderConfiguration, \ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \ SystemConfiguration from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses from utils.sampler import TripletSampler -from utils.triplet_loss import BatchTripletLoss class Model: @@ -62,8 +58,6 @@ class Model: self.in_size: tuple[int, int] = (64, 48) self.pr: Optional[int] = None self.k: Optional[int] = None - self.num_pairs: Optional[int] = None - self.num_pos_pairs: Optional[int] = None self._gallery_dataset_meta: Optional[dict[str, list]] = None self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None @@ -73,8 +67,6 @@ class Model: self._dataset_sig: str = 'undefined' self.rgb_pn: Optional[RGBPartNet] = None - self.triplet_loss_hpm: Optional[BatchTripletLoss] = None - self.triplet_loss_pn: Optional[BatchTripletLoss] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None @@ -166,71 +158,23 @@ class Model: )) # Prepare for model, optimizer and scheduler model_hp: dict = self.hp.get('model', {}).copy() - triplet_is_hard = model_hp.pop('triplet_is_hard', True) - triplet_is_mean = model_hp.pop('triplet_is_mean', True) - triplet_margins = model_hp.pop('triplet_margins', None) optim_hp: dict = self.hp.get('optimizer', {}).copy() - ae_optim_hp = optim_hp.pop('auto_encoder', {}) - hpm_optim_hp = optim_hp.pop('hpm', {}) - pn_optim_hp = optim_hp.pop('part_net', {}) sched_hp = self.hp.get('scheduler', {}) - ae_sched_hp = sched_hp.get('auto_encoder', {}) - hpm_sched_hp = sched_hp.get('hpm', {}) - pn_sched_hp = sched_hp.get('part_net', {}) self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) - # Hard margins - if triplet_margins: - self.triplet_loss_hpm = BatchTripletLoss( - triplet_is_hard, triplet_is_mean, triplet_margins[0] - ) - self.triplet_loss_pn = BatchTripletLoss( - triplet_is_hard, triplet_is_mean, triplet_margins[1] - ) - else: # Soft margins - self.triplet_loss_hpm = BatchTripletLoss( - triplet_is_hard, triplet_is_mean, None - ) - self.triplet_loss_pn = BatchTripletLoss( - triplet_is_hard, triplet_is_mean, None - ) - - self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 - self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) - self.triplet_loss_hpm = self.triplet_loss_hpm.to(self.device) - self.triplet_loss_pn = self.triplet_loss_pn.to(self.device) - - self.optimizer = optim.Adam([ - {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp}, - {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp}, - {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, - ], **optim_hp) - - # Scheduler + self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp) start_step = sched_hp.get('start_step', 15_000) final_gamma = sched_hp.get('final_gamma', 0.001) - ae_start_step = ae_sched_hp.get('start_step', start_step) - ae_final_gamma = ae_sched_hp.get('final_gamma', final_gamma) - ae_all_step = self.total_iter - ae_start_step - hpm_start_step = hpm_sched_hp.get('start_step', start_step) - hpm_final_gamma = hpm_sched_hp.get('final_gamma', final_gamma) - hpm_all_step = self.total_iter - hpm_start_step - pn_start_step = pn_sched_hp.get('start_step', start_step) - pn_final_gamma = pn_sched_hp.get('final_gamma', final_gamma) - pn_all_step = self.total_iter - pn_start_step - self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ - lambda t: ae_final_gamma ** ((t - ae_start_step) / ae_all_step) - if t > ae_start_step else 1, - lambda t: hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step) - if t > hpm_start_step else 1, - lambda t: pn_final_gamma ** ((t - pn_start_step) / pn_all_step) - if t > pn_start_step else 1, - ]) - + all_step = self.total_iter - start_step + self.scheduler = optim.lr_scheduler.LambdaLR( + self.optimizer, + lambda t: final_gamma ** ((t - start_step) / all_step) + if t > start_step else 1, + ) self.writer = SummaryWriter(self._log_name) # Set seeds for reproducibility @@ -259,24 +203,18 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2) - y = batch_c1['label'].to(self.device) - losses, hpm_result, pn_result = self._classification_loss( - embed_c, embed_p, ae_losses, y - ) + losses, features, images = self.rgb_pn(x_c1, x_c2) loss = losses.sum() loss.backward() self.optimizer.step() self.scheduler.step() # Learning rate - self.writer.add_scalars('Learning rate', dict(zip(( - 'Auto-encoder', 'HPM', 'PartNet' - ), self.scheduler.get_last_lr())), self.curr_iter) - # Other stats - self._write_stat( - 'Train', embed_c, embed_p, hpm_result, pn_result, loss, losses + self.writer.add_scalar( + 'Learning rate', self.scheduler.get_last_lr()[0], self.curr_iter ) + # Other stats + self._write_stat('Train', loss, losses) if self.curr_iter % 100 == 99: # Write disentangled images @@ -295,32 +233,33 @@ class Model: self.writer.add_images( f'Pose image/batch {i}', p, self.curr_iter ) - - # Validation - embed_c = self._flatten_embedding(embed_c) - embed_p = self._flatten_embedding(embed_p) - self._write_embedding('HPM Train', embed_c, x_c1, y) - self._write_embedding('PartNet Train', embed_p, x_c1, y) + f_a, f_c, f_p = features + for i, (f_a_i, f_c_i, f_p_i) in enumerate( + zip(f_a, f_c, f_p) + ): + self.writer.add_images( + f'Appearance features/Layer {i}', + f_a_i[:, :3, :, :], self.curr_iter + ) + self.writer.add_images( + f'Canonical features/Layer {i}', + f_c_i[:, :3, :, :], self.curr_iter + ) + for j, p in enumerate(f_p_i): + self.writer.add_images( + f'Pose features/Layer {i}/batch{j}', + p[:, :3, :, :], self.curr_iter + ) # Calculate losses on testing batch batch_c1, batch_c2 = next(val_dataloader) x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) with torch.no_grad(): - embed_c, embed_p, ae_losses, _ = self.rgb_pn(x_c1, x_c2) - y = batch_c1['label'].to(self.device) - losses, hpm_result, pn_result = self._classification_loss( - embed_c, embed_p, ae_losses, y - ) + losses, _, _ = self.rgb_pn(x_c1, x_c2) loss = losses.sum() - self._write_stat( - 'Val', embed_c, embed_p, hpm_result, pn_result, loss, losses - ) - embed_c = self._flatten_embedding(embed_c) - embed_p = self._flatten_embedding(embed_p) - self._write_embedding('HPM Val', embed_c, x_c1, y) - self._write_embedding('PartNet Val', embed_p, x_c1, y) + self._write_stat('Val', loss, losses) # Checkpoint if self.curr_iter % 1000 == 999: @@ -333,117 +272,15 @@ class Model: self.writer.close() - def _classification_loss(self, embed_c, embed_p, ae_losses, y): - # Duplicate labels for each part - y_triplet = y.repeat(self.rgb_pn.num_parts, 1) - hpm_result = self.triplet_loss_hpm( - embed_c, y_triplet[:self.rgb_pn.hpm.num_parts] - ) - pn_result = self.triplet_loss_pn( - embed_p, y_triplet[self.rgb_pn.hpm.num_parts:] - ) - losses = torch.stack(( - *ae_losses, - hpm_result.pop('loss').mean(), - pn_result.pop('loss').mean() - )) - return losses, hpm_result, pn_result - - def _write_embedding(self, tag, embed, x, y): - frame = x[:, 0, :, :, :].cpu() - n, c, h, w = frame.size() - padding = torch.zeros(n, c, h, (h-w) // 2) - padded_frame = torch.cat((padding, frame, padding), dim=-1) - self.writer.add_embedding( - embed, - metadata=y.cpu().tolist(), - label_img=padded_frame, - global_step=self.curr_iter, - tag=tag - ) - - def _flatten_embedding(self, embed): - return embed.detach().transpose(0, 1).reshape(self.k * self.pr, -1) - def _write_stat( - self, postfix, embed_c, embed_p, hpm_result, pn_result, loss, losses + self, postfix, loss, losses ): # Write losses to TensorBoard self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter) self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip(( 'Cross reconstruction loss', 'Canonical consistency loss', 'Pose similarity loss' - ), losses[:3])), self.curr_iter) - self.writer.add_scalars(f'Loss/triplet loss {postfix}', { - 'HPM': losses[3], - 'PartNet': losses[4] - }, self.curr_iter) - # None-zero losses in batch - if hpm_result['counts'] is not None and pn_result['counts'] is not None: - self.writer.add_scalars(f'Loss/non-zero counts {postfix}', { - 'HPM': hpm_result['counts'].mean(), - 'PartNet': pn_result['counts'].mean() - }, self.curr_iter) - # Embedding distance - mean_hpm_dist = hpm_result['dist'].mean(0) - self._add_ranked_scalars( - f'Embedding/HPM distance {postfix}', mean_hpm_dist, - self.num_pos_pairs, self.num_pairs, self.curr_iter - ) - mean_pn_dist = pn_result['dist'].mean(0) - self._add_ranked_scalars( - f'Embedding/ParNet distance {postfix}', mean_pn_dist, - self.num_pos_pairs, self.num_pairs, self.curr_iter - ) - # Embedding norm - mean_hpm_embedding = embed_c.mean(0) - mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) - self._add_ranked_scalars( - f'Embedding/HPM norm {postfix}', mean_hpm_norm, - self.k, self.pr * self.k, self.curr_iter - ) - mean_pa_embedding = embed_p.mean(0) - mean_pa_norm = mean_pa_embedding.norm(dim=-1) - self._add_ranked_scalars( - f'Embedding/PartNet norm {postfix}', mean_pa_norm, - self.k, self.pr * self.k, self.curr_iter - ) - - def _add_ranked_scalars( - self, - main_tag: str, - metric: torch.Tensor, - num_pos: int, - num_all: int, - global_step: int - ): - rank = metric.argsort() - pos_ile = 100 - (num_pos - 1) * 100 // num_all - self.writer.add_scalars(main_tag, { - '0%-ile': metric[rank[-1]], - f'{100 - pos_ile}%-ile': metric[rank[-num_pos]], - '50%-ile': metric[rank[num_all // 2 - 1]], - f'{pos_ile}%-ile': metric[rank[num_pos - 1]], - '100%-ile': metric[rank[0]] - }, global_step) - - def predict_all( - self, - iters: tuple[int], - dataset_config: DatasetConfiguration, - dataset_selectors: dict[ - str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]] - ], - dataloader_config: DataloaderConfiguration, - ) -> dict[str, torch.Tensor]: - # Transform data to features - gallery_samples, probe_samples = self.transform( - iters, dataset_config, dataset_selectors, dataloader_config - ) - # Evaluate features - accuracy = self.evaluate(gallery_samples, probe_samples) - - return accuracy + ), losses)), self.curr_iter) def transform( self, @@ -466,9 +303,6 @@ class Model: # Init models model_hp: dict = self.hp.get('model', {}).copy() - model_hp.pop('triplet_is_hard', True) - model_hp.pop('triplet_is_mean', True) - model_hp.pop('triplet_margins', None) self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) @@ -509,54 +343,6 @@ class Model: 'feature': torch.cat((feature_c, feature_p)).view(-1) } - @staticmethod - def evaluate( - gallery_samples: dict[str, dict[str, Union[list, torch.Tensor]]], - probe_samples: dict[str, dict[str, Union[list, torch.Tensor]]], - num_ranks: int = 5 - ) -> dict[str, torch.Tensor]: - conditions = list(probe_samples.keys()) - gallery_views_meta = gallery_samples['meta']['views'] - probe_views_meta = probe_samples[conditions[0]]['meta']['views'] - accuracy = { - condition: torch.empty( - len(gallery_views_meta), len(probe_views_meta), num_ranks - ) - for condition in conditions - } - - for condition in conditions: - gallery_samples_c = gallery_samples[condition] - (labels_g, _, views_g, features_g) = gallery_samples_c.values() - views_g = np.asarray(views_g) - probe_samples_c = probe_samples[condition] - (labels_p, _, views_p, features_p, _) = probe_samples_c.values() - views_p = np.asarray(views_p) - accuracy_c = accuracy[condition] - for (v_g_i, view_g) in enumerate(gallery_views_meta): - gallery_view_mask = (views_g == view_g) - f_g = features_g[gallery_view_mask] - y_g = labels_g[gallery_view_mask] - for (v_p_i, view_p) in enumerate(probe_views_meta): - probe_view_mask = (views_p == view_p) - f_p = features_p[probe_view_mask] - y_p = labels_p[probe_view_mask] - # Euclidean distance - f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(1) - f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(0) - f_p_times_f_g_sum = f_p @ f_g.T - dist = torch.sqrt(F.relu( - f_p_squared_sum - 2*f_p_times_f_g_sum + f_g_squared_sum - )) - # Ranked accuracy - rank_mask = dist.argsort(1)[:, :num_ranks] - positive_mat = torch.eq(y_p.unsqueeze(1), - y_g[rank_mask]).cumsum(1).gt(0) - positive_counts = positive_mat.sum(0) - total_counts, _ = dist.size() - accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts - return accuracy - def _load_pretrained( self, iters: tuple[int], @@ -629,8 +415,6 @@ class Model: nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) - elif isinstance(m, (HorizontalPyramidMatching, PartNet)): - nn.init.xavier_uniform_(m.fc_mat) def _parse_dataset_config( self, diff --git a/models/part_net.py b/models/part_net.py deleted file mode 100644 index f2236bf..0000000 --- a/models/part_net.py +++ /dev/null @@ -1,149 +0,0 @@ -import copy - -import torch -import torch.nn as nn - -from models.layers import BasicConv1d, FocalConv2dBlock - - -class FrameLevelPartFeatureExtractor(nn.Module): - - def __init__( - self, - in_channels: int = 3, - feature_channels: int = 32, - kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)), - paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)), - halving: tuple[int, ...] = (0, 2, 3) - ): - super().__init__() - num_blocks = len(kernel_sizes) - out_channels = [feature_channels * 2 ** i for i in range(num_blocks)] - in_channels = [in_channels] + out_channels[:-1] - use_pools = [True] * (num_blocks - 1) + [False] - params = (in_channels, out_channels, kernel_sizes, - paddings, halving, use_pools) - - self.fconv_blocks = nn.ModuleList([ - FocalConv2dBlock(*_params) for _params in zip(*params) - ]) - - def forward(self, x): - # Flatten frames in all batches - n, t, c, h, w = x.size() - x = x.view(n * t, c, h, w) - - for fconv_block in self.fconv_blocks: - x = fconv_block(x) - return x - - -class TemporalFeatureAggregator(nn.Module): - def __init__( - self, - in_channels: int, - squeeze_ratio: int = 4, - num_part: int = 16 - ): - super().__init__() - hidden_dim = in_channels // squeeze_ratio - self.num_part = num_part - - # MTB1 - conv3x1 = nn.Sequential( - BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1), - nn.LeakyReLU(inplace=True), - BasicConv1d(hidden_dim, in_channels, kernel_size=1, padding=0) - ) - self.conv1d3x1 = self._parted(conv3x1) - self.avg_pool3x1 = nn.AvgPool1d(kernel_size=3, stride=1, padding=1) - self.max_pool3x1 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1) - - # MTB2 - conv3x3 = nn.Sequential( - BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1), - nn.LeakyReLU(inplace=True), - BasicConv1d(hidden_dim, in_channels, kernel_size=3, padding=1) - ) - self.conv1d3x3 = self._parted(conv3x3) - self.avg_pool3x3 = nn.AvgPool1d(kernel_size=5, stride=1, padding=2) - self.max_pool3x3 = nn.MaxPool1d(kernel_size=5, stride=1, padding=2) - - def _parted(self, module: nn.Module): - """Duplicate module `part_num` times.""" - return nn.ModuleList([copy.deepcopy(module) - for _ in range(self.num_part)]) - - def forward(self, x): - # p, n, t, c - x = x.transpose(2, 3) - p, n, c, t = x.size() - feature = x.split(1, dim=0) - feature = [f.squeeze(0) for f in feature] - x = x.view(-1, c, t) - - # MTB1: ConvNet1d & Sigmoid - logits3x1 = torch.stack( - [conv(f) for conv, f in zip(self.conv1d3x1, feature)] - ) - scores3x1 = torch.sigmoid(logits3x1) - # MTB1: Template Function - feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x) - feature3x1 = feature3x1.view(p, n, c, t) - feature3x1 = feature3x1 * scores3x1 - - # MTB2: ConvNet1d & Sigmoid - logits3x3 = torch.stack( - [conv(f) for conv, f in zip(self.conv1d3x3, feature)] - ) - scores3x3 = torch.sigmoid(logits3x3) - # MTB2: Template Function - feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x) - feature3x3 = feature3x3.view(p, n, c, t) - feature3x3 = feature3x3 * scores3x3 - - # Temporal Pooling - ret = (feature3x1 + feature3x3).max(-1)[0] - return ret - - -class PartNet(nn.Module): - def __init__( - self, - in_channels: int = 128, - embedding_dims: int = 256, - num_parts: int = 16, - squeeze_ratio: int = 4, - ): - super().__init__() - self.num_part = num_parts - - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.max_pool = nn.AdaptiveMaxPool2d(1) - self.tfa = TemporalFeatureAggregator( - in_channels, squeeze_ratio, self.num_part - ) - self.fc_mat = nn.Parameter( - torch.empty(num_parts, in_channels, embedding_dims) - ) - - def forward(self, x): - n, t, c, h, w = x.size() - x = x.view(n * t, c, h, w) - # n * t x c x h x w - - # Horizontal Pooling - _, c, h, w = x.size() - split_size = h // self.num_part - x = x.split(split_size, dim=2) - x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] - x = [x_.view(n, t, c) for x_ in x] - x = torch.stack(x) - - # p, n, t, c - x = self.tfa(x) - - # p, n, c - x = x @ self.fc_mat - # p, n, d - return x diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 4a82da3..d3f8ade 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -1,9 +1,8 @@ import torch import torch.nn as nn +import torch.nn.functional as F from models.auto_encoder import AutoEncoder -from models.hpm import HorizontalPyramidMatching -from models.part_net import PartNet class RGBPartNet(nn.Module): @@ -13,12 +12,6 @@ class RGBPartNet(nn.Module): ae_in_size: tuple[int, int] = (64, 48), ae_feature_channels: int = 64, f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), - hpm_scales: tuple[int, ...] = (1, 2, 4), - hpm_use_avg_pool: bool = True, - hpm_use_max_pool: bool = True, - tfa_squeeze_ratio: int = 4, - tfa_num_parts: int = 16, - embedding_dims: tuple[int] = (256, 256), image_log_on: bool = False ): super().__init__() @@ -29,94 +22,66 @@ class RGBPartNet(nn.Module): self.ae = AutoEncoder( ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims ) - self.pn_in_channels = ae_feature_channels * 2 - self.hpm = HorizontalPyramidMatching( - self.pn_in_channels, embedding_dims[0], hpm_scales, - hpm_use_avg_pool, hpm_use_max_pool - ) - self.pn = PartNet(self.pn_in_channels, embedding_dims[1], - tfa_num_parts, tfa_squeeze_ratio) - - self.num_parts = self.hpm.num_parts + tfa_num_parts def forward(self, x_c1, x_c2=None): - # Step 1: Disentanglement - # n, t, c, h, w - ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2) - - # Step 2.a: Static Gait Feature Aggregation & HPM - # n, c, h, w - x_c = self.hpm(x_c) - # p, n, d - - # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) - # n, t, c, h, w - x_p = self.pn(x_p) - # p, n, d + losses, features, images = self._disentangle(x_c1, x_c2) if self.training: - return x_c, x_p, ae_losses, images + losses = torch.stack(losses) + return losses, features, images else: - return x_c, x_p + return features def _disentangle(self, x_c1_t2, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() - device = x_c1_t2.device if self.training: x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :] ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) - # Decode features - x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p_ = self._decode_pose_feature(f_p_, n, t, device) - x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) + f_a = f_a_.view(n, t, -1) + f_c = f_c_.view(n, t, -1) + f_p = f_p_.view(n, t, -1) i_a, i_c, i_p = None, None, None if self.image_log_on: with torch.no_grad(): - i_a = self._decode_appr_feature(f_a_, n, t, device) - # Continue decoding canonical features - i_c = self.ae.decoder.trans_conv3(x_c) - i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c)) - i_p_ = self.ae.decoder.trans_conv3(x_p_) - i_p_ = torch.sigmoid(self.ae.decoder.trans_conv4(i_p_)) + x_a, i_a = self._separate_decode( + f_a.mean(1), + torch.zeros_like(f_c[:, 0, :]), + torch.zeros_like(f_p[:, 0, :]) + ) + x_c, i_c = self._separate_decode( + torch.zeros_like(f_a[:, 0, :]), + f_c.mean(1), + torch.zeros_like(f_p[:, 0, :]), + ) + x_p_, i_p_ = self._separate_decode( + torch.zeros_like(f_a_), + torch.zeros_like(f_c_), + f_p_ + ) + x_p = tuple(_x_p.view(n, t, *_x_p.size()[1:]) for _x_p in x_p_) i_p = i_p_.view(n, t, c, h, w) - return (x_c, x_p), losses, (i_a, i_c, i_p) + return losses, (x_a, x_c, x_p), (i_a, i_c, i_p) else: # evaluating f_c_, f_p_ = self.ae(x_c1_t2) - x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p_ = self._decode_pose_feature(f_p_, n, t, device) - x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) - return (x_c, x_p), None, None - - def _decode_appr_feature(self, f_a_, n, t, device): - # Decode appearance features - f_a = f_a_.view(n, t, -1) - x_a = self.ae.decoder( - f_a.mean(1), - torch.zeros((n, self.f_c_dim), device=device), - torch.zeros((n, self.f_p_dim), device=device) - ) - return x_a - - def _decode_cano_feature(self, f_c_, n, t, device): - # Decode average canonical features to higher dimension - f_c = f_c_.view(n, t, -1) - x_c = self.ae.decoder( - torch.zeros((n, self.f_a_dim), device=device), - f_c.mean(1), - torch.zeros((n, self.f_p_dim), device=device), - is_feature_map=True - ) - return x_c - - def _decode_pose_feature(self, f_p_, n, t, device): - # Decode pose features to images - x_p_ = self.ae.decoder( - torch.zeros((n * t, self.f_a_dim), device=device), - torch.zeros((n * t, self.f_c_dim), device=device), - f_p_, - is_feature_map=True + f_c = f_c_.view(n, t, -1) + f_p = f_p_.view(n, t, -1) + return (f_c, f_p), None, None + + def _separate_decode(self, f_a, f_c, f_p): + x_1 = torch.cat((f_a, f_c, f_p), dim=1) + x_1 = self.ae.decoder.fc(x_1).view( + -1, + self.ae.decoder.feature_channels * 8, + self.ae.decoder.h_0, + self.ae.decoder.w_0 ) - return x_p_ + x_1 = F.relu(x_1, inplace=True) + x_2 = self.ae.decoder.trans_conv1(x_1) + x_3 = self.ae.decoder.trans_conv2(x_2) + x_4 = self.ae.decoder.trans_conv3(x_3) + image = torch.sigmoid(self.ae.decoder.trans_conv4(x_4)) + x = (x_1, x_2, x_3, x_4) + return x, image |