summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/model.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/models/model.py b/models/model.py
index 18896ae..34cb816 100644
--- a/models/model.py
+++ b/models/model.py
@@ -146,6 +146,7 @@ class Model:
# Prepare for model, optimizer and scheduler
model_hp: dict = self.hp.get('model', {}).copy()
triplet_is_hard = model_hp.pop('triplet_is_hard', True)
+ triplet_is_mean = model_hp.pop('triplet_is_mean', True)
triplet_margins = model_hp.pop('triplet_margins', None)
optim_hp: dict = self.hp.get('optimizer', {}).copy()
start_iter = optim_hp.pop('start_iter', 0)
@@ -165,10 +166,13 @@ class Model:
)
else: # Different margins
self.triplet_loss = JointBatchTripletLoss(
- self.rgb_pn.hpm_num_parts, triplet_is_hard, triplet_margins
+ self.rgb_pn.hpm_num_parts,
+ triplet_is_hard, triplet_is_mean, triplet_margins
)
else: # Soft margins
- self.triplet_loss = BatchTripletLoss(triplet_is_hard, None)
+ self.triplet_loss = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, None
+ )
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
@@ -243,7 +247,7 @@ class Model:
'PartNet': losses[4]
}, self.curr_iter)
# None-zero losses in batch
- if num_non_zero:
+ if num_non_zero is not None:
self.writer.add_scalars('Loss/non-zero counts', {
'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(),
'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean()