diff options
-rw-r--r-- | config.py | 2 | ||||
-rw-r--r-- | models/model.py | 10 | ||||
-rw-r--r-- | utils/configuration.py | 1 | ||||
-rw-r--r-- | utils/triplet_loss.py | 43 |
4 files changed, 38 insertions, 18 deletions
@@ -65,6 +65,8 @@ config: Configuration = { 'embedding_dims': 256, # Batch Hard or Batch All 'triplet_is_hard': True, + # Use non-zero mean or sum + 'triplet_is_mean': True, # Triplet loss margins for HPM and PartNet, None for soft margin 'triplet_margins': None, }, diff --git a/models/model.py b/models/model.py index 18896ae..34cb816 100644 --- a/models/model.py +++ b/models/model.py @@ -146,6 +146,7 @@ class Model: # Prepare for model, optimizer and scheduler model_hp: dict = self.hp.get('model', {}).copy() triplet_is_hard = model_hp.pop('triplet_is_hard', True) + triplet_is_mean = model_hp.pop('triplet_is_mean', True) triplet_margins = model_hp.pop('triplet_margins', None) optim_hp: dict = self.hp.get('optimizer', {}).copy() start_iter = optim_hp.pop('start_iter', 0) @@ -165,10 +166,13 @@ class Model: ) else: # Different margins self.triplet_loss = JointBatchTripletLoss( - self.rgb_pn.hpm_num_parts, triplet_is_hard, triplet_margins + self.rgb_pn.hpm_num_parts, + triplet_is_hard, triplet_is_mean, triplet_margins ) else: # Soft margins - self.triplet_loss = BatchTripletLoss(triplet_is_hard, None) + self.triplet_loss = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, None + ) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) @@ -243,7 +247,7 @@ class Model: 'PartNet': losses[4] }, self.curr_iter) # None-zero losses in batch - if num_non_zero: + if num_non_zero is not None: self.writer.add_scalars('Loss/non-zero counts', { 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(), 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean() diff --git a/utils/configuration.py b/utils/configuration.py index 20aec76..31eb243 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -44,6 +44,7 @@ class ModelHPConfiguration(TypedDict): tfa_num_parts: int embedding_dims: int triplet_is_hard: bool + triplet_is_mean: bool triplet_margins: tuple[float, float] diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 52d676e..db0cf0f 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -9,10 +9,12 @@ class BatchTripletLoss(nn.Module): def __init__( self, is_hard: bool = True, + is_mean: bool = True, margin: Optional[float] = 0.2, ): super().__init__() self.is_hard = is_hard + self.is_mean = is_mean self.margin = margin def forward(self, x, y): @@ -27,13 +29,20 @@ class BatchTripletLoss(nn.Module): positive_negative_dist = self._all_distance(dist, y, p, n) if self.margin: - all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) - loss_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) - return loss_mean, flat_dist, non_zero_counts + losses = F.relu(self.margin + positive_negative_dist).view(p, -1) + non_zero_counts = (losses != 0).sum(1).float() + if self.is_mean: + loss_metric = self._none_zero_mean(losses, non_zero_counts) + else: # is_sum + loss_metric = losses.sum(1) + return loss_metric, flat_dist, non_zero_counts else: # Soft margin - all_loss = F.softplus(positive_negative_dist).view(p, -1) - loss_mean = all_loss.mean(1) - return loss_mean, flat_dist, None + losses = F.softplus(positive_negative_dist).view(p, -1) + if self.is_mean: + loss_metric = losses.mean(1) + else: # is_sum + loss_metric = losses.sum(1) + return loss_metric, flat_dist, None @staticmethod def _batch_distance(x): @@ -68,13 +77,11 @@ class BatchTripletLoss(nn.Module): return positive_negative_dist @staticmethod - def _none_zero_parted_mean(all_loss): + def _none_zero_mean(losses, non_zero_counts): # Non-zero parted mean - non_zero_counts = (all_loss != 0).sum(1).float() - non_zero_mean = all_loss.sum(1) / non_zero_counts + non_zero_mean = losses.sum(1) / non_zero_counts non_zero_mean[non_zero_counts == 0] = 0 - - return non_zero_mean, non_zero_counts + return non_zero_mean class JointBatchTripletLoss(BatchTripletLoss): @@ -82,9 +89,10 @@ class JointBatchTripletLoss(BatchTripletLoss): self, hpm_num_parts: int, is_hard: bool = True, + is_mean: bool = True, margins: tuple[float, float] = (0.2, 0.2) ): - super().__init__(is_hard) + super().__init__(is_hard, is_mean) self.hpm_num_parts = hpm_num_parts self.margin_hpm, self.margin_pn = margins @@ -103,7 +111,12 @@ class JointBatchTripletLoss(BatchTripletLoss): pn_part_loss = F.relu( self.margin_pn + positive_negative_dist[self.hpm_num_parts:] ) - all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) - non_zero_mean, non_zero_counts = self._none_zero_parted_mean(all_loss) + losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) + + non_zero_counts = (losses != 0).sum(1).float() + if self.is_mean: + loss_metric = self._none_zero_mean(losses, non_zero_counts) + else: # is_sum + loss_metric = losses.sum(1) - return non_zero_mean, dist, non_zero_counts + return loss_metric, dist, non_zero_counts |