From 6fb1c7cb34a65769c018a08324387af419355b32 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 15 Feb 2021 12:03:07 +0800 Subject: Add DataParallel support on new codebase --- models/rgb_part_net.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'models/rgb_part_net.py') diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 2aa680c..66609fd 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -83,7 +83,7 @@ class RGBPartNet(nn.Module): pn_ba_trip = self.pn_ba_trip( x[self.hpm_num_parts:], y[self.hpm_num_parts:] ) - losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip)) + losses = (*losses, hpm_ba_trip, pn_ba_trip) return losses, images else: return x.unsqueeze(1).view(-1) -- cgit v1.2.3