| author | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-06 22:02:00 +0800 |
|---|---|---|
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-06 22:02:00 +0800 |
| commit | 5e8947fbc90e1d67dadae36d32330a280d057267 | |
| tree | f32519b4e6564ddfcdeefb697432646fedc82131 | |
| parent | 7f6c65fa954a43c9d248219525cad35fe1b8a046 | |
Turn off gradient when decoding images
-rw-r--r-- models/rgb_part_net.py | 39

1 file changed, 20 insertions(+), 19 deletions(-)
```diff
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index ecc38c0..b0169e3 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -57,25 +57,26 @@ class RGBPartNet(nn.Module):
         if self.training:
             i_a, i_c, i_p = None, None, None
             if self.image_log_on:
-                f_a_mean = f_a.mean(1)
-                i_a = self.ae.decoder(
-                    f_a_mean,
-                    torch.zeros_like(f_c_mean),
-                    torch.zeros_like(f_p[:, 0])
-                )
-                i_c = self.ae.decoder(
-                    torch.zeros_like(f_a_mean),
-                    f_c_mean,
-                    torch.zeros_like(f_p[:, 0])
-                )
-                f_p_size = f_p.size()
-                i_p = self.ae.decoder(
-                    torch.zeros(f_p_size[0] * f_p_size[1], *f_a.shape[2:],
-                                device=f_a.device),
-                    torch.zeros(f_p_size[0] * f_p_size[1], *f_c.shape[2:],
-                                device=f_c.device),
-                    f_p.view(-1, *f_p_size[2:])
-                ).view(x_c1.size())
+                with torch.no_grad():
+                    f_a_mean = f_a.mean(1)
+                    i_a = self.ae.decoder(
+                        f_a_mean,
+                        torch.zeros_like(f_c_mean),
+                        torch.zeros_like(f_p[:, 0])
+                    )
+                    i_c = self.ae.decoder(
+                        torch.zeros_like(f_a_mean),
+                        f_c_mean,
+                        torch.zeros_like(f_p[:, 0])
+                    )
+                    f_p_size = f_p.size()
+                    i_p = self.ae.decoder(
+                        torch.zeros(f_p_size[0] * f_p_size[1], *f_a.shape[2:],
+                                    device=f_a.device),
+                        torch.zeros(f_p_size[0] * f_p_size[1], *f_c.shape[2:],
+                                    device=f_c.device),
+                        f_p.view(-1, *f_p_size[2:])
+                    ).view(x_c1.size())
             return x_c, x_p, ae_losses, (i_a, i_c, i_p)
         else:
             return x_c, x_p
```
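For context, a minimal sketch of the pattern this commit applies: the decoder calls here exist only to produce images for logging, so wrapping them in `torch.no_grad()` stops autograd from recording the forward operations, saving the memory and compute of an unneeded backward graph without touching the training path. The `TinyDecoder` module and tensor shapes below are hypothetical stand-ins, not code from this repository.

```python
import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    """Hypothetical stand-in for self.ae.decoder."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, z):
        return self.fc(z)


decoder = TinyDecoder()
z = torch.randn(4, 8, requires_grad=True)

# Decoding purely for visualization: inside no_grad(), no operations
# are recorded on the autograd tape, even though z requires grad.
with torch.no_grad():
    preview = decoder(z)

# The preview tensor carries no graph, so it cannot accidentally
# contribute to the training backward pass.
assert not preview.requires_grad
```

Calling `.detach()` on the decoder outputs would likewise keep them out of the backward pass, but the graph (and its intermediate activations) would still be built during the forward call; `no_grad()` avoids constructing it at all.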
