author     Jordan Gong <jordan.gong@protonmail.com>   2021-02-08 18:31:52 +0800
committer  Jordan Gong <jordan.gong@protonmail.com>   2021-02-08 18:31:52 +0800
commit     d380e04df37593e414bd5641db100613fb2ad882 (patch)
tree       1e3b3ea55a464d59d790711372bbca42cb203d0a /utils
parent     a040400d7caa267d4bfbe8e5520568806f92b3d4 (diff)
parent     99ddd7c142a4ec97cb8bd14b204651790b3cf4ee (diff)
Merge branch 'master' into python3.8
# Conflicts:
#	models/hpm.py
#	models/layers.py
#	models/model.py
#	models/rgb_part_net.py
#	utils/configuration.py
Diffstat (limited to 'utils')
-rw-r--r--  utils/configuration.py  5
-rw-r--r--  utils/triplet_loss.py   2
2 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/utils/configuration.py b/utils/configuration.py
index ef6b757..f44bcf0 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -7,6 +7,7 @@ class SystemConfiguration(TypedDict):
     disable_acc: bool
     CUDA_VISIBLE_DEVICES: str
     save_dir: str
+    image_log_on: bool
 
 
 class DatasetConfiguration(TypedDict):
@@ -31,6 +32,7 @@ class ModelHPConfiguration(TypedDict):
     ae_feature_channels: int
     f_a_c_p_dims: Tuple[int, int, int]
     hpm_scales: Tuple[int, ...]
+    hpm_use_1x1conv: bool
     hpm_use_avg_pool: bool
     hpm_use_max_pool: bool
     fpfe_feature_channels: int
@@ -40,7 +42,7 @@ class ModelHPConfiguration(TypedDict):
     tfa_squeeze_ratio: int
     tfa_num_parts: int
     embedding_dims: int
-    triplet_margin: float
+    triplet_margins: Tuple[float, float]
 
 
 class SubOptimizerHPConfiguration(TypedDict):
@@ -52,6 +54,7 @@ class SubOptimizerHPConfiguration(TypedDict):
 
 
 class OptimizerHPConfiguration(TypedDict):
+    start_iter: int
     lr: int
     betas: Tuple[float, float]
     eps: float
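
For reference, the configuration.py hunks above add an image_log_on switch to SystemConfiguration, an hpm_use_1x1conv option and a two-value triplet_margins setting (replacing the single triplet_margin float) to ModelHPConfiguration, and a start_iter field to OptimizerHPConfiguration. The sketch below shows config entries for only the fields touched by this merge; the excerpt dict names and all values are illustrative, and the real configuration requires additional keys not shown here.

# Illustrative excerpts only -- not the repository's full configuration.
system_config_excerpt = {
    'disable_acc': False,
    'CUDA_VISIBLE_DEVICES': '0',
    'save_dir': 'runs',              # placeholder value
    'image_log_on': True,            # new switch added in this merge
}

model_hp_config_excerpt = {
    'hpm_use_1x1conv': False,        # new HPM option
    'triplet_margins': (0.2, 0.2),   # replaces the single-float triplet_margin
}

optimizer_hp_config_excerpt = {
    'start_iter': 0,                 # new field added in this merge
    'betas': (0.9, 0.999),           # placeholder Adam-style values
    'eps': 1e-8,
}
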
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 8c143d6..d573ef4 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -34,5 +34,5 @@ class BatchAllTripletLoss(nn.Module):
         parted_loss_mean = all_loss.sum(1) / non_zero_counts
         parted_loss_mean[non_zero_counts == 0] = 0
 
-        loss = parted_loss_mean.sum()
+        loss = parted_loss_mean.mean()
         return loss