summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-18 16:00:32 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-18 16:00:32 +0800
commit2a9507204ae2dd14504556ab5885c4f39bddd89a (patch)
tree112351f0844b4a7af24d92a91b327f2fa74f1b16 /models
parent641d19c1ebdf44486de139fadeff3276aecdf284 (diff)
Decode mean appearance feature
Diffstat (limited to 'models')
-rw-r--r--models/model.py8
-rw-r--r--models/rgb_part_net.py10
2 files changed, 9 insertions, 9 deletions
diff --git a/models/model.py b/models/model.py
index 0c3e5eb..139fd59 100644
--- a/models/model.py
+++ b/models/model.py
@@ -223,16 +223,16 @@ class Model:
if self.image_log_on:
i_a, i_c, i_p = images
self.writer.add_images(
+ 'Appearance image', i_a, self.curr_iter
+ )
+ self.writer.add_images(
'Canonical image', i_c, self.curr_iter
)
- for (i, (o, a, p)) in enumerate(zip(x_c1, i_a, i_p)):
+ for i, (o, p) in enumerate(zip(x_c1, i_p)):
self.writer.add_images(
f'Original image/batch {i}', o, self.curr_iter
)
self.writer.add_images(
- f'Appearance image/batch {i}', a, self.curr_iter
- )
- self.writer.add_images(
f'Pose image/batch {i}', p, self.curr_iter
)
time_used = datetime.now() - start_time
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 2aa680c..bf52efe 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -101,7 +101,7 @@ class RGBPartNet(nn.Module):
i_a, i_c, i_p = None, None, None
if self.image_log_on:
- i_a = self._decode_appr_feature(f_a_, n, t, c, h, w, device)
+ i_a = self._decode_appr_feature(f_a_, n, t, device)
# Continue decoding canonical features
i_c = self.ae.decoder.trans_conv3(x_c)
i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c))
@@ -115,14 +115,14 @@ class RGBPartNet(nn.Module):
x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
return (x_c, x_p), None, None
- def _decode_appr_feature(self, f_a_, n, t, c, h, w, device):
+ def _decode_appr_feature(self, f_a_, n, t, device):
# Decode appearance features
- x_a_ = self.ae.decoder(
- f_a_,
+ f_a = f_a_.view(n, t, -1)
+ x_a = self.ae.decoder(
+ f_a.mean(1),
torch.zeros((n * t, self.f_c_dim), device=device),
torch.zeros((n * t, self.f_p_dim), device=device)
)
- x_a = x_a_.view(n, t, c, h, w)
return x_a
def _decode_cano_feature(self, f_c_, n, t, device):