summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-12 13:59:05 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-12 13:59:05 +0800
commitd63b267dd15388dd323d9b8672cdb9461b96c885 (patch)
tree5095fc80fb93b946e4cfdee88258ab4fd49a8275 /models
parent08911dcb80ecb769972c2d2659c8ad152bbeb447 (diff)
parentc74df416b00f837ba051f3947be92f76e7afbd88 (diff)
Merge branch 'master' into python3.8
# Conflicts: # models/hpm.py # models/rgb_part_net.py # utils/configuration.py # utils/triplet_loss.py
Diffstat (limited to 'models')
-rw-r--r--models/auto_encoder.py2
-rw-r--r--models/hpm.py25
-rw-r--r--models/layers.py9
-rw-r--r--models/model.py119
-rw-r--r--models/part_net.py18
-rw-r--r--models/rgb_part_net.py32
6 files changed, 103 insertions, 102 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index dbd1da0..7f0eb6c 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -173,7 +173,7 @@ class AutoEncoder(nn.Module):
return (
(f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_),
- torch.stack((xrecon_loss, cano_cons_loss, pose_sim_loss * 10))
+ (xrecon_loss, cano_cons_loss, pose_sim_loss * 10)
)
else: # evaluating
return f_c_c1_t2_, f_p_c1_t2_
diff --git a/models/hpm.py b/models/hpm.py
index b49be3a..8320569 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -11,32 +11,26 @@ class HorizontalPyramidMatching(nn.Module):
self,
in_channels: int,
out_channels: int = 128,
- use_1x1conv: bool = False,
scales: Tuple[int, ...] = (1, 2, 4),
use_avg_pool: bool = True,
use_max_pool: bool = False,
- **kwargs
):
super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.use_1x1conv = use_1x1conv
self.scales = scales
+ self.num_parts = sum(scales)
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
self.pyramids = nn.ModuleList([
- self._make_pyramid(scale, **kwargs) for scale in self.scales
+ self._make_pyramid(scale) for scale in scales
])
+ self.fc_mat = nn.Parameter(
+ torch.empty(self.num_parts, in_channels, out_channels)
+ )
- def _make_pyramid(self, scale: int, **kwargs):
+ def _make_pyramid(self, scale: int):
pyramid = nn.ModuleList([
- HorizontalPyramidPooling(self.in_channels,
- self.out_channels,
- use_1x1conv=self.use_1x1conv,
- use_avg_pool=self.use_avg_pool,
- use_max_pool=self.use_max_pool,
- **kwargs)
+ HorizontalPyramidPooling(self.use_avg_pool, self.use_max_pool)
for _ in range(scale)
])
return pyramid
@@ -54,4 +48,9 @@ class HorizontalPyramidMatching(nn.Module):
x_slice = x_slice.view(n, -1)
feature.append(x_slice)
x = torch.stack(feature)
+
+ # p, n, c
+ x = x @ self.fc_mat
+ # p, n, d
+
return x
diff --git a/models/layers.py b/models/layers.py
index e30d0c4..a0933e8 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -167,17 +167,10 @@ class BasicConv1d(nn.Module):
class HorizontalPyramidPooling(nn.Module):
def __init__(
self,
- in_channels: int,
- out_channels: int,
- use_1x1conv: bool = False,
use_avg_pool: bool = True,
use_max_pool: bool = False,
- **kwargs
):
super().__init__()
- self.use_1x1conv = use_1x1conv
- if use_1x1conv:
- self.conv = BasicConv2d(in_channels, out_channels, 1, **kwargs)
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.'
@@ -191,6 +184,4 @@ class HorizontalPyramidPooling(nn.Module):
x = self.avg_pool(x)
elif not self.use_avg_pool and self.use_max_pool:
x = self.max_pool(x)
- if self.use_1x1conv:
- x = self.conv(x)
return x
diff --git a/models/model.py b/models/model.py
index e83cc7f..2608236 100644
--- a/models/model.py
+++ b/models/model.py
@@ -13,13 +13,15 @@ from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
+from models.hpm import HorizontalPyramidMatching
+from models.part_net import PartNet
from models.rgb_part_net import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
SystemConfiguration
from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
from utils.sampler import TripletSampler
-from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss
+from utils.triplet_loss import BatchTripletLoss
class Model:
@@ -69,7 +71,8 @@ class Model:
self._dataset_sig: str = 'undefined'
self.rgb_pn: Optional[RGBPartNet] = None
- self.triplet_loss: Optional[JointBatchTripletLoss] = None
+ self.triplet_loss_hpm: Optional[BatchTripletLoss] = None
+ self.triplet_loss_pn: Optional[BatchTripletLoss] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
@@ -149,26 +152,28 @@ class Model:
triplet_margins = model_hp.pop('triplet_margins', None)
optim_hp: Dict = self.hp.get('optimizer', {}).copy()
ae_optim_hp = optim_hp.pop('auto_encoder', {})
- pn_optim_hp = optim_hp.pop('part_net', {})
hpm_optim_hp = optim_hp.pop('hpm', {})
- fc_optim_hp = optim_hp.pop('fc', {})
+ pn_optim_hp = optim_hp.pop('part_net', {})
sched_hp = self.hp.get('scheduler', {})
+ ae_sched_hp = sched_hp.get('auto_encoder', {})
+ hpm_sched_hp = sched_hp.get('hpm', {})
+ pn_sched_hp = sched_hp.get('part_net', {})
+
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
# Hard margins
if triplet_margins:
- # Same margins
- if triplet_margins[0] == triplet_margins[1]:
- self.triplet_loss = BatchTripletLoss(
- triplet_is_hard, triplet_margins[0]
- )
- else: # Different margins
- self.triplet_loss = JointBatchTripletLoss(
- self.rgb_pn.hpm_num_parts,
- triplet_is_hard, triplet_is_mean, triplet_margins
- )
+ self.triplet_loss_hpm = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, triplet_margins[0]
+ )
+ self.triplet_loss_pn = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, triplet_margins[1]
+ )
else: # Soft margins
- self.triplet_loss = BatchTripletLoss(
+ self.triplet_loss_hpm = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, None
+ )
+ self.triplet_loss_pn = BatchTripletLoss(
triplet_is_hard, triplet_is_mean, None
)
@@ -177,25 +182,33 @@ class Model:
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
- self.triplet_loss = self.triplet_loss.to(self.device)
+ self.triplet_loss_hpm = self.triplet_loss_hpm.to(self.device)
+ self.triplet_loss_pn = self.triplet_loss_pn.to(self.device)
+
self.optimizer = optim.Adam([
{'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
- {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
{'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
- {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
+ {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
], **optim_hp)
- sched_final_gamma = sched_hp.get('final_gamma', 0.001)
- sched_start_step = sched_hp.get('start_step', 15_000)
- all_step = self.total_iter - sched_start_step
-
- def lr_lambda(epoch):
- if epoch > sched_start_step:
- passed_step = epoch - sched_start_step
- return sched_final_gamma ** (passed_step / all_step)
- else:
- return 1
+
+ start_step = sched_hp.get('start_step', 15_000)
+ final_gamma = sched_hp.get('final_gamma', 0.001)
+ ae_start_step = ae_sched_hp.get('start_step', start_step)
+ ae_final_gamma = ae_sched_hp.get('final_gamma', final_gamma)
+ ae_all_step = self.total_iter - ae_start_step
+ hpm_start_step = hpm_sched_hp.get('start_step', start_step)
+ hpm_final_gamma = hpm_sched_hp.get('final_gamma', final_gamma)
+ hpm_all_step = self.total_iter - hpm_start_step
+ pn_start_step = pn_sched_hp.get('start_step', start_step)
+ pn_final_gamma = pn_sched_hp.get('final_gamma', final_gamma)
+ pn_all_step = self.total_iter - pn_start_step
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
- lr_lambda, lr_lambda, lr_lambda, lr_lambda
+ lambda t: ae_final_gamma ** ((t - ae_start_step) / ae_all_step)
+ if t > ae_start_step else 1,
+ lambda t: hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step)
+ if t > hpm_start_step else 1,
+ lambda t: pn_final_gamma ** ((t - pn_start_step) / pn_all_step)
+ if t > pn_start_step else 1,
])
self.writer = SummaryWriter(self._log_name)
@@ -220,7 +233,7 @@ class Model:
running_loss = torch.zeros(5, device=self.device)
print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
- f"{'BATripH':^8} {'BATripP':^8} {'LR':^9}")
+ f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}")
for (batch_c1, batch_c2) in dataloader:
self.curr_iter += 1
# Zero the parameter gradients
@@ -228,17 +241,20 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+ embedding_c, embedding_p, ae_losses, images = self.rgb_pn(x_c1, x_c2)
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
- y = y.repeat(self.rgb_pn.num_total_parts, 1)
- trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
- losses = torch.cat((
- ae_losses,
- torch.stack((
- trip_loss[:self.rgb_pn.hpm_num_parts].mean(),
- trip_loss[self.rgb_pn.hpm_num_parts:].mean()
- ))
+ y = y.repeat(self.rgb_pn.num_parts, 1)
+ trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm(
+ embedding_c, y[:self.rgb_pn.hpm.num_parts]
+ )
+ trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn(
+ embedding_p, y[self.rgb_pn.hpm.num_parts:]
+ )
+ losses = torch.stack((
+ *ae_losses,
+ trip_loss_hpm.mean(),
+ trip_loss_pn.mean()
))
loss = losses.sum()
loss.backward()
@@ -257,30 +273,30 @@ class Model:
'PartNet': losses[4]
}, self.curr_iter)
# None-zero losses in batch
- if num_non_zero is not None:
+ if hpm_num_non_zero is not None and hpm_num_non_zero is not None:
self.writer.add_scalars('Loss/non-zero counts', {
- 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(),
- 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean()
+ 'HPM': hpm_num_non_zero.mean(),
+ 'PartNet': pn_num_non_zero.mean()
}, self.curr_iter)
# Embedding distance
- mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0)
+ mean_hpm_dist = hpm_dist.mean(0)
self._add_ranked_scalars(
'Embedding/HPM distance', mean_hpm_dist,
num_pos_pairs, num_pairs, self.curr_iter
)
- mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0)
+ mean_pa_dist = pn_dist.mean(0)
self._add_ranked_scalars(
'Embedding/ParNet distance', mean_pa_dist,
num_pos_pairs, num_pairs, self.curr_iter
)
# Embedding norm
- mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0)
+ mean_hpm_embedding = embedding_c.mean(0)
mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/HPM norm', mean_hpm_norm,
self.k, self.pr * self.k, self.curr_iter
)
- mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0)
+ mean_pa_embedding = embedding_p.mean(0)
mean_pa_norm = mean_pa_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/PartNet norm', mean_pa_norm,
@@ -288,10 +304,9 @@ class Model:
)
# Learning rate
lrs = self.scheduler.get_last_lr()
- # Write learning rates
- self.writer.add_scalar(
- 'Learning rate', lrs[0], self.curr_iter
- )
+ self.writer.add_scalars('Learning rate', dict(zip((
+ 'Auto-encoder', 'HPM', 'PartNet'
+ ), lrs)), self.curr_iter)
if self.curr_iter % 100 == 0:
# Write disentangled images
@@ -316,7 +331,7 @@ class Model:
print(f'{hour:02}:{minute:02}:{second:02}',
f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
'{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- f'{lrs[0]:.3e}')
+ '{:.3e} {:.3e} {:.3e}'.format(*lrs))
running_loss.zero_()
# Step scheduler
@@ -548,7 +563,7 @@ class Model:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
- elif isinstance(m, RGBPartNet):
+ elif isinstance(m, (HorizontalPyramidMatching, PartNet)):
nn.init.xavier_uniform_(m.fc_mat)
def _parse_dataset_config(
diff --git a/models/part_net.py b/models/part_net.py
index d9b954f..de19c8c 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -112,17 +112,21 @@ class PartNet(nn.Module):
def __init__(
self,
in_channels: int = 128,
+ embedding_dims: int = 256,
+ num_parts: int = 16,
squeeze_ratio: int = 4,
- num_part: int = 16
):
super().__init__()
- self.num_part = num_part
- self.tfa = TemporalFeatureAggregator(
- in_channels, squeeze_ratio, self.num_part
- )
+ self.num_part = num_parts
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
+ self.tfa = TemporalFeatureAggregator(
+ in_channels, squeeze_ratio, self.num_part
+ )
+ self.fc_mat = nn.Parameter(
+ torch.empty(num_parts, in_channels, embedding_dims)
+ )
def forward(self, x):
n, t, c, h, w = x.size()
@@ -139,4 +143,8 @@ class PartNet(nn.Module):
# p, n, t, c
x = self.tfa(x)
+
+ # p, n, c
+ x = x @ self.fc_mat
+ # p, n, d
return x
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 310ef25..3a5777e 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -15,39 +15,31 @@ class RGBPartNet(nn.Module):
ae_in_size: Tuple[int, int] = (64, 48),
ae_feature_channels: int = 64,
f_a_c_p_dims: Tuple[int, int, int] = (128, 128, 64),
- hpm_use_1x1conv: bool = False,
hpm_scales: Tuple[int, ...] = (1, 2, 4),
hpm_use_avg_pool: bool = True,
hpm_use_max_pool: bool = True,
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
- embedding_dims: int = 256,
+ embedding_dims: tuple[int] = (256, 256),
image_log_on: bool = False
):
super().__init__()
self.h, self.w = ae_in_size
(self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
- self.hpm_num_parts = sum(hpm_scales)
self.image_log_on = image_log_on
self.ae = AutoEncoder(
ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
)
self.pn_in_channels = ae_feature_channels * 2
- self.pn = PartNet(
- self.pn_in_channels, tfa_squeeze_ratio, tfa_num_parts
- )
self.hpm = HorizontalPyramidMatching(
- ae_feature_channels * 2, self.pn_in_channels, hpm_use_1x1conv,
- hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
+ self.pn_in_channels, embedding_dims[0], hpm_scales,
+ hpm_use_avg_pool, hpm_use_max_pool
)
- self.num_total_parts = self.hpm_num_parts + tfa_num_parts
- empty_fc = torch.empty(self.num_total_parts,
- self.pn_in_channels, embedding_dims)
- self.fc_mat = nn.Parameter(empty_fc)
+ self.pn = PartNet(self.pn_in_channels, embedding_dims[1],
+ tfa_num_parts, tfa_squeeze_ratio)
- def fc(self, x):
- return x @ self.fc_mat
+ self.num_parts = self.hpm.num_parts + tfa_num_parts
def forward(self, x_c1, x_c2=None):
# Step 1: Disentanglement
@@ -57,21 +49,17 @@ class RGBPartNet(nn.Module):
# Step 2.a: Static Gait Feature Aggregation & HPM
# n, c, h, w
x_c = self.hpm(x_c)
- # p, n, c
+ # p, n, d
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
# n, t, c, h, w
x_p = self.pn(x_p)
- # p, n, c
-
- # Step 3: Cat feature map together and fc
- x = torch.cat((x_c, x_p))
- x = self.fc(x)
+ # p, n, d
if self.training:
- return x, ae_losses, images
+ return x_c, x_p, ae_losses, images
else:
- return x.unsqueeze(1).view(-1)
+ return torch.cat((x_c, x_p)).unsqueeze(1).view(-1)
def _disentangle(self, x_c1_t2, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()