summaryrefslogtreecommitdiff
path: root/utils/triplet_loss.py
diff options
context:
space:
mode:
Diffstat (limited to 'utils/triplet_loss.py')
-rw-r--r--utils/triplet_loss.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index e05b69d..ae899ec 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Tuple
import torch
import torch.nn as nn
@@ -93,7 +93,7 @@ class JointBatchTripletLoss(BatchTripletLoss):
hpm_num_parts: int,
is_hard: bool = True,
is_mean: bool = True,
- margins: tuple[float, float] = (0.2, 0.2)
+ margins: Tuple[float, float] = (0.2, 0.2)
):
super().__init__(is_hard, is_mean)
self.hpm_num_parts = hpm_num_parts