summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-27 22:14:21 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-27 22:14:21 +0800
commit46391257ff50848efa1aa251ab3f15dc8b7a2d2c (patch)
tree1e04084a9f0e42a7421b951134dd0588ea691c08 /models/model.py
parent9001f7e13d8985b220bd218d8de716bc586dbdcf (diff)
Implement Batch Hard triplet loss and soft margin
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py57
1 files changed, 42 insertions, 15 deletions
diff --git a/models/model.py b/models/model.py
index 90d48e0..79952cb 100644
--- a/models/model.py
+++ b/models/model.py
@@ -18,7 +18,7 @@ from utils.configuration import DataloaderConfiguration, \
SystemConfiguration
from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
from utils.sampler import TripletSampler
-from utils.triplet_loss import JointBatchAllTripletLoss
+from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss
class Model:
@@ -68,7 +68,7 @@ class Model:
self._dataset_sig: str = 'undefined'
self.rgb_pn: Optional[RGBPartNet] = None
- self.ba_triplet_loss: Optional[JointBatchAllTripletLoss] = None
+ self.triplet_loss: Optional[JointBatchTripletLoss] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
@@ -143,7 +143,8 @@ class Model:
dataloader = self._parse_dataloader_config(dataset, dataloader_config)
# Prepare for model, optimizer and scheduler
model_hp: dict = self.hp.get('model', {}).copy()
- triplet_margins = model_hp.pop('triplet_margins', (0.2, 0.2))
+ triplet_is_hard = model_hp.pop('triplet_is_hard', True)
+ triplet_margins = model_hp.pop('triplet_margins', None)
optim_hp: dict = self.hp.get('optimizer', {}).copy()
start_iter = optim_hp.pop('start_iter', 0)
ae_optim_hp = optim_hp.pop('auto_encoder', {})
@@ -153,12 +154,23 @@ class Model:
sched_hp = self.hp.get('scheduler', {})
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
- self.ba_triplet_loss = JointBatchAllTripletLoss(
- self.rgb_pn.hpm_num_parts, triplet_margins
- )
+ # Hard margins
+ if triplet_margins:
+ # Same margins
+ if triplet_margins[0] == triplet_margins[1]:
+ self.triplet_loss = BatchTripletLoss(
+ triplet_is_hard, triplet_margins[0]
+ )
+ else: # Different margins
+ self.triplet_loss = JointBatchTripletLoss(
+ self.rgb_pn.hpm_num_parts, triplet_is_hard, triplet_margins
+ )
+ else: # Soft margins
+ self.triplet_loss = BatchTripletLoss(triplet_is_hard, None)
+
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
- self.ba_triplet_loss = self.ba_triplet_loss.to(self.device)
+ self.triplet_loss = self.triplet_loss.to(self.device)
self.optimizer = optim.Adam([
{'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
{'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
@@ -200,16 +212,16 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- feature, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+ embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2)
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
y = y.repeat(self.rgb_pn.num_total_parts, 1)
- triplet_loss = self.ba_triplet_loss(feature, y)
+ trip_loss, dist, non_zero_counts = self.triplet_loss(embedding, y)
losses = torch.cat((
ae_losses,
torch.stack((
- triplet_loss[:self.rgb_pn.hpm_num_parts].mean(),
- triplet_loss[self.rgb_pn.hpm_num_parts:].mean()
+ trip_loss[:self.rgb_pn.hpm_num_parts].mean(),
+ trip_loss[self.rgb_pn.hpm_num_parts:].mean()
))
))
loss = losses.sum()
@@ -220,11 +232,26 @@ class Model:
running_loss += losses.detach()
# Write losses to TensorBoard
self.writer.add_scalar('Loss/all', loss, self.curr_iter)
- self.writer.add_scalars('Loss/details', dict(zip([
+ self.writer.add_scalars('Loss/disentanglement', dict(zip((
'Cross reconstruction loss', 'Canonical consistency loss',
- 'Pose similarity loss', 'Batch All triplet loss (HPM)',
- 'Batch All triplet loss (PartNet)'
- ], losses)), self.curr_iter)
+ 'Pose similarity loss'
+ ), ae_losses)), self.curr_iter)
+ self.writer.add_scalars('Loss/triplet loss', {
+ 'HPM': losses[3],
+ 'PartNet': losses[4]
+ }, self.curr_iter)
+ self.writer.add_scalars('Loss/non-zero counts', {
+ 'HPM': non_zero_counts[:self.rgb_pn.hpm_num_parts].mean(),
+ 'PartNet': non_zero_counts[self.rgb_pn.hpm_num_parts:].mean()
+ }, self.curr_iter)
+ self.writer.add_scalars('Embedding/distance', {
+ 'HPM': dist[:self.rgb_pn.hpm_num_parts].mean(),
+ 'PartNet': dist[self.rgb_pn.hpm_num_parts].mean()
+ }, self.curr_iter)
+ self.writer.add_scalars('Embedding/2-norm', {
+ 'HPM': embedding[:self.rgb_pn.hpm_num_parts].norm(),
+ 'PartNet': embedding[self.rgb_pn.hpm_num_parts].norm()
+ }, self.curr_iter)
if self.curr_iter % 100 == 0:
lrs = self.scheduler.get_last_lr()