Diffstat (limited to 'models/auto_encoder.py')
-rw-r--r--  models/auto_encoder.py  86
1 file changed, 56 insertions(+), 30 deletions(-)
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index dc7843a..aa55a42 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -12,42 +12,53 @@ class Encoder(nn.Module):
         self,
         in_channels: int = 3,
         frame_size: tuple[int, int] = (64, 48),
-        feature_channels: int = 64,
-        output_dims: tuple[int, int, int] = (192, 192, 128)
+        feature_channels: int = 32,
+        output_dims: tuple[int, int, int] = (48, 48, 32)
     ):
         super().__init__()
         h_0, w_0 = frame_size
         h_1, w_1 = h_0 // 2, w_0 // 2
-        h_2, w_2 = h_1 // 2, w_1 // 2
+        h_2, w_2 = h_1 // 4, w_1 // 4
 
         # Appearance features, canonical features, pose features
         (self.f_a_dim, self.f_c_dim, self.f_p_dim) = output_dims
         # Conv1 in_channels x H x W
         # -> feature_map_size x H x W
-        self.conv1 = VGGConv2d(in_channels, feature_channels)
+        self.conv1 = VGGConv2d(in_channels,
+                               feature_channels,
+                               kernel_size=5,
+                               padding=2)
+        # Conv2 feature_map_size x H x W
+        # -> feature_map_size x H x W
+        self.conv2 = VGGConv2d(feature_channels, feature_channels)
         # MaxPool1 feature_map_size x H x W
         # -> feature_map_size x H//2 x W//2
         self.max_pool1 = nn.AdaptiveMaxPool2d((h_1, w_1))
-        # Conv2 feature_map_size x H//2 x W//2
-        # -> feature_map_size*4 x H//2 x W//2
-        self.conv2 = VGGConv2d(feature_channels, feature_channels * 4)
-        # MaxPool2 feature_map_size*4 x H//2 x W//2
-        # -> feature_map_size*4 x H//4 x W//4
+        # Conv3 feature_map_size x H//2 x W//2
+        # -> feature_map_size*2 x H//2 x W//2
+        self.conv3 = VGGConv2d(feature_channels, feature_channels * 2)
+        # Conv4 feature_map_size*2 x H//2 x W//2
+        # -> feature_map_size*2 x H//2 x W//2
+        self.conv4 = VGGConv2d(feature_channels * 2, feature_channels * 2)
+        # MaxPool2 feature_map_size*2 x H//2 x W//2
+        # -> feature_map_size*2 x H//8 x W//8
         self.max_pool2 = nn.AdaptiveMaxPool2d((h_2, w_2))
-        # Conv3 feature_map_size*4 x H//4 x W//4
-        # -> feature_map_size*8 x H//4 x W//4
-        self.conv3 = VGGConv2d(feature_channels * 4, feature_channels * 8)
-        # Conv4 feature_map_size*8 x H//4 x W//4
-        # -> feature_map_size*8 x H//4 x W//4 (for large dataset)
-        self.conv4 = VGGConv2d(feature_channels * 8, feature_channels * 8)
+        # Conv5 feature_map_size*2 x H//8 x W//8
+        # -> feature_map_size*4 x H//8 x W//8
+        self.conv5 = VGGConv2d(feature_channels * 2, feature_channels * 4)
+        # Conv6 feature_map_size*4 x H//8 x W//8
+        # -> feature_map_size*4 x H//8 x W//8
+        self.conv6 = VGGConv2d(feature_channels * 4, feature_channels * 4)
 
     def forward(self, x):
         x = self.conv1(x)
-        x = self.max_pool1(x)
         x = self.conv2(x)
-        x = self.max_pool2(x)
+        x = self.max_pool1(x)
         x = self.conv3(x)
         x = self.conv4(x)
+        x = self.max_pool2(x)
+        x = self.conv5(x)
+        x = self.conv6(x)
         f_appearance, f_canonical, f_pose = x.split(
             (self.f_a_dim, self.f_c_dim, self.f_p_dim), dim=1
         )
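The encoder change halves feature_channels (64 -> 32) but deepens the stack from four convolutions to six VGG-style layers in three pairs, and the second pool now brings the output down to H//8 x W//8 before the channel split into appearance, canonical, and pose features. Below is a minimal shape check of the new flow; plain nn.Conv2d stands in for VGGConv2d, which is assumed here to be a same-padding convolution block.

import torch
import torch.nn as nn

fc = 32                                 # feature_channels
x = torch.randn(2, 3, 64, 48)           # N x C x H x W
conv1 = nn.Conv2d(3, fc, kernel_size=5, padding=2)
conv2 = nn.Conv2d(fc, fc, kernel_size=3, padding=1)
pool1 = nn.AdaptiveMaxPool2d((32, 24))  # H//2 x W//2
conv3 = nn.Conv2d(fc, fc * 2, kernel_size=3, padding=1)
conv4 = nn.Conv2d(fc * 2, fc * 2, kernel_size=3, padding=1)
pool2 = nn.AdaptiveMaxPool2d((8, 6))    # H//8 x W//8
conv5 = nn.Conv2d(fc * 2, fc * 4, kernel_size=3, padding=1)
conv6 = nn.Conv2d(fc * 4, fc * 4, kernel_size=3, padding=1)

x = conv6(conv5(pool2(conv4(conv3(pool1(conv2(conv1(x))))))))
f_a, f_c, f_p = x.split((48, 48, 32), dim=1)
print(f_a.shape, f_c.shape, f_p.shape)
# torch.Size([2, 48, 8, 6]) torch.Size([2, 48, 8, 6]) torch.Size([2, 32, 8, 6])

The split is valid because the new output_dims (48, 48, 32) sum to feature_channels * 4 = 128, the channel count after conv6.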
@@ -59,41 +70,56 @@ class Decoder(nn.Module):
     def __init__(
         self,
-        feature_channels: int = 64,
+        feature_channels: int = 32,
         out_channels: int = 3,
     ):
         super().__init__()
         self.feature_channels = feature_channels
-        # TransConv1 feature_map_size*8 x H x W
+        # TransConv1 feature_map_size*4 x H x W
         # -> feature_map_size*4 x H x W
-        self.trans_conv1 = DCGANConvTranspose2d(feature_channels * 8,
+        self.trans_conv1 = DCGANConvTranspose2d(feature_channels * 4,
                                                 feature_channels * 4,
                                                 kernel_size=3,
                                                 stride=1,
                                                 padding=1)
-        # TransConv2 feature_map_size*4 x H x W
+        # TransConv2 feature_map_size*4 x H x W
         # -> feature_map_size*2 x H*2 x W*2
         self.trans_conv2 = DCGANConvTranspose2d(feature_channels * 4,
                                                 feature_channels * 2)
         # TransConv3 feature_map_size*2 x H*2 x W*2
-        # -> feature_map_size x H*2 x W*2
+        # -> feature_map_size*2 x H*2 x W*2
         self.trans_conv3 = DCGANConvTranspose2d(feature_channels * 2,
+                                                feature_channels * 2,
+                                                kernel_size=3,
+                                                stride=1,
+                                                padding=1)
+        # TransConv4 feature_map_size*2 x H*2 x W*2
+        # -> feature_map_size x H*4 x W*4
+        self.trans_conv4 = DCGANConvTranspose2d(feature_channels * 2,
+                                                feature_channels)
+
+        # TransConv5 feature_map_size x H*4 x W*4
+        # -> feature_map_size x H*4 x W*4
+        self.trans_conv5 = DCGANConvTranspose2d(feature_channels,
                                                 feature_channels,
                                                 kernel_size=3,
                                                 stride=1,
                                                 padding=1)
-        # TransConv4 feature_map_size x H*2 x W*2
-        # -> in_channels x H*4 x W*4
-        self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
-                                                is_last_layer=True)
+
+        # TransConv6 feature_map_size x H*4 x W*4
+        # -> out_channels x H*8 x W*8
+        self.trans_conv6 = DCGANConvTranspose2d(feature_channels,
+                                                out_channels)
 
     def forward(self, f_appearance, f_canonical, f_pose):
         x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
         x = self.trans_conv1(x)
         x = self.trans_conv2(x)
         x = self.trans_conv3(x)
-        x = torch.sigmoid(self.trans_conv4(x))
+        x = self.trans_conv4(x)
+        x = self.trans_conv5(x)
+        x = torch.sigmoid(self.trans_conv6(x))
         return x
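The decoder is lengthened to match: six transposed convolutions, alternating stride-1 refinements (kernel_size=3, padding=1) with 2x upsamplings, take the concatenated 128-channel code at H//8 x W//8 back to a full-resolution frame, with a sigmoid only on the last layer. A sketch of the shape flow, assuming DCGANConvTranspose2d defaults to the usual DCGAN kernel_size=4, stride=2, padding=1 upsampling block (plain nn.ConvTranspose2d stands in for it):

import torch
import torch.nn as nn

fc = 32
f_a = torch.randn(2, 48, 8, 6)                             # appearance
f_c = torch.randn(2, 48, 8, 6)                             # canonical
f_p = torch.randn(2, 32, 8, 6)                             # pose
x = torch.cat((f_a, f_c, f_p), dim=1)                      # 128 x 8 x 6
x = nn.ConvTranspose2d(fc * 4, fc * 4, 3, 1, 1)(x)         # 128 x 8 x 6
x = nn.ConvTranspose2d(fc * 4, fc * 2, 4, 2, 1)(x)         # 64 x 16 x 12
x = nn.ConvTranspose2d(fc * 2, fc * 2, 3, 1, 1)(x)         # 64 x 16 x 12
x = nn.ConvTranspose2d(fc * 2, fc, 4, 2, 1)(x)             # 32 x 32 x 24
x = nn.ConvTranspose2d(fc, fc, 3, 1, 1)(x)                 # 32 x 32 x 24
x = torch.sigmoid(nn.ConvTranspose2d(fc, 3, 4, 2, 1)(x))   # 3 x 64 x 48
print(x.shape)  # torch.Size([2, 3, 64, 48])

The three stride-2 layers each double the resolution, giving the 8x upsampling that mirrors the encoder's new 8x downsampling.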
@@ -103,8 +129,8 @@ class AutoEncoder(nn.Module):
         self,
         channels: int = 3,
         frame_size: tuple[int, int] = (64, 48),
-        feature_channels: int = 64,
-        embedding_dims: tuple[int, int, int] = (192, 192, 128)
+        feature_channels: int = 32,
+        embedding_dims: tuple[int, int, int] = (48, 48, 32)
     ):
         super().__init__()
         self.embedding_dims = embedding_dims
@@ -118,7 +144,7 @@ class AutoEncoder(nn.Module):
         x_c1_t2_ = x_c1_t2.view(n * t, c, h, w)
         (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_) = self.encoder(x_c1_t2_)
-        f_size = [torch.Size([n, t, embedding_dim, h // 4, w // 4])
+        f_size = [torch.Size([n, t, embedding_dim, h // 8, w // 8])
                   for embedding_dim in self.embedding_dims]
         f_a_c1_t2 = f_a_c1_t2_.view(f_size[0])
         f_c_c1_t2 = f_c_c1_t2_.view(f_size[1])
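With the encoder now downsampling by 8 instead of 4, the per-frame features (computed with time folded into the batch dimension) are viewed back to n x t x dim x h//8 x w//8. A small stand-alone illustration of that reshape, with random tensors standing in for the encoder outputs:

import torch

n, t, c, h, w = 4, 10, 3, 64, 48
embedding_dims = (48, 48, 32)

x_c1_t2 = torch.randn(n, t, c, h, w)
x_c1_t2_ = x_c1_t2.view(n * t, c, h, w)   # fold time into the batch
# stand-ins for the encoder's three per-frame outputs at h//8 x w//8
(f_a_, f_c_, f_p_) = (torch.randn(n * t, d, h // 8, w // 8)
                      for d in embedding_dims)
f_size = [torch.Size([n, t, d, h // 8, w // 8]) for d in embedding_dims]
f_a = f_a_.view(f_size[0])
f_c = f_c_.view(f_size[1])
f_p = f_p_.view(f_size[2])
print(f_a.shape)  # torch.Size([4, 10, 48, 8, 6])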