From 2a9507204ae2dd14504556ab5885c4f39bddd89a Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 18 Feb 2021 16:00:32 +0800 Subject: Decode mean appearance feature --- models/rgb_part_net.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'models/rgb_part_net.py') diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 2aa680c..bf52efe 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -101,7 +101,7 @@ class RGBPartNet(nn.Module): i_a, i_c, i_p = None, None, None if self.image_log_on: - i_a = self._decode_appr_feature(f_a_, n, t, c, h, w, device) + i_a = self._decode_appr_feature(f_a_, n, t, device) # Continue decoding canonical features i_c = self.ae.decoder.trans_conv3(x_c) i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c)) @@ -115,14 +115,14 @@ class RGBPartNet(nn.Module): x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device) return (x_c, x_p), None, None - def _decode_appr_feature(self, f_a_, n, t, c, h, w, device): + def _decode_appr_feature(self, f_a_, n, t, device): # Decode appearance features - x_a_ = self.ae.decoder( - f_a_, + f_a = f_a_.view(n, t, -1) + x_a = self.ae.decoder( + f_a.mean(1), torch.zeros((n * t, self.f_c_dim), device=device), torch.zeros((n * t, self.f_p_dim), device=device) ) - x_a = x_a_.view(n, t, c, h, w) return x_a def _decode_cano_feature(self, f_c_, n, t, device): -- cgit v1.2.3 From fd77ad26d5c4ede79e3406e736fcdaa29eb1c7c9 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 18 Feb 2021 16:01:26 +0800 Subject: Decode mean appearance feature --- models/rgb_part_net.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'models/rgb_part_net.py') diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index bf52efe..c3954bc 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -120,8 +120,8 @@ class RGBPartNet(nn.Module): f_a = f_a_.view(n, t, -1) x_a = self.ae.decoder( f_a.mean(1), - torch.zeros((n * t, self.f_c_dim), device=device), - torch.zeros((n * t, self.f_p_dim), device=device) + torch.zeros((n, self.f_c_dim), device=device), + torch.zeros((n, self.f_p_dim), device=device) ) return x_a -- cgit v1.2.3 From 84a3d5991f2f7272d1be54ad6cfe6ce695f915a0 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 18 Feb 2021 18:34:06 +0800 Subject: Implement adjustable input size and change some default configs --- models/rgb_part_net.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'models/rgb_part_net.py') diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index c3954bc..67acac3 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -11,6 +11,7 @@ class RGBPartNet(nn.Module): def __init__( self, ae_in_channels: int = 3, + ae_in_size: tuple[int, int] = (64, 48), ae_feature_channels: int = 64, f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), hpm_use_1x1conv: bool = False, @@ -33,7 +34,7 @@ class RGBPartNet(nn.Module): self.image_log_on = image_log_on self.ae = AutoEncoder( - ae_in_channels, ae_feature_channels, f_a_c_p_dims + ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims ) self.pn = PartNet( ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes, -- cgit v1.2.3