summaryrefslogtreecommitdiff
path: root/utils/triplet_loss.py
diff options
context:
space:
mode:
Diffstat (limited to 'utils/triplet_loss.py')
-rw-r--r--utils/triplet_loss.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 0df2188..6025bd3 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -1,3 +1,5 @@
+from typing import Tuple
+
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -55,7 +57,7 @@ class JointBatchAllTripletLoss(BatchAllTripletLoss):
def __init__(
self,
hpm_num_parts: int,
- margins: tuple[float, float] = (0.2, 0.2)
+ margins: Tuple[float, float] = (0.2, 0.2)
):
super().__init__()
self.hpm_num_parts = hpm_num_parts