summary | refs | log | tree | commit | diff
path: root/models/auto_encoder.py
diff options
context:
space:
mode:
author Jordan Gong <jordan.gong@protonmail.com> 2021-02-08 18:31:52 +0800
committer Jordan Gong <jordan.gong@protonmail.com> 2021-02-08 18:31:52 +0800
commit d380e04df37593e414bd5641db100613fb2ad882 (patch)
tree 1e3b3ea55a464d59d790711372bbca42cb203d0a /models/auto_encoder.py
parent a040400d7caa267d4bfbe8e5520568806f92b3d4 (diff)
parent 99ddd7c142a4ec97cb8bd14b204651790b3cf4ee (diff)
Merge branch 'master' into python3.8
# Conflicts:
#   models/hpm.py
#   models/layers.py
#   models/model.py
#   models/rgb_part_net.py
#   utils/configuration.py
Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r-- models/auto_encoder.py 36
1 files changed, 7 insertions, 29 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index befd2d3..69dae4e 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -97,15 +97,14 @@ class Decoder(nn.Module):
self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
is_last_layer=True)
- def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False):
+ def forward(self, f_appearance, f_canonical, f_pose, cano_only=False):
x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
x = self.fc(x)
x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True)
- # Decode canonical features without transpose convolutions
- if no_trans_conv:
- return x
x = self.trans_conv1(x)
x = self.trans_conv2(x)
+ if cano_only:
+ return x
x = self.trans_conv3(x)
x = torch.sigmoid(self.trans_conv4(x))
@@ -115,7 +114,6 @@ class Decoder(nn.Module):
class AutoEncoder(nn.Module):
def __init__(
self,
- num_class: int = 74,
channels: int = 3,
feature_channels: int = 64,
embedding_dims: Tuple[int, int, int] = (128, 128, 64)
@@ -124,27 +122,10 @@ class AutoEncoder(nn.Module):
self.encoder = Encoder(channels, feature_channels, embedding_dims)
self.decoder = Decoder(embedding_dims, feature_channels, channels)
- f_c_dim = embedding_dims[1]
- self.classifier = nn.Sequential(
- nn.LeakyReLU(0.2, inplace=True),
- BasicLinear(f_c_dim, num_class)
- )
-
- def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None, y=None):
+ def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None):
# x_c1_t2 is the frame for later module
(f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
- with torch.no_grad():
- # Decode canonical features for HPM
- x_c_c1_t2 = self.decoder(
- torch.zeros_like(f_a_c1_t2), f_c_c1_t2, torch.zeros_like(f_p_c1_t2),
- no_trans_conv=True
- )
- # Decode pose features for Part Net
- x_p_c1_t2 = self.decoder(
- torch.zeros_like(f_a_c1_t2), torch.zeros_like(f_c_c1_t2), f_p_c1_t2
- )
-
if self.training:
# t1 is random time step, c2 is another condition
(f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
@@ -152,16 +133,13 @@ class AutoEncoder(nn.Module):
x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
xrecon_loss_t2 = F.mse_loss(x_c1_t2, x_c1_t2_)
-
- y_ = self.classifier(f_c_c1_t2.contiguous())
cano_cons_loss_t2 = (F.mse_loss(f_c_c1_t1, f_c_c1_t2)
- + F.mse_loss(f_c_c1_t2, f_c_c2_t2)
- + F.cross_entropy(y_, y))
+ + F.mse_loss(f_c_c1_t2, f_c_c2_t2))
return (
- (x_c_c1_t2, x_p_c1_t2),
+ (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2),
(f_p_c1_t2, f_p_c2_t2),
(xrecon_loss_t2, cano_cons_loss_t2)
)
else: # evaluating
- return x_c_c1_t2, x_p_c1_t2
+ return f_c_c1_t2, f_p_c1_t2