diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-15 12:03:07 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-15 14:06:59 +0800 |
commit | 6fb1c7cb34a65769c018a08324387af419355b32 (patch) | |
tree | c940384179b7f492592fb11f066f0c816ee57ce6 /models/rgb_part_net.py | |
parent | d51312415a32686793d3f0d14eda7fa7cc3990ea (diff) |
Add DataParallel support on new codebase
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 2aa680c..66609fd 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -83,7 +83,7 @@ class RGBPartNet(nn.Module): pn_ba_trip = self.pn_ba_trip( x[self.hpm_num_parts:], y[self.hpm_num_parts:] ) - losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip)) + losses = (*losses, hpm_ba_trip, pn_ba_trip) return losses, images else: return x.unsqueeze(1).view(-1) |