summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py17
1 files changed, 9 insertions, 8 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 845a477..80b3e17 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -13,6 +13,7 @@ class RGBPartNet(nn.Module):
def __init__(
self,
ae_in_channels: int = 3,
+ ae_in_size: tuple[int, int] = (64, 48),
ae_feature_channels: int = 64,
f_a_c_p_dims: Tuple[int, int, int] = (128, 128, 64),
hpm_use_1x1conv: bool = False,
@@ -35,7 +36,7 @@ class RGBPartNet(nn.Module):
self.image_log_on = image_log_on
self.ae = AutoEncoder(
- ae_in_channels, ae_feature_channels, f_a_c_p_dims
+ ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
)
self.pn = PartNet(
ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
@@ -103,7 +104,7 @@ class RGBPartNet(nn.Module):
i_a, i_c, i_p = None, None, None
if self.image_log_on:
- i_a = self._decode_appr_feature(f_a_, n, t, c, h, w, device)
+ i_a = self._decode_appr_feature(f_a_, n, t, device)
# Continue decoding canonical features
i_c = self.ae.decoder.trans_conv3(x_c)
i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c))
@@ -117,14 +118,14 @@ class RGBPartNet(nn.Module):
x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
return (x_c, x_p), None, None
- def _decode_appr_feature(self, f_a_, n, t, c, h, w, device):
+ def _decode_appr_feature(self, f_a_, n, t, device):
# Decode appearance features
- x_a_ = self.ae.decoder(
- f_a_,
- torch.zeros((n * t, self.f_c_dim), device=device),
- torch.zeros((n * t, self.f_p_dim), device=device)
+ f_a = f_a_.view(n, t, -1)
+ x_a = self.ae.decoder(
+ f_a.mean(1),
+ torch.zeros((n, self.f_c_dim), device=device),
+ torch.zeros((n, self.f_p_dim), device=device)
)
- x_a = x_a_.view(n, t, c, h, w)
return x_a
def _decode_cano_feature(self, f_c_, n, t, device):