summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.py6
-rw-r--r--models/model.py57
-rw-r--r--models/rgb_part_net.py6
-rw-r--r--utils/triplet_loss.py84
4 files changed, 104 insertions, 49 deletions
diff --git a/config.py b/config.py
index a7d5371..282a9b4 100644
--- a/config.py
+++ b/config.py
@@ -61,8 +61,10 @@ config = {
'tfa_num_parts': 16,
# Embedding dimension for each part
'embedding_dims': 256,
- # Triplet loss margins for HPM and PartNet
- 'triplet_margins': (1.5, 1.5),
+ # Batch Hard or Batch All
+ 'triplet_is_hard': True,
+ # Triplet loss margins for HPM and PartNet, None for soft margin
+ 'triplet_margins': None,
},
'optimizer': {
# Global parameters
diff --git a/models/model.py b/models/model.py
index 1e22483..f4654f3 100644
--- a/models/model.py
+++ b/models/model.py
@@ -15,7 +15,7 @@ from tqdm import tqdm
from models.rgb_part_net import RGBPartNet
from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
from utils.sampler import TripletSampler
-from utils.triplet_loss import JointBatchAllTripletLoss
+from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss
class Model:
@@ -65,7 +65,7 @@ class Model:
self._dataset_sig: str = 'undefined'
self.rgb_pn: Optional[RGBPartNet] = None
- self.ba_triplet_loss: Optional[JointBatchAllTripletLoss] = None
+ self.triplet_loss: Optional[JointBatchTripletLoss] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
@@ -140,7 +140,8 @@ class Model:
dataloader = self._parse_dataloader_config(dataset, dataloader_config)
# Prepare for model, optimizer and scheduler
model_hp: Dict = self.hp.get('model', {}).copy()
- triplet_margins = model_hp.pop('triplet_margins', (0.2, 0.2))
+ triplet_is_hard = model_hp.pop('triplet_is_hard', True)
+ triplet_margins = model_hp.pop('triplet_margins', None)
optim_hp: Dict = self.hp.get('optimizer', {}).copy()
start_iter = optim_hp.pop('start_iter', 0)
ae_optim_hp = optim_hp.pop('auto_encoder', {})
@@ -150,12 +151,23 @@ class Model:
sched_hp = self.hp.get('scheduler', {})
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
- self.ba_triplet_loss = JointBatchAllTripletLoss(
- self.rgb_pn.hpm_num_parts, triplet_margins
- )
+ # Hard margins
+ if triplet_margins:
+ # Same margins
+ if triplet_margins[0] == triplet_margins[1]:
+ self.triplet_loss = BatchTripletLoss(
+ triplet_is_hard, triplet_margins[0]
+ )
+ else: # Different margins
+ self.triplet_loss = JointBatchTripletLoss(
+ self.rgb_pn.hpm_num_parts, triplet_is_hard, triplet_margins
+ )
+ else: # Soft margins
+ self.triplet_loss = BatchTripletLoss(triplet_is_hard, None)
+
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
- self.ba_triplet_loss = self.ba_triplet_loss.to(self.device)
+ self.triplet_loss = self.triplet_loss.to(self.device)
self.optimizer = optim.Adam([
{'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
{'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
@@ -197,16 +209,16 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- feature, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+ embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2)
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
y = y.repeat(self.rgb_pn.num_total_parts, 1)
- triplet_loss = self.ba_triplet_loss(feature, y)
+ trip_loss, dist, non_zero_counts = self.triplet_loss(embedding, y)
losses = torch.cat((
ae_losses,
torch.stack((
- triplet_loss[:self.rgb_pn.hpm_num_parts].mean(),
- triplet_loss[self.rgb_pn.hpm_num_parts:].mean()
+ trip_loss[:self.rgb_pn.hpm_num_parts].mean(),
+ trip_loss[self.rgb_pn.hpm_num_parts:].mean()
))
))
loss = losses.sum()
@@ -217,11 +229,26 @@ class Model:
running_loss += losses.detach()
# Write losses to TensorBoard
self.writer.add_scalar('Loss/all', loss, self.curr_iter)
- self.writer.add_scalars('Loss/details', dict(zip([
+ self.writer.add_scalars('Loss/disentanglement', dict(zip((
'Cross reconstruction loss', 'Canonical consistency loss',
- 'Pose similarity loss', 'Batch All triplet loss (HPM)',
- 'Batch All triplet loss (PartNet)'
- ], losses)), self.curr_iter)
+ 'Pose similarity loss'
+ ), ae_losses)), self.curr_iter)
+ self.writer.add_scalars('Loss/triplet loss', {
+ 'HPM': losses[3],
+ 'PartNet': losses[4]
+ }, self.curr_iter)
+ self.writer.add_scalars('Loss/non-zero counts', {
+ 'HPM': non_zero_counts[:self.rgb_pn.hpm_num_parts].mean(),
+ 'PartNet': non_zero_counts[self.rgb_pn.hpm_num_parts:].mean()
+ }, self.curr_iter)
+ self.writer.add_scalars('Embedding/distance', {
+ 'HPM': dist[:self.rgb_pn.hpm_num_parts].mean(),
+ 'PartNet': dist[self.rgb_pn.hpm_num_parts].mean()
+ }, self.curr_iter)
+ self.writer.add_scalars('Embedding/2-norm', {
+ 'HPM': embedding[:self.rgb_pn.hpm_num_parts].norm(),
+ 'PartNet': embedding[self.rgb_pn.hpm_num_parts].norm()
+ }, self.curr_iter)
if self.curr_iter % 100 == 0:
lrs = self.scheduler.get_last_lr()
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 4d7ba7f..310ef25 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -81,7 +81,7 @@ class RGBPartNet(nn.Module):
((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
# Decode features
x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p_ = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+ x_p_ = self._decode_pose_feature(f_p_, n, t, device)
x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4)
i_a, i_c, i_p = None, None, None
@@ -100,7 +100,7 @@ class RGBPartNet(nn.Module):
else: # evaluating
f_c_, f_p_ = self.ae(x_c1_t2)
x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p_ = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+ x_p_ = self._decode_pose_feature(f_p_, n, t, device)
x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4)
return (x_c, x_p), None, None
@@ -125,7 +125,7 @@ class RGBPartNet(nn.Module):
)
return x_c
- def _decode_pose_feature(self, f_p_, n, t, c, h, w, device):
+ def _decode_pose_feature(self, f_p_, n, t, device):
# Decode pose features to images
x_p_ = self.ae.decoder(
torch.zeros((n * t, self.f_a_dim), device=device),
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 6025bd3..22ac2ab 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -1,34 +1,36 @@
-from typing import Tuple
+from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
-class BatchAllTripletLoss(nn.Module):
- def __init__(self, margin: float = 0.2):
+class BatchTripletLoss(nn.Module):
+ def __init__(
+ self,
+ is_hard: bool = True,
+ margin: Optional[float] = 0.2,
+ ):
super().__init__()
+ self.is_hard = is_hard
self.margin = margin
def forward(self, x, y):
p, n, c = x.size()
-
dist = self._batch_distance(x)
- positive_negative_dist = self._hard_distance(dist, y, p, n)
- all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1)
- parted_loss_mean = self._none_zero_parted_mean(all_loss)
- return parted_loss_mean
+ if self.is_hard:
+ positive_negative_dist = self._hard_distance(dist, y, p, n)
+ else: # is_all
+ positive_negative_dist = self._all_distance(dist, y, p, n)
- @staticmethod
- def _hard_distance(dist, y, p, n):
- hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
- hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
- all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1)
- all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1)
- positive_negative_dist = all_hard_positive - all_hard_negative
+ if self.margin:
+ all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1)
+ else:
+ all_loss = F.softplus(positive_negative_dist).view(p, -1)
+ non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss)
- return positive_negative_dist
+ return non_zero_mean, dist.mean((1, 2)), non_zero_counts
@staticmethod
def _batch_distance(x):
@@ -40,41 +42,65 @@ class BatchAllTripletLoss(nn.Module):
dist = torch.sqrt(
F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum)
)
-
return dist
@staticmethod
+ def _hard_distance(dist, y, p, n):
+ positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
+ negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
+ hard_positive = dist[positive_mask].view(p, n, -1).max(-1).values
+ hard_negative = dist[negative_mask].view(p, n, -1).min(-1).values
+ positive_negative_dist = hard_positive - hard_negative
+
+ return positive_negative_dist
+
+ @staticmethod
+ def _all_distance(dist, y, p, n):
+ positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
+ negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
+ all_positive = dist[positive_mask].view(p, n, -1, 1)
+ all_negative = dist[negative_mask].view(p, n, 1, -1)
+ positive_negative_dist = all_positive - all_negative
+
+ return positive_negative_dist
+
+ @staticmethod
def _none_zero_parted_mean(all_loss):
# Non-zero parted mean
- non_zero_counts = (all_loss != 0).sum(1)
- parted_loss_mean = all_loss.sum(1) / non_zero_counts
- parted_loss_mean[non_zero_counts == 0] = 0
+ non_zero_counts = (all_loss != 0).sum(1).float()
+ non_zero_mean = all_loss.sum(1) / non_zero_counts
+ non_zero_mean[non_zero_counts == 0] = 0
- return parted_loss_mean
+ return non_zero_mean, non_zero_counts
-class JointBatchAllTripletLoss(BatchAllTripletLoss):
+class JointBatchTripletLoss(BatchTripletLoss):
def __init__(
self,
hpm_num_parts: int,
+ is_hard: bool = True,
margins: Tuple[float, float] = (0.2, 0.2)
):
- super().__init__()
+ super().__init__(is_hard)
self.hpm_num_parts = hpm_num_parts
self.margin_hpm, self.margin_pn = margins
def forward(self, x, y):
p, n, c = x.size()
-
dist = self._batch_distance(x)
- positive_negative_dist = self._hard_distance(dist, y, p, n)
+
+ if self.is_hard:
+ positive_negative_dist = self._hard_distance(dist, y, p, n)
+ else: # is_all
+ positive_negative_dist = self._all_distance(dist, y, p, n)
+
hpm_part_loss = F.relu(
self.margin_hpm + positive_negative_dist[:self.hpm_num_parts]
- ).view(self.hpm_num_parts, -1)
+ )
pn_part_loss = F.relu(
self.margin_pn + positive_negative_dist[self.hpm_num_parts:]
- ).view(p - self.hpm_num_parts, -1)
+ )
all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1)
- parted_loss_mean = self._none_zero_parted_mean(all_loss)
+ non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss)
- return parted_loss_mean
+ return non_zero_mean, dist.mean((1, 2)), non_zero_counts