summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-19 22:39:49 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-19 22:39:49 +0800
commitd12dd6b04a4e7c2b1ee43ab6f36f25d0c35ca364 (patch)
tree71b5209ce4b5cfb1d09b89fe133028bbfa481dc9
parent4aa9044122878a8e2b887a8b170c036983431559 (diff)
New branch with auto-encoder only
-rw-r--r--config.py38
-rw-r--r--eval.py30
-rw-r--r--models/hpm.py55
-rw-r--r--models/layers.py96
-rw-r--r--models/model.py114
-rw-r--r--models/part_net.py151
-rw-r--r--models/rgb_part_net.py63
-rw-r--r--test/hpm.py23
-rw-r--r--test/part_net.py71
-rw-r--r--utils/configuration.py24
-rw-r--r--utils/triplet_loss.py36
11 files changed, 20 insertions, 681 deletions
diff --git a/config.py b/config.py
index 424bf5b..afd40d5 100644
--- a/config.py
+++ b/config.py
@@ -7,9 +7,9 @@ config: Configuration = {
# GPU(s) used in training or testing if available
'CUDA_VISIBLE_DEVICES': '0',
# Directory used in training or testing for temporary storage
- 'save_dir': 'runs',
+ 'save_dir': 'runs/dis_only',
# Recorde disentangled image or not
- 'image_log_on': False
+ 'image_log_on': True
},
# Dataset settings
'dataset': {
@@ -37,7 +37,7 @@ config: Configuration = {
# Batch size (pr, k)
# `pr` denotes number of persons
# `k` denotes number of sequences per person
- 'batch_size': (4, 8),
+ 'batch_size': (2, 2),
# Number of workers of Dataloader
'num_workers': 4,
# Faster data transfer from RAM to GPU if enabled
@@ -49,35 +49,10 @@ config: Configuration = {
# Auto-encoder feature channels coefficient
'ae_feature_channels': 64,
# Appearance, canonical and pose feature dimensions
- 'f_a_c_p_dims': (128, 128, 64),
- # Use 1x1 convolution in dimensionality reduction
- 'hpm_use_1x1conv': False,
- # HPM pyramid scales, of which sum is number of parts
- 'hpm_scales': (1, 2, 4),
- # Global pooling method
- 'hpm_use_avg_pool': True,
- 'hpm_use_max_pool': False,
- # FConv feature channels coefficient
- 'fpfe_feature_channels': 32,
- # FConv blocks kernel sizes
- 'fpfe_kernel_sizes': ((5, 3), (3, 3), (3, 3)),
- # FConv blocks paddings
- 'fpfe_paddings': ((2, 1), (1, 1), (1, 1)),
- # FConv blocks halving
- 'fpfe_halving': (0, 2, 3),
- # Attention squeeze ratio
- 'tfa_squeeze_ratio': 4,
- # Number of parts after Part Net
- 'tfa_num_parts': 16,
- # Embedding dimension for each part
- 'embedding_dims': 256,
- # Triplet loss margins for HPM and PartNet
- 'triplet_margins': (0.2, 0.2),
+ 'f_a_c_p_dims': (192, 192, 96),
},
'optimizer': {
# Global parameters
- # Iteration start to optimize non-disentangling parts
- # 'start_iter': 0,
# Initial learning rate of Adam Optimizer
'lr': 1e-4,
# Coefficients used for computing running averages of
@@ -89,11 +64,6 @@ config: Configuration = {
# 'weight_decay': 0,
# Use AMSGrad or not
# 'amsgrad': False,
-
- # Local parameters (override global ones)
- 'auto_encoder': {
- 'weight_decay': 0.001
- },
},
'scheduler': {
# Period of learning rate decay
diff --git a/eval.py b/eval.py
deleted file mode 100644
index c0505f8..0000000
--- a/eval.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import numpy as np
-
-from config import config
-from models import Model
-from utils.dataset import ClipConditions
-from utils.misc import set_visible_cuda
-
-set_visible_cuda(config['system'])
-model = Model(config['system'], config['model'], config['hyperparameter'])
-
-dataset_selectors = {
- 'nm': {'conditions': ClipConditions({r'nm-0\d'})},
- 'bg': {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})},
- 'cl': {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},
-}
-
-accuracy = model.predict_all(config['model']['total_iters'], config['dataset'],
- dataset_selectors, config['dataloader'])
-rank = 5
-np.set_printoptions(formatter={'float': '{:5.2f}'.format})
-for n in range(rank):
- print(f'===Rank-{n + 1} Accuracy===')
- for (condition, accuracy_c) in accuracy.items():
- acc_excl_identical_view = accuracy_c[:, :, n].fill_diagonal_(0)
- num_gallery_views = (acc_excl_identical_view != 0).sum(0)
- acc_each_angle = acc_excl_identical_view.sum(0) / num_gallery_views
- print('{0}: {1} mean: {2:5.2f}'.format(
- condition, acc_each_angle.cpu().numpy() * 100,
- acc_each_angle.mean() * 100)
- )
diff --git a/models/hpm.py b/models/hpm.py
deleted file mode 100644
index 9879cfb..0000000
--- a/models/hpm.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import torch
-import torch.nn as nn
-
-from models.layers import HorizontalPyramidPooling
-
-
-class HorizontalPyramidMatching(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int = 128,
- use_1x1conv: bool = False,
- scales: tuple[int, ...] = (1, 2, 4),
- use_avg_pool: bool = True,
- use_max_pool: bool = False,
- **kwargs
- ):
- super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.use_1x1conv = use_1x1conv
- self.scales = scales
- self.use_avg_pool = use_avg_pool
- self.use_max_pool = use_max_pool
-
- self.pyramids = nn.ModuleList([
- self._make_pyramid(scale, **kwargs) for scale in self.scales
- ])
-
- def _make_pyramid(self, scale: int, **kwargs):
- pyramid = nn.ModuleList([
- HorizontalPyramidPooling(self.in_channels,
- self.out_channels,
- use_1x1conv=self.use_1x1conv,
- use_avg_pool=self.use_avg_pool,
- use_max_pool=self.use_max_pool,
- **kwargs)
- for _ in range(scale)
- ])
- return pyramid
-
- def forward(self, x):
- n, c, h, w = x.size()
- feature = []
- for scale, pyramid in zip(self.scales, self.pyramids):
- h_per_hpp = h // scale
- for hpp_index, hpp in enumerate(pyramid):
- h_filter = torch.arange(hpp_index * h_per_hpp,
- (hpp_index + 1) * h_per_hpp)
- x_slice = x[:, :, h_filter, :]
- x_slice = hpp(x_slice)
- x_slice = x_slice.view(n, -1)
- feature.append(x_slice)
- x = torch.stack(feature)
- return x
diff --git a/models/layers.py b/models/layers.py
index ef53a95..1b4640f 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -1,6 +1,5 @@
from typing import Union
-import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -97,98 +96,3 @@ class BasicLinear(nn.Module):
x = self.fc(x)
x = self.bn(x)
return x
-
-
-class FocalConv2d(BasicConv2d):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: Union[int, tuple[int, int]],
- halving: int,
- **kwargs
- ):
- super().__init__(in_channels, out_channels, kernel_size, **kwargs)
- self.halving = halving
-
- def forward(self, x):
- h = x.size(2)
- split_size = h // 2 ** self.halving
- z = x.split(split_size, dim=2)
- z = torch.cat([self.conv(_) for _ in z], dim=2)
- return F.leaky_relu(z, inplace=True)
-
-
-class FocalConv2dBlock(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_sizes: tuple[int, int],
- paddings: tuple[int, int],
- halving: int,
- use_pool: bool = True,
- **kwargs
- ):
- super().__init__()
- self.use_pool = use_pool
- self.fconv1 = FocalConv2d(in_channels, out_channels, kernel_sizes[0],
- halving, padding=paddings[0], **kwargs)
- self.fconv2 = FocalConv2d(out_channels, out_channels, kernel_sizes[1],
- halving, padding=paddings[1], **kwargs)
- self.max_pool = nn.MaxPool2d(2)
-
- def forward(self, x):
- x = self.fconv1(x)
- x = self.fconv2(x)
- if self.use_pool:
- x = self.max_pool(x)
- return x
-
-
-class BasicConv1d(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: Union[int, tuple[int]],
- **kwargs
- ):
- super().__init__()
- self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
- bias=False, **kwargs)
-
- def forward(self, x):
- return self.conv(x)
-
-
-class HorizontalPyramidPooling(nn.Module):
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- use_1x1conv: bool = False,
- use_avg_pool: bool = True,
- use_max_pool: bool = False,
- **kwargs
- ):
- super().__init__()
- self.use_1x1conv = use_1x1conv
- if use_1x1conv:
- self.conv = BasicConv2d(in_channels, out_channels, 1, **kwargs)
- self.use_avg_pool = use_avg_pool
- self.use_max_pool = use_max_pool
- assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.'
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
- self.max_pool = nn.AdaptiveMaxPool2d(1)
-
- def forward(self, x):
- if self.use_avg_pool and self.use_max_pool:
- x = self.avg_pool(x) + self.max_pool(x)
- elif self.use_avg_pool and not self.use_max_pool:
- x = self.avg_pool(x)
- elif not self.use_avg_pool and self.use_max_pool:
- x = self.max_pool(x)
- if self.use_1x1conv:
- x = self.conv(x)
- return x
diff --git a/models/model.py b/models/model.py
index 82d6461..3f24936 100644
--- a/models/model.py
+++ b/models/model.py
@@ -5,7 +5,6 @@ from typing import Union, Optional
import numpy as np
import torch
import torch.nn as nn
-import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
@@ -142,29 +141,16 @@ class Model:
# Prepare for model, optimizer and scheduler
model_hp = self.hp.get('model', {})
optim_hp: dict = self.hp.get('optimizer', {}).copy()
- start_iter = optim_hp.pop('start_iter', 0)
- ae_optim_hp = optim_hp.pop('auto_encoder', {})
- pn_optim_hp = optim_hp.pop('part_net', {})
- hpm_optim_hp = optim_hp.pop('hpm', {})
- fc_optim_hp = optim_hp.pop('fc', {})
sched_hp = self.hp.get('scheduler', {})
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
- self.optimizer = optim.Adam([
- {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
- {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
- {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
- {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
- ], **optim_hp)
+ self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp)
sched_gamma = sched_hp.get('gamma', 0.9)
sched_step_size = sched_hp.get('step_size', 500)
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
lambda epoch: sched_gamma ** (epoch // sched_step_size),
- lambda epoch: 0 if epoch < start_iter else 1,
- lambda epoch: 0 if epoch < start_iter else 1,
- lambda epoch: 0 if epoch < start_iter else 1,
])
self.writer = SummaryWriter(self._log_name)
@@ -182,10 +168,10 @@ class Model:
# Training start
start_time = datetime.now()
- running_loss = torch.zeros(5, device=self.device)
+ running_loss = torch.zeros(3, device=self.device)
print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
- f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
+ f"{'LR':^9}")
for (batch_c1, batch_c2) in dataloader:
self.curr_iter += 1
# Zero the parameter gradients
@@ -193,10 +179,7 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- y = batch_c1['label'].to(self.device)
- # Duplicate labels for each part
- y = y.unsqueeze(1).repeat(1, self.rgb_pn.num_total_parts)
- losses, images = self.rgb_pn(x_c1, x_c2, y)
+ losses, images = self.rgb_pn(x_c1, x_c2)
loss = losses.sum()
loss.backward()
self.optimizer.step()
@@ -206,19 +189,16 @@ class Model:
# Write losses to TensorBoard
self.writer.add_scalar('Loss/all', loss, self.curr_iter)
self.writer.add_scalars('Loss/details', dict(zip([
- 'Cross reconstruction loss', 'Canonical consistency loss',
- 'Pose similarity loss', 'Batch All triplet loss (HPM)',
- 'Batch All triplet loss (PartNet)'
+ 'Cross reconstruction loss',
+ 'Canonical consistency loss',
+ 'Pose similarity loss'
], losses)), self.curr_iter)
if self.curr_iter % 100 == 0:
- lrs = self.scheduler.get_last_lr()
+ lr = self.scheduler.get_last_lr()[0]
# Write learning rates
self.writer.add_scalar(
- 'Learning rate/Auto-encoder', lrs[0], self.curr_iter
- )
- self.writer.add_scalar(
- 'Learning rate/Others', lrs[1], self.curr_iter
+ 'Learning rate/Auto-encoder', lr, self.curr_iter
)
# Write disentangled images
if self.image_log_on:
@@ -241,8 +221,8 @@ class Model:
hour, minute = divmod(remaining_minute, 60)
print(f'{hour:02}:{minute:02}:{second:02}',
f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
- '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- '{:.3e} {:.3e}'.format(lrs[0], lrs[1]))
+ '{:f} {:f} {:f}'.format(*running_loss / 100),
+ f'{lr:.3e}')
running_loss.zero_()
# Step scheduler
@@ -261,24 +241,6 @@ class Model:
self.writer.close()
break
- def predict_all(
- self,
- iters: tuple[int],
- dataset_config: DatasetConfiguration,
- dataset_selectors: dict[
- str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
- ],
- dataloader_config: DataloaderConfiguration,
- ) -> dict[str, torch.Tensor]:
- # Transform data to features
- gallery_samples, probe_samples = self.transform(
- iters, dataset_config, dataset_selectors, dataloader_config
- )
- # Evaluate features
- accuracy = self.evaluate(gallery_samples, probe_samples)
-
- return accuracy
-
def transform(
self,
iters: tuple[int],
@@ -329,61 +291,13 @@ class Model:
def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]):
label = sample.pop('label').item()
clip = sample.pop('clip').to(self.device)
- feature = self.rgb_pn(clip).detach()
+ x_c, x_p = self.rgb_pn(clip).detach()
return {
**{'label': label},
**sample,
- **{'feature': feature}
- }
-
- def evaluate(
- self,
- gallery_samples: dict[str, Union[list[str], torch.Tensor]],
- probe_samples: dict[str, dict[str, Union[list[str], torch.Tensor]]],
- num_ranks: int = 5
- ) -> dict[str, torch.Tensor]:
- probe_conditions = self._probe_datasets_meta.keys()
- gallery_views_meta = self._gallery_dataset_meta['views']
- probe_views_meta = list(self._probe_datasets_meta.values())[0]['views']
- accuracy = {
- condition: torch.empty(
- len(gallery_views_meta), len(probe_views_meta), num_ranks
- )
- for condition in self._probe_datasets_meta.keys()
+ **{'cano_feature': x_c, 'pose_feature': x_p}
}
- (labels_g, _, views_g, features_g) = gallery_samples.values()
- views_g = np.asarray(views_g)
- for (v_g_i, view_g) in enumerate(gallery_views_meta):
- gallery_view_mask = (views_g == view_g)
- f_g = features_g[gallery_view_mask]
- y_g = labels_g[gallery_view_mask]
- for condition in probe_conditions:
- probe_samples_c = probe_samples[condition]
- accuracy_c = accuracy[condition]
- (labels_p, _, views_p, features_p) = probe_samples_c.values()
- views_p = np.asarray(views_p)
- for (v_p_i, view_p) in enumerate(probe_views_meta):
- probe_view_mask = (views_p == view_p)
- f_p = features_p[probe_view_mask]
- y_p = labels_p[probe_view_mask]
- # Euclidean distance
- f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(1)
- f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(0)
- f_p_times_f_g_sum = f_p @ f_g.T
- dist = torch.sqrt(F.relu(
- f_p_squared_sum - 2*f_p_times_f_g_sum + f_g_squared_sum
- ))
- # Ranked accuracy
- rank_mask = dist.argsort(1)[:, :num_ranks]
- positive_mat = torch.eq(y_p.unsqueeze(1),
- y_g[rank_mask]).cumsum(1).gt(0)
- positive_counts = positive_mat.sum(0)
- total_counts, _ = dist.size()
- accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts
-
- return accuracy
-
def _load_pretrained(
self,
iters: tuple[int],
@@ -452,8 +366,6 @@ class Model:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
- elif isinstance(m, RGBPartNet):
- nn.init.xavier_uniform_(m.fc_mat)
def _parse_dataset_config(
self,
diff --git a/models/part_net.py b/models/part_net.py
deleted file mode 100644
index 62a2bac..0000000
--- a/models/part_net.py
+++ /dev/null
@@ -1,151 +0,0 @@
-import copy
-
-import torch
-import torch.nn as nn
-
-from models.layers import BasicConv1d, FocalConv2dBlock
-
-
-class FrameLevelPartFeatureExtractor(nn.Module):
-
- def __init__(
- self,
- in_channels: int = 3,
- feature_channels: int = 32,
- kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)),
- paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
- halving: tuple[int, ...] = (0, 2, 3)
- ):
- super().__init__()
- num_blocks = len(kernel_sizes)
- out_channels = [feature_channels * 2 ** i for i in range(num_blocks)]
- in_channels = [in_channels] + out_channels[:-1]
- use_pools = [True] * (num_blocks - 1) + [False]
- params = (in_channels, out_channels, kernel_sizes,
- paddings, halving, use_pools)
-
- self.fconv_blocks = nn.ModuleList([
- FocalConv2dBlock(*_params) for _params in zip(*params)
- ])
-
- def forward(self, x):
- # Flatten frames in all batches
- n, t, c, h, w = x.size()
- x = x.view(n * t, c, h, w)
-
- for fconv_block in self.fconv_blocks:
- x = fconv_block(x)
- return x
-
-
-class TemporalFeatureAggregator(nn.Module):
- def __init__(
- self,
- in_channels: int,
- squeeze_ratio: int = 4,
- num_part: int = 16
- ):
- super().__init__()
- hidden_dim = in_channels // squeeze_ratio
- self.num_part = num_part
-
- # MTB1
- conv3x1 = nn.Sequential(
- BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1),
- nn.LeakyReLU(inplace=True),
- BasicConv1d(hidden_dim, in_channels, kernel_size=1, padding=0)
- )
- self.conv1d3x1 = self._parted(conv3x1)
- self.avg_pool3x1 = nn.AvgPool1d(kernel_size=3, stride=1, padding=1)
- self.max_pool3x1 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1)
-
- # MTB2
- conv3x3 = nn.Sequential(
- BasicConv1d(in_channels, hidden_dim, kernel_size=3, padding=1),
- nn.LeakyReLU(inplace=True),
- BasicConv1d(hidden_dim, in_channels, kernel_size=3, padding=1)
- )
- self.conv1d3x3 = self._parted(conv3x3)
- self.avg_pool3x3 = nn.AvgPool1d(kernel_size=5, stride=1, padding=2)
- self.max_pool3x3 = nn.MaxPool1d(kernel_size=5, stride=1, padding=2)
-
- def _parted(self, module: nn.Module):
- """Duplicate module `part_num` times."""
- return nn.ModuleList([copy.deepcopy(module)
- for _ in range(self.num_part)])
-
- def forward(self, x):
- # p, n, t, c
- x = x.transpose(2, 3)
- p, n, c, t = x.size()
- feature = x.split(1, dim=0)
- feature = [f.squeeze(0) for f in feature]
- x = x.view(-1, c, t)
-
- # MTB1: ConvNet1d & Sigmoid
- logits3x1 = torch.stack(
- [conv(f) for conv, f in zip(self.conv1d3x1, feature)]
- )
- scores3x1 = torch.sigmoid(logits3x1)
- # MTB1: Template Function
- feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x)
- feature3x1 = feature3x1.view(p, n, c, t)
- feature3x1 = feature3x1 * scores3x1
-
- # MTB2: ConvNet1d & Sigmoid
- logits3x3 = torch.stack(
- [conv(f) for conv, f in zip(self.conv1d3x3, feature)]
- )
- scores3x3 = torch.sigmoid(logits3x3)
- # MTB2: Template Function
- feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x)
- feature3x3 = feature3x3.view(p, n, c, t)
- feature3x3 = feature3x3 * scores3x3
-
- # Temporal Pooling
- ret = (feature3x1 + feature3x3).max(-1)[0]
- return ret
-
-
-class PartNet(nn.Module):
- def __init__(
- self,
- in_channels: int = 3,
- feature_channels: int = 32,
- kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)),
- paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
- halving: tuple[int, ...] = (0, 2, 3),
- squeeze_ratio: int = 4,
- num_part: int = 16
- ):
- super().__init__()
- self.num_part = num_part
- self.fpfe = FrameLevelPartFeatureExtractor(
- in_channels, feature_channels, kernel_sizes, paddings, halving
- )
-
- num_fconv_blocks = len(self.fpfe.fconv_blocks)
- self.tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1)
- self.tfa = TemporalFeatureAggregator(
- self.tfa_in_channels, squeeze_ratio, self.num_part
- )
-
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
- self.max_pool = nn.AdaptiveMaxPool2d(1)
-
- def forward(self, x):
- n, t, _, _, _ = x.size()
- x = self.fpfe(x)
- # n * t x c x h x w
-
- # Horizontal Pooling
- _, c, h, w = x.size()
- split_size = h // self.num_part
- x = x.split(split_size, dim=2)
- x = [self.avg_pool(x_) + self.max_pool(x_) for x_ in x]
- x = [x_.view(n, t, c) for x_ in x]
- x = torch.stack(x)
-
- # p, n, t, c
- x = self.tfa(x)
- return x
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 67acac3..f18d675 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -2,9 +2,6 @@ import torch
import torch.nn as nn
from models.auto_encoder import AutoEncoder
-from models.hpm import HorizontalPyramidMatching
-from models.part_net import PartNet
-from utils.triplet_loss import BatchAllTripletLoss
class RGBPartNet(nn.Module):
@@ -14,80 +11,26 @@ class RGBPartNet(nn.Module):
ae_in_size: tuple[int, int] = (64, 48),
ae_feature_channels: int = 64,
f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
- hpm_use_1x1conv: bool = False,
- hpm_scales: tuple[int, ...] = (1, 2, 4),
- hpm_use_avg_pool: bool = True,
- hpm_use_max_pool: bool = True,
- fpfe_feature_channels: int = 32,
- fpfe_kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)),
- fpfe_paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
- fpfe_halving: tuple[int, ...] = (0, 2, 3),
- tfa_squeeze_ratio: int = 4,
- tfa_num_parts: int = 16,
- embedding_dims: int = 256,
- triplet_margins: tuple[float, float] = (0.2, 0.2),
image_log_on: bool = False
):
super().__init__()
(self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
- self.hpm_num_parts = sum(hpm_scales)
self.image_log_on = image_log_on
self.ae = AutoEncoder(
ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
)
- self.pn = PartNet(
- ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
- fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_parts
- )
- out_channels = self.pn.tfa_in_channels
- self.hpm = HorizontalPyramidMatching(
- ae_feature_channels * 2, out_channels, hpm_use_1x1conv,
- hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
- )
- self.num_total_parts = self.hpm_num_parts + tfa_num_parts
- empty_fc = torch.empty(self.num_total_parts,
- out_channels, embedding_dims)
- self.fc_mat = nn.Parameter(empty_fc)
-
- (hpm_margin, pn_margin) = triplet_margins
- self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin)
- self.pn_ba_trip = BatchAllTripletLoss(pn_margin)
-
- def fc(self, x):
- return x @ self.fc_mat
- def forward(self, x_c1, x_c2=None, y=None):
+ def forward(self, x_c1, x_c2=None):
# Step 1: Disentanglement
# n, t, c, h, w
((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2)
- # Step 2.a: Static Gait Feature Aggregation & HPM
- # n, c, h, w
- x_c = self.hpm(x_c)
- # p, n, c
-
- # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
- # n, t, c, h, w
- x_p = self.pn(x_p)
- # p, n, c
-
- # Step 3: Cat feature map together and fc
- x = torch.cat((x_c, x_p))
- x = self.fc(x)
-
if self.training:
- y = y.T
- hpm_ba_trip = self.hpm_ba_trip(
- x[:self.hpm_num_parts], y[:self.hpm_num_parts]
- )
- pn_ba_trip = self.pn_ba_trip(
- x[self.hpm_num_parts:], y[self.hpm_num_parts:]
- )
- losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip))
+ losses = torch.stack(losses)
return losses, images
else:
- return x.unsqueeze(1).view(-1)
+ return x_c, x_p
def _disentangle(self, x_c1_t2, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()
diff --git a/test/hpm.py b/test/hpm.py
deleted file mode 100644
index 0aefbb8..0000000
--- a/test/hpm.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import torch
-
-from models.hpm import HorizontalPyramidMatching
-
-T, N, C, H, W = 15, 4, 256, 32, 16
-
-
-def test_default_hpm():
- hpm = HorizontalPyramidMatching(in_channels=C)
- x = torch.rand(T, N, C, H, W)
- x = hpm(x)
- assert tuple(x.size()) == (1 + 2 + 4, T, N, 128)
-
-
-def test_custom_hpm():
- hpm = HorizontalPyramidMatching(in_channels=2048,
- out_channels=256,
- scales=(1, 2, 4, 8),
- use_avg_pool=True,
- use_max_pool=False)
- x = torch.rand(T, N, 2048, H, W)
- x = hpm(x)
- assert tuple(x.size()) == (1 + 2 + 4 + 8, T, N, 256)
diff --git a/test/part_net.py b/test/part_net.py
deleted file mode 100644
index 25e92ae..0000000
--- a/test/part_net.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import torch
-
-from models.part_net import FrameLevelPartFeatureExtractor, \
- TemporalFeatureAggregator, PartNet
-
-T, N, C, H, W = 15, 4, 3, 64, 32
-
-
-def test_default_fpfe():
- fpfe = FrameLevelPartFeatureExtractor()
- x = torch.rand(T, N, C, H, W)
- x = fpfe(x)
-
- assert tuple(x.size()) == (T * N, 32 * 4, 16, 8)
-
-
-def test_custom_fpfe():
- feature_channels = 64
- fpfe = FrameLevelPartFeatureExtractor(
- in_channels=1,
- feature_channels=feature_channels,
- kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)),
- paddings=((2, 1), (1, 1), (1, 1), (1, 1)),
- halving=(1, 1, 3, 3)
- )
- x = torch.rand(T, N, 1, H, W)
- x = fpfe(x)
-
- assert tuple(x.size()) == (T * N, feature_channels * 8, 8, 4)
-
-
-def test_default_tfa():
- in_channels = 32 * 4
- tfa = TemporalFeatureAggregator(in_channels)
- x = torch.rand(16, T, N, in_channels)
- x = tfa(x)
-
- assert tuple(x.size()) == (16, N, in_channels)
-
-
-def test_custom_tfa():
- in_channels = 64 * 8
- num_part = 8
- tfa = TemporalFeatureAggregator(in_channels=in_channels,
- squeeze_ratio=8, num_part=num_part)
- x = torch.rand(num_part, T, N, in_channels)
- x = tfa(x)
-
- assert tuple(x.size()) == (num_part, N, in_channels)
-
-
-def test_default_part_net():
- pa = PartNet()
- x = torch.rand(T, N, C, H, W)
- x = pa(x)
-
- assert tuple(x.size()) == (16, N, 32 * 4)
-
-
-def test_custom_part_net():
- feature_channels = 64
- pa = PartNet(in_channels=1, feature_channels=feature_channels,
- kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)),
- paddings=((2, 1), (1, 1), (1, 1), (1, 1)),
- halving=(1, 1, 3, 3),
- squeeze_ratio=8,
- num_part=8)
- x = torch.rand(T, N, 1, H, W)
- x = pa(x)
-
- assert tuple(x.size()) == (8, N, pa.tfa_in_channels)
diff --git a/utils/configuration.py b/utils/configuration.py
index 435d815..1b7c8d3 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -32,26 +32,6 @@ class DataloaderConfiguration(TypedDict):
class ModelHPConfiguration(TypedDict):
ae_feature_channels: int
f_a_c_p_dims: tuple[int, int, int]
- hpm_scales: tuple[int, ...]
- hpm_use_1x1conv: bool
- hpm_use_avg_pool: bool
- hpm_use_max_pool: bool
- fpfe_feature_channels: int
- fpfe_kernel_sizes: tuple[tuple, ...]
- fpfe_paddings: tuple[tuple, ...]
- fpfe_halving: tuple[int, ...]
- tfa_squeeze_ratio: int
- tfa_num_parts: int
- embedding_dims: int
- triplet_margins: tuple[float, float]
-
-
-class SubOptimizerHPConfiguration(TypedDict):
- lr: int
- betas: tuple[float, float]
- eps: float
- weight_decay: float
- amsgrad: bool
class OptimizerHPConfiguration(TypedDict):
@@ -61,10 +41,6 @@ class OptimizerHPConfiguration(TypedDict):
eps: float
weight_decay: float
amsgrad: bool
- auto_encoder: SubOptimizerHPConfiguration
- part_net: SubOptimizerHPConfiguration
- hpm: SubOptimizerHPConfiguration
- fc: SubOptimizerHPConfiguration
class SchedulerHPConfiguration(TypedDict):
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
deleted file mode 100644
index 954def2..0000000
--- a/utils/triplet_loss.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class BatchAllTripletLoss(nn.Module):
- def __init__(self, margin: float = 0.2):
- super().__init__()
- self.margin = margin
-
- def forward(self, x, y):
- p, n, c = x.size()
-
- # Euclidean distance p x n x n
- x_squared_sum = torch.sum(x ** 2, dim=2)
- x1_squared_sum = x_squared_sum.unsqueeze(2)
- x2_squared_sum = x_squared_sum.unsqueeze(1)
- x1_times_x2_sum = x @ x.transpose(1, 2)
- dist = torch.sqrt(
- F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum)
- )
-
- hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
- hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
- all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1)
- all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1)
- positive_negative_dist = all_hard_positive - all_hard_negative
- all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1)
-
- # Non-zero parted mean
- non_zero_counts = (all_loss != 0).sum(1)
- parted_loss_mean = all_loss.sum(1) / non_zero_counts
- parted_loss_mean[non_zero_counts == 0] = 0
-
- loss = parted_loss_mean.mean()
- return loss