summaryrefslogtreecommitdiff
path: root/utils
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-12 20:30:25 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-12 20:34:00 +0800
commit30b475c0a27e0f848743abf0f909607defc6a3ee (patch)
treeaaab163d3d76a835c32ce5014ce62637550d0b0d /utils
parent3d8fc322623ba61610fd206b9f52b406e85cae61 (diff)
parente83ae0bcb5c763636fd522c2712a3c8aef558f3c (diff)
Merge branch 'data_parallel' into data_parallel_py3.8
# Conflicts: # models/hpm.py # models/model.py # models/rgb_part_net.py # utils/configuration.py # utils/triplet_loss.py
Diffstat (limited to 'utils')
-rw-r--r--utils/configuration.py20
-rw-r--r--utils/triplet_loss.py42
2 files changed, 12 insertions, 50 deletions
diff --git a/utils/configuration.py b/utils/configuration.py
index 376ae0f..8ee08f2 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -33,16 +33,11 @@ class ModelHPConfiguration(TypedDict):
ae_feature_channels: int
f_a_c_p_dims: Tuple[int, int, int]
hpm_scales: Tuple[int, ...]
- hpm_use_1x1conv: bool
hpm_use_avg_pool: bool
hpm_use_max_pool: bool
- fpfe_feature_channels: int
- fpfe_kernel_sizes: Tuple[Tuple, ...]
- fpfe_paddings: Tuple[Tuple, ...]
- fpfe_halving: Tuple[int, ...]
- tfa_squeeze_ratio: int
tfa_num_parts: int
- embedding_dims: int
+ tfa_squeeze_ratio: int
+ embedding_dims: Tuple[int]
triplet_is_hard: bool
triplet_is_mean: bool
triplet_margins: Tuple[float, float]
@@ -63,14 +58,21 @@ class OptimizerHPConfiguration(TypedDict):
weight_decay: float
amsgrad: bool
auto_encoder: SubOptimizerHPConfiguration
- part_net: SubOptimizerHPConfiguration
hpm: SubOptimizerHPConfiguration
- fc: SubOptimizerHPConfiguration
+ part_net: SubOptimizerHPConfiguration
+
+
+class SubSchedulerHPConfiguration(TypedDict):
+ start_step: int
+ final_gamma: float
class SchedulerHPConfiguration(TypedDict):
start_step: int
final_gamma: float
+ auto_encoder: SubSchedulerHPConfiguration
+ hpm: SubSchedulerHPConfiguration
+ part_net: SubSchedulerHPConfiguration
class HyperparameterConfiguration(TypedDict):
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index ae899ec..03fff21 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import Optional
import torch
import torch.nn as nn
@@ -85,43 +85,3 @@ class BatchTripletLoss(nn.Module):
non_zero_mean = losses.sum(1) / non_zero_counts
non_zero_mean[non_zero_counts == 0] = 0
return non_zero_mean
-
-
-class JointBatchTripletLoss(BatchTripletLoss):
- def __init__(
- self,
- hpm_num_parts: int,
- is_hard: bool = True,
- is_mean: bool = True,
- margins: Tuple[float, float] = (0.2, 0.2)
- ):
- super().__init__(is_hard, is_mean)
- self.hpm_num_parts = hpm_num_parts
- self.margin_hpm, self.margin_pn = margins
-
- def forward(self, x, y):
- p, n, c = x.size()
- dist = self._batch_distance(x)
- flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device)
- flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]]
-
- if self.is_hard:
- positive_negative_dist = self._hard_distance(dist, y, p, n)
- else: # is_all
- positive_negative_dist = self._all_distance(dist, y, p, n)
-
- hpm_part_loss = F.relu(
- self.margin_hpm + positive_negative_dist[:self.hpm_num_parts]
- )
- pn_part_loss = F.relu(
- self.margin_pn + positive_negative_dist[self.hpm_num_parts:]
- )
- losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1)
-
- non_zero_counts = (losses != 0).sum(1).float()
- if self.is_mean:
- loss_metric = self._none_zero_mean(losses, non_zero_counts)
- else: # is_sum
- loss_metric = losses.sum(1)
-
- return loss_metric, flat_dist, non_zero_counts