diff options
Diffstat (limited to 'utils/triplet_loss.py')
-rw-r--r-- | utils/triplet_loss.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 0df2188..6025bd3 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -1,3 +1,5 @@ +from typing import Tuple + import torch import torch.nn as nn import torch.nn.functional as F @@ -55,7 +57,7 @@ class JointBatchAllTripletLoss(BatchAllTripletLoss): def __init__( self, hpm_num_parts: int, - margins: tuple[float, float] = (0.2, 0.2) + margins: Tuple[float, float] = (0.2, 0.2) ): super().__init__() self.hpm_num_parts = hpm_num_parts |