From 5e8947fbc90e1d67dadae36d32330a280d057267 Mon Sep 17 00:00:00 2001
From: Jordan Gong <jordan.gong@protonmail.com>
Date: Tue, 6 Apr 2021 22:02:00 +0800
Subject: Turn off gradient computation when decoding images

The decoded images are only produced when image_log_on is set and are
returned for logging; there is no need to build an autograd graph for
them. Wrap the decoder calls in torch.no_grad() to skip gradient
tracking and save memory.

---
 models/rgb_part_net.py | 39 ++++++++++++++++++++-------------------
 1 file changed, 20 insertions(+), 19 deletions(-)

diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index ecc38c0..b0169e3 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -57,25 +57,26 @@ class RGBPartNet(nn.Module):
         if self.training:
             i_a, i_c, i_p = None, None, None
             if self.image_log_on:
-                f_a_mean = f_a.mean(1)
-                i_a = self.ae.decoder(
-                    f_a_mean,
-                    torch.zeros_like(f_c_mean),
-                    torch.zeros_like(f_p[:, 0])
-                )
-                i_c = self.ae.decoder(
-                    torch.zeros_like(f_a_mean),
-                    f_c_mean,
-                    torch.zeros_like(f_p[:, 0])
-                )
-                f_p_size = f_p.size()
-                i_p = self.ae.decoder(
-                    torch.zeros(f_p_size[0] * f_p_size[1], *f_a.shape[2:],
-                                device=f_a.device),
-                    torch.zeros(f_p_size[0] * f_p_size[1], *f_c.shape[2:],
-                                device=f_c.device),
-                    f_p.view(-1, *f_p_size[2:])
-                ).view(x_c1.size())
+                with torch.no_grad():
+                    f_a_mean = f_a.mean(1)
+                    i_a = self.ae.decoder(
+                        f_a_mean,
+                        torch.zeros_like(f_c_mean),
+                        torch.zeros_like(f_p[:, 0])
+                    )
+                    i_c = self.ae.decoder(
+                        torch.zeros_like(f_a_mean),
+                        f_c_mean,
+                        torch.zeros_like(f_p[:, 0])
+                    )
+                    f_p_size = f_p.size()
+                    i_p = self.ae.decoder(
+                        torch.zeros(f_p_size[0] * f_p_size[1], *f_a.shape[2:],
+                                    device=f_a.device),
+                        torch.zeros(f_p_size[0] * f_p_size[1], *f_c.shape[2:],
+                                    device=f_c.device),
+                        f_p.view(-1, *f_p_size[2:])
+                    ).view(x_c1.size())
             return x_c, x_p, ae_losses, (i_a, i_c, i_p)
         else:
             return x_c, x_p
-- 
cgit v1.2.3
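
For reference, the change applies the standard PyTorch pattern of running
inference-only work under torch.no_grad(). Below is a minimal, self-contained
sketch of that pattern; TinyDecoder, its shapes, and the variable names are
illustrative assumptions, not code from this repository:

import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    """Illustrative stand-in for self.ae.decoder: feature vector -> image."""

    def __init__(self, feat_dim=32, img_size=64):
        super().__init__()
        self.img_size = img_size
        self.fc = nn.Linear(feat_dim, img_size * img_size)

    def forward(self, f):
        return self.fc(f).view(-1, 1, self.img_size, self.img_size)


decoder = TinyDecoder()
features = torch.randn(4, 32, requires_grad=True)  # pretend encoder output

# Decoding purely for image logging: no_grad stops autograd from recording
# these ops, so no graph (and no extra activations) is kept in memory.
with torch.no_grad():
    logged_images = decoder(features)

assert not logged_images.requires_grad  # nothing to backpropagate through

On PyTorch 1.9 and later, torch.inference_mode() is a stricter alternative
for logging-only decoding like this: it has the same effect, and tensors
created inside it additionally cannot be fed back into autograd later.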