summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/hpm.py25
-rw-r--r--models/layers.py9
-rw-r--r--models/model.py144
-rw-r--r--models/part_net.py18
-rw-r--r--models/rgb_part_net.py32
5 files changed, 114 insertions, 114 deletions
diff --git a/models/hpm.py b/models/hpm.py
index 9879cfb..8186b20 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -9,32 +9,26 @@ class HorizontalPyramidMatching(nn.Module):
self,
in_channels: int,
out_channels: int = 128,
- use_1x1conv: bool = False,
scales: tuple[int, ...] = (1, 2, 4),
use_avg_pool: bool = True,
use_max_pool: bool = False,
- **kwargs
):
super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.use_1x1conv = use_1x1conv
self.scales = scales
+ self.num_parts = sum(scales)
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
self.pyramids = nn.ModuleList([
- self._make_pyramid(scale, **kwargs) for scale in self.scales
+ self._make_pyramid(scale) for scale in scales
])
+ self.fc_mat = nn.Parameter(
+ torch.empty(self.num_parts, in_channels, out_channels)
+ )
- def _make_pyramid(self, scale: int, **kwargs):
+ def _make_pyramid(self, scale: int):
pyramid = nn.ModuleList([
- HorizontalPyramidPooling(self.in_channels,
- self.out_channels,
- use_1x1conv=self.use_1x1conv,
- use_avg_pool=self.use_avg_pool,
- use_max_pool=self.use_max_pool,
- **kwargs)
+ HorizontalPyramidPooling(self.use_avg_pool, self.use_max_pool)
for _ in range(scale)
])
return pyramid
@@ -52,4 +46,9 @@ class HorizontalPyramidMatching(nn.Module):
x_slice = x_slice.view(n, -1)
feature.append(x_slice)
x = torch.stack(feature)
+
+ # p, n, c
+ x = x @ self.fc_mat
+ # p, n, d
+
return x
diff --git a/models/layers.py b/models/layers.py
index f1d72b6..c609698 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -167,17 +167,10 @@ class BasicConv1d(nn.Module):
class HorizontalPyramidPooling(nn.Module):
def __init__(
self,
- in_channels: int,
- out_channels: int,
- use_1x1conv: bool = False,
use_avg_pool: bool = True,
use_max_pool: bool = False,
- **kwargs
):
super().__init__()
- self.use_1x1conv = use_1x1conv
- if use_1x1conv:
- self.conv = BasicConv2d(in_channels, out_channels, 1, **kwargs)
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.'
@@ -191,6 +184,4 @@ class HorizontalPyramidPooling(nn.Module):
x = self.avg_pool(x)
elif not self.use_avg_pool and self.use_max_pool:
x = self.max_pool(x)
- if self.use_1x1conv:
- x = self.conv(x)
return x
diff --git a/models/model.py b/models/model.py
index f515e05..bf79564 100644
--- a/models/model.py
+++ b/models/model.py
@@ -13,13 +13,15 @@ from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
+from models.hpm import HorizontalPyramidMatching
+from models.part_net import PartNet
from models.rgb_part_net import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
SystemConfiguration
from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
from utils.sampler import TripletSampler
-from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss
+from utils.triplet_loss import BatchTripletLoss
class Model:
@@ -69,7 +71,8 @@ class Model:
self._dataset_sig: str = 'undefined'
self.rgb_pn: Optional[RGBPartNet] = None
- self.triplet_loss: Optional[JointBatchTripletLoss] = None
+ self.triplet_loss_hpm: Optional[BatchTripletLoss] = None
+ self.triplet_loss_pn: Optional[BatchTripletLoss] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
@@ -149,26 +152,28 @@ class Model:
triplet_margins = model_hp.pop('triplet_margins', None)
optim_hp: dict = self.hp.get('optimizer', {}).copy()
ae_optim_hp = optim_hp.pop('auto_encoder', {})
- pn_optim_hp = optim_hp.pop('part_net', {})
hpm_optim_hp = optim_hp.pop('hpm', {})
- fc_optim_hp = optim_hp.pop('fc', {})
+ pn_optim_hp = optim_hp.pop('part_net', {})
sched_hp = self.hp.get('scheduler', {})
+ ae_sched_hp = sched_hp.get('auto_encoder', {})
+ hpm_sched_hp = sched_hp.get('hpm', {})
+ pn_sched_hp = sched_hp.get('part_net', {})
+
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
# Hard margins
if triplet_margins:
- # Same margins
- if triplet_margins[0] == triplet_margins[1]:
- self.triplet_loss = BatchTripletLoss(
- triplet_is_hard, triplet_margins[0]
- )
- else: # Different margins
- self.triplet_loss = JointBatchTripletLoss(
- self.rgb_pn.module.hpm_num_parts,
- triplet_is_hard, triplet_is_mean, triplet_margins
- )
+ self.triplet_loss_hpm = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, triplet_margins[0]
+ )
+ self.triplet_loss_pn = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, triplet_margins[1]
+ )
else: # Soft margins
- self.triplet_loss = BatchTripletLoss(
+ self.triplet_loss_hpm = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, None
+ )
+ self.triplet_loss_pn = BatchTripletLoss(
triplet_is_hard, triplet_is_mean, None
)
@@ -179,26 +184,34 @@ class Model:
# Try to accelerate computation using CUDA or others
self.rgb_pn = nn.DataParallel(self.rgb_pn)
self.rgb_pn = self.rgb_pn.to(self.device)
- self.triplet_loss = nn.DataParallel(self.triplet_loss)
- self.triplet_loss = self.triplet_loss.to(self.device)
+ self.triplet_loss_hpm = nn.DataParallel(self.triplet_loss_hpm)
+ self.triplet_loss_hpm = self.triplet_loss_hpm.to(self.device)
+ self.triplet_loss_pn = nn.DataParallel(self.triplet_loss_pn)
+ self.triplet_loss_pn = self.triplet_loss_pn.to(self.device)
self.optimizer = optim.Adam([
{'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp},
- {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
{'params': self.rgb_pn.module.hpm.parameters(), **hpm_optim_hp},
- {'params': self.rgb_pn.module.fc_mat, **fc_optim_hp}
+ {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
], **optim_hp)
- sched_final_gamma = sched_hp.get('final_gamma', 0.001)
- sched_start_step = sched_hp.get('start_step', 15_000)
- all_step = self.total_iter - sched_start_step
-
- def lr_lambda(epoch):
- if epoch > sched_start_step:
- passed_step = epoch - sched_start_step
- return sched_final_gamma ** (passed_step / all_step)
- else:
- return 1
+
+ start_step = sched_hp.get('start_step', 15_000)
+ final_gamma = sched_hp.get('final_gamma', 0.001)
+ ae_start_step = ae_sched_hp.get('start_step', start_step)
+ ae_final_gamma = ae_sched_hp.get('final_gamma', final_gamma)
+ ae_all_step = self.total_iter - ae_start_step
+ hpm_start_step = hpm_sched_hp.get('start_step', start_step)
+ hpm_final_gamma = hpm_sched_hp.get('final_gamma', final_gamma)
+ hpm_all_step = self.total_iter - hpm_start_step
+ pn_start_step = pn_sched_hp.get('start_step', start_step)
+ pn_final_gamma = pn_sched_hp.get('final_gamma', final_gamma)
+ pn_all_step = self.total_iter - pn_start_step
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
- lr_lambda, lr_lambda, lr_lambda, lr_lambda
+ lambda t: ae_final_gamma ** ((t - ae_start_step) / ae_all_step)
+ if t > ae_start_step else 1,
+ lambda t: hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step)
+ if t > hpm_start_step else 1,
+ lambda t: pn_final_gamma ** ((t - pn_start_step) / pn_all_step)
+ if t > pn_start_step else 1,
])
self.writer = SummaryWriter(self._log_name)
@@ -223,7 +236,7 @@ class Model:
running_loss = torch.zeros(5, device=self.device)
print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
- f"{'BATripH':^8} {'BATripP':^8} {'LR':^9}")
+ f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}")
for (batch_c1, batch_c2) in dataloader:
self.curr_iter += 1
# Zero the parameter gradients
@@ -231,7 +244,7 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- embedding, images, feature_for_loss = self.rgb_pn(x_c1, x_c2)
+ embed_c, embed_p, images, feature_for_loss = self.rgb_pn(x_c1, x_c2)
x_c1_pred = feature_for_loss[0]
xrecon_loss = torch.stack([
F.mse_loss(x_c1_pred[:, i, :, :, :], x_c1[:, i, :, :, :])
@@ -249,13 +262,16 @@ class Model:
) * 10
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
- y = y.repeat(self.rgb_pn.module.num_total_parts, 1)
- embedding = embedding.transpose(0, 1)
- triplet_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
- hpm_loss = triplet_loss[:self.rgb_pn.module.hpm_num_parts].mean()
- pn_loss = triplet_loss[self.rgb_pn.module.hpm_num_parts:].mean()
+ y = y.repeat(self.rgb_pn.module.num_parts, 1)
+ trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm(
+ embed_c.transpose(0, 1), y[:self.rgb_pn.module.hpm.num_parts]
+ )
+ trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn(
+ embed_p.transpose(0, 1), y[self.rgb_pn.module.hpm.num_parts:]
+ )
losses = torch.stack((
- xrecon_loss, cano_cons_loss, pose_sim_loss, hpm_loss, pn_loss
+ xrecon_loss, cano_cons_loss, pose_sim_loss,
+ trip_loss_hpm.mean(), trip_loss_pn.mean()
))
loss = losses.sum()
loss.backward()
@@ -271,37 +287,34 @@ class Model:
'Pose similarity loss': pose_sim_loss
}, self.curr_iter)
self.writer.add_scalars('Loss/triplet loss', {
- 'HPM': hpm_loss, 'PartNet': pn_loss
+ 'HPM': losses[3],
+ 'PartNet': losses[4]
}, self.curr_iter)
# None-zero losses in batch
- if num_non_zero is not None:
+ if hpm_num_non_zero is not None and hpm_num_non_zero is not None:
self.writer.add_scalars('Loss/non-zero counts', {
- 'HPM': num_non_zero[
- :self.rgb_pn.module.hpm_num_parts].mean(),
- 'PartNet': num_non_zero[
- self.rgb_pn.module.hpm_num_parts:].mean()
+ 'HPM': hpm_num_non_zero.mean(),
+ 'PartNet': pn_num_non_zero.mean()
}, self.curr_iter)
# Embedding distance
- mean_hpm_dist = dist[:self.rgb_pn.module.hpm_num_parts].mean(0)
+ mean_hpm_dist = hpm_dist.mean(0)
self._add_ranked_scalars(
'Embedding/HPM distance', mean_hpm_dist,
num_pos_pairs, num_pairs, self.curr_iter
)
- mean_pa_dist = dist[self.rgb_pn.module.hpm_num_parts:].mean(0)
+ mean_pa_dist = pn_dist.mean(0)
self._add_ranked_scalars(
'Embedding/ParNet distance', mean_pa_dist,
num_pos_pairs, num_pairs, self.curr_iter
)
# Embedding norm
- mean_hpm_embedding = embedding[
- :self.rgb_pn.module.hpm_num_parts].mean(0)
+ mean_hpm_embedding = embed_c.mean(0)
mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/HPM norm', mean_hpm_norm,
self.k, self.pr * self.k, self.curr_iter
)
- mean_pa_embedding = embedding[
- self.rgb_pn.module.hpm_num_parts:].mean(0)
+ mean_pa_embedding = embed_p.mean(0)
mean_pa_norm = mean_pa_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/PartNet norm', mean_pa_norm,
@@ -309,10 +322,9 @@ class Model:
)
# Learning rate
lrs = self.scheduler.get_last_lr()
- # Write learning rates
- self.writer.add_scalar(
- 'Learning rate', lrs[0], self.curr_iter
- )
+ self.writer.add_scalars('Learning rate', dict(zip((
+ 'Auto-encoder', 'HPM', 'PartNet'
+ ), lrs)), self.curr_iter)
if self.curr_iter % 100 == 0:
# Write disentangled images
@@ -337,7 +349,7 @@ class Model:
print(f'{hour:02}:{minute:02}:{second:02}',
f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
'{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- f'{lrs[0]:.3e}')
+ '{:.3e} {:.3e} {:.3e}'.format(*lrs))
running_loss.zero_()
# Step scheduler
@@ -432,13 +444,16 @@ class Model:
unit='clips'):
gallery_samples_c.append(self._get_eval_sample(sample))
gallery_samples[condition] = default_collate(gallery_samples_c)
+ gallery_samples['meta'] = self._gallery_dataset_meta
# Probe
probe_samples_c = []
for sample in tqdm(probe_dataloader,
desc=f'Transforming probe {condition}',
unit='clips'):
probe_samples_c.append(self._get_eval_sample(sample))
- probe_samples[condition] = default_collate(probe_samples_c)
+ probe_samples_c = default_collate(probe_samples_c)
+ probe_samples_c['meta'] = self._probe_datasets_meta[condition]
+ probe_samples[condition] = probe_samples_c
return gallery_samples, probe_samples
@@ -453,15 +468,15 @@ class Model:
**{'feature': feature}
}
+ @staticmethod
def evaluate(
- self,
- gallery_samples: dict[str, Union[list[str], torch.Tensor]],
- probe_samples: dict[str, dict[str, Union[list[str], torch.Tensor]]],
+ gallery_samples: dict[str, dict[str, Union[list, torch.Tensor]]],
+ probe_samples: dict[str, dict[str, Union[list, torch.Tensor]]],
num_ranks: int = 5
) -> dict[str, torch.Tensor]:
- conditions = gallery_samples.keys()
- gallery_views_meta = self._gallery_dataset_meta['views']
- probe_views_meta = list(self._probe_datasets_meta.values())[0]['views']
+ conditions = list(probe_samples.keys())
+ gallery_views_meta = gallery_samples['meta']['views']
+ probe_views_meta = probe_samples[conditions[0]]['meta']['views']
accuracy = {
condition: torch.empty(
len(gallery_views_meta), len(probe_views_meta), num_ranks
@@ -474,7 +489,7 @@ class Model:
(labels_g, _, views_g, features_g) = gallery_samples_c.values()
views_g = np.asarray(views_g)
probe_samples_c = probe_samples[condition]
- (labels_p, _, views_p, features_p) = probe_samples_c.values()
+ (labels_p, _, views_p, features_p, _) = probe_samples_c.values()
views_p = np.asarray(views_p)
accuracy_c = accuracy[condition]
for (v_g_i, view_g) in enumerate(gallery_views_meta):
@@ -499,7 +514,6 @@ class Model:
positive_counts = positive_mat.sum(0)
total_counts, _ = dist.size()
accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts
-
return accuracy
def _load_pretrained(
@@ -570,7 +584,7 @@ class Model:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
- elif isinstance(m, RGBPartNet):
+ elif isinstance(m, (HorizontalPyramidMatching, PartNet)):
nn.init.xavier_uniform_(m.fc_mat)
def _parse_dataset_config(
diff --git a/models/part_net.py b/models/part_net.py
index 29cf9cd..f2236bf 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -111,17 +111,21 @@ class PartNet(nn.Module):
def __init__(
self,
in_channels: int = 128,
+ embedding_dims: int = 256,
+ num_parts: int = 16,
squeeze_ratio: int = 4,
- num_part: int = 16
):
super().__init__()
- self.num_part = num_part
- self.tfa = TemporalFeatureAggregator(
- in_channels, squeeze_ratio, self.num_part
- )
+ self.num_part = num_parts
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
+ self.tfa = TemporalFeatureAggregator(
+ in_channels, squeeze_ratio, self.num_part
+ )
+ self.fc_mat = nn.Parameter(
+ torch.empty(num_parts, in_channels, embedding_dims)
+ )
def forward(self, x):
n, t, c, h, w = x.size()
@@ -138,4 +142,8 @@ class PartNet(nn.Module):
# p, n, t, c
x = self.tfa(x)
+
+ # p, n, c
+ x = x @ self.fc_mat
+ # p, n, d
return x
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 7785bb7..fdeed17 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -13,39 +13,31 @@ class RGBPartNet(nn.Module):
ae_in_size: tuple[int, int] = (64, 48),
ae_feature_channels: int = 64,
f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
- hpm_use_1x1conv: bool = False,
hpm_scales: tuple[int, ...] = (1, 2, 4),
hpm_use_avg_pool: bool = True,
hpm_use_max_pool: bool = True,
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
- embedding_dims: int = 256,
+ embedding_dims: tuple[int] = (256, 256),
image_log_on: bool = False
):
super().__init__()
self.h, self.w = ae_in_size
(self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
- self.hpm_num_parts = sum(hpm_scales)
self.image_log_on = image_log_on
self.ae = AutoEncoder(
ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
)
self.pn_in_channels = ae_feature_channels * 2
- self.pn = PartNet(
- self.pn_in_channels, tfa_squeeze_ratio, tfa_num_parts
- )
self.hpm = HorizontalPyramidMatching(
- ae_feature_channels * 2, self.pn_in_channels, hpm_use_1x1conv,
- hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
+ self.pn_in_channels, embedding_dims[0], hpm_scales,
+ hpm_use_avg_pool, hpm_use_max_pool
)
- self.num_total_parts = self.hpm_num_parts + tfa_num_parts
- empty_fc = torch.empty(self.num_total_parts,
- self.pn_in_channels, embedding_dims)
- self.fc_mat = nn.Parameter(empty_fc)
+ self.pn = PartNet(self.pn_in_channels, embedding_dims[1],
+ tfa_num_parts, tfa_squeeze_ratio)
- def fc(self, x):
- return x @ self.fc_mat
+ self.num_parts = self.hpm.num_parts + tfa_num_parts
def forward(self, x_c1, x_c2=None):
# Step 1: Disentanglement
@@ -55,21 +47,17 @@ class RGBPartNet(nn.Module):
# Step 2.a: Static Gait Feature Aggregation & HPM
# n, c, h, w
x_c = self.hpm(x_c)
- # p, n, c
+ # p, n, d
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
# n, t, c, h, w
x_p = self.pn(x_p)
- # p, n, c
-
- # Step 3: Cat feature map together and fc
- x = torch.cat((x_c, x_p))
- x = self.fc(x)
+ # p, n, d
if self.training:
- return x.transpose(0, 1), images, f_loss
+ return x_c.transpose(0, 1), x_p.transpose(0, 1), images, f_loss
else:
- return x.unsqueeze(1).view(-1)
+ return torch.cat((x_c, x_p)).unsqueeze(1).view(-1)
def _disentangle(self, x_c1_t2, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()