diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-06 22:02:52 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-06 22:02:52 +0800 |
commit | b8892d3a3838fe6f5e18c9d76f16ea6368d715f2 (patch) | |
tree | e50ecce3955beb23242246a2eea290f22c8e5bd6 /models/rgb_part_net.py | |
parent | 1b3afdead66ac4995b83c7fa953bb041caa92051 (diff) | |
parent | b54596ab5ce41110100214cd76ff50acc65b589a (diff) |
Merge branch 'python3.8' into python3.7
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 39 |
1 files changed, 20 insertions, 19 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 811a711..ffee044 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -59,25 +59,26 @@ class RGBPartNet(nn.Module): if self.training: i_a, i_c, i_p = None, None, None if self.image_log_on: - f_a_mean = f_a.mean(1) - i_a = self.ae.decoder( - f_a_mean, - torch.zeros_like(f_c_mean), - torch.zeros_like(f_p[:, 0]) - ) - i_c = self.ae.decoder( - torch.zeros_like(f_a_mean), - f_c_mean, - torch.zeros_like(f_p[:, 0]) - ) - f_p_size = f_p.size() - i_p = self.ae.decoder( - torch.zeros(f_p_size[0] * f_p_size[1], *f_a.shape[2:], - device=f_a.device), - torch.zeros(f_p_size[0] * f_p_size[1], *f_c.shape[2:], - device=f_c.device), - f_p.view(-1, *f_p_size[2:]) - ).view(x_c1.size()) + with torch.no_grad(): + f_a_mean = f_a.mean(1) + i_a = self.ae.decoder( + f_a_mean, + torch.zeros_like(f_c_mean), + torch.zeros_like(f_p[:, 0]) + ) + i_c = self.ae.decoder( + torch.zeros_like(f_a_mean), + f_c_mean, + torch.zeros_like(f_p[:, 0]) + ) + f_p_size = f_p.size() + i_p = self.ae.decoder( + torch.zeros(f_p_size[0] * f_p_size[1], *f_a.shape[2:], + device=f_a.device), + torch.zeros(f_p_size[0] * f_p_size[1], *f_c.shape[2:], + device=f_c.device), + f_p.view(-1, *f_p_size[2:]) + ).view(x_c1.size()) return x_c, x_p, ae_losses, (i_a, i_c, i_p) else: return x_c, x_p |