summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2020-12-28 15:45:31 +0800
committerJordan Gong <jordan.gong@protonmail.com>2020-12-28 15:45:31 +0800
commit6e94fdb587656074dc2e65a80e51b8446f834b41 (patch)
tree182541e36e549455d937d05832cf83400d28f9c6
parentff04b0eb6f1e632d4487a04f6e5b4d8398accb16 (diff)
Wrap the auto-encoder, return 3 losses at t2
-rw-r--r--models/auto_encoder.py42
1 files changed, 40 insertions, 2 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 601348e..feec5e2 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -67,9 +67,9 @@ class Decoder(nn.Module):
def __init__(
self,
- out_channels: int = 3,
+ input_dims: tuple[int, int, int] = (128, 128, 64),
feature_channels: int = 64,
- input_dims: tuple[int, int, int] = (128, 128, 64)
+ out_channels: int = 3,
):
super().__init__()
self.feature_channels = feature_channels
@@ -105,3 +105,41 @@ class Decoder(nn.Module):
x = torch.sigmoid(self.trans_conv4(x))
return x
+
+
+class AutoEncoder(nn.Module):
+ def __init__(
+ self,
+ num_class: int = 74,
+ channels: int = 3,
+ feature_channels: int = 64,
+ embedding_dims: tuple[int, int, int] = (128, 128, 64)
+ ):
+ super().__init__()
+ self.encoder = Encoder(channels, feature_channels, embedding_dims)
+ self.decoder = Decoder(embedding_dims, feature_channels, channels)
+
+ f_c_dim = embedding_dims[1]
+ self.classifier = nn.Sequential(
+ nn.LeakyReLU(0.2, inplace=True),
+ BasicLinear(f_c_dim, num_class)
+ )
+
+ self.mse_loss = nn.MSELoss()
+ self.xent_loss = nn.CrossEntropyLoss()
+
+ def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y):
+ # t2 is random time step
+ (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
+ (_, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
+ (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
+
+ x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
+ xrecon_loss_t2 = self.mse_loss(x_c1_t2, x_c1_t2_)
+
+ y_ = self.classifier(f_c_c1_t2)
+ cano_cons_loss_t2 = (self.mse_loss(f_c_c1_t1, f_c_c1_t2)
+ + self.mse_loss(f_c_c1_t2, f_c_c2_t2)
+ + self.xent_loss(y, y_))
+
+ return xrecon_loss_t2, (f_p_c1_t2, f_p_c2_t2), cano_cons_loss_t2