-rw-r--r--  models/rgb_part_net.py | 15
-rw-r--r--  test/rgb_part_net.py   | 65
-rw-r--r--  utils/triplet_loss.py  | 35
3 files changed, 109 insertions, 6 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 5012765..a58be39 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -5,6 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from models import AutoEncoder, HorizontalPyramidMatching, PartNet
+from utils.triplet_loss import BatchAllTripletLoss
 
 
 class RGBPartNet(nn.Module):
@@ -24,7 +25,7 @@ class RGBPartNet(nn.Module):
             tfa_squeeze_ratio: int = 4,
             tfa_num_parts: int = 16,
             embedding_dims: int = 256,
-            triplet_margin: int = 0.2
+            triplet_margin: float = 0.2
     ):
         super().__init__()
         self.ae = AutoEncoder(
@@ -43,8 +44,10 @@ class RGBPartNet(nn.Module):
         empty_fc = torch.empty(total_parts, out_channels, embedding_dims)
         self.fc_mat = nn.Parameter(empty_fc)
 
+        self.ba_triplet_loss = BatchAllTripletLoss(triplet_margin)
+
     def fc(self, x):
-        return torch.matmul(x, self.fc_mat)
+        return x @ self.fc_mat
 
     def forward(self, x_c1, x_c2, y=None):
         # Step 0: Swap batch_size and time dimensions for next step
@@ -72,10 +75,10 @@ class RGBPartNet(nn.Module):
         x = self.fc(x)
 
         if self.training:
-            # TODO Implement Batch All triplet loss function
-            batch_all_triplet_loss = torch.tensor(0.)
-            loss = torch.sum(torch.stack((*losses, batch_all_triplet_loss)))
-            return loss
+            batch_all_triplet_loss = self.ba_triplet_loss(x, y)
+            losses = (*losses, batch_all_triplet_loss)
+            loss = torch.sum(torch.stack(losses))
+            return loss, (loss.item() for loss in losses)
         else:
             return x
 
diff --git a/test/rgb_part_net.py b/test/rgb_part_net.py
new file mode 100644
index 0000000..1d754a0
--- /dev/null
+++ b/test/rgb_part_net.py
@@ -0,0 +1,65 @@
+import torch
+
+from models import RGBPartNet
+
+P, K = 2, 4
+N, T, C, H, W = P * K, 10, 3, 64, 32
+
+
+def rand_x1_x2_y(n, t, c, h, w):
+    x1 = torch.rand(n, t, c, h, w)
+    x2 = torch.rand(n, t, c, h, w)
+    y = []
+    for p in range(P):
+        y += [p] * K
+    y = torch.as_tensor(y)
+    return x1, x2, y
+
+
+def test_default_rgb_part_net():
+    rgb_pa = RGBPartNet()
+    x1, x2, y = rand_x1_x2_y(N, T, C, H, W)
+
+    rgb_pa.train()
+    loss, metrics = rgb_pa(x1, x2, y)
+    _, _, _, _ = metrics
+    assert tuple(loss.size()) == ()
+    assert isinstance(_, float)
+
+    rgb_pa.eval()
+    x = rgb_pa(x1, x2)
+    assert tuple(x.size()) == (23, N, 256)
+
+
+def test_custom_rgb_part_net():
+    hpm_scales = (1, 2, 4, 8)
+    tfa_num_parts = 8
+    embedding_dims = 1024
+    rgb_pa = RGBPartNet(num_class=10,
+                        ae_in_channels=1,
+                        ae_feature_channels=32,
+                        f_a_c_p_dims=(64, 64, 32),
+                        hpm_scales=hpm_scales,
+                        hpm_use_avg_pool=True,
+                        hpm_use_max_pool=False,
+                        fpfe_feature_channels=64,
+                        fpfe_kernel_sizes=((5, 3), (3, 3), (3, 3), (3, 3)),
+                        fpfe_paddings=((2, 1), (1, 1), (1, 1), (1, 1)),
+                        fpfe_halving=(1, 1, 3, 3),
+                        tfa_squeeze_ratio=8,
+                        tfa_num_parts=tfa_num_parts,
+                        embedding_dims=1024,
+                        triplet_margin=0.4)
+    x1, x2, y = rand_x1_x2_y(N, T, 1, H, W)
+
+    rgb_pa.train()
+    loss, metrics = rgb_pa(x1, x2, y)
+    _, _, _, _ = metrics
+    assert tuple(loss.size()) == ()
+    assert isinstance(_, float)
+
+    rgb_pa.eval()
+    x = rgb_pa(x1, x2)
+    assert tuple(x.size()) == (
+        sum(hpm_scales) + tfa_num_parts, N, embedding_dims
+    )
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
new file mode 100644
index 0000000..242be45
--- /dev/null
+++ b/utils/triplet_loss.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BatchAllTripletLoss(nn.Module):
+    def __init__(self, margin: float = 0.2):
+        super().__init__()
+        self.margin = margin
+
+    def forward(self, x, y):
+        # Duplicate labels for each part
+        p, n, c = x.size()
+        y = y.repeat(p, 1)
+
+        # Euclidean distance p x n x n
+        x_squared_sum = torch.sum(x ** 2, dim=2)
+        x1_squared_sum = x_squared_sum.unsqueeze(1)
+        x2_squared_sum = x_squared_sum.unsqueeze(2)
+        x1_times_x2_sum = x @ x.transpose(1, 2)
+        dist = torch.sqrt(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum)
+
+        hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
+        hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
+        all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1)
+        all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1)
+        positive_negative_dist = all_hard_positive - all_hard_negative
+        all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1)
+
+        # Non-zero parted mean
+        parted_loss_mean = all_loss.sum(1) / (all_loss != 0).sum(1)
+        parted_loss_mean[parted_loss_mean == float('Inf')] = 0
+
+        loss = parted_loss_mean.mean()
+        return loss
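
Note (not part of the commit): a minimal standalone sketch of how the new BatchAllTripletLoss can be exercised on its own, assuming the default RGBPartNet configuration whose embedding shape (23, N, 256) is asserted in test/rgb_part_net.py, with labels laid out as P = 2 classes of K = 4 samples, as the test builds them. Tensor values are random placeholders.

import torch

from utils.triplet_loss import BatchAllTripletLoss

# Part-wise embeddings in the (p, n, c) layout BatchAllTripletLoss.forward
# expects: 23 parts, a batch of P * K = 8 sequences, 256 embedding dims.
x = torch.rand(23, 8, 256)
# One label per sequence; forward repeats them across the part dimension.
y = torch.as_tensor([0, 0, 0, 0, 1, 1, 1, 1])

criterion = BatchAllTripletLoss(margin=0.2)
loss = criterion(x, y)  # 0-dim tensor: per-part non-zero means averaged over parts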