| field | value | date |
|---|---|---|
| author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-12 13:56:17 +0800 |
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-12 13:56:17 +0800 |
| commit | c74df416b00f837ba051f3947be92f76e7afbd88 (patch) | |
| tree | 02983df94008bbb427c2066c5f619e0ffdefe1c5 | |
| parent | 1b8d1614168ce6590c5e029c7f1007ac9b17048c (diff) | |
Code refactoring
1. Separate FCs and triplet losses for HPM and PartNet (a loss sketch follows the diffstat below)
2. Remove FC-equivalent 1x1 conv layers in HPM (see the batched-FC sketch right after this list)
3. Support adjustable learning rate schedulers (see the scheduler sketch after the diff)
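
For change 2, the removed 1x1 convolutions were acting as per-part fully connected layers; both HPM and PartNet now express the same mapping as one batched matrix multiplication over a single 3-D `fc_mat` parameter. A minimal sketch of that layout, with illustrative sizes (in the repo they come from `config.py` and `ae_feature_channels`):

```python
import torch
import torch.nn as nn

# Illustrative sizes: 'hpm_scales': (1, 2, 4, 8) in config.py gives
# sum(scales) = 15 parts; HPM maps 128 channels to a 256-D embedding.
num_parts, batch_size = 15, 8
in_channels, out_channels = 128, 256

# One (in_channels x out_channels) weight matrix per part, stored as a
# single 3-D parameter -- the `fc_mat` layout HPM and PartNet now use.
fc_mat = nn.Parameter(torch.empty(num_parts, in_channels, out_channels))
nn.init.xavier_uniform_(fc_mat)  # mirrors Model._init_weights in this diff

x = torch.rand(num_parts, batch_size, in_channels)  # p, n, c
y = x @ fc_mat   # batched matmul applies each part's matrix at once
print(y.size())  # torch.Size([15, 8, 256]) -- p, n, d
```

Keeping all part matrices in one tensor also lets a single parameter group (and a single `xavier_uniform_` call) cover every part, instead of one conv module per part.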
| mode | file | lines |
|---|---|---:|
| -rw-r--r-- | config.py | 21 |
| -rw-r--r-- | models/auto_encoder.py | 2 |
| -rw-r--r-- | models/hpm.py | 25 |
| -rw-r--r-- | models/layers.py | 9 |
| -rw-r--r-- | models/model.py | 119 |
| -rw-r--r-- | models/part_net.py | 18 |
| -rw-r--r-- | models/rgb_part_net.py | 32 |
| -rw-r--r-- | test/part_net.py | 2 |
| -rw-r--r-- | utils/configuration.py | 20 |
| -rw-r--r-- | utils/triplet_loss.py | 40 |

10 files changed, 127 insertions(+), 161 deletions(-)
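
Change 1 drops `JointBatchTripletLoss` (deleted at the bottom of this diff) in favor of two independent `BatchTripletLoss` instances, so the HPM and PartNet branches each get their own margin and their own non-zero statistics. The idea in isolation, using a simplified batch-hard stand-in (this helper, its sizes, and the margins are illustrative, not the repo's class):

```python
import torch
import torch.nn.functional as F

def batch_hard_triplet_loss(x, y, margin):
    """Simplified batch-hard triplet loss over per-part embeddings.

    x: (p, n, d) embeddings, y: (p, n) labels; illustration only.
    """
    dist = torch.cdist(x, x)                 # (p, n, n) pairwise distances
    same = y.unsqueeze(1) == y.unsqueeze(2)  # (p, n, n) same-label mask
    hard_pos = dist.masked_fill(~same, 0).max(-1).values  # farthest positive
    hard_neg = dist.masked_fill(same, float('inf')).min(-1).values  # nearest negative
    return F.relu(margin + hard_pos - hard_neg).mean()

# Two instances, two margins -- no special-cased joint loss needed.
p_hpm, p_pn, n, d = 15, 16, 8, 256
y = torch.arange(2).repeat_interleave(4)  # two classes, four clips each
loss_hpm = batch_hard_triplet_loss(torch.rand(p_hpm, n, d), y.expand(p_hpm, n), 0.2)
loss_pn = batch_hard_triplet_loss(torch.rand(p_pn, n, d), y.expand(p_pn, n), 0.3)
loss = loss_hpm + loss_pn  # summed exactly like the training loop does
```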
```diff
diff --git a/config.py b/config.py
--- a/config.py
+++ b/config.py
@@ -50,19 +50,17 @@ config: Configuration = {
             'ae_feature_channels': 64,
             # Appearance, canonical and pose feature dimensions
             'f_a_c_p_dims': (192, 192, 96),
-            # Use 1x1 convolution in dimensionality reduction
-            'hpm_use_1x1conv': False,
             # HPM pyramid scales, of which sum is number of parts
             'hpm_scales': (1, 2, 4, 8),
             # Global pooling method
             'hpm_use_avg_pool': True,
             'hpm_use_max_pool': True,
-            # Attention squeeze ratio
-            'tfa_squeeze_ratio': 4,
             # Number of parts after Part Net
             'tfa_num_parts': 16,
-            # Embedding dimension for each part
-            'embedding_dims': 256,
+            # Attention squeeze ratio
+            'tfa_squeeze_ratio': 4,
+            # Embedding dimensions for each part
+            'embedding_dims': (256, 256),
             # Batch Hard or Batch All
             'triplet_is_hard': True,
             # Use non-zero mean or sum
@@ -91,9 +89,14 @@ config: Configuration = {
         },
         'scheduler': {
             # Step start to decay
-            'start_step': 15_000,
+            'start_step': 500,
             # Multiplicative factor of decay in the end
-            'final_gamma': 0.001,
+            'final_gamma': 0.01,
+
+            # Local parameters (override global ones)
+            'hpm': {
+                'final_gamma': 0.001
+            }
         }
     },
     # Model metadata
@@ -107,6 +110,6 @@ config: Configuration = {
         # Restoration iteration (multiple models, e.g. nm, bg and cl)
         'restore_iters': (0, 0, 0),
         # Total iteration for training (multiple models)
-        'total_iters': (25_000, 25_000, 25_000),
+        'total_iters': (30_000, 40_000, 60_000),
     },
 }
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index e6a3e60..4fece69 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -171,7 +171,7 @@ class AutoEncoder(nn.Module):
             return (
                 (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_),
-                torch.stack((xrecon_loss, cano_cons_loss, pose_sim_loss * 10))
+                (xrecon_loss, cano_cons_loss, pose_sim_loss * 10)
             )
         else:  # evaluating
             return f_c_c1_t2_, f_p_c1_t2_
diff --git a/models/hpm.py b/models/hpm.py
index 9879cfb..8186b20 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -9,32 +9,26 @@ class HorizontalPyramidMatching(nn.Module):
             self,
             in_channels: int,
             out_channels: int = 128,
-            use_1x1conv: bool = False,
             scales: tuple[int, ...] = (1, 2, 4),
             use_avg_pool: bool = True,
             use_max_pool: bool = False,
-            **kwargs
     ):
         super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.use_1x1conv = use_1x1conv
         self.scales = scales
+        self.num_parts = sum(scales)
         self.use_avg_pool = use_avg_pool
         self.use_max_pool = use_max_pool
         self.pyramids = nn.ModuleList([
-            self._make_pyramid(scale, **kwargs) for scale in self.scales
+            self._make_pyramid(scale) for scale in scales
         ])
+        self.fc_mat = nn.Parameter(
+            torch.empty(self.num_parts, in_channels, out_channels)
+        )

-    def _make_pyramid(self, scale: int, **kwargs):
+    def _make_pyramid(self, scale: int):
         pyramid = nn.ModuleList([
-            HorizontalPyramidPooling(self.in_channels,
-                                     self.out_channels,
-                                     use_1x1conv=self.use_1x1conv,
-                                     use_avg_pool=self.use_avg_pool,
-                                     use_max_pool=self.use_max_pool,
-                                     **kwargs)
+            HorizontalPyramidPooling(self.use_avg_pool, self.use_max_pool)
             for _ in range(scale)
         ])
         return pyramid
@@ -52,4 +46,9 @@ class HorizontalPyramidMatching(nn.Module):
                 x_slice = x_slice.view(n, -1)
                 feature.append(x_slice)
         x = torch.stack(feature)
+
+        # p, n, c
+        x = x @ self.fc_mat
+        # p, n, d
+
         return x
diff --git a/models/layers.py b/models/layers.py
index f1d72b6..c609698 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -167,17 +167,10 @@ class BasicConv1d(nn.Module):

 class HorizontalPyramidPooling(nn.Module):
     def __init__(
             self,
-            in_channels: int,
-            out_channels: int,
-            use_1x1conv: bool = False,
             use_avg_pool: bool = True,
             use_max_pool: bool = False,
-            **kwargs
     ):
         super().__init__()
-        self.use_1x1conv = use_1x1conv
-        if use_1x1conv:
-            self.conv = BasicConv2d(in_channels, out_channels, 1, **kwargs)
         self.use_avg_pool = use_avg_pool
         self.use_max_pool = use_max_pool
         assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.'
@@ -191,6 +184,4 @@ class HorizontalPyramidPooling(nn.Module):
             x = self.avg_pool(x)
         elif not self.use_avg_pool and self.use_max_pool:
             x = self.max_pool(x)
-        if self.use_1x1conv:
-            x = self.conv(x)
         return x
diff --git a/models/model.py b/models/model.py
index cb455f3..adea626 100644
--- a/models/model.py
+++ b/models/model.py
@@ -13,13 +13,15 @@ from torch.utils.data.dataloader import default_collate
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm

+from models.hpm import HorizontalPyramidMatching
+from models.part_net import PartNet
 from models.rgb_part_net import RGBPartNet
 from utils.configuration import DataloaderConfiguration, \
     HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
     SystemConfiguration
 from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
 from utils.sampler import TripletSampler
-from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss
+from utils.triplet_loss import BatchTripletLoss


 class Model:
@@ -69,7 +71,8 @@ class Model:
         self._dataset_sig: str = 'undefined'

         self.rgb_pn: Optional[RGBPartNet] = None
-        self.triplet_loss: Optional[JointBatchTripletLoss] = None
+        self.triplet_loss_hpm: Optional[BatchTripletLoss] = None
+        self.triplet_loss_pn: Optional[BatchTripletLoss] = None
         self.optimizer: Optional[optim.Adam] = None
         self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
         self.writer: Optional[SummaryWriter] = None
@@ -149,26 +152,28 @@ class Model:
         triplet_margins = model_hp.pop('triplet_margins', None)
         optim_hp: dict = self.hp.get('optimizer', {}).copy()
         ae_optim_hp = optim_hp.pop('auto_encoder', {})
-        pn_optim_hp = optim_hp.pop('part_net', {})
         hpm_optim_hp = optim_hp.pop('hpm', {})
-        fc_optim_hp = optim_hp.pop('fc', {})
+        pn_optim_hp = optim_hp.pop('part_net', {})
         sched_hp = self.hp.get('scheduler', {})
+        ae_sched_hp = sched_hp.get('auto_encoder', {})
+        hpm_sched_hp = sched_hp.get('hpm', {})
+        pn_sched_hp = sched_hp.get('part_net', {})
+
         self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
                                  image_log_on=self.image_log_on)
         # Hard margins
         if triplet_margins:
-            # Same margins
-            if triplet_margins[0] == triplet_margins[1]:
-                self.triplet_loss = BatchTripletLoss(
-                    triplet_is_hard, triplet_margins[0]
-                )
-            else:  # Different margins
-                self.triplet_loss = JointBatchTripletLoss(
-                    self.rgb_pn.hpm_num_parts,
-                    triplet_is_hard, triplet_is_mean, triplet_margins
-                )
+            self.triplet_loss_hpm = BatchTripletLoss(
+                triplet_is_hard, triplet_is_mean, triplet_margins[0]
+            )
+            self.triplet_loss_pn = BatchTripletLoss(
+                triplet_is_hard, triplet_is_mean, triplet_margins[1]
+            )
         else:  # Soft margins
-            self.triplet_loss = BatchTripletLoss(
+            self.triplet_loss_hpm = BatchTripletLoss(
+                triplet_is_hard, triplet_is_mean, None
+            )
+            self.triplet_loss_pn = BatchTripletLoss(
                 triplet_is_hard, triplet_is_mean, None
             )
@@ -177,25 +182,33 @@ class Model:
         # Try to accelerate computation using CUDA or others
         self.rgb_pn = self.rgb_pn.to(self.device)
-        self.triplet_loss = self.triplet_loss.to(self.device)
+        self.triplet_loss_hpm = self.triplet_loss_hpm.to(self.device)
+        self.triplet_loss_pn = self.triplet_loss_pn.to(self.device)
+
         self.optimizer = optim.Adam([
             {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
-            {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
             {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
-            {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
+            {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
         ], **optim_hp)
-        sched_final_gamma = sched_hp.get('final_gamma', 0.001)
-        sched_start_step = sched_hp.get('start_step', 15_000)
-        all_step = self.total_iter - sched_start_step
-
-        def lr_lambda(epoch):
-            if epoch > sched_start_step:
-                passed_step = epoch - sched_start_step
-                return sched_final_gamma ** (passed_step / all_step)
-            else:
-                return 1
+
+        start_step = sched_hp.get('start_step', 15_000)
+        final_gamma = sched_hp.get('final_gamma', 0.001)
+        ae_start_step = ae_sched_hp.get('start_step', start_step)
+        ae_final_gamma = ae_sched_hp.get('final_gamma', final_gamma)
+        ae_all_step = self.total_iter - ae_start_step
+        hpm_start_step = hpm_sched_hp.get('start_step', start_step)
+        hpm_final_gamma = hpm_sched_hp.get('final_gamma', final_gamma)
+        hpm_all_step = self.total_iter - hpm_start_step
+        pn_start_step = pn_sched_hp.get('start_step', start_step)
+        pn_final_gamma = pn_sched_hp.get('final_gamma', final_gamma)
+        pn_all_step = self.total_iter - pn_start_step
         self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
-            lr_lambda, lr_lambda, lr_lambda, lr_lambda
+            lambda t: ae_final_gamma ** ((t - ae_start_step) / ae_all_step)
+            if t > ae_start_step else 1,
+            lambda t: hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step)
+            if t > hpm_start_step else 1,
+            lambda t: pn_final_gamma ** ((t - pn_start_step) / pn_all_step)
+            if t > pn_start_step else 1,
         ])

         self.writer = SummaryWriter(self._log_name)
@@ -220,7 +233,7 @@ class Model:
         running_loss = torch.zeros(5, device=self.device)
         print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
               f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
-              f"{'BATripH':^8} {'BATripP':^8} {'LR':^9}")
+              f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}")
         for (batch_c1, batch_c2) in dataloader:
             self.curr_iter += 1
             # Zero the parameter gradients
@@ -228,17 +241,20 @@ class Model:
             # forward + backward + optimize
             x_c1 = batch_c1['clip'].to(self.device)
             x_c2 = batch_c2['clip'].to(self.device)
-            embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+            embedding_c, embedding_p, ae_losses, images = self.rgb_pn(x_c1, x_c2)
             y = batch_c1['label'].to(self.device)
             # Duplicate labels for each part
-            y = y.repeat(self.rgb_pn.num_total_parts, 1)
-            trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
-            losses = torch.cat((
-                ae_losses,
-                torch.stack((
-                    trip_loss[:self.rgb_pn.hpm_num_parts].mean(),
-                    trip_loss[self.rgb_pn.hpm_num_parts:].mean()
-                ))
+            y = y.repeat(self.rgb_pn.num_parts, 1)
+            trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm(
+                embedding_c, y[:self.rgb_pn.hpm.num_parts]
+            )
+            trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn(
+                embedding_p, y[self.rgb_pn.hpm.num_parts:]
+            )
+            losses = torch.stack((
+                *ae_losses,
+                trip_loss_hpm.mean(),
+                trip_loss_pn.mean()
             ))
             loss = losses.sum()
             loss.backward()
@@ -257,30 +273,30 @@ class Model:
                 'PartNet': losses[4]
             }, self.curr_iter)
             # Non-zero losses in batch
-            if num_non_zero is not None:
+            if hpm_num_non_zero is not None and pn_num_non_zero is not None:
                 self.writer.add_scalars('Loss/non-zero counts', {
-                    'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(),
-                    'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean()
+                    'HPM': hpm_num_non_zero.mean(),
+                    'PartNet': pn_num_non_zero.mean()
                 }, self.curr_iter)
             # Embedding distance
-            mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0)
+            mean_hpm_dist = hpm_dist.mean(0)
             self._add_ranked_scalars(
                 'Embedding/HPM distance', mean_hpm_dist,
                 num_pos_pairs, num_pairs, self.curr_iter
             )
-            mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0)
+            mean_pa_dist = pn_dist.mean(0)
             self._add_ranked_scalars(
                 'Embedding/ParNet distance', mean_pa_dist,
                 num_pos_pairs, num_pairs, self.curr_iter
             )
             # Embedding norm
-            mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0)
+            mean_hpm_embedding = embedding_c.mean(0)
             mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
             self._add_ranked_scalars(
                 'Embedding/HPM norm', mean_hpm_norm,
                 self.k, self.pr * self.k, self.curr_iter
             )
-            mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0)
+            mean_pa_embedding = embedding_p.mean(0)
             mean_pa_norm = mean_pa_embedding.norm(dim=-1)
             self._add_ranked_scalars(
                 'Embedding/PartNet norm', mean_pa_norm,
@@ -288,10 +304,9 @@ class Model:
             )
             # Learning rate
             lrs = self.scheduler.get_last_lr()
-            # Write learning rates
-            self.writer.add_scalar(
-                'Learning rate', lrs[0], self.curr_iter
-            )
+            self.writer.add_scalars('Learning rate', dict(zip((
+                'Auto-encoder', 'HPM', 'PartNet'
+            ), lrs)), self.curr_iter)

             if self.curr_iter % 100 == 0:
                 # Write disentangled images
@@ -316,7 +331,7 @@ class Model:
                 print(f'{hour:02}:{minute:02}:{second:02}',
                       f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
                       '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
-                      f'{lrs[0]:.3e}')
+                      '{:.3e} {:.3e} {:.3e}'.format(*lrs))
                 running_loss.zero_()
             # Step scheduler
@@ -548,7 +563,7 @@ class Model:
             nn.init.zeros_(m.bias)
         elif isinstance(m, nn.Linear):
             nn.init.xavier_uniform_(m.weight)
-        elif isinstance(m, RGBPartNet):
+        elif isinstance(m, (HorizontalPyramidMatching, PartNet)):
             nn.init.xavier_uniform_(m.fc_mat)

     def _parse_dataset_config(
diff --git a/models/part_net.py b/models/part_net.py
index 29cf9cd..f2236bf 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -111,17 +111,21 @@ class PartNet(nn.Module):
     def __init__(
             self,
             in_channels: int = 128,
+            embedding_dims: int = 256,
+            num_parts: int = 16,
             squeeze_ratio: int = 4,
-            num_part: int = 16
     ):
         super().__init__()
-        self.num_part = num_part
-        self.tfa = TemporalFeatureAggregator(
-            in_channels, squeeze_ratio, self.num_part
-        )
+        self.num_part = num_parts
         self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.max_pool = nn.AdaptiveMaxPool2d(1)
+        self.tfa = TemporalFeatureAggregator(
+            in_channels, squeeze_ratio, self.num_part
+        )
+        self.fc_mat = nn.Parameter(
+            torch.empty(num_parts, in_channels, embedding_dims)
+        )

     def forward(self, x):
         n, t, c, h, w = x.size()
@@ -138,4 +142,8 @@ class PartNet(nn.Module):
         # p, n, t, c
         x = self.tfa(x)
+
+        # p, n, c
+        x = x @ self.fc_mat
+        # p, n, d
         return x
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 8a0f3a7..c38a567 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -13,39 +13,31 @@ class RGBPartNet(nn.Module):
             ae_in_size: tuple[int, int] = (64, 48),
             ae_feature_channels: int = 64,
             f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
-            hpm_use_1x1conv: bool = False,
             hpm_scales: tuple[int, ...] = (1, 2, 4),
             hpm_use_avg_pool: bool = True,
             hpm_use_max_pool: bool = True,
             tfa_squeeze_ratio: int = 4,
             tfa_num_parts: int = 16,
-            embedding_dims: int = 256,
+            embedding_dims: tuple[int] = (256, 256),
             image_log_on: bool = False
     ):
         super().__init__()
         self.h, self.w = ae_in_size
         (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
-        self.hpm_num_parts = sum(hpm_scales)
         self.image_log_on = image_log_on

         self.ae = AutoEncoder(
             ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
         )
         self.pn_in_channels = ae_feature_channels * 2
-        self.pn = PartNet(
-            self.pn_in_channels, tfa_squeeze_ratio, tfa_num_parts
-        )
         self.hpm = HorizontalPyramidMatching(
-            ae_feature_channels * 2, self.pn_in_channels, hpm_use_1x1conv,
-            hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
+            self.pn_in_channels, embedding_dims[0], hpm_scales,
+            hpm_use_avg_pool, hpm_use_max_pool
         )
-        self.num_total_parts = self.hpm_num_parts + tfa_num_parts
-        empty_fc = torch.empty(self.num_total_parts,
-                               self.pn_in_channels, embedding_dims)
-        self.fc_mat = nn.Parameter(empty_fc)
+        self.pn = PartNet(self.pn_in_channels, embedding_dims[1],
+                          tfa_num_parts, tfa_squeeze_ratio)

-    def fc(self, x):
-        return x @ self.fc_mat
+        self.num_parts = self.hpm.num_parts + tfa_num_parts

     def forward(self, x_c1, x_c2=None):
         # Step 1: Disentanglement
@@ -55,21 +47,17 @@ class RGBPartNet(nn.Module):
         # Step 2.a: Static Gait Feature Aggregation & HPM
         # n, c, h, w
         x_c = self.hpm(x_c)
-        # p, n, c
+        # p, n, d

         # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
         # n, t, c, h, w
         x_p = self.pn(x_p)
-        # p, n, c
-
-        # Step 3: Cat feature map together and fc
-        x = torch.cat((x_c, x_p))
-        x = self.fc(x)
+        # p, n, d

         if self.training:
-            return x, ae_losses, images
+            return x_c, x_p, ae_losses, images
         else:
-            return x.unsqueeze(1).view(-1)
+            return torch.cat((x_c, x_p)).unsqueeze(1).view(-1)

     def _disentangle(self, x_c1_t2, x_c2_t2=None):
         n, t, c, h, w = x_c1_t2.size()
diff --git a/test/part_net.py b/test/part_net.py
index 25e92ae..fada2c4 100644
--- a/test/part_net.py
+++ b/test/part_net.py
@@ -64,7 +64,7 @@ def test_custom_part_net():
                  paddings=((2, 1), (1, 1), (1, 1), (1, 1)),
                  halving=(1, 1, 3, 3),
                  squeeze_ratio=8,
-                 num_part=8)
+                 num_parts=8)

     x = torch.rand(T, N, 1, H, W)
     x = pa(x)
diff --git a/utils/configuration.py b/utils/configuration.py
index 0f8d9ff..f6ac182 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -33,16 +33,11 @@ class ModelHPConfiguration(TypedDict):
     ae_feature_channels: int
     f_a_c_p_dims: tuple[int, int, int]
     hpm_scales: tuple[int, ...]
-    hpm_use_1x1conv: bool
     hpm_use_avg_pool: bool
     hpm_use_max_pool: bool
-    fpfe_feature_channels: int
-    fpfe_kernel_sizes: tuple[tuple, ...]
-    fpfe_paddings: tuple[tuple, ...]
-    fpfe_halving: tuple[int, ...]
-    tfa_squeeze_ratio: int
     tfa_num_parts: int
-    embedding_dims: int
+    tfa_squeeze_ratio: int
+    embedding_dims: tuple[int]
     triplet_is_hard: bool
     triplet_is_mean: bool
     triplet_margins: tuple[float, float]
@@ -63,14 +58,21 @@ class OptimizerHPConfiguration(TypedDict):
     weight_decay: float
     amsgrad: bool
     auto_encoder: SubOptimizerHPConfiguration
-    part_net: SubOptimizerHPConfiguration
     hpm: SubOptimizerHPConfiguration
-    fc: SubOptimizerHPConfiguration
+    part_net: SubOptimizerHPConfiguration
+
+
+class SubSchedulerHPConfiguration(TypedDict):
+    start_step: int
+    final_gamma: float


 class SchedulerHPConfiguration(TypedDict):
     start_step: int
     final_gamma: float
+    auto_encoder: SubSchedulerHPConfiguration
+    hpm: SubSchedulerHPConfiguration
+    part_net: SubSchedulerHPConfiguration


 class HyperparameterConfiguration(TypedDict):
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index e05b69d..03fff21 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -85,43 +85,3 @@ class BatchTripletLoss(nn.Module):
         non_zero_mean = losses.sum(1) / non_zero_counts
         non_zero_mean[non_zero_counts == 0] = 0
         return non_zero_mean
-
-
-class JointBatchTripletLoss(BatchTripletLoss):
-    def __init__(
-            self,
-            hpm_num_parts: int,
-            is_hard: bool = True,
-            is_mean: bool = True,
-            margins: tuple[float, float] = (0.2, 0.2)
-    ):
-        super().__init__(is_hard, is_mean)
-        self.hpm_num_parts = hpm_num_parts
-        self.margin_hpm, self.margin_pn = margins
-
-    def forward(self, x, y):
-        p, n, c = x.size()
-        dist = self._batch_distance(x)
-        flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device)
-        flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]]
-
-        if self.is_hard:
-            positive_negative_dist = self._hard_distance(dist, y, p, n)
-        else:  # is_all
-            positive_negative_dist = self._all_distance(dist, y, p, n)
-
-        hpm_part_loss = F.relu(
-            self.margin_hpm + positive_negative_dist[:self.hpm_num_parts]
-        )
-        pn_part_loss = F.relu(
-            self.margin_pn + positive_negative_dist[self.hpm_num_parts:]
-        )
-        losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1)
-
-        non_zero_counts = (losses != 0).sum(1).float()
-        if self.is_mean:
-            loss_metric = self._none_zero_mean(losses, non_zero_counts)
-        else:  # is_sum
-            loss_metric = losses.sum(1)
-
-        return loss_metric, flat_dist, non_zero_counts
```
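
Change 3 turns the single shared LR lambda into one closure per optimizer parameter group, with per-module `start_step`/`final_gamma` read from the new `SubSchedulerHPConfiguration` entries. A runnable sketch of the mechanism with placeholder modules, using the values now in `config.py` (global `final_gamma` 0.01, HPM override 0.001):

```python
import torch
from torch import nn, optim

# Placeholder modules standing in for the auto-encoder, HPM and PartNet.
ae, hpm, pn = nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4)
total_iter, start_step = 30_000, 500

optimizer = optim.Adam([
    {'params': ae.parameters()},
    {'params': hpm.parameters()},
    {'params': pn.parameters()},
], lr=1e-4)

def make_lambda(final_gamma, start=start_step, total=total_iter):
    all_step = total - start
    # Multiplier stays at 1 until `start`, then decays exponentially so
    # that it reaches exactly `final_gamma` at `total`.
    return lambda t: final_gamma ** ((t - start) / all_step) if t > start else 1

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[
    make_lambda(0.01),   # auto-encoder: global final_gamma
    make_lambda(0.001),  # hpm: local override from config.py
    make_lambda(0.01),   # part_net: falls back to the global value
])

for t in range(1000):
    optimizer.step()  # actual training step elided
    scheduler.step()
print(scheduler.get_last_lr())  # three LRs, one per parameter group
```

Because each group owns its closure, one module (HPM here) can decay faster without touching the others; the inline lambdas in `Model.build` do the same with the values pulled from `sched_hp`.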
