From 04c9d3210ff659bbe00dedb2d193a748e7a97b54 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 21 Jan 2021 23:32:53 +0800 Subject: Print average losses after 100 iters --- models/rgb_part_net.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'models/rgb_part_net.py') diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index f39b40b..e707c26 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -80,9 +80,8 @@ class RGBPartNet(nn.Module): if self.training: batch_all_triplet_loss = self.ba_triplet_loss(x, y) - losses = (*losses, batch_all_triplet_loss) - loss = torch.sum(torch.stack(losses)) - return loss, [loss.item() for loss in losses] + losses = torch.stack((*losses, batch_all_triplet_loss)) + return losses else: return x.unsqueeze(1).view(-1) -- cgit v1.2.3