diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/hpm.py | 55 | ||||
-rw-r--r-- | models/layers.py | 96 | ||||
-rw-r--r-- | models/model.py | 114 | ||||
-rw-r--r-- | models/part_net.py | 151 | ||||
-rw-r--r-- | models/rgb_part_net.py | 63 |
5 files changed, 16 insertions, 463 deletions
diff --git a/models/hpm.py b/models/hpm.py deleted file mode 100644 index 9879cfb..0000000 --- a/models/hpm.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch -import torch.nn as nn - -from models.layers import HorizontalPyramidPooling - - -class HorizontalPyramidMatching(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int = 128, - use_1x1conv: bool = False, - scales: tuple[int, ...] = (1, 2, 4), - use_avg_pool: bool = True, - use_max_pool: bool = False, - **kwargs - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.use_1x1conv = use_1x1conv - self.scales = scales - self.use_avg_pool = use_avg_pool - self.use_max_pool = use_max_pool - - self.pyramids = nn.ModuleList([ - self._make_pyramid(scale, **kwargs) for scale in self.scales - ]) - - def _make_pyramid(self, scale: int, **kwargs): - pyramid = nn.ModuleList([ - HorizontalPyramidPooling(self.in_channels, - self.out_channels, - use_1x1conv=self.use_1x1conv, - use_avg_pool=self.use_avg_pool, - use_max_pool=self.use_max_pool, - **kwargs) - for _ in range(scale) - ]) - return pyramid - - def forward(self, x): - n, c, h, w = x.size() - feature = [] - for scale, pyramid in zip(self.scales, self.pyramids): - h_per_hpp = h // scale - for hpp_index, hpp in enumerate(pyramid): - h_filter = torch.arange(hpp_index * h_per_hpp, - (hpp_index + 1) * h_per_hpp) - x_slice = x[:, :, h_filter, :] - x_slice = hpp(x_slice) - x_slice = x_slice.view(n, -1) - feature.append(x_slice) - x = torch.stack(feature) - return x diff --git a/models/layers.py b/models/layers.py index ef53a95..1b4640f 100644 --- a/models/layers.py +++ b/models/layers.py @@ -1,6 +1,5 @@ from typing import Union -import torch import torch.nn as nn import torch.nn.functional as F @@ -97,98 +96,3 @@ class BasicLinear(nn.Module): x = self.fc(x) x = self.bn(x) return x - - -class FocalConv2d(BasicConv2d): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, tuple[int, int]], - halving: int, - **kwargs - ): - super().__init__(in_channels, out_channels, kernel_size, **kwargs) - self.halving = halving - - def forward(self, x): - h = x.size(2) - split_size = h // 2 ** self.halving - z = x.split(split_size, dim=2) - z = torch.cat([self.conv(_) for _ in z], dim=2) - return F.leaky_relu(z, inplace=True) - - -class FocalConv2dBlock(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_sizes: tuple[int, int], - paddings: tuple[int, int], - halving: int, - use_pool: bool = True, - **kwargs - ): - super().__init__() - self.use_pool = use_pool - self.fconv1 = FocalConv2d(in_channels, out_channels, kernel_sizes[0], - halving, padding=paddings[0], **kwargs) - self.fconv2 = FocalConv2d(out_channels, out_channels, kernel_sizes[1], - halving, padding=paddings[1], **kwargs) - self.max_pool = nn.MaxPool2d(2) - - def forward(self, x): - x = self.fconv1(x) - x = self.fconv2(x) - if self.use_pool: - x = self.max_pool(x) - return x - - -class BasicConv1d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, tuple[int]], - **kwargs - ): - super().__init__() - self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, - bias=False, **kwargs) - - def forward(self, x): - return self.conv(x) - - -class HorizontalPyramidPooling(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - use_1x1conv: bool = False, - use_avg_pool: bool = True, - use_max_pool: bool = False, - **kwargs - ): - super().__init__() - self.use_1x1conv = use_1x1conv - if use_1x1conv: - self.conv = BasicConv2d(in_channels, out_channels, 1, **kwargs) - self.use_avg_pool = use_avg_pool - self.use_max_pool = use_max_pool - assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.' - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.max_pool = nn.AdaptiveMaxPool2d(1) - - def forward(self, x): - if self.use_avg_pool and self.use_max_pool: - x = self.avg_pool(x) + self.max_pool(x) - elif self.use_avg_pool and not self.use_max_pool: - x = self.avg_pool(x) - elif not self.use_avg_pool and self.use_max_pool: - x = self.max_pool(x) - if self.use_1x1conv: - x = self.conv(x) - return x diff --git a/models/model.py b/models/model.py index 82d6461..3f24936 100644 --- a/models/model.py +++ b/models/model.py @@ -5,7 +5,6 @@ from typing import Union, Optional import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate @@ -142,29 +141,16 @@ class Model: # Prepare for model, optimizer and scheduler model_hp = self.hp.get('model', {}) optim_hp: dict = self.hp.get('optimizer', {}).copy() - start_iter = optim_hp.pop('start_iter', 0) - ae_optim_hp = optim_hp.pop('auto_encoder', {}) - pn_optim_hp = optim_hp.pop('part_net', {}) - hpm_optim_hp = optim_hp.pop('hpm', {}) - fc_optim_hp = optim_hp.pop('fc', {}) sched_hp = self.hp.get('scheduler', {}) self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) - self.optimizer = optim.Adam([ - {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp}, - {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, - {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp}, - {'params': self.rgb_pn.fc_mat, **fc_optim_hp} - ], **optim_hp) + self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp) sched_gamma = sched_hp.get('gamma', 0.9) sched_step_size = sched_hp.get('step_size', 500) self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ lambda epoch: sched_gamma ** (epoch // sched_step_size), - lambda epoch: 0 if epoch < start_iter else 1, - lambda epoch: 0 if epoch < start_iter else 1, - lambda epoch: 0 if epoch < start_iter else 1, ]) self.writer = SummaryWriter(self._log_name) @@ -182,10 +168,10 @@ class Model: # Training start start_time = datetime.now() - running_loss = torch.zeros(5, device=self.device) + running_loss = torch.zeros(3, device=self.device) print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}", f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}", - f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}") + f"{'LR':^9}") for (batch_c1, batch_c2) in dataloader: self.curr_iter += 1 # Zero the parameter gradients @@ -193,10 +179,7 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - y = batch_c1['label'].to(self.device) - # Duplicate labels for each part - y = y.unsqueeze(1).repeat(1, self.rgb_pn.num_total_parts) - losses, images = self.rgb_pn(x_c1, x_c2, y) + losses, images = self.rgb_pn(x_c1, x_c2) loss = losses.sum() loss.backward() self.optimizer.step() @@ -206,19 +189,16 @@ class Model: # Write losses to TensorBoard self.writer.add_scalar('Loss/all', loss, self.curr_iter) self.writer.add_scalars('Loss/details', dict(zip([ - 'Cross reconstruction loss', 'Canonical consistency loss', - 'Pose similarity loss', 'Batch All triplet loss (HPM)', - 'Batch All triplet loss (PartNet)' + 'Cross reconstruction loss', + 'Canonical consistency loss', + 'Pose similarity loss' ], losses)), self.curr_iter) if self.curr_iter % 100 == 0: - lrs = self.scheduler.get_last_lr() + lr = self.scheduler.get_last_lr()[0] # Write learning rates self.writer.add_scalar( - 'Learning rate/Auto-encoder', lrs[0], self.curr_iter - ) - self.writer.add_scalar( - 'Learning rate/Others', lrs[1], self.curr_iter + 'Learning rate/Auto-encoder', lr, self.curr_iter ) # Write disentangled images if self.image_log_on: @@ -241,8 +221,8 @@ class Model: hour, minute = divmod(remaining_minute, 60) print(f'{hour:02}:{minute:02}:{second:02}', f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', - '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100), - '{:.3e} {:.3e}'.format(lrs[0], lrs[1])) + '{:f} {:f} {:f}'.format(*running_loss / 100), + f'{lr:.3e}') running_loss.zero_() # Step scheduler @@ -261,24 +241,6 @@ class Model: self.writer.close() break - def predict_all( - self, - iters: tuple[int], - dataset_config: DatasetConfiguration, - dataset_selectors: dict[ - str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]] - ], - dataloader_config: DataloaderConfiguration, - ) -> dict[str, torch.Tensor]: - # Transform data to features - gallery_samples, probe_samples = self.transform( - iters, dataset_config, dataset_selectors, dataloader_config - ) - # Evaluate features - accuracy = self.evaluate(gallery_samples, probe_samples) - - return accuracy - def transform( self, iters: tuple[int], @@ -329,61 +291,13 @@ class Model: def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]): label = sample.pop('label').item() clip = sample.pop('clip').to(self.device) - feature = self.rgb_pn(clip).detach() + x_c, x_p = self.rgb_pn(clip).detach() return { **{'label': label}, **sample, - **{'feature': feature} - } - - def evaluate( - self, - gallery_samples: dict[str, Union[list[str], torch.Tensor]], - probe_samples: dict[str, dict[str, Union[list[str], torch.Tensor]]], - num_ranks: int = 5 - ) -> dict[str, torch.Tensor]: - probe_conditions = self._probe_datasets_meta.keys() - gallery_views_meta = self._gallery_dataset_meta['views'] - probe_views_meta = list(self._probe_datasets_meta.values())[0]['views'] - accuracy = { - condition: torch.empty( - len(gallery_views_meta), len(probe_views_meta), num_ranks - ) - for condition in self._probe_datasets_meta.keys() + **{'cano_feature': x_c, 'pose_feature': x_p} } - (labels_g, _, views_g, features_g) = gallery_samples.values() - views_g = np.asarray(views_g) - for (v_g_i, view_g) in enumerate(gallery_views_meta): - gallery_view_mask = (views_g == view_g) - f_g = features_g[gallery_view_mask] - y_g = labels_g[gallery_view_mask] - for condition in probe_conditions: - probe_samples_c = probe_samples[condition] - accuracy_c = accuracy[condition] - (labels_p, _, views_p, features_p) = probe_samples_c.values() - views_p = np.asarray(views_p) - for (v_p_i, view_p) in enumerate(probe_views_meta): - probe_view_mask = (views_p == view_p) - f_p = features_p[probe_view_mask] - y_p = labels_p[probe_view_mask] - # Euclidean distance - f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(1) - f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(0) - f_p_times_f_g_sum = f_p @ f_g.T - dist = torch.sqrt(F.relu( - f_p_squared_sum - 2*f_p_times_f_g_sum + f_g_squared_sum - )) - # Ranked accuracy - rank_mask = dist.argsort(1)[:, :num_ranks] - positive_mat = torch.eq(y_p.unsqueeze(1), - y_g[rank_mask]).cumsum(1).gt(0) - positive_counts = positive_mat.sum(0) - total_counts, _ = dist.size() - accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts - - return accuracy - def _load_pretrained( self, iters: tuple[int], @@ -452,8 +366,6 @@ class Model: nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) - elif isinstance(m, RGBPartNet): - nn.init.xavier_uniform_(m.fc_mat) def _parse_dataset_config( self, diff --git a/models/part_net.py b/models/part_net.py deleted file mode 100644 index 62a2bac..0000000 --- a/models/part_net.py +++ /dev/null @@ -1,151 +0,0 @@ -import copy - -import torch -import torch.nn as nn - -from models.layers import BasicConv1d, FocalConv2dBlock - - -class FrameLevelPartFeatureExtractor(nn.Module): - - def __init__( - self, - in_channels: int = 3, - feature_channels: int = 32, - kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)), - paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)), - halving: tuple[int, ...] = (0, 2, 3) - ): - super().__init__() - num_blocks = len(kernel_sizes) - out_channels = [feature_channels * 2 ** i for i in range(num_blocks)] - in_channels = [in_channels] + out_channels[:-1] - use_pools = [True] * (num_blocks - 1) + [False] - params = (in_channels, out_channels, kernel_sizes, - paddings, halving, use_pools) - - self.fconv_blocks = nn.ModuleList([ - FocalConv2dBlock(*_params) for _params in zip(*params) - ]) - - def forward(self, x): - # Flatten frames in all batches - n, t, c, h, w = x.size() - x = x.view(n * t, c, h, w) - - for fconv_block in self.fconv_blocks: - x = fconv_block(x) - return x - - -class TemporalFeatureAggregator(nn.Module): - def __init__( - self, - in_channels: int, - squeeze_ratio: int = 4, - num_part: int = 16 - ): - super().__init__() - hidden_dim = in_channels // squeeze_ratio - self.num_part = num_part - - # MTB1 - conv3x1 = nn.Sequential( - BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1), - nn.LeakyReLU(inplace=True), - BasicConv1d(hidden_dim, in_channels, kernel_size=1, padding=0) - ) - self.conv1d3x1 = self._parted(conv3x1) - self.avg_pool3x1 = nn.AvgPool1d(kernel_size=3, stride=1, padding=1) - self.max_pool3x1 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1) - - # MTB2 - conv3x3 = nn.Sequential( - BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1), - nn.LeakyReLU(inplace=True), - BasicConv1d(hidden_dim, in_channels, kernel_size=3, padding=1) - ) - self.conv1d3x3 = self._parted(conv3x3) - self.avg_pool3x3 = nn.AvgPool1d(kernel_size=5, stride=1, padding=2) - self.max_pool3x3 = nn.MaxPool1d(kernel_size=5, stride=1, padding=2) - - def _parted(self, module: nn.Module): - """Duplicate module `part_num` times.""" - return nn.ModuleList([copy.deepcopy(module) - for _ in range(self.num_part)]) - - def forward(self, x): - # p, n, t, c - x = x.transpose(2, 3) - p, n, c, t = x.size() - feature = x.split(1, dim=0) - feature = [f.squeeze(0) for f in feature] - x = x.view(-1, c, t) - - # MTB1: ConvNet1d & Sigmoid - logits3x1 = torch.stack( - [conv(f) for conv, f in zip(self.conv1d3x1, feature)] - ) - scores3x1 = torch.sigmoid(logits3x1) - # MTB1: Template Function - feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x) - feature3x1 = feature3x1.view(p, n, c, t) - feature3x1 = feature3x1 * scores3x1 - - # MTB2: ConvNet1d & Sigmoid - logits3x3 = torch.stack( - [conv(f) for conv, f in zip(self.conv1d3x3, feature)] - ) - scores3x3 = torch.sigmoid(logits3x3) - # MTB2: Template Function - feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x) - feature3x3 = feature3x3.view(p, n, c, t) - feature3x3 = feature3x3 * scores3x3 - - # Temporal Pooling - ret = (feature3x1 + feature3x3).max(-1)[0] - return ret - - -class PartNet(nn.Module): - def __init__( - self, - in_channels: int = 3, - feature_channels: int = 32, - kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)), - paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)), - halving: tuple[int, ...] = (0, 2, 3), - squeeze_ratio: int = 4, - num_part: int = 16 - ): - super().__init__() - self.num_part = num_part - self.fpfe = FrameLevelPartFeatureExtractor( - in_channels, feature_channels, kernel_sizes, paddings, halving - ) - - num_fconv_blocks = len(self.fpfe.fconv_blocks) - self.tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1) - self.tfa = TemporalFeatureAggregator( - self.tfa_in_channels, squeeze_ratio, self.num_part - ) - - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.max_pool = nn.AdaptiveMaxPool2d(1) - - def forward(self, x): - n, t, _, _, _ = x.size() - x = self.fpfe(x) - # n * t x c x h x w - - # Horizontal Pooling - _, c, h, w = x.size() - split_size = h // self.num_part - x = x.split(split_size, dim=2) - x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x] - x = [x_.view(n, t, c) for x_ in x] - x = torch.stack(x) - - # p, n, t, c - x = self.tfa(x) - return x diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 67acac3..f18d675 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -2,9 +2,6 @@ import torch import torch.nn as nn from models.auto_encoder import AutoEncoder -from models.hpm import HorizontalPyramidMatching -from models.part_net import PartNet -from utils.triplet_loss import BatchAllTripletLoss class RGBPartNet(nn.Module): @@ -14,80 +11,26 @@ class RGBPartNet(nn.Module): ae_in_size: tuple[int, int] = (64, 48), ae_feature_channels: int = 64, f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), - hpm_use_1x1conv: bool = False, - hpm_scales: tuple[int, ...] = (1, 2, 4), - hpm_use_avg_pool: bool = True, - hpm_use_max_pool: bool = True, - fpfe_feature_channels: int = 32, - fpfe_kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)), - fpfe_paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)), - fpfe_halving: tuple[int, ...] = (0, 2, 3), - tfa_squeeze_ratio: int = 4, - tfa_num_parts: int = 16, - embedding_dims: int = 256, - triplet_margins: tuple[float, float] = (0.2, 0.2), image_log_on: bool = False ): super().__init__() (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims - self.hpm_num_parts = sum(hpm_scales) self.image_log_on = image_log_on self.ae = AutoEncoder( ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims ) - self.pn = PartNet( - ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes, - fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_parts - ) - out_channels = self.pn.tfa_in_channels - self.hpm = HorizontalPyramidMatching( - ae_feature_channels * 2, out_channels, hpm_use_1x1conv, - hpm_scales, hpm_use_avg_pool, hpm_use_max_pool - ) - self.num_total_parts = self.hpm_num_parts + tfa_num_parts - empty_fc = torch.empty(self.num_total_parts, - out_channels, embedding_dims) - self.fc_mat = nn.Parameter(empty_fc) - - (hpm_margin, pn_margin) = triplet_margins - self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin) - self.pn_ba_trip = BatchAllTripletLoss(pn_margin) - - def fc(self, x): - return x @ self.fc_mat - def forward(self, x_c1, x_c2=None, y=None): + def forward(self, x_c1, x_c2=None): # Step 1: Disentanglement # n, t, c, h, w ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2) - # Step 2.a: Static Gait Feature Aggregation & HPM - # n, c, h, w - x_c = self.hpm(x_c) - # p, n, c - - # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) - # n, t, c, h, w - x_p = self.pn(x_p) - # p, n, c - - # Step 3: Cat feature map together and fc - x = torch.cat((x_c, x_p)) - x = self.fc(x) - if self.training: - y = y.T - hpm_ba_trip = self.hpm_ba_trip( - x[:self.hpm_num_parts], y[:self.hpm_num_parts] - ) - pn_ba_trip = self.pn_ba_trip( - x[self.hpm_num_parts:], y[self.hpm_num_parts:] - ) - losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip)) + losses = torch.stack(losses) return losses, images else: - return x.unsqueeze(1).view(-1) + return x_c, x_p def _disentangle(self, x_c1_t2, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() |