-rw-r--r--  models/rgb_part_net.py  15
-rw-r--r--  test/rgb_part_net.py    65
-rw-r--r--  utils/triplet_loss.py   35
3 files changed, 109 insertions, 6 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 5012765..a58be39 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -5,6 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from models import AutoEncoder, HorizontalPyramidMatching, PartNet
+from utils.triplet_loss import BatchAllTripletLoss
 
 
 class RGBPartNet(nn.Module):
@@ -24,7 +25,7 @@ class RGBPartNet(nn.Module):
             tfa_squeeze_ratio: int = 4,
             tfa_num_parts: int = 16,
             embedding_dims: int = 256,
-            triplet_margin: int = 0.2
+            triplet_margin: float = 0.2
     ):
         super().__init__()
         self.ae = AutoEncoder(
@@ -43,8 +44,10 @@ class RGBPartNet(nn.Module):
         empty_fc = torch.empty(total_parts, out_channels, embedding_dims)
         self.fc_mat = nn.Parameter(empty_fc)
 
+        self.ba_triplet_loss = BatchAllTripletLoss(triplet_margin)
+
     def fc(self, x):
-        return torch.matmul(x, self.fc_mat)
+        return x @ self.fc_mat
 
     def forward(self, x_c1, x_c2, y=None):
         # Step 0: Swap batch_size and time dimensions for next step
@@ -72,10 +75,10 @@
 
         x = self.fc(x)
 
         if self.training:
-            # TODO Implement Batch All triplet loss function
-            batch_all_triplet_loss = torch.tensor(0.)
-            loss = torch.sum(torch.stack((*losses, batch_all_triplet_loss)))
-            return loss
+            batch_all_triplet_loss = self.ba_triplet_loss(x, y)
+            losses = (*losses, batch_all_triplet_loss)
+            loss = torch.sum(torch.stack(losses))
+            return loss, tuple(l.item() for l in losses)
         else:
             return x
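With this change, forward has two contracts: in training mode it returns the combined scalar loss plus a tuple of per-term Python floats for logging (the three pre-existing autoencoder terms and the new triplet term), and in eval mode it returns the part embeddings. A minimal sketch of a training step consuming the new return values; the optimizer choice and logging here are illustrative assumptions, not part of this commit:

    import torch

    from models import RGBPartNet

    model = RGBPartNet()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # assumed optimizer

    # Balanced P x K batch (2 classes x 4 samples), as the triplet loss expects
    x_c1 = torch.rand(8, 10, 3, 64, 32)       # condition-1 clips: N, T, C, H, W
    x_c2 = torch.rand(8, 10, 3, 64, 32)       # condition-2 clips
    y = torch.arange(2).repeat_interleave(4)  # labels [0,0,0,0,1,1,1,1]

    model.train()
    loss, metrics = model(x_c1, x_c2, y)      # metrics: 4 per-term floats
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print('loss terms:', metrics)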
diff --git a/test/rgb_part_net.py b/test/rgb_part_net.py
new file mode 100644
index 0000000..1d754a0
--- /dev/null
+++ b/test/rgb_part_net.py
@@ -0,0 +1,65 @@
+import torch
+
+from models import RGBPartNet
+
+P, K = 2, 4
+N, T, C, H, W = P * K, 10, 3, 64, 32
+
+
+def rand_x1_x2_y(n, t, c, h, w):
+    x1 = torch.rand(n, t, c, h, w)
+    x2 = torch.rand(n, t, c, h, w)
+    y = []
+    for p in range(P):
+        y += [p] * K
+    y = torch.as_tensor(y)
+    return x1, x2, y
+
+
+def test_default_rgb_part_net():
+    rgb_pa = RGBPartNet()
+    x1, x2, y = rand_x1_x2_y(N, T, C, H, W)
+
+    rgb_pa.train()
+    loss, metrics = rgb_pa(x1, x2, y)
+    assert tuple(loss.size()) == ()
+    assert len(metrics) == 4
+    assert all(isinstance(m, float) for m in metrics)
+
+    rgb_pa.eval()
+    x = rgb_pa(x1, x2)
+    assert tuple(x.size()) == (23, N, 256)  # 7 HPM strips + 16 TFA parts
+
+
+def test_custom_rgb_part_net():
+    hpm_scales = (1, 2, 4, 8)
+    tfa_num_parts = 8
+    embedding_dims = 1024
+    rgb_pa = RGBPartNet(num_class=10,
+                        ae_in_channels=1,
+                        ae_feature_channels=32,
+                        f_a_c_p_dims=(64, 64, 32),
+                        hpm_scales=hpm_scales,
+                        hpm_use_avg_pool=True,
+                        hpm_use_max_pool=False,
+                        fpfe_feature_channels=64,
+                        fpfe_kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)),
+                        fpfe_paddings=((2, 1), (1, 1), (1, 1), (1, 1)),
+                        fpfe_halving=(1, 1, 3, 3),
+                        tfa_squeeze_ratio=8,
+                        tfa_num_parts=tfa_num_parts,
+                        embedding_dims=embedding_dims,
+                        triplet_margin=0.4)
+    x1, x2, y = rand_x1_x2_y(N, T, 1, H, W)
+
+    rgb_pa.train()
+    loss, metrics = rgb_pa(x1, x2, y)
+    assert tuple(loss.size()) == ()
+    assert len(metrics) == 4
+    assert all(isinstance(m, float) for m in metrics)
+
+    rgb_pa.eval()
+    x = rgb_pa(x1, x2)
+    assert tuple(x.size()) == (
+        sum(hpm_scales) + tfa_num_parts, N, embedding_dims
+    )
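Both shape assertions encode the same rule: the first axis of the eval output is the total part count, sum(hpm_scales) from the HPM branch plus tfa_num_parts from the TFA branch. A quick check of both cases, assuming the defaults are hpm_scales=(1, 2, 4) and tfa_num_parts=16 (the 16 is visible in the constructor diff above; the scale tuple is inferred from the expected 23):

    # Part-count arithmetic behind the asserts above
    assert sum((1, 2, 4)) + 16 == 23    # default test: (23, N, 256)
    assert sum((1, 2, 4, 8)) + 8 == 23  # custom test: also 23 parts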
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
new file mode 100644
index 0000000..242be45
--- /dev/null
+++ b/utils/triplet_loss.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BatchAllTripletLoss(nn.Module):
+    def __init__(self, margin: float = 0.2):
+        super().__init__()
+        self.margin = margin
+
+    def forward(self, x, y):
+        # Duplicate labels for each part: (n,) -> (p, n)
+        p, n, c = x.size()
+        y = y.repeat(p, 1)
+
+        # Euclidean distance p x n x n (clamped at 0 before sqrt)
+        x_squared_sum = torch.sum(x ** 2, dim=2)
+        x1_squared_sum = x_squared_sum.unsqueeze(1)
+        x2_squared_sum = x_squared_sum.unsqueeze(2)
+        x1_times_x2_sum = x @ x.transpose(1, 2)
+        dist = torch.sqrt(F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum))
+
+        positive_mask = y.unsqueeze(1) == y.unsqueeze(2)  # keeps the diagonal
+        negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
+        all_positive = dist[positive_mask].view(p, n, -1, 1)  # needs balanced P x K batch
+        all_negative = dist[negative_mask].view(p, n, 1, -1)
+        positive_negative_dist = all_positive - all_negative
+        all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1)
+
+        # Non-zero parted mean; clamp avoids 0 / 0 -> NaN on zero-loss parts
+        non_zero_counts = (all_loss != 0).sum(1)
+        parted_loss_mean = all_loss.sum(1) / non_zero_counts.clamp(min=1)
+
+        loss = parted_loss_mean.mean()
+        return loss
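BatchAllTripletLoss implements the batch-all strategy per part: pairwise Euclidean distances come from the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, every (positive, negative) distance pair of an anchor forms a hinge term relu(margin + d_ap - d_an), and only the non-zero terms are averaged within each part before averaging over parts. Note the positive mask keeps the diagonal, so each anchor also pairs with itself at distance zero, and the view() reshapes rely on a balanced P x K batch where every anchor has the same number of positives and negatives. A standalone sanity check on random features; the part and dimension counts here are arbitrary:

    import torch

    from utils.triplet_loss import BatchAllTripletLoss

    P, K = 2, 4                                # classes and samples per class
    parts, n, dims = 23, P * K, 256
    x = torch.rand(parts, n, dims)             # per-part embeddings
    y = torch.arange(P).repeat_interleave(K)   # labels [0,0,0,0,1,1,1,1]

    criterion = BatchAllTripletLoss(margin=0.2)
    loss = criterion(x, y)
    assert loss.dim() == 0 and loss.item() >= 0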