diff options
Diffstat (limited to 'utils')
-rw-r--r-- | utils/configuration.py | 26 | ||||
-rw-r--r-- | utils/triplet_loss.py | 36 |
2 files changed, 1 insertions, 61 deletions
diff --git a/utils/configuration.py b/utils/configuration.py index b9e6d92..340815b 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -31,27 +31,7 @@ class DataloaderConfiguration(TypedDict): class ModelHPConfiguration(TypedDict): ae_feature_channels: int - f_a_c_p_dims: Tuple[int, int, int] - hpm_scales: Tuple[int, ...] - hpm_use_1x1conv: bool - hpm_use_avg_pool: bool - hpm_use_max_pool: bool - fpfe_feature_channels: int - fpfe_kernel_sizes: Tuple[Tuple, ...] - fpfe_paddings: Tuple[Tuple, ...] - fpfe_halving: Tuple[int, ...] - tfa_squeeze_ratio: int - tfa_num_parts: int - embedding_dims: int - triplet_margins: Tuple[float, float] - - -class SubOptimizerHPConfiguration(TypedDict): - lr: int - betas: Tuple[float, float] - eps: float - weight_decay: float - amsgrad: bool + f_a_c_p_dims: tuple[int, int, int] class OptimizerHPConfiguration(TypedDict): @@ -61,10 +41,6 @@ class OptimizerHPConfiguration(TypedDict): eps: float weight_decay: float amsgrad: bool - auto_encoder: SubOptimizerHPConfiguration - part_net: SubOptimizerHPConfiguration - hpm: SubOptimizerHPConfiguration - fc: SubOptimizerHPConfiguration class SchedulerHPConfiguration(TypedDict): diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py deleted file mode 100644 index 954def2..0000000 --- a/utils/triplet_loss.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class BatchAllTripletLoss(nn.Module): - def __init__(self, margin: float = 0.2): - super().__init__() - self.margin = margin - - def forward(self, x, y): - p, n, c = x.size() - - # Euclidean distance p x n x n - x_squared_sum = torch.sum(x ** 2, dim=2) - x1_squared_sum = x_squared_sum.unsqueeze(2) - x2_squared_sum = x_squared_sum.unsqueeze(1) - x1_times_x2_sum = x @ x.transpose(1, 2) - dist = torch.sqrt( - F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum) - ) - - hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2) - hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2) - all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1) - all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1) - positive_negative_dist = all_hard_positive - all_hard_negative - all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1) - - # Non-zero parted mean - non_zero_counts = (all_loss != 0).sum(1) - parted_loss_mean = all_loss.sum(1) / non_zero_counts - parted_loss_mean[non_zero_counts == 0] = 0 - - loss = parted_loss_mean.mean() - return loss |