summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-04-06 22:02:40 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-04-06 22:02:40 +0800
commitb54596ab5ce41110100214cd76ff50acc65b589a (patch)
treed418b5cbaa04444ee0b3d1c97a5ac347b4b75815 /models/rgb_part_net.py
parenta7d83c5447d0b81dc5e2a824c7ba638833fead86 (diff)
parent5e8947fbc90e1d67dadae36d32330a280d057267 (diff)
Merge branch 'master' into python3.8
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py39
1 files changed, 20 insertions, 19 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 811a711..ffee044 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -59,25 +59,26 @@ class RGBPartNet(nn.Module):
if self.training:
i_a, i_c, i_p = None, None, None
if self.image_log_on:
- f_a_mean = f_a.mean(1)
- i_a = self.ae.decoder(
- f_a_mean,
- torch.zeros_like(f_c_mean),
- torch.zeros_like(f_p[:, 0])
- )
- i_c = self.ae.decoder(
- torch.zeros_like(f_a_mean),
- f_c_mean,
- torch.zeros_like(f_p[:, 0])
- )
- f_p_size = f_p.size()
- i_p = self.ae.decoder(
- torch.zeros(f_p_size[0] * f_p_size[1], *f_a.shape[2:],
- device=f_a.device),
- torch.zeros(f_p_size[0] * f_p_size[1], *f_c.shape[2:],
- device=f_c.device),
- f_p.view(-1, *f_p_size[2:])
- ).view(x_c1.size())
+ with torch.no_grad():
+ f_a_mean = f_a.mean(1)
+ i_a = self.ae.decoder(
+ f_a_mean,
+ torch.zeros_like(f_c_mean),
+ torch.zeros_like(f_p[:, 0])
+ )
+ i_c = self.ae.decoder(
+ torch.zeros_like(f_a_mean),
+ f_c_mean,
+ torch.zeros_like(f_p[:, 0])
+ )
+ f_p_size = f_p.size()
+ i_p = self.ae.decoder(
+ torch.zeros(f_p_size[0] * f_p_size[1], *f_a.shape[2:],
+ device=f_a.device),
+ torch.zeros(f_p_size[0] * f_p_size[1], *f_c.shape[2:],
+ device=f_c.device),
+ f_p.view(-1, *f_p_size[2:])
+ ).view(x_c1.size())
return x_c, x_p, ae_losses, (i_a, i_c, i_p)
else:
return x_c, x_p