summaryrefslogtreecommitdiff
path: root/utils
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-08 18:11:25 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-08 18:25:42 +0800
commit99ddd7c142a4ec97cb8bd14b204651790b3cf4ee (patch)
treea4ccbd08a7155e90df63aba60eb93ab2b7969c9b /utils
parent507e1d163aaa6ea4be23e7f08ff6ce0ef58c830b (diff)
Code refactoring, modifications and new features
1. Decode features outside of auto-encoder 2. Turn off HPM 1x1 conv by default 3. Change canonical feature map size from `feature_channels * 8 x 4 x 2` to `feature_channels * 2 x 16 x 8` 4. Use mean of canonical embeddings instead of mean of static features 5. Calculate static and dynamic loss separately 6. Calculate mean of parts in triplet loss instead of sum of parts 7. Add switch to log disentangled images 8. Change default configuration
Diffstat (limited to 'utils')
-rw-r--r--utils/configuration.py4
-rw-r--r--utils/triplet_loss.py2
2 files changed, 4 insertions, 2 deletions
diff --git a/utils/configuration.py b/utils/configuration.py
index c4c4b4d..4ab1520 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -7,6 +7,7 @@ class SystemConfiguration(TypedDict):
disable_acc: bool
CUDA_VISIBLE_DEVICES: str
save_dir: str
+ image_log_on: bool
class DatasetConfiguration(TypedDict):
@@ -31,6 +32,7 @@ class ModelHPConfiguration(TypedDict):
ae_feature_channels: int
f_a_c_p_dims: tuple[int, int, int]
hpm_scales: tuple[int, ...]
+ hpm_use_1x1conv: bool
hpm_use_avg_pool: bool
hpm_use_max_pool: bool
fpfe_feature_channels: int
@@ -40,7 +42,7 @@ class ModelHPConfiguration(TypedDict):
tfa_squeeze_ratio: int
tfa_num_parts: int
embedding_dims: int
- triplet_margin: float
+ triplet_margins: tuple[float, float]
class SubOptimizerHPConfiguration(TypedDict):
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 8c143d6..d573ef4 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -34,5 +34,5 @@ class BatchAllTripletLoss(nn.Module):
parted_loss_mean = all_loss.sum(1) / non_zero_counts
parted_loss_mean[non_zero_counts == 0] = 0
- loss = parted_loss_mean.sum()
+ loss = parted_loss_mean.mean()
return loss