diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-08 18:31:52 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-08 18:31:52 +0800 |
commit | d380e04df37593e414bd5641db100613fb2ad882 (patch) | |
tree | 1e3b3ea55a464d59d790711372bbca42cb203d0a /utils | |
parent | a040400d7caa267d4bfbe8e5520568806f92b3d4 (diff) | |
parent | 99ddd7c142a4ec97cb8bd14b204651790b3cf4ee (diff) |
Merge branch 'master' into python3.8
# Conflicts:
# models/hpm.py
# models/layers.py
# models/model.py
# models/rgb_part_net.py
# utils/configuration.py
Diffstat (limited to 'utils')
-rw-r--r-- | utils/configuration.py | 5 | ||||
-rw-r--r-- | utils/triplet_loss.py | 2 |
2 files changed, 5 insertions, 2 deletions
diff --git a/utils/configuration.py b/utils/configuration.py index ef6b757..f44bcf0 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -7,6 +7,7 @@ class SystemConfiguration(TypedDict): disable_acc: bool CUDA_VISIBLE_DEVICES: str save_dir: str + image_log_on: bool class DatasetConfiguration(TypedDict): @@ -31,6 +32,7 @@ class ModelHPConfiguration(TypedDict): ae_feature_channels: int f_a_c_p_dims: Tuple[int, int, int] hpm_scales: Tuple[int, ...] + hpm_use_1x1conv: bool hpm_use_avg_pool: bool hpm_use_max_pool: bool fpfe_feature_channels: int @@ -40,7 +42,7 @@ class ModelHPConfiguration(TypedDict): tfa_squeeze_ratio: int tfa_num_parts: int embedding_dims: int - triplet_margin: float + triplet_margins: Tuple[float, float] class SubOptimizerHPConfiguration(TypedDict): @@ -52,6 +54,7 @@ class SubOptimizerHPConfiguration(TypedDict): class OptimizerHPConfiguration(TypedDict): + start_iter: int lr: int betas: Tuple[float, float] eps: float diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 8c143d6..d573ef4 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -34,5 +34,5 @@ class BatchAllTripletLoss(nn.Module): parted_loss_mean = all_loss.sum(1) / non_zero_counts parted_loss_mean[non_zero_counts == 0] = 0 - loss = parted_loss_mean.sum() + loss = parted_loss_mean.mean() return loss |