summaryrefslogtreecommitdiff
path: root/models/auto_encoder.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-02 19:10:08 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-02 19:10:08 +0800
commit02aaefaba26b6842d2feb403edfd71aaa75904da (patch)
tree18744854e8d80e0239c0b2f3e7eaf39bc0a7974e /models/auto_encoder.py
parentde8561d1d053730c5af03e1d06850efb60865d3c (diff)
Correct feature dims after disentanglement and HPM backbone removal
1. Features used in HPM is decoded canonical embedding without transpose convolution 2. Decode pose embedding to image for Part Net 3. Backbone seems to be redundant, we can use feature map given by auto-decoder
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r--models/auto_encoder.py31
1 files changed, 25 insertions, 6 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index c84061c..234111a 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -95,10 +95,13 @@ class Decoder(nn.Module):
self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
is_last_layer=True)
- def forward(self, f_appearance, f_canonical, f_pose):
+ def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False):
x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
x = self.fc(x)
x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True)
+ # Decode canonical features without transpose convolutions
+ if no_trans_conv:
+ return x
x = self.trans_conv1(x)
x = self.trans_conv2(x)
x = self.trans_conv3(x)
@@ -131,16 +134,32 @@ class AutoEncoder(nn.Module):
def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y):
# t1 is random time step
(f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
- (_, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
+ (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
(_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
xrecon_loss_t2 = self.mse_loss(x_c1_t2, x_c1_t2_)
- y_ = self.classifier(f_c_c1_t2)
+ y_ = self.classifier(f_c_c1_t2.contiguous())
cano_cons_loss_t2 = (self.mse_loss(f_c_c1_t1, f_c_c1_t2)
+ self.mse_loss(f_c_c1_t2, f_c_c2_t2)
- + self.xent_loss(y, y_))
+ + self.xent_loss(y_, y))
- return ((f_c_c1_t2, f_p_c1_t2, f_p_c2_t2),
- xrecon_loss_t2, cano_cons_loss_t2)
+ f_a_size, f_c_size, f_p_size = (
+ f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size()
+ )
+ # Decode canonical features for HPM
+ x_c_c1_t2 = self.decoder(
+ torch.zeros(f_a_size), f_c_c1_t1, torch.zeros(f_p_size),
+ no_trans_conv=True
+ )
+ # Decode pose features for Part Net
+ x_p_c1_t2 = self.decoder(
+ torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2
+ )
+
+ return (
+ (x_c_c1_t2, x_p_c1_t2),
+ (f_p_c1_t2, f_p_c2_t2),
+ (xrecon_loss_t2, cano_cons_loss_t2)
+ )