diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-22 20:50:30 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-22 20:50:30 +0800 |
commit | 4a4506fde7867df4d4d446c8992621fb0da03939 (patch) | |
tree | c2cedb06e538c6c66cc8fe9dd7f47e00b66cc98e /utils | |
parent | f9f45c8d32c6f99c314f1ca93140598d1005c8fb (diff) | |
parent | b6e5972b64cc61fc967cf3d098fc629d781adce4 (diff) |
Merge branch 'master' into python3.8
# Conflicts:
# models/model.py
Diffstat (limited to 'utils')
-rw-r--r-- | utils/configuration.py | 1 | ||||
-rw-r--r-- | utils/triplet_loss.py | 9 |
2 files changed, 8 insertions, 2 deletions
diff --git a/utils/configuration.py b/utils/configuration.py index 8ee08f2..959791b 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -14,6 +14,7 @@ class DatasetConfiguration(TypedDict): name: str root_dir: str train_size: int + val_size: int num_sampled_frames: int truncate_threshold: int discard_threshold: int diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 03fff21..5e3a97a 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -28,6 +28,7 @@ class BatchTripletLoss(nn.Module): else: # is_all positive_negative_dist = self._all_distance(dist, y, p, n) + non_zero_counts = None if self.margin: losses = F.relu(self.margin + positive_negative_dist).view(p, -1) non_zero_counts = (losses != 0).sum(1).float() @@ -35,14 +36,18 @@ class BatchTripletLoss(nn.Module): loss_metric = self._none_zero_mean(losses, non_zero_counts) else: # is_sum loss_metric = losses.sum(1) - return loss_metric, flat_dist, non_zero_counts else: # Soft margin losses = F.softplus(positive_negative_dist).view(p, -1) if self.is_mean: loss_metric = losses.mean(1) else: # is_sum loss_metric = losses.sum(1) - return loss_metric, flat_dist, None + + return { + 'loss': loss_metric, + 'dist': flat_dist, + 'counts': non_zero_counts + } @staticmethod def _batch_distance(x): |