diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-27 22:14:21 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-27 22:14:21 +0800 |
commit | 46391257ff50848efa1aa251ab3f15dc8b7a2d2c (patch) | |
tree | 1e04084a9f0e42a7421b951134dd0588ea691c08 /models | |
parent | 9001f7e13d8985b220bd218d8de716bc586dbdcf (diff) |
Implement Batch Hard triplet loss and soft margin
Diffstat (limited to 'models')
-rw-r--r-- | models/model.py | 57 | ||||
-rw-r--r-- | models/rgb_part_net.py | 6 |
2 files changed, 45 insertions, 18 deletions
diff --git a/models/model.py b/models/model.py index 90d48e0..79952cb 100644 --- a/models/model.py +++ b/models/model.py @@ -18,7 +18,7 @@ from utils.configuration import DataloaderConfiguration, \ SystemConfiguration from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses from utils.sampler import TripletSampler -from utils.triplet_loss import JointBatchAllTripletLoss +from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss class Model: @@ -68,7 +68,7 @@ class Model: self._dataset_sig: str = 'undefined' self.rgb_pn: Optional[RGBPartNet] = None - self.ba_triplet_loss: Optional[JointBatchAllTripletLoss] = None + self.triplet_loss: Optional[JointBatchTripletLoss] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None @@ -143,7 +143,8 @@ class Model: dataloader = self._parse_dataloader_config(dataset, dataloader_config) # Prepare for model, optimizer and scheduler model_hp: dict = self.hp.get('model', {}).copy() - triplet_margins = model_hp.pop('triplet_margins', (0.2, 0.2)) + triplet_is_hard = model_hp.pop('triplet_is_hard', True) + triplet_margins = model_hp.pop('triplet_margins', None) optim_hp: dict = self.hp.get('optimizer', {}).copy() start_iter = optim_hp.pop('start_iter', 0) ae_optim_hp = optim_hp.pop('auto_encoder', {}) @@ -153,12 +154,23 @@ class Model: sched_hp = self.hp.get('scheduler', {}) self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) - self.ba_triplet_loss = JointBatchAllTripletLoss( - self.rgb_pn.hpm_num_parts, triplet_margins - ) + # Hard margins + if triplet_margins: + # Same margins + if triplet_margins[0] == triplet_margins[1]: + self.triplet_loss = BatchTripletLoss( + triplet_is_hard, triplet_margins[0] + ) + else: # Different margins + self.triplet_loss = JointBatchTripletLoss( + self.rgb_pn.hpm_num_parts, triplet_is_hard, triplet_margins + ) + else: # Soft margins + self.triplet_loss = BatchTripletLoss(triplet_is_hard, None) + # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) - self.ba_triplet_loss = self.ba_triplet_loss.to(self.device) + self.triplet_loss = self.triplet_loss.to(self.device) self.optimizer = optim.Adam([ {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp}, {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, @@ -200,16 +212,16 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - feature, ae_losses, images = self.rgb_pn(x_c1, x_c2) + embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) # Duplicate labels for each part y = y.repeat(self.rgb_pn.num_total_parts, 1) - triplet_loss = self.ba_triplet_loss(feature, y) + trip_loss, dist, non_zero_counts = self.triplet_loss(embedding, y) losses = torch.cat(( ae_losses, torch.stack(( - triplet_loss[:self.rgb_pn.hpm_num_parts].mean(), - triplet_loss[self.rgb_pn.hpm_num_parts:].mean() + trip_loss[:self.rgb_pn.hpm_num_parts].mean(), + trip_loss[self.rgb_pn.hpm_num_parts:].mean() )) )) loss = losses.sum() @@ -220,11 +232,26 @@ class Model: running_loss += losses.detach() # Write losses to TensorBoard self.writer.add_scalar('Loss/all', loss, self.curr_iter) - self.writer.add_scalars('Loss/details', dict(zip([ + self.writer.add_scalars('Loss/disentanglement', dict(zip(( 'Cross reconstruction loss', 'Canonical consistency loss', - 'Pose similarity loss', 'Batch All triplet loss (HPM)', - 'Batch All triplet loss (PartNet)' - ], losses)), self.curr_iter) + 'Pose similarity loss' + ), ae_losses)), self.curr_iter) + self.writer.add_scalars('Loss/triplet loss', { + 'HPM': losses[3], + 'PartNet': losses[4] + }, self.curr_iter) + self.writer.add_scalars('Loss/non-zero counts', { + 'HPM': non_zero_counts[:self.rgb_pn.hpm_num_parts].mean(), + 'PartNet': non_zero_counts[self.rgb_pn.hpm_num_parts:].mean() + }, self.curr_iter) + self.writer.add_scalars('Embedding/distance', { + 'HPM': dist[:self.rgb_pn.hpm_num_parts].mean(), + 'PartNet': dist[self.rgb_pn.hpm_num_parts].mean() + }, self.curr_iter) + self.writer.add_scalars('Embedding/2-norm', { + 'HPM': embedding[:self.rgb_pn.hpm_num_parts].norm(), + 'PartNet': embedding[self.rgb_pn.hpm_num_parts].norm() + }, self.curr_iter) if self.curr_iter % 100 == 0: lrs = self.scheduler.get_last_lr() diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 4367c62..8a0f3a7 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -79,7 +79,7 @@ class RGBPartNet(nn.Module): ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) # Decode features x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p_ = self._decode_pose_feature(f_p_, n, t, c, h, w, device) + x_p_ = self._decode_pose_feature(f_p_, n, t, device) x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) i_a, i_c, i_p = None, None, None @@ -98,7 +98,7 @@ class RGBPartNet(nn.Module): else: # evaluating f_c_, f_p_ = self.ae(x_c1_t2) x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p_ = self._decode_pose_feature(f_p_, n, t, c, h, w, device) + x_p_ = self._decode_pose_feature(f_p_, n, t, device) x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4) return (x_c, x_p), None, None @@ -123,7 +123,7 @@ class RGBPartNet(nn.Module): ) return x_c - def _decode_pose_feature(self, f_p_, n, t, c, h, w, device): + def _decode_pose_feature(self, f_p_, n, t, device): # Decode pose features to images x_p_ = self.ae.decoder( torch.zeros((n * t, self.f_a_dim), device=device), |