author     Jordan Gong <jordan.gong@protonmail.com>  2021-03-12 13:56:17 +0800
committer  Jordan Gong <jordan.gong@protonmail.com>  2021-03-12 13:56:17 +0800
commit     c74df416b00f837ba051f3947be92f76e7afbd88 (patch)
tree       02983df94008bbb427c2066c5f619e0ffdefe1c5
parent     1b8d1614168ce6590c5e029c7f1007ac9b17048c (diff)
Code refactoring

1. Separate FCs and triplet losses for HPM and PartNet
2. Remove FC-equivalent 1x1 conv layers in HPM
3. Support adjustable learning rate schedulers
-rw-r--r--  config.py                21
-rw-r--r--  models/auto_encoder.py    2
-rw-r--r--  models/hpm.py            25
-rw-r--r--  models/layers.py          9
-rw-r--r--  models/model.py         119
-rw-r--r--  models/part_net.py       18
-rw-r--r--  models/rgb_part_net.py   32
-rw-r--r--  test/part_net.py          2
-rw-r--r--  utils/configuration.py   20
-rw-r--r--  utils/triplet_loss.py    40
10 files changed, 127 insertions, 161 deletions
diff --git a/config.py b/config.py
index d6de788..8abeba3 100644
--- a/config.py
+++ b/config.py
@@ -50,19 +50,17 @@ config: Configuration = {
'ae_feature_channels': 64,
# Appearance, canonical and pose feature dimensions
'f_a_c_p_dims': (192, 192, 96),
- # Use 1x1 convolution in dimensionality reduction
- 'hpm_use_1x1conv': False,
# HPM pyramid scales, of which sum is number of parts
'hpm_scales': (1, 2, 4, 8),
# Global pooling method
'hpm_use_avg_pool': True,
'hpm_use_max_pool': True,
- # Attention squeeze ratio
- 'tfa_squeeze_ratio': 4,
# Number of parts after Part Net
'tfa_num_parts': 16,
- # Embedding dimension for each part
- 'embedding_dims': 256,
+ # Attention squeeze ratio
+ 'tfa_squeeze_ratio': 4,
+ # Embedding dimensions for each part
+ 'embedding_dims': (256, 256),
# Batch Hard or Batch All
'triplet_is_hard': True,
# Use non-zero mean or sum
@@ -91,9 +89,14 @@ config: Configuration = {
},
'scheduler': {
# Step start to decay
- 'start_step': 15_000,
+ 'start_step': 500,
# Multiplicative factor of decay in the end
- 'final_gamma': 0.001,
+ 'final_gamma': 0.01,
+
+ # Local parameters (override global ones)
+ 'hpm': {
+ 'final_gamma': 0.001
+ }
}
},
# Model metadata
@@ -107,6 +110,6 @@ config: Configuration = {
# Restoration iteration (multiple models, e.g. nm, bg and cl)
'restore_iters': (0, 0, 0),
# Total iteration for training (multiple models)
- 'total_iters': (25_000, 25_000, 25_000),
+ 'total_iters': (30_000, 40_000, 60_000),
},
}
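Note on the new nested scheduler block: model.py reads it with plain dict lookups that fall back to the global values, so the 'hpm' entry above only overrides final_gamma while inheriting start_step. A minimal sketch of that resolution, with the values copied from the config above and the defaults used in model.py:

# How a per-module scheduler block overrides the global one (sketch).
sched_hp = {
    'start_step': 500,
    'final_gamma': 0.01,
    'hpm': {'final_gamma': 0.001},                      # local override
}
hpm_sched_hp = sched_hp.get('hpm', {})
start_step = sched_hp.get('start_step', 15_000)         # global values
final_gamma = sched_hp.get('final_gamma', 0.001)
hpm_start_step = hpm_sched_hp.get('start_step', start_step)
hpm_final_gamma = hpm_sched_hp.get('final_gamma', final_gamma)
assert (hpm_start_step, hpm_final_gamma) == (500, 0.001)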
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index e6a3e60..4fece69 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -171,7 +171,7 @@ class AutoEncoder(nn.Module):
return (
(f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_),
- torch.stack((xrecon_loss, cano_cons_loss, pose_sim_loss * 10))
+ (xrecon_loss, cano_cons_loss, pose_sim_loss * 10)
)
else: # evaluating
return f_c_c1_t2_, f_p_c1_t2_
diff --git a/models/hpm.py b/models/hpm.py
index 9879cfb..8186b20 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -9,32 +9,26 @@ class HorizontalPyramidMatching(nn.Module):
self,
in_channels: int,
out_channels: int = 128,
- use_1x1conv: bool = False,
scales: tuple[int, ...] = (1, 2, 4),
use_avg_pool: bool = True,
use_max_pool: bool = False,
- **kwargs
):
super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.use_1x1conv = use_1x1conv
self.scales = scales
+ self.num_parts = sum(scales)
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
self.pyramids = nn.ModuleList([
- self._make_pyramid(scale, **kwargs) for scale in self.scales
+ self._make_pyramid(scale) for scale in scales
])
+ self.fc_mat = nn.Parameter(
+ torch.empty(self.num_parts, in_channels, out_channels)
+ )
- def _make_pyramid(self, scale: int, **kwargs):
+ def _make_pyramid(self, scale: int):
pyramid = nn.ModuleList([
- HorizontalPyramidPooling(self.in_channels,
- self.out_channels,
- use_1x1conv=self.use_1x1conv,
- use_avg_pool=self.use_avg_pool,
- use_max_pool=self.use_max_pool,
- **kwargs)
+ HorizontalPyramidPooling(self.use_avg_pool, self.use_max_pool)
for _ in range(scale)
])
return pyramid
@@ -52,4 +46,9 @@ class HorizontalPyramidMatching(nn.Module):
x_slice = x_slice.view(n, -1)
feature.append(x_slice)
x = torch.stack(feature)
+
+ # p, n, c
+ x = x @ self.fc_mat
+ # p, n, d
+
return x
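With the 1x1 convolutions gone, dimensionality reduction in HPM is a single learned tensor of per-part projection matrices applied by one batched matmul. A toy shape check, with sizes implied by the config above (hpm_scales sum to 15 parts; the 128 input channels and 256-dim embedding are what the defaults work out to):

# Per-part FC in HPM: (p, n, c) @ (p, c, d) -> (p, n, d), no conv involved.
import torch

p, n, c, d = 1 + 2 + 4 + 8, 8, 128, 256   # parts from hpm_scales, toy batch of 8
x = torch.rand(p, n, c)                    # pooled horizontal strips
fc_mat = torch.empty(p, c, d)
torch.nn.init.xavier_uniform_(fc_mat)      # the init model.py applies to fc_mat
x = x @ fc_mat
assert x.shape == (p, n, d)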
diff --git a/models/layers.py b/models/layers.py
index f1d72b6..c609698 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -167,17 +167,10 @@ class BasicConv1d(nn.Module):
class HorizontalPyramidPooling(nn.Module):
def __init__(
self,
- in_channels: int,
- out_channels: int,
- use_1x1conv: bool = False,
use_avg_pool: bool = True,
use_max_pool: bool = False,
- **kwargs
):
super().__init__()
- self.use_1x1conv = use_1x1conv
- if use_1x1conv:
- self.conv = BasicConv2d(in_channels, out_channels, 1, **kwargs)
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.'
@@ -191,6 +184,4 @@ class HorizontalPyramidPooling(nn.Module):
x = self.avg_pool(x)
elif not self.use_avg_pool and self.use_max_pool:
x = self.max_pool(x)
- if self.use_1x1conv:
- x = self.conv(x)
return x
diff --git a/models/model.py b/models/model.py
index cb455f3..adea626 100644
--- a/models/model.py
+++ b/models/model.py
@@ -13,13 +13,15 @@ from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
+from models.hpm import HorizontalPyramidMatching
+from models.part_net import PartNet
from models.rgb_part_net import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
SystemConfiguration
from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
from utils.sampler import TripletSampler
-from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss
+from utils.triplet_loss import BatchTripletLoss
class Model:
@@ -69,7 +71,8 @@ class Model:
self._dataset_sig: str = 'undefined'
self.rgb_pn: Optional[RGBPartNet] = None
- self.triplet_loss: Optional[JointBatchTripletLoss] = None
+ self.triplet_loss_hpm: Optional[BatchTripletLoss] = None
+ self.triplet_loss_pn: Optional[BatchTripletLoss] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
@@ -149,26 +152,28 @@ class Model:
triplet_margins = model_hp.pop('triplet_margins', None)
optim_hp: dict = self.hp.get('optimizer', {}).copy()
ae_optim_hp = optim_hp.pop('auto_encoder', {})
- pn_optim_hp = optim_hp.pop('part_net', {})
hpm_optim_hp = optim_hp.pop('hpm', {})
- fc_optim_hp = optim_hp.pop('fc', {})
+ pn_optim_hp = optim_hp.pop('part_net', {})
sched_hp = self.hp.get('scheduler', {})
+ ae_sched_hp = sched_hp.get('auto_encoder', {})
+ hpm_sched_hp = sched_hp.get('hpm', {})
+ pn_sched_hp = sched_hp.get('part_net', {})
+
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
# Hard margins
if triplet_margins:
- # Same margins
- if triplet_margins[0] == triplet_margins[1]:
- self.triplet_loss = BatchTripletLoss(
- triplet_is_hard, triplet_margins[0]
- )
- else: # Different margins
- self.triplet_loss = JointBatchTripletLoss(
- self.rgb_pn.hpm_num_parts,
- triplet_is_hard, triplet_is_mean, triplet_margins
- )
+ self.triplet_loss_hpm = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, triplet_margins[0]
+ )
+ self.triplet_loss_pn = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, triplet_margins[1]
+ )
else: # Soft margins
- self.triplet_loss = BatchTripletLoss(
+ self.triplet_loss_hpm = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, None
+ )
+ self.triplet_loss_pn = BatchTripletLoss(
triplet_is_hard, triplet_is_mean, None
)
@@ -177,25 +182,33 @@ class Model:
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
- self.triplet_loss = self.triplet_loss.to(self.device)
+ self.triplet_loss_hpm = self.triplet_loss_hpm.to(self.device)
+ self.triplet_loss_pn = self.triplet_loss_pn.to(self.device)
+
self.optimizer = optim.Adam([
{'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
- {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
{'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
- {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
+ {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
], **optim_hp)
- sched_final_gamma = sched_hp.get('final_gamma', 0.001)
- sched_start_step = sched_hp.get('start_step', 15_000)
- all_step = self.total_iter - sched_start_step
-
- def lr_lambda(epoch):
- if epoch > sched_start_step:
- passed_step = epoch - sched_start_step
- return sched_final_gamma ** (passed_step / all_step)
- else:
- return 1
+
+ start_step = sched_hp.get('start_step', 15_000)
+ final_gamma = sched_hp.get('final_gamma', 0.001)
+ ae_start_step = ae_sched_hp.get('start_step', start_step)
+ ae_final_gamma = ae_sched_hp.get('final_gamma', final_gamma)
+ ae_all_step = self.total_iter - ae_start_step
+ hpm_start_step = hpm_sched_hp.get('start_step', start_step)
+ hpm_final_gamma = hpm_sched_hp.get('final_gamma', final_gamma)
+ hpm_all_step = self.total_iter - hpm_start_step
+ pn_start_step = pn_sched_hp.get('start_step', start_step)
+ pn_final_gamma = pn_sched_hp.get('final_gamma', final_gamma)
+ pn_all_step = self.total_iter - pn_start_step
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
- lr_lambda, lr_lambda, lr_lambda, lr_lambda
+ lambda t: ae_final_gamma ** ((t - ae_start_step) / ae_all_step)
+ if t > ae_start_step else 1,
+ lambda t: hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step)
+ if t > hpm_start_step else 1,
+ lambda t: pn_final_gamma ** ((t - pn_start_step) / pn_all_step)
+ if t > pn_start_step else 1,
])
self.writer = SummaryWriter(self._log_name)
@@ -220,7 +233,7 @@ class Model:
running_loss = torch.zeros(5, device=self.device)
print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
- f"{'BATripH':^8} {'BATripP':^8} {'LR':^9}")
+ f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}")
for (batch_c1, batch_c2) in dataloader:
self.curr_iter += 1
# Zero the parameter gradients
@@ -228,17 +241,20 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+ embedding_c, embedding_p, ae_losses, images = self.rgb_pn(x_c1, x_c2)
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
- y = y.repeat(self.rgb_pn.num_total_parts, 1)
- trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
- losses = torch.cat((
- ae_losses,
- torch.stack((
- trip_loss[:self.rgb_pn.hpm_num_parts].mean(),
- trip_loss[self.rgb_pn.hpm_num_parts:].mean()
- ))
+ y = y.repeat(self.rgb_pn.num_parts, 1)
+ trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm(
+ embedding_c, y[:self.rgb_pn.hpm.num_parts]
+ )
+ trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn(
+ embedding_p, y[self.rgb_pn.hpm.num_parts:]
+ )
+ losses = torch.stack((
+ *ae_losses,
+ trip_loss_hpm.mean(),
+ trip_loss_pn.mean()
))
loss = losses.sum()
loss.backward()
@@ -257,30 +273,30 @@ class Model:
'PartNet': losses[4]
}, self.curr_iter)
# Non-zero losses in batch
- if num_non_zero is not None:
+ if hpm_num_non_zero is not None and pn_num_non_zero is not None:
self.writer.add_scalars('Loss/non-zero counts', {
- 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(),
- 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean()
+ 'HPM': hpm_num_non_zero.mean(),
+ 'PartNet': pn_num_non_zero.mean()
}, self.curr_iter)
# Embedding distance
- mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0)
+ mean_hpm_dist = hpm_dist.mean(0)
self._add_ranked_scalars(
'Embedding/HPM distance', mean_hpm_dist,
num_pos_pairs, num_pairs, self.curr_iter
)
- mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0)
+ mean_pa_dist = pn_dist.mean(0)
self._add_ranked_scalars(
'Embedding/ParNet distance', mean_pa_dist,
num_pos_pairs, num_pairs, self.curr_iter
)
# Embedding norm
- mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0)
+ mean_hpm_embedding = embedding_c.mean(0)
mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/HPM norm', mean_hpm_norm,
self.k, self.pr * self.k, self.curr_iter
)
- mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0)
+ mean_pa_embedding = embedding_p.mean(0)
mean_pa_norm = mean_pa_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/PartNet norm', mean_pa_norm,
@@ -288,10 +304,9 @@ class Model:
)
# Learning rate
lrs = self.scheduler.get_last_lr()
- # Write learning rates
- self.writer.add_scalar(
- 'Learning rate', lrs[0], self.curr_iter
- )
+ self.writer.add_scalars('Learning rate', dict(zip((
+ 'Auto-encoder', 'HPM', 'PartNet'
+ ), lrs)), self.curr_iter)
if self.curr_iter % 100 == 0:
# Write disentangled images
@@ -316,7 +331,7 @@ class Model:
print(f'{hour:02}:{minute:02}:{second:02}',
f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
'{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- f'{lrs[0]:.3e}')
+ '{:.3e} {:.3e} {:.3e}'.format(*lrs))
running_loss.zero_()
# Step scheduler
@@ -548,7 +563,7 @@ class Model:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
- elif isinstance(m, RGBPartNet):
+ elif isinstance(m, (HorizontalPyramidMatching, PartNet)):
nn.init.xavier_uniform_(m.fc_mat)
def _parse_dataset_config(
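The single shared lr_lambda is replaced by one exponential-decay lambda per parameter group: each learning rate stays at its base value until its start_step, then decays so it lands at base_lr * final_gamma by the last iteration. A self-contained sketch of that schedule with made-up step counts and toy parameters:

# Per-group exponential decay via LambdaLR (sketch; step counts are illustrative).
import torch
from torch import nn, optim

total_iter = 1_000
groups = {'auto_encoder': (100, 0.01), 'hpm': (100, 0.001), 'part_net': (200, 0.01)}

params = [nn.Parameter(torch.zeros(1)) for _ in groups]
optimizer = optim.Adam([{'params': [p]} for p in params], lr=1e-4)

def make_lambda(start_step, final_gamma, all_step):
    # Flat until start_step, then gamma ** progress, reaching final_gamma at the end.
    return lambda t: final_gamma ** ((t - start_step) / all_step) \
        if t > start_step else 1

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[
    make_lambda(start, gamma, total_iter - start)
    for start, gamma in groups.values()
])

for _ in range(total_iter):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())   # roughly [1e-06, 1e-07, 1e-06]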
diff --git a/models/part_net.py b/models/part_net.py
index 29cf9cd..f2236bf 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -111,17 +111,21 @@ class PartNet(nn.Module):
def __init__(
self,
in_channels: int = 128,
+ embedding_dims: int = 256,
+ num_parts: int = 16,
squeeze_ratio: int = 4,
- num_part: int = 16
):
super().__init__()
- self.num_part = num_part
- self.tfa = TemporalFeatureAggregator(
- in_channels, squeeze_ratio, self.num_part
- )
+ self.num_part = num_parts
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
+ self.tfa = TemporalFeatureAggregator(
+ in_channels, squeeze_ratio, self.num_part
+ )
+ self.fc_mat = nn.Parameter(
+ torch.empty(num_parts, in_channels, embedding_dims)
+ )
def forward(self, x):
n, t, c, h, w = x.size()
@@ -138,4 +142,8 @@ class PartNet(nn.Module):
# p, n, t, c
x = self.tfa(x)
+
+ # p, n, c
+ x = x @ self.fc_mat
+ # p, n, d
return x
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 8a0f3a7..c38a567 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -13,39 +13,31 @@ class RGBPartNet(nn.Module):
ae_in_size: tuple[int, int] = (64, 48),
ae_feature_channels: int = 64,
f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
- hpm_use_1x1conv: bool = False,
hpm_scales: tuple[int, ...] = (1, 2, 4),
hpm_use_avg_pool: bool = True,
hpm_use_max_pool: bool = True,
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
- embedding_dims: int = 256,
+ embedding_dims: tuple[int, int] = (256, 256),
image_log_on: bool = False
):
super().__init__()
self.h, self.w = ae_in_size
(self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
- self.hpm_num_parts = sum(hpm_scales)
self.image_log_on = image_log_on
self.ae = AutoEncoder(
ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
)
self.pn_in_channels = ae_feature_channels * 2
- self.pn = PartNet(
- self.pn_in_channels, tfa_squeeze_ratio, tfa_num_parts
- )
self.hpm = HorizontalPyramidMatching(
- ae_feature_channels * 2, self.pn_in_channels, hpm_use_1x1conv,
- hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
+ self.pn_in_channels, embedding_dims[0], hpm_scales,
+ hpm_use_avg_pool, hpm_use_max_pool
)
- self.num_total_parts = self.hpm_num_parts + tfa_num_parts
- empty_fc = torch.empty(self.num_total_parts,
- self.pn_in_channels, embedding_dims)
- self.fc_mat = nn.Parameter(empty_fc)
+ self.pn = PartNet(self.pn_in_channels, embedding_dims[1],
+ tfa_num_parts, tfa_squeeze_ratio)
- def fc(self, x):
- return x @ self.fc_mat
+ self.num_parts = self.hpm.num_parts + tfa_num_parts
def forward(self, x_c1, x_c2=None):
# Step 1: Disentanglement
@@ -55,21 +47,17 @@ class RGBPartNet(nn.Module):
# Step 2.a: Static Gait Feature Aggregation & HPM
# n, c, h, w
x_c = self.hpm(x_c)
- # p, n, c
+ # p, n, d
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
# n, t, c, h, w
x_p = self.pn(x_p)
- # p, n, c
-
- # Step 3: Cat feature map together and fc
- x = torch.cat((x_c, x_p))
- x = self.fc(x)
+ # p, n, d
if self.training:
- return x, ae_losses, images
+ return x_c, x_p, ae_losses, images
else:
- return x.unsqueeze(1).view(-1)
+ return torch.cat((x_c, x_p)).unsqueeze(1).view(-1)
def _disentangle(self, x_c1_t2, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()
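During evaluation the two branches are still fused: HPM and PartNet part embeddings are concatenated along the part axis and flattened into one vector, which implicitly requires the two entries of embedding_dims to be equal (as they are in the config above). A small shape check with the part counts from that config:

# Eval-time fusion (sketch): 15 HPM parts + 16 PartNet parts, shared dim d.
import torch

n, d = 1, 256                      # single evaluation clip, embedding dim
x_c = torch.rand(15, n, d)         # HPM output, p1 x n x d
x_p = torch.rand(16, n, d)         # PartNet output, p2 x n x d
out = torch.cat((x_c, x_p)).unsqueeze(1).view(-1)
assert out.shape == ((15 + 16) * n * d,)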
diff --git a/test/part_net.py b/test/part_net.py
index 25e92ae..fada2c4 100644
--- a/test/part_net.py
+++ b/test/part_net.py
@@ -64,7 +64,7 @@ def test_custom_part_net():
paddings=((2, 1), (1, 1), (1, 1), (1, 1)),
halving=(1, 1, 3, 3),
squeeze_ratio=8,
- num_part=8)
+ num_parts=8)
x = torch.rand(T, N, 1, H, W)
x = pa(x)
diff --git a/utils/configuration.py b/utils/configuration.py
index 0f8d9ff..f6ac182 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -33,16 +33,11 @@ class ModelHPConfiguration(TypedDict):
ae_feature_channels: int
f_a_c_p_dims: tuple[int, int, int]
hpm_scales: tuple[int, ...]
- hpm_use_1x1conv: bool
hpm_use_avg_pool: bool
hpm_use_max_pool: bool
- fpfe_feature_channels: int
- fpfe_kernel_sizes: tuple[tuple, ...]
- fpfe_paddings: tuple[tuple, ...]
- fpfe_halving: tuple[int, ...]
- tfa_squeeze_ratio: int
tfa_num_parts: int
- embedding_dims: int
+ tfa_squeeze_ratio: int
+ embedding_dims: tuple[int, int]
triplet_is_hard: bool
triplet_is_mean: bool
triplet_margins: tuple[float, float]
@@ -63,14 +58,21 @@ class OptimizerHPConfiguration(TypedDict):
weight_decay: float
amsgrad: bool
auto_encoder: SubOptimizerHPConfiguration
- part_net: SubOptimizerHPConfiguration
hpm: SubOptimizerHPConfiguration
- fc: SubOptimizerHPConfiguration
+ part_net: SubOptimizerHPConfiguration
+
+
+class SubSchedulerHPConfiguration(TypedDict):
+ start_step: int
+ final_gamma: float
class SchedulerHPConfiguration(TypedDict):
start_step: int
final_gamma: float
+ auto_encoder: SubSchedulerHPConfiguration
+ hpm: SubSchedulerHPConfiguration
+ part_net: SubSchedulerHPConfiguration
class HyperparameterConfiguration(TypedDict):
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index e05b69d..03fff21 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -85,43 +85,3 @@ class BatchTripletLoss(nn.Module):
non_zero_mean = losses.sum(1) / non_zero_counts
non_zero_mean[non_zero_counts == 0] = 0
return non_zero_mean
-
-
-class JointBatchTripletLoss(BatchTripletLoss):
- def __init__(
- self,
- hpm_num_parts: int,
- is_hard: bool = True,
- is_mean: bool = True,
- margins: tuple[float, float] = (0.2, 0.2)
- ):
- super().__init__(is_hard, is_mean)
- self.hpm_num_parts = hpm_num_parts
- self.margin_hpm, self.margin_pn = margins
-
- def forward(self, x, y):
- p, n, c = x.size()
- dist = self._batch_distance(x)
- flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device)
- flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]]
-
- if self.is_hard:
- positive_negative_dist = self._hard_distance(dist, y, p, n)
- else: # is_all
- positive_negative_dist = self._all_distance(dist, y, p, n)
-
- hpm_part_loss = F.relu(
- self.margin_hpm + positive_negative_dist[:self.hpm_num_parts]
- )
- pn_part_loss = F.relu(
- self.margin_pn + positive_negative_dist[self.hpm_num_parts:]
- )
- losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1)
-
- non_zero_counts = (losses != 0).sum(1).float()
- if self.is_mean:
- loss_metric = self._none_zero_mean(losses, non_zero_counts)
- else: # is_sum
- loss_metric = losses.sum(1)
-
- return loss_metric, flat_dist, non_zero_counts
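With the embeddings kept separate, JointBatchTripletLoss is no longer needed: two independent BatchTripletLoss instances cover both the equal- and different-margin cases. A hedged usage sketch, assuming the positional (is_hard, is_mean, margin) constructor and the (loss, dist, non_zero) return seen in model.py; margins and shapes are illustrative:

# Two plain BatchTripletLoss instances replace the joint, two-margin loss.
import torch
from utils.triplet_loss import BatchTripletLoss

triplet_loss_hpm = BatchTripletLoss(True, True, 0.2)   # hard mining, non-zero mean
triplet_loss_pn = BatchTripletLoss(True, True, 0.3)    # a different margin per branch
# A soft-margin variant would pass None as the margin, as model.py does.

p_hpm, p_pn, n, d = 15, 16, 8, 256
y = torch.randint(4, (n,))                             # subject labels
emb_c, emb_p = torch.rand(p_hpm, n, d), torch.rand(p_pn, n, d)

loss_hpm, _, _ = triplet_loss_hpm(emb_c, y.repeat(p_hpm, 1))
loss_pn, _, _ = triplet_loss_pn(emb_p, y.repeat(p_pn, 1))
total_trip_loss = loss_hpm.mean() + loss_pn.mean()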