summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-20 14:19:30 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-20 14:19:30 +0800
commit50eedae4f320c446544772fb2b0abcbce1be7590 (patch)
tree93a5b252a6d11b0c677d0700522d928d59916a9e
parent4aa9044122878a8e2b887a8b170c036983431559 (diff)
Separate triplet loss from model
-rw-r--r--models/auto_encoder.py2
-rw-r--r--models/model.py20
-rw-r--r--models/rgb_part_net.py20
-rw-r--r--utils/triplet_loss.py58
4 files changed, 71 insertions, 29 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 2d715db..1ef7494 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -171,7 +171,7 @@ class AutoEncoder(nn.Module):
return (
(f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_),
- (xrecon_loss, cano_cons_loss, pose_sim_loss * 10)
+ torch.stack((xrecon_loss, cano_cons_loss, pose_sim_loss * 10))
)
else: # evaluating
return f_c_c1_t2_, f_p_c1_t2_
diff --git a/models/model.py b/models/model.py
index 82d6461..5aa6436 100644
--- a/models/model.py
+++ b/models/model.py
@@ -18,6 +18,7 @@ from utils.configuration import DataloaderConfiguration, \
SystemConfiguration
from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
from utils.sampler import TripletSampler
+from utils.triplet_loss import JointBatchAllTripletLoss
class Model:
@@ -67,6 +68,7 @@ class Model:
self._dataset_sig: str = 'undefined'
self.rgb_pn: Optional[RGBPartNet] = None
+ self.ba_triplet_loss: Optional[JointBatchAllTripletLoss] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
@@ -140,7 +142,8 @@ class Model:
dataset = self._parse_dataset_config(dataset_config)
dataloader = self._parse_dataloader_config(dataset, dataloader_config)
# Prepare for model, optimizer and scheduler
- model_hp = self.hp.get('model', {})
+ model_hp: dict = self.hp.get('model', {}).copy()
+ triplet_margins = model_hp.pop('triplet_margins', (0.2, 0.2))
optim_hp: dict = self.hp.get('optimizer', {}).copy()
start_iter = optim_hp.pop('start_iter', 0)
ae_optim_hp = optim_hp.pop('auto_encoder', {})
@@ -150,6 +153,9 @@ class Model:
sched_hp = self.hp.get('scheduler', {})
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
+ self.ba_triplet_loss = JointBatchAllTripletLoss(
+ self.rgb_pn.hpm_num_parts, triplet_margins
+ )
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
self.optimizer = optim.Adam([
@@ -193,10 +199,18 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
+ feature, ae_losses, images = self.rgb_pn(x_c1, x_c2)
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
- y = y.unsqueeze(1).repeat(1, self.rgb_pn.num_total_parts)
- losses, images = self.rgb_pn(x_c1, x_c2, y)
+ y = y.repeat(self.rgb_pn.num_total_parts, 1)
+ triplet_loss = self.ba_triplet_loss(feature, y)
+ losses = torch.cat((
+ ae_losses,
+ torch.stack((
+ triplet_loss[:self.rgb_pn.hpm_num_parts].mean(),
+ triplet_loss[self.rgb_pn.hpm_num_parts:].mean()
+ ))
+ ))
loss = losses.sum()
loss.backward()
self.optimizer.step()
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 67acac3..408bca0 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -4,7 +4,6 @@ import torch.nn as nn
from models.auto_encoder import AutoEncoder
from models.hpm import HorizontalPyramidMatching
from models.part_net import PartNet
-from utils.triplet_loss import BatchAllTripletLoss
class RGBPartNet(nn.Module):
@@ -25,7 +24,6 @@ class RGBPartNet(nn.Module):
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
embedding_dims: int = 256,
- triplet_margins: tuple[float, float] = (0.2, 0.2),
image_log_on: bool = False
):
super().__init__()
@@ -50,17 +48,13 @@ class RGBPartNet(nn.Module):
out_channels, embedding_dims)
self.fc_mat = nn.Parameter(empty_fc)
- (hpm_margin, pn_margin) = triplet_margins
- self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin)
- self.pn_ba_trip = BatchAllTripletLoss(pn_margin)
-
def fc(self, x):
return x @ self.fc_mat
- def forward(self, x_c1, x_c2=None, y=None):
+ def forward(self, x_c1, x_c2=None):
# Step 1: Disentanglement
# n, t, c, h, w
- ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2)
+ ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2)
# Step 2.a: Static Gait Feature Aggregation & HPM
# n, c, h, w
@@ -77,15 +71,7 @@ class RGBPartNet(nn.Module):
x = self.fc(x)
if self.training:
- y = y.T
- hpm_ba_trip = self.hpm_ba_trip(
- x[:self.hpm_num_parts], y[:self.hpm_num_parts]
- )
- pn_ba_trip = self.pn_ba_trip(
- x[self.hpm_num_parts:], y[self.hpm_num_parts:]
- )
- losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip))
- return losses, images
+ return x, ae_losses, images
else:
return x.unsqueeze(1).view(-1)
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 954def2..0df2188 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -11,6 +11,25 @@ class BatchAllTripletLoss(nn.Module):
def forward(self, x, y):
p, n, c = x.size()
+ dist = self._batch_distance(x)
+ positive_negative_dist = self._hard_distance(dist, y, p, n)
+ all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1)
+ parted_loss_mean = self._none_zero_parted_mean(all_loss)
+
+ return parted_loss_mean
+
+ @staticmethod
+ def _hard_distance(dist, y, p, n):
+ hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
+ hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
+ all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1)
+ all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1)
+ positive_negative_dist = all_hard_positive - all_hard_negative
+
+ return positive_negative_dist
+
+ @staticmethod
+ def _batch_distance(x):
# Euclidean distance p x n x n
x_squared_sum = torch.sum(x ** 2, dim=2)
x1_squared_sum = x_squared_sum.unsqueeze(2)
@@ -20,17 +39,40 @@ class BatchAllTripletLoss(nn.Module):
F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum)
)
- hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
- hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
- all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1)
- all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1)
- positive_negative_dist = all_hard_positive - all_hard_negative
- all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1)
+ return dist
+ @staticmethod
+ def _none_zero_parted_mean(all_loss):
# Non-zero parted mean
non_zero_counts = (all_loss != 0).sum(1)
parted_loss_mean = all_loss.sum(1) / non_zero_counts
parted_loss_mean[non_zero_counts == 0] = 0
- loss = parted_loss_mean.mean()
- return loss
+ return parted_loss_mean
+
+
+class JointBatchAllTripletLoss(BatchAllTripletLoss):
+ def __init__(
+ self,
+ hpm_num_parts: int,
+ margins: tuple[float, float] = (0.2, 0.2)
+ ):
+ super().__init__()
+ self.hpm_num_parts = hpm_num_parts
+ self.margin_hpm, self.margin_pn = margins
+
+ def forward(self, x, y):
+ p, n, c = x.size()
+
+ dist = self._batch_distance(x)
+ positive_negative_dist = self._hard_distance(dist, y, p, n)
+ hpm_part_loss = F.relu(
+ self.margin_hpm + positive_negative_dist[:self.hpm_num_parts]
+ ).view(self.hpm_num_parts, -1)
+ pn_part_loss = F.relu(
+ self.margin_pn + positive_negative_dist[self.hpm_num_parts:]
+ ).view(p - self.hpm_num_parts, -1)
+ all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1)
+ parted_loss_mean = self._none_zero_parted_mean(all_loss)
+
+ return parted_loss_mean