diff options
| author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-27 22:14:21 +0800 | 
|---|---|---|
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-27 22:14:21 +0800 | 
| commit | 46391257ff50848efa1aa251ab3f15dc8b7a2d2c (patch) | |
| tree | 1e04084a9f0e42a7421b951134dd0588ea691c08 /models/model.py | |
| parent | 9001f7e13d8985b220bd218d8de716bc586dbdcf (diff) | |
Implement Batch Hard triplet loss and soft margin
Diffstat (limited to 'models/model.py')
| -rw-r--r-- | models/model.py | 57 | 
1 files changed, 42 insertions, 15 deletions
diff --git a/models/model.py b/models/model.py index 90d48e0..79952cb 100644 --- a/models/model.py +++ b/models/model.py @@ -18,7 +18,7 @@ from utils.configuration import DataloaderConfiguration, \      SystemConfiguration  from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses  from utils.sampler import TripletSampler -from utils.triplet_loss import JointBatchAllTripletLoss +from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss  class Model: @@ -68,7 +68,7 @@ class Model:          self._dataset_sig: str = 'undefined'          self.rgb_pn: Optional[RGBPartNet] = None -        self.ba_triplet_loss: Optional[JointBatchAllTripletLoss] = None +        self.triplet_loss: Optional[JointBatchTripletLoss] = None          self.optimizer: Optional[optim.Adam] = None          self.scheduler: Optional[optim.lr_scheduler.StepLR] = None          self.writer: Optional[SummaryWriter] = None @@ -143,7 +143,8 @@ class Model:          dataloader = self._parse_dataloader_config(dataset, dataloader_config)          # Prepare for model, optimizer and scheduler          model_hp: dict = self.hp.get('model', {}).copy() -        triplet_margins = model_hp.pop('triplet_margins', (0.2, 0.2)) +        triplet_is_hard = model_hp.pop('triplet_is_hard', True) +        triplet_margins = model_hp.pop('triplet_margins', None)          optim_hp: dict = self.hp.get('optimizer', {}).copy()          start_iter = optim_hp.pop('start_iter', 0)          ae_optim_hp = optim_hp.pop('auto_encoder', {}) @@ -153,12 +154,23 @@ class Model:          sched_hp = self.hp.get('scheduler', {})          self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,                                   image_log_on=self.image_log_on) -        self.ba_triplet_loss = JointBatchAllTripletLoss( -            self.rgb_pn.hpm_num_parts, triplet_margins -        ) +        # Hard margins +        if triplet_margins: +            # Same margins +            if triplet_margins[0] == triplet_margins[1]: +                self.triplet_loss = BatchTripletLoss( +                    triplet_is_hard, triplet_margins[0] +                ) +            else:  # Different margins +                self.triplet_loss = JointBatchTripletLoss( +                    self.rgb_pn.hpm_num_parts, triplet_is_hard, triplet_margins +                ) +        else:  # Soft margins +            self.triplet_loss = BatchTripletLoss(triplet_is_hard, None) +          # Try to accelerate computation using CUDA or others          self.rgb_pn = self.rgb_pn.to(self.device) -        self.ba_triplet_loss = self.ba_triplet_loss.to(self.device) +        self.triplet_loss = self.triplet_loss.to(self.device)          self.optimizer = optim.Adam([              {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},              {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, @@ -200,16 +212,16 @@ class Model:              # forward + backward + optimize              x_c1 = batch_c1['clip'].to(self.device)              x_c2 = batch_c2['clip'].to(self.device) -            feature, ae_losses, images = self.rgb_pn(x_c1, x_c2) +            embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2)              y = batch_c1['label'].to(self.device)              # Duplicate labels for each part              y = y.repeat(self.rgb_pn.num_total_parts, 1) -            triplet_loss = self.ba_triplet_loss(feature, y) +            trip_loss, dist, non_zero_counts = self.triplet_loss(embedding, y)              losses = torch.cat((                  ae_losses,                  torch.stack(( -                    triplet_loss[:self.rgb_pn.hpm_num_parts].mean(), -                    triplet_loss[self.rgb_pn.hpm_num_parts:].mean() +                    trip_loss[:self.rgb_pn.hpm_num_parts].mean(), +                    trip_loss[self.rgb_pn.hpm_num_parts:].mean()                  ))              ))              loss = losses.sum() @@ -220,11 +232,26 @@ class Model:              running_loss += losses.detach()              # Write losses to TensorBoard              self.writer.add_scalar('Loss/all', loss, self.curr_iter) -            self.writer.add_scalars('Loss/details', dict(zip([ +            self.writer.add_scalars('Loss/disentanglement', dict(zip((                  'Cross reconstruction loss', 'Canonical consistency loss', -                'Pose similarity loss', 'Batch All triplet loss (HPM)', -                'Batch All triplet loss (PartNet)' -            ], losses)), self.curr_iter) +                'Pose similarity loss' +            ), ae_losses)), self.curr_iter) +            self.writer.add_scalars('Loss/triplet loss', { +                'HPM': losses[3], +                'PartNet': losses[4] +            }, self.curr_iter) +            self.writer.add_scalars('Loss/non-zero counts', { +                'HPM': non_zero_counts[:self.rgb_pn.hpm_num_parts].mean(), +                'PartNet': non_zero_counts[self.rgb_pn.hpm_num_parts:].mean() +            }, self.curr_iter) +            self.writer.add_scalars('Embedding/distance', { +                'HPM': dist[:self.rgb_pn.hpm_num_parts].mean(), +                'PartNet': dist[self.rgb_pn.hpm_num_parts].mean() +            }, self.curr_iter) +            self.writer.add_scalars('Embedding/2-norm', { +                'HPM': embedding[:self.rgb_pn.hpm_num_parts].norm(), +                'PartNet': embedding[self.rgb_pn.hpm_num_parts].norm() +            }, self.curr_iter)              if self.curr_iter % 100 == 0:                  lrs = self.scheduler.get_last_lr()  | 
