diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-28 15:45:31 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-28 15:45:31 +0800 |
commit | 6e94fdb587656074dc2e65a80e51b8446f834b41 (patch) | |
tree | 182541e36e549455d937d05832cf83400d28f9c6 /models | |
parent | ff04b0eb6f1e632d4487a04f6e5b4d8398accb16 (diff) |
Wrap the auto-encoder, return 3 losses at t2
Diffstat (limited to 'models')
-rw-r--r-- | models/auto_encoder.py | 42 |
1 files changed, 40 insertions, 2 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 601348e..feec5e2 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -67,9 +67,9 @@ class Decoder(nn.Module):
 
     def __init__(
             self,
-            out_channels: int = 3,
+            input_dims: tuple[int, int, int] = (128, 128, 64),
             feature_channels: int = 64,
-            input_dims: tuple[int, int, int] = (128, 128, 64)
+            out_channels: int = 3,
     ):
         super().__init__()
         self.feature_channels = feature_channels
@@ -105,3 +105,74 @@ class Decoder(nn.Module):
         x = torch.sigmoid(self.trans_conv4(x))
 
         return x
+
+
+class AutoEncoder(nn.Module):
+    """Wrap the encoder, decoder and a classifier; return losses.
+
+    The classifier is a LeakyReLU followed by a BasicLinear that maps
+    embedding_dims[1] features to num_class outputs.
+    """
+
+    def __init__(
+            self,
+            num_class: int = 74,
+            channels: int = 3,
+            feature_channels: int = 64,
+            embedding_dims: tuple[int, int, int] = (128, 128, 64)
+    ):
+        """Build the sub-modules and the loss functions.
+
+        Args:
+            num_class: Output size of the classifier.
+            channels: First argument of Encoder, last of Decoder.
+            feature_channels: Passed to both Encoder and Decoder.
+            embedding_dims: Passed to both Encoder and Decoder; the
+                element at index 1 is the classifier input size.
+        """
+        super().__init__()
+        self.encoder = Encoder(channels, feature_channels, embedding_dims)
+        self.decoder = Decoder(embedding_dims, feature_channels, channels)
+
+        f_c_dim = embedding_dims[1]
+        self.classifier = nn.Sequential(
+            # Out-of-place on purpose: forward() reuses the classifier
+            # input in its MSE terms, so it must not be overwritten.
+            nn.LeakyReLU(0.2),
+            BasicLinear(f_c_dim, num_class)
+        )
+
+        self.mse_loss = nn.MSELoss()
+        self.xent_loss = nn.CrossEntropyLoss()
+
+    def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y):
+        """Compute the reconstruction and consistency losses.
+
+        Naming looks like c1/c2 = two conditions and t1/t2 = two time
+        steps of one sequence -- TODO confirm against the data loader.
+
+        Args:
+            x_c1_t1: Encoder input; its 1st and 2nd outputs feed the decoder.
+            x_c1_t2: Encoder input; also the reconstruction target.
+            x_c2_t2: Encoder input; only its 2nd and 3rd outputs are used.
+            y: Target for the cross-entropy loss.
+
+        Returns:
+            Tuple (xrecon_loss_t2, (f_p_c1_t2, f_p_c2_t2),
+            cano_cons_loss_t2).
+        """
+        # t2 is random time step
+        (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
+        (_, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
+        (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
+
+        x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
+        xrecon_loss_t2 = self.mse_loss(x_c1_t2, x_c1_t2_)
+
+        y_ = self.classifier(f_c_c1_t2)
+        # CrossEntropyLoss expects (input, target): predictions first.
+        cano_cons_loss_t2 = (self.mse_loss(f_c_c1_t1, f_c_c1_t2)
+                             + self.mse_loss(f_c_c1_t2, f_c_c2_t2)
+                             + self.xent_loss(y_, y))
+
+        return xrecon_loss_t2, (f_p_c1_t2, f_p_c2_t2), cano_cons_loss_t2