summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/auto_encoder.py4
-rw-r--r--models/hpm.py54
-rw-r--r--models/layers.py87
-rw-r--r--models/model.py282
-rw-r--r--models/part_net.py149
-rw-r--r--models/rgb_part_net.py121
6 files changed, 77 insertions, 620 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 4fece69..91071dd 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -106,15 +106,13 @@ class Decoder(nn.Module):
self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
is_last_layer=True)
- def forward(self, f_appearance, f_canonical, f_pose, is_feature_map=False):
+ def forward(self, f_appearance, f_canonical, f_pose):
x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
x = self.fc(x)
x = x.view(-1, self.feature_channels * 8, self.h_0, self.w_0)
x = F.relu(x, inplace=True)
x = self.trans_conv1(x)
x = self.trans_conv2(x)
- if is_feature_map:
- return x
x = self.trans_conv3(x)
x = torch.sigmoid(self.trans_conv4(x))
diff --git a/models/hpm.py b/models/hpm.py
deleted file mode 100644
index 8186b20..0000000
--- a/models/hpm.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import torch
-import torch.nn as nn
-
-from models.layers import HorizontalPyramidPooling
-
-
-class HorizontalPyramidMatching(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int = 128,
- scales: tuple[int, ...] = (1, 2, 4),
- use_avg_pool: bool = True,
- use_max_pool: bool = False,
- ):
- super().__init__()
- self.scales = scales
- self.num_parts = sum(scales)
- self.use_avg_pool = use_avg_pool
- self.use_max_pool = use_max_pool
-
- self.pyramids = nn.ModuleList([
- self._make_pyramid(scale) for scale in scales
- ])
- self.fc_mat = nn.Parameter(
- torch.empty(self.num_parts, in_channels, out_channels)
- )
-
- def _make_pyramid(self, scale: int):
- pyramid = nn.ModuleList([
- HorizontalPyramidPooling(self.use_avg_pool, self.use_max_pool)
- for _ in range(scale)
- ])
- return pyramid
-
- def forward(self, x):
- n, c, h, w = x.size()
- feature = []
- for scale, pyramid in zip(self.scales, self.pyramids):
- h_per_hpp = h // scale
- for hpp_index, hpp in enumerate(pyramid):
- h_filter = torch.arange(hpp_index * h_per_hpp,
- (hpp_index + 1) * h_per_hpp)
- x_slice = x[:, :, h_filter, :]
- x_slice = hpp(x_slice)
- x_slice = x_slice.view(n, -1)
- feature.append(x_slice)
- x = torch.stack(feature)
-
- # p, n, c
- x = x @ self.fc_mat
- # p, n, d
-
- return x
diff --git a/models/layers.py b/models/layers.py
index c609698..5306769 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -1,6 +1,5 @@
from typing import Union
-import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -99,89 +98,3 @@ class BasicLinear(nn.Module):
x = self.fc(x)
x = self.bn(x)
return x
-
-
-class FocalConv2d(BasicConv2d):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: Union[int, tuple[int, int]],
- halving: int,
- **kwargs
- ):
- super().__init__(in_channels, out_channels, kernel_size, **kwargs)
- self.halving = halving
-
- def forward(self, x):
- h = x.size(2)
- split_size = h // 2 ** self.halving
- z = x.split(split_size, dim=2)
- z = torch.cat([self.conv(_) for _ in z], dim=2)
- return F.leaky_relu(z, inplace=True)
-
-
-class FocalConv2dBlock(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_sizes: tuple[int, int],
- paddings: tuple[int, int],
- halving: int,
- use_pool: bool = True,
- **kwargs
- ):
- super().__init__()
- self.use_pool = use_pool
- self.fconv1 = FocalConv2d(in_channels, out_channels, kernel_sizes[0],
- halving, padding=paddings[0], **kwargs)
- self.fconv2 = FocalConv2d(out_channels, out_channels, kernel_sizes[1],
- halving, padding=paddings[1], **kwargs)
- self.max_pool = nn.MaxPool2d(2)
-
- def forward(self, x):
- x = self.fconv1(x)
- x = self.fconv2(x)
- if self.use_pool:
- x = self.max_pool(x)
- return x
-
-
-class BasicConv1d(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: Union[int, tuple[int]],
- **kwargs
- ):
- super().__init__()
- self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
- bias=False, **kwargs)
-
- def forward(self, x):
- return self.conv(x)
-
-
-class HorizontalPyramidPooling(nn.Module):
- def __init__(
- self,
- use_avg_pool: bool = True,
- use_max_pool: bool = False,
- ):
- super().__init__()
- self.use_avg_pool = use_avg_pool
- self.use_max_pool = use_max_pool
- assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.'
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
- self.max_pool = nn.AdaptiveMaxPool2d(1)
-
- def forward(self, x):
- if self.use_avg_pool and self.use_max_pool:
- x = self.avg_pool(x) + self.max_pool(x)
- elif self.use_avg_pool and not self.use_max_pool:
- x = self.avg_pool(x)
- elif not self.use_avg_pool and self.use_max_pool:
- x = self.max_pool(x)
- return x
diff --git a/models/model.py b/models/model.py
index ceadb92..25c8a4f 100644
--- a/models/model.py
+++ b/models/model.py
@@ -6,22 +6,18 @@ from typing import Union, Optional
import numpy as np
import torch
import torch.nn as nn
-import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
-from models.hpm import HorizontalPyramidMatching
-from models.part_net import PartNet
from models.rgb_part_net import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
SystemConfiguration
from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
from utils.sampler import TripletSampler
-from utils.triplet_loss import BatchTripletLoss
class Model:
@@ -62,8 +58,6 @@ class Model:
self.in_size: tuple[int, int] = (64, 48)
self.pr: Optional[int] = None
self.k: Optional[int] = None
- self.num_pairs: Optional[int] = None
- self.num_pos_pairs: Optional[int] = None
self._gallery_dataset_meta: Optional[dict[str, list]] = None
self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None
@@ -73,8 +67,6 @@ class Model:
self._dataset_sig: str = 'undefined'
self.rgb_pn: Optional[RGBPartNet] = None
- self.triplet_loss_hpm: Optional[BatchTripletLoss] = None
- self.triplet_loss_pn: Optional[BatchTripletLoss] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
@@ -166,71 +158,23 @@ class Model:
))
# Prepare for model, optimizer and scheduler
model_hp: dict = self.hp.get('model', {}).copy()
- triplet_is_hard = model_hp.pop('triplet_is_hard', True)
- triplet_is_mean = model_hp.pop('triplet_is_mean', True)
- triplet_margins = model_hp.pop('triplet_margins', None)
optim_hp: dict = self.hp.get('optimizer', {}).copy()
- ae_optim_hp = optim_hp.pop('auto_encoder', {})
- hpm_optim_hp = optim_hp.pop('hpm', {})
- pn_optim_hp = optim_hp.pop('part_net', {})
sched_hp = self.hp.get('scheduler', {})
- ae_sched_hp = sched_hp.get('auto_encoder', {})
- hpm_sched_hp = sched_hp.get('hpm', {})
- pn_sched_hp = sched_hp.get('part_net', {})
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
- # Hard margins
- if triplet_margins:
- self.triplet_loss_hpm = BatchTripletLoss(
- triplet_is_hard, triplet_is_mean, triplet_margins[0]
- )
- self.triplet_loss_pn = BatchTripletLoss(
- triplet_is_hard, triplet_is_mean, triplet_margins[1]
- )
- else: # Soft margins
- self.triplet_loss_hpm = BatchTripletLoss(
- triplet_is_hard, triplet_is_mean, None
- )
- self.triplet_loss_pn = BatchTripletLoss(
- triplet_is_hard, triplet_is_mean, None
- )
-
- self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
- self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
- self.triplet_loss_hpm = self.triplet_loss_hpm.to(self.device)
- self.triplet_loss_pn = self.triplet_loss_pn.to(self.device)
-
- self.optimizer = optim.Adam([
- {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
- {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
- {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
- ], **optim_hp)
-
- # Scheduler
+ self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp)
start_step = sched_hp.get('start_step', 15_000)
final_gamma = sched_hp.get('final_gamma', 0.001)
- ae_start_step = ae_sched_hp.get('start_step', start_step)
- ae_final_gamma = ae_sched_hp.get('final_gamma', final_gamma)
- ae_all_step = self.total_iter - ae_start_step
- hpm_start_step = hpm_sched_hp.get('start_step', start_step)
- hpm_final_gamma = hpm_sched_hp.get('final_gamma', final_gamma)
- hpm_all_step = self.total_iter - hpm_start_step
- pn_start_step = pn_sched_hp.get('start_step', start_step)
- pn_final_gamma = pn_sched_hp.get('final_gamma', final_gamma)
- pn_all_step = self.total_iter - pn_start_step
- self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
- lambda t: ae_final_gamma ** ((t - ae_start_step) / ae_all_step)
- if t > ae_start_step else 1,
- lambda t: hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step)
- if t > hpm_start_step else 1,
- lambda t: pn_final_gamma ** ((t - pn_start_step) / pn_all_step)
- if t > pn_start_step else 1,
- ])
-
+ all_step = self.total_iter - start_step
+ self.scheduler = optim.lr_scheduler.LambdaLR(
+ self.optimizer,
+ lambda t: final_gamma ** ((t - start_step) / all_step)
+ if t > start_step else 1,
+ )
self.writer = SummaryWriter(self._log_name)
# Set seeds for reproducibility
@@ -259,24 +203,18 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2)
- y = batch_c1['label'].to(self.device)
- losses, hpm_result, pn_result = self._classification_loss(
- embed_c, embed_p, ae_losses, y
- )
+ losses, features, images = self.rgb_pn(x_c1, x_c2)
loss = losses.sum()
loss.backward()
self.optimizer.step()
self.scheduler.step()
# Learning rate
- self.writer.add_scalars('Learning rate', dict(zip((
- 'Auto-encoder', 'HPM', 'PartNet'
- ), self.scheduler.get_last_lr())), self.curr_iter)
- # Other stats
- self._write_stat(
- 'Train', embed_c, embed_p, hpm_result, pn_result, loss, losses
+ self.writer.add_scalar(
+ 'Learning rate', self.scheduler.get_last_lr()[0], self.curr_iter
)
+ # Other stats
+ self._write_stat('Train', loss, losses)
if self.curr_iter % 100 == 99:
# Write disentangled images
@@ -295,32 +233,33 @@ class Model:
self.writer.add_images(
f'Pose image/batch {i}', p, self.curr_iter
)
-
- # Validation
- embed_c = self._flatten_embedding(embed_c)
- embed_p = self._flatten_embedding(embed_p)
- self._write_embedding('HPM Train', embed_c, x_c1, y)
- self._write_embedding('PartNet Train', embed_p, x_c1, y)
+ f_a, f_c, f_p = features
+ for i, (f_a_i, f_c_i, f_p_i) in enumerate(
+ zip(f_a, f_c, f_p)
+ ):
+ self.writer.add_images(
+ f'Appearance features/Layer {i}',
+ f_a_i[:, :3, :, :], self.curr_iter
+ )
+ self.writer.add_images(
+ f'Canonical features/Layer {i}',
+ f_c_i[:, :3, :, :], self.curr_iter
+ )
+ for j, p in enumerate(f_p_i):
+ self.writer.add_images(
+ f'Pose features/Layer {i}/batch{j}',
+ p[:, :3, :, :], self.curr_iter
+ )
# Calculate losses on testing batch
batch_c1, batch_c2 = next(val_dataloader)
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
with torch.no_grad():
- embed_c, embed_p, ae_losses, _ = self.rgb_pn(x_c1, x_c2)
- y = batch_c1['label'].to(self.device)
- losses, hpm_result, pn_result = self._classification_loss(
- embed_c, embed_p, ae_losses, y
- )
+ losses, _, _ = self.rgb_pn(x_c1, x_c2)
loss = losses.sum()
- self._write_stat(
- 'Val', embed_c, embed_p, hpm_result, pn_result, loss, losses
- )
- embed_c = self._flatten_embedding(embed_c)
- embed_p = self._flatten_embedding(embed_p)
- self._write_embedding('HPM Val', embed_c, x_c1, y)
- self._write_embedding('PartNet Val', embed_p, x_c1, y)
+ self._write_stat('Val', loss, losses)
# Checkpoint
if self.curr_iter % 1000 == 999:
@@ -333,117 +272,15 @@ class Model:
self.writer.close()
- def _classification_loss(self, embed_c, embed_p, ae_losses, y):
- # Duplicate labels for each part
- y_triplet = y.repeat(self.rgb_pn.num_parts, 1)
- hpm_result = self.triplet_loss_hpm(
- embed_c, y_triplet[:self.rgb_pn.hpm.num_parts]
- )
- pn_result = self.triplet_loss_pn(
- embed_p, y_triplet[self.rgb_pn.hpm.num_parts:]
- )
- losses = torch.stack((
- *ae_losses,
- hpm_result.pop('loss').mean(),
- pn_result.pop('loss').mean()
- ))
- return losses, hpm_result, pn_result
-
- def _write_embedding(self, tag, embed, x, y):
- frame = x[:, 0, :, :, :].cpu()
- n, c, h, w = frame.size()
- padding = torch.zeros(n, c, h, (h-w) // 2)
- padded_frame = torch.cat((padding, frame, padding), dim=-1)
- self.writer.add_embedding(
- embed,
- metadata=y.cpu().tolist(),
- label_img=padded_frame,
- global_step=self.curr_iter,
- tag=tag
- )
-
- def _flatten_embedding(self, embed):
- return embed.detach().transpose(0, 1).reshape(self.k * self.pr, -1)
-
def _write_stat(
- self, postfix, embed_c, embed_p, hpm_result, pn_result, loss, losses
+ self, postfix, loss, losses
):
# Write losses to TensorBoard
self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter)
self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip((
'Cross reconstruction loss', 'Canonical consistency loss',
'Pose similarity loss'
- ), losses[:3])), self.curr_iter)
- self.writer.add_scalars(f'Loss/triplet loss {postfix}', {
- 'HPM': losses[3],
- 'PartNet': losses[4]
- }, self.curr_iter)
- # None-zero losses in batch
- if hpm_result['counts'] is not None and pn_result['counts'] is not None:
- self.writer.add_scalars(f'Loss/non-zero counts {postfix}', {
- 'HPM': hpm_result['counts'].mean(),
- 'PartNet': pn_result['counts'].mean()
- }, self.curr_iter)
- # Embedding distance
- mean_hpm_dist = hpm_result['dist'].mean(0)
- self._add_ranked_scalars(
- f'Embedding/HPM distance {postfix}', mean_hpm_dist,
- self.num_pos_pairs, self.num_pairs, self.curr_iter
- )
- mean_pn_dist = pn_result['dist'].mean(0)
- self._add_ranked_scalars(
- f'Embedding/ParNet distance {postfix}', mean_pn_dist,
- self.num_pos_pairs, self.num_pairs, self.curr_iter
- )
- # Embedding norm
- mean_hpm_embedding = embed_c.mean(0)
- mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
- self._add_ranked_scalars(
- f'Embedding/HPM norm {postfix}', mean_hpm_norm,
- self.k, self.pr * self.k, self.curr_iter
- )
- mean_pa_embedding = embed_p.mean(0)
- mean_pa_norm = mean_pa_embedding.norm(dim=-1)
- self._add_ranked_scalars(
- f'Embedding/PartNet norm {postfix}', mean_pa_norm,
- self.k, self.pr * self.k, self.curr_iter
- )
-
- def _add_ranked_scalars(
- self,
- main_tag: str,
- metric: torch.Tensor,
- num_pos: int,
- num_all: int,
- global_step: int
- ):
- rank = metric.argsort()
- pos_ile = 100 - (num_pos - 1) * 100 // num_all
- self.writer.add_scalars(main_tag, {
- '0%-ile': metric[rank[-1]],
- f'{100 - pos_ile}%-ile': metric[rank[-num_pos]],
- '50%-ile': metric[rank[num_all // 2 - 1]],
- f'{pos_ile}%-ile': metric[rank[num_pos - 1]],
- '100%-ile': metric[rank[0]]
- }, global_step)
-
- def predict_all(
- self,
- iters: tuple[int],
- dataset_config: DatasetConfiguration,
- dataset_selectors: dict[
- str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
- ],
- dataloader_config: DataloaderConfiguration,
- ) -> dict[str, torch.Tensor]:
- # Transform data to features
- gallery_samples, probe_samples = self.transform(
- iters, dataset_config, dataset_selectors, dataloader_config
- )
- # Evaluate features
- accuracy = self.evaluate(gallery_samples, probe_samples)
-
- return accuracy
+ ), losses)), self.curr_iter)
def transform(
self,
@@ -466,9 +303,6 @@ class Model:
# Init models
model_hp: dict = self.hp.get('model', {}).copy()
- model_hp.pop('triplet_is_hard', True)
- model_hp.pop('triplet_is_mean', True)
- model_hp.pop('triplet_margins', None)
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp)
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
@@ -509,54 +343,6 @@ class Model:
'feature': torch.cat((feature_c, feature_p)).view(-1)
}
- @staticmethod
- def evaluate(
- gallery_samples: dict[str, dict[str, Union[list, torch.Tensor]]],
- probe_samples: dict[str, dict[str, Union[list, torch.Tensor]]],
- num_ranks: int = 5
- ) -> dict[str, torch.Tensor]:
- conditions = list(probe_samples.keys())
- gallery_views_meta = gallery_samples['meta']['views']
- probe_views_meta = probe_samples[conditions[0]]['meta']['views']
- accuracy = {
- condition: torch.empty(
- len(gallery_views_meta), len(probe_views_meta), num_ranks
- )
- for condition in conditions
- }
-
- for condition in conditions:
- gallery_samples_c = gallery_samples[condition]
- (labels_g, _, views_g, features_g) = gallery_samples_c.values()
- views_g = np.asarray(views_g)
- probe_samples_c = probe_samples[condition]
- (labels_p, _, views_p, features_p, _) = probe_samples_c.values()
- views_p = np.asarray(views_p)
- accuracy_c = accuracy[condition]
- for (v_g_i, view_g) in enumerate(gallery_views_meta):
- gallery_view_mask = (views_g == view_g)
- f_g = features_g[gallery_view_mask]
- y_g = labels_g[gallery_view_mask]
- for (v_p_i, view_p) in enumerate(probe_views_meta):
- probe_view_mask = (views_p == view_p)
- f_p = features_p[probe_view_mask]
- y_p = labels_p[probe_view_mask]
- # Euclidean distance
- f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(1)
- f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(0)
- f_p_times_f_g_sum = f_p @ f_g.T
- dist = torch.sqrt(F.relu(
- f_p_squared_sum - 2*f_p_times_f_g_sum + f_g_squared_sum
- ))
- # Ranked accuracy
- rank_mask = dist.argsort(1)[:, :num_ranks]
- positive_mat = torch.eq(y_p.unsqueeze(1),
- y_g[rank_mask]).cumsum(1).gt(0)
- positive_counts = positive_mat.sum(0)
- total_counts, _ = dist.size()
- accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts
- return accuracy
-
def _load_pretrained(
self,
iters: tuple[int],
@@ -629,8 +415,6 @@ class Model:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
- elif isinstance(m, (HorizontalPyramidMatching, PartNet)):
- nn.init.xavier_uniform_(m.fc_mat)
def _parse_dataset_config(
self,
diff --git a/models/part_net.py b/models/part_net.py
deleted file mode 100644
index f2236bf..0000000
--- a/models/part_net.py
+++ /dev/null
@@ -1,149 +0,0 @@
-import copy
-
-import torch
-import torch.nn as nn
-
-from models.layers import BasicConv1d, FocalConv2dBlock
-
-
-class FrameLevelPartFeatureExtractor(nn.Module):
-
- def __init__(
- self,
- in_channels: int = 3,
- feature_channels: int = 32,
- kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)),
- paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
- halving: tuple[int, ...] = (0, 2, 3)
- ):
- super().__init__()
- num_blocks = len(kernel_sizes)
- out_channels = [feature_channels * 2 ** i for i in range(num_blocks)]
- in_channels = [in_channels] + out_channels[:-1]
- use_pools = [True] * (num_blocks - 1) + [False]
- params = (in_channels, out_channels, kernel_sizes,
- paddings, halving, use_pools)
-
- self.fconv_blocks = nn.ModuleList([
- FocalConv2dBlock(*_params) for _params in zip(*params)
- ])
-
- def forward(self, x):
- # Flatten frames in all batches
- n, t, c, h, w = x.size()
- x = x.view(n * t, c, h, w)
-
- for fconv_block in self.fconv_blocks:
- x = fconv_block(x)
- return x
-
-
-class TemporalFeatureAggregator(nn.Module):
- def __init__(
- self,
- in_channels: int,
- squeeze_ratio: int = 4,
- num_part: int = 16
- ):
- super().__init__()
- hidden_dim = in_channels // squeeze_ratio
- self.num_part = num_part
-
- # MTB1
- conv3x1 = nn.Sequential(
- BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1),
- nn.LeakyReLU(inplace=True),
- BasicConv1d(hidden_dim, in_channels, kernel_size=1, padding=0)
- )
- self.conv1d3x1 = self._parted(conv3x1)
- self.avg_pool3x1 = nn.AvgPool1d(kernel_size=3, stride=1, padding=1)
- self.max_pool3x1 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1)
-
- # MTB2
- conv3x3 = nn.Sequential(
- BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1),
- nn.LeakyReLU(inplace=True),
- BasicConv1d(hidden_dim, in_channels, kernel_size=3, padding=1)
- )
- self.conv1d3x3 = self._parted(conv3x3)
- self.avg_pool3x3 = nn.AvgPool1d(kernel_size=5, stride=1, padding=2)
- self.max_pool3x3 = nn.MaxPool1d(kernel_size=5, stride=1, padding=2)
-
- def _parted(self, module: nn.Module):
- """Duplicate module `part_num` times."""
- return nn.ModuleList([copy.deepcopy(module)
- for _ in range(self.num_part)])
-
- def forward(self, x):
- # p, n, t, c
- x = x.transpose(2, 3)
- p, n, c, t = x.size()
- feature = x.split(1, dim=0)
- feature = [f.squeeze(0) for f in feature]
- x = x.view(-1, c, t)
-
- # MTB1: ConvNet1d & Sigmoid
- logits3x1 = torch.stack(
- [conv(f) for conv, f in zip(self.conv1d3x1, feature)]
- )
- scores3x1 = torch.sigmoid(logits3x1)
- # MTB1: Template Function
- feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x)
- feature3x1 = feature3x1.view(p, n, c, t)
- feature3x1 = feature3x1 * scores3x1
-
- # MTB2: ConvNet1d & Sigmoid
- logits3x3 = torch.stack(
- [conv(f) for conv, f in zip(self.conv1d3x3, feature)]
- )
- scores3x3 = torch.sigmoid(logits3x3)
- # MTB2: Template Function
- feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x)
- feature3x3 = feature3x3.view(p, n, c, t)
- feature3x3 = feature3x3 * scores3x3
-
- # Temporal Pooling
- ret = (feature3x1 + feature3x3).max(-1)[0]
- return ret
-
-
-class PartNet(nn.Module):
- def __init__(
- self,
- in_channels: int = 128,
- embedding_dims: int = 256,
- num_parts: int = 16,
- squeeze_ratio: int = 4,
- ):
- super().__init__()
- self.num_part = num_parts
-
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
- self.max_pool = nn.AdaptiveMaxPool2d(1)
- self.tfa = TemporalFeatureAggregator(
- in_channels, squeeze_ratio, self.num_part
- )
- self.fc_mat = nn.Parameter(
- torch.empty(num_parts, in_channels, embedding_dims)
- )
-
- def forward(self, x):
- n, t, c, h, w = x.size()
- x = x.view(n * t, c, h, w)
- # n * t x c x h x w
-
- # Horizontal Pooling
- _, c, h, w = x.size()
- split_size = h // self.num_part
- x = x.split(split_size, dim=2)
- x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
- x = [x_.view(n, t, c) for x_ in x]
- x = torch.stack(x)
-
- # p, n, t, c
- x = self.tfa(x)
-
- # p, n, c
- x = x @ self.fc_mat
- # p, n, d
- return x
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 4a82da3..d3f8ade 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -1,9 +1,8 @@
import torch
import torch.nn as nn
+import torch.nn.functional as F
from models.auto_encoder import AutoEncoder
-from models.hpm import HorizontalPyramidMatching
-from models.part_net import PartNet
class RGBPartNet(nn.Module):
@@ -13,12 +12,6 @@ class RGBPartNet(nn.Module):
ae_in_size: tuple[int, int] = (64, 48),
ae_feature_channels: int = 64,
f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
- hpm_scales: tuple[int, ...] = (1, 2, 4),
- hpm_use_avg_pool: bool = True,
- hpm_use_max_pool: bool = True,
- tfa_squeeze_ratio: int = 4,
- tfa_num_parts: int = 16,
- embedding_dims: tuple[int] = (256, 256),
image_log_on: bool = False
):
super().__init__()
@@ -29,94 +22,66 @@ class RGBPartNet(nn.Module):
self.ae = AutoEncoder(
ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
)
- self.pn_in_channels = ae_feature_channels * 2
- self.hpm = HorizontalPyramidMatching(
- self.pn_in_channels, embedding_dims[0], hpm_scales,
- hpm_use_avg_pool, hpm_use_max_pool
- )
- self.pn = PartNet(self.pn_in_channels, embedding_dims[1],
- tfa_num_parts, tfa_squeeze_ratio)
-
- self.num_parts = self.hpm.num_parts + tfa_num_parts
def forward(self, x_c1, x_c2=None):
- # Step 1: Disentanglement
- # n, t, c, h, w
- ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2)
-
- # Step 2.a: Static Gait Feature Aggregation & HPM
- # n, c, h, w
- x_c = self.hpm(x_c)
- # p, n, d
-
- # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
- # n, t, c, h, w
- x_p = self.pn(x_p)
- # p, n, d
+ losses, features, images = self._disentangle(x_c1, x_c2)
if self.training:
- return x_c, x_p, ae_losses, images
+ losses = torch.stack(losses)
+ return losses, features, images
else:
- return x_c, x_p
+ return features
def _disentangle(self, x_c1_t2, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()
- device = x_c1_t2.device
if self.training:
x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
- # Decode features
- x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p_ = self._decode_pose_feature(f_p_, n, t, device)
- x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4)
+ f_a = f_a_.view(n, t, -1)
+ f_c = f_c_.view(n, t, -1)
+ f_p = f_p_.view(n, t, -1)
i_a, i_c, i_p = None, None, None
if self.image_log_on:
with torch.no_grad():
- i_a = self._decode_appr_feature(f_a_, n, t, device)
- # Continue decoding canonical features
- i_c = self.ae.decoder.trans_conv3(x_c)
- i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c))
- i_p_ = self.ae.decoder.trans_conv3(x_p_)
- i_p_ = torch.sigmoid(self.ae.decoder.trans_conv4(i_p_))
+ x_a, i_a = self._separate_decode(
+ f_a.mean(1),
+ torch.zeros_like(f_c[:, 0, :]),
+ torch.zeros_like(f_p[:, 0, :])
+ )
+ x_c, i_c = self._separate_decode(
+ torch.zeros_like(f_a[:, 0, :]),
+ f_c.mean(1),
+ torch.zeros_like(f_p[:, 0, :]),
+ )
+ x_p_, i_p_ = self._separate_decode(
+ torch.zeros_like(f_a_),
+ torch.zeros_like(f_c_),
+ f_p_
+ )
+ x_p = tuple(_x_p.view(n, t, *_x_p.size()[1:]) for _x_p in x_p_)
i_p = i_p_.view(n, t, c, h, w)
- return (x_c, x_p), losses, (i_a, i_c, i_p)
+ return losses, (x_a, x_c, x_p), (i_a, i_c, i_p)
else: # evaluating
f_c_, f_p_ = self.ae(x_c1_t2)
- x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p_ = self._decode_pose_feature(f_p_, n, t, device)
- x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4)
- return (x_c, x_p), None, None
-
- def _decode_appr_feature(self, f_a_, n, t, device):
- # Decode appearance features
- f_a = f_a_.view(n, t, -1)
- x_a = self.ae.decoder(
- f_a.mean(1),
- torch.zeros((n, self.f_c_dim), device=device),
- torch.zeros((n, self.f_p_dim), device=device)
- )
- return x_a
-
- def _decode_cano_feature(self, f_c_, n, t, device):
- # Decode average canonical features to higher dimension
- f_c = f_c_.view(n, t, -1)
- x_c = self.ae.decoder(
- torch.zeros((n, self.f_a_dim), device=device),
- f_c.mean(1),
- torch.zeros((n, self.f_p_dim), device=device),
- is_feature_map=True
- )
- return x_c
-
- def _decode_pose_feature(self, f_p_, n, t, device):
- # Decode pose features to images
- x_p_ = self.ae.decoder(
- torch.zeros((n * t, self.f_a_dim), device=device),
- torch.zeros((n * t, self.f_c_dim), device=device),
- f_p_,
- is_feature_map=True
+ f_c = f_c_.view(n, t, -1)
+ f_p = f_p_.view(n, t, -1)
+ return (f_c, f_p), None, None
+
+ def _separate_decode(self, f_a, f_c, f_p):
+ x_1 = torch.cat((f_a, f_c, f_p), dim=1)
+ x_1 = self.ae.decoder.fc(x_1).view(
+ -1,
+ self.ae.decoder.feature_channels * 8,
+ self.ae.decoder.h_0,
+ self.ae.decoder.w_0
)
- return x_p_
+ x_1 = F.relu(x_1, inplace=True)
+ x_2 = self.ae.decoder.trans_conv1(x_1)
+ x_3 = self.ae.decoder.trans_conv2(x_2)
+ x_4 = self.ae.decoder.trans_conv3(x_3)
+ image = torch.sigmoid(self.ae.decoder.trans_conv4(x_4))
+ x = (x_1, x_2, x_3, x_4)
+ return x, image