summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/auto_encoder.py119
-rw-r--r--models/layers.py16
2 files changed, 89 insertions, 46 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 4a8aaa5..701f299 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -2,34 +2,48 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
-from models.layers import VGGConv2d, DCGANConvTranspose2d
+from models.layers import VGGConv2d, DCGANConvTranspose2d, BasicLinear
class Encoder(nn.Module):
- def __init__(self, in_channels: int, opt):
- super(Encoder, self).__init__()
- self.opt = opt
- self.em_dim = opt.em_dim
- nf = 64
+ """Squeeze input feature map to lower dimension"""
- # Cx[HxW]
- # Conv1 3x64x32 -> 64x64x32
- self.conv1 = VGGConv2d(in_channels, nf)
- # MaxPool1 64x64x32 -> 64x32x16
+ def __init__(
+ self,
+ in_channels: int = 3,
+ feature_channels: int = 64,
+ output_dims: tuple[int, int, int] = (128, 128, 64)
+ ):
+ super().__init__()
+ self.feature_channels = feature_channels
+ # Appearance features, canonical features, pose features
+ (self.f_a_dim, self.f_c_dim, self.f_p_dim) = output_dims
+
+ # Conv1 in_channels x 64 x 32
+ # -> feature_map_size x 64 x 32
+ self.conv1 = VGGConv2d(in_channels, feature_channels)
+ # MaxPool1 feature_map_size x 64 x 32
+ # -> feature_map_size x 32 x 16
self.max_pool1 = nn.AdaptiveMaxPool2d((32, 16))
- # Conv2 64x32x16 -> 256x32x16
- self.conv2 = VGGConv2d(nf, nf * 4)
- # MaxPool2 256x32x16 -> 256x16x8
+ # Conv2 feature_map_size x 32 x 16
+ # -> (feature_map_size*4) x 32 x 16
+ self.conv2 = VGGConv2d(feature_channels, feature_channels * 4)
+ # MaxPool2 (feature_map_size*4) x 32 x 16
+ # -> (feature_map_size*4) x 16 x 8
self.max_pool2 = nn.AdaptiveMaxPool2d((16, 8))
- # Conv3 256x16x8 -> 512x16x8
- self.conv3 = VGGConv2d(nf * 4, nf * 8)
- # Conv4 512x16x8 -> 512x16x8 (for large dataset)
- self.conv4 = VGGConv2d(nf * 8, nf * 8)
- # MaxPool3 512x16x8 -> 512x4x2
+ # Conv3 (feature_map_size*4) x 16 x 8
+ # -> (feature_map_size*8) x 16 x 8
+ self.conv3 = VGGConv2d(feature_channels * 4, feature_channels * 8)
+ # Conv4 (feature_map_size*8) x 16 x 8
+ # -> (feature_map_size*8) x 16 x 8 (for large dataset)
+ self.conv4 = VGGConv2d(feature_channels * 8, feature_channels * 8)
+ # MaxPool3 (feature_map_size*8) x 16 x 8
+ # -> (feature_map_size*8) x 4 x 2
self.max_pool3 = nn.AdaptiveMaxPool2d((4, 2))
- # FC 512*4*2 -> 320
- self.fc = nn.Linear(nf * 8 * 2 * 4, self.em_dim, bias=False)
- self.bn_fc = nn.BatchNorm1d(self.em_dim)
+
+ embedding_dim = sum(output_dims)
+ # FC (feature_map_size*8) * 4 * 2 -> 320
+ self.fc = BasicLinear(feature_channels * 8 * 2 * 4, embedding_dim)
def forward(self, x):
x = self.conv1(x)
@@ -39,39 +53,52 @@ class Encoder(nn.Module):
x = self.conv3(x)
x = self.conv4(x)
x = self.max_pool3(x)
- x = x.view(-1, (64 * 8) * 2 * 4)
- embedding = self.bn_fc(self.fc(x))
+ x = x.view(-1, (self.feature_channels * 8) * 2 * 4)
+ embedding = self.fc(x)
- fa, fgs, fgd = embedding.split(
- (self.opt.fa_dim, self.opt.fg_dim / 2, self.opt.fg_dim / 2), dim=1
+ f_appearance, f_canonical, f_pose = embedding.split(
+ (self.f_a_dim, self.f_c_dim, self.f_p_dim), dim=1
)
- return fa, fgs, fgd
+ return f_appearance, f_canonical, f_pose
class Decoder(nn.Module):
- def __init__(self, out_channels: int, opt):
- super(Decoder, self).__init__()
- self.em_dim = opt.em_dim
- nf = 64
+ """Upscale embedding to original image"""
+
+ def __init__(
+ self,
+ out_channels: int,
+ feature_channels: int = 64,
+ input_dims: tuple[int, int, int] = (128, 128, 64)
+ ):
+ super().__init__()
+ self.feature_channels = feature_channels
+
+ embedding_dim = sum(input_dims)
+ # FC 320 -> (feature_map_size*8) * 4 * 2
+ self.fc = BasicLinear(embedding_dim, feature_channels * 8 * 2 * 4)
- # Cx[HxW]
- # FC 320 -> 512*4*2
- self.fc = nn.Linear(self.em_dim, nf * 8 * 2 * 4)
- self.bn_fc = nn.BatchNorm1d(nf * 8 * 2 * 4)
- # TransConv1 512x4x2 -> 256x8x4
- self.trans_conv1 = DCGANConvTranspose2d(nf * 8, nf * 4)
- # TransConv2 256x8x4 -> 128x16x8
- self.trans_conv2 = DCGANConvTranspose2d(nf * 4, nf * 2)
- # TransConv3 128x16x8 -> 64x32x16
- self.trans_conv3 = DCGANConvTranspose2d(nf * 2, nf)
- # TransConv4 3x32x16
- self.trans_conv4 = DCGANConvTranspose2d(nf, out_channels,
+ # TransConv1 (feature_map_size*8) x 4 x 2
+ # -> (feature_map_size*4) x 8 x 4
+ self.trans_conv1 = DCGANConvTranspose2d(feature_channels * 8,
+ feature_channels * 4)
+ # TransConv2 (feature_map_size*4) x 8 x 4
+ # -> (feature_map_size*2) x 16 x 8
+ self.trans_conv2 = DCGANConvTranspose2d(feature_channels * 4,
+ feature_channels * 2)
+ # TransConv3 (feature_map_size*2) x 16 x 8
+ # -> feature_map_size x 32 x 16
+ self.trans_conv3 = DCGANConvTranspose2d(feature_channels * 2,
+ feature_channels)
+ # TransConv4 feature_map_size x 32 x 16
+ # -> in_channels x 64 x 32
+ self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
is_last_layer=True)
- def forward(self, fa, fgs, fgd):
- x = torch.cat((fa, fgs, fgd), dim=1).view(-1, self.em_dim)
- x = self.bn_fc(self.fc(x))
- x = F.relu(x.view(-1, 64 * 8, 4, 2), True)
+ def forward(self, f_appearance, f_canonical, f_pose):
+ x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
+ x = self.fc(x)
+ x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True)
x = self.trans_conv1(x)
x = self.trans_conv2(x)
x = self.trans_conv3(x)
diff --git a/models/layers.py b/models/layers.py
index cba6e47..f824078 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -83,6 +83,22 @@ class DCGANConvTranspose2d(BasicConvTranspose2d):
return super().forward(x)
+class BasicLinear(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ ):
+ super().__init__()
+ self.fc = nn.Linear(in_features, out_features, bias=False)
+ self.bn = nn.BatchNorm1d(out_features)
+
+ def forward(self, x):
+ x = self.fc(x)
+ x = self.bn(x)
+ return x
+
+
class FocalConv2d(BasicConv2d):
def __init__(
self,