summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-18 16:01:26 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-18 17:13:34 +0800
commitfd77ad26d5c4ede79e3406e736fcdaa29eb1c7c9 (patch)
tree163889734f97e4c2c94e607e9dd2312d1b8fc43d /models/rgb_part_net.py
parent2a9507204ae2dd14504556ab5885c4f39bddd89a (diff)
Decode mean appearance feature
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index bf52efe..c3954bc 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -120,8 +120,8 @@ class RGBPartNet(nn.Module):
f_a = f_a_.view(n, t, -1)
x_a = self.ae.decoder(
f_a.mean(1),
- torch.zeros((n * t, self.f_c_dim), device=device),
- torch.zeros((n * t, self.f_p_dim), device=device)
+ torch.zeros((n, self.f_c_dim), device=device),
+ torch.zeros((n, self.f_p_dim), device=device)
)
return x_a