 models/auto_encoder.py | 36
 models/fpfe.py         | 14
 models/layers.py       | 49
 3 files changed, 52 insertions(+), 47 deletions(-)
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 1be878f..6483bd9 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -15,36 +15,32 @@ class Encoder(nn.Module):
        # Cx[HxW]
        # Conv1 3x64x32 -> 64x64x32
        self.conv1 = VGGConv2d(in_channels, nf)
-        self.batch_norm1 = nn.BatchNorm2d(nf)
        # MaxPool1 64x64x32 -> 64x32x16
        self.max_pool1 = nn.AdaptiveMaxPool2d((32, 16))
        # Conv2 64x32x16 -> 256x32x16
        self.conv2 = VGGConv2d(nf, nf * 4)
-        self.batch_norm2 = nn.BatchNorm2d(nf * 4)
        # MaxPool2 256x32x16 -> 256x16x8
        self.max_pool2 = nn.AdaptiveMaxPool2d((16, 8))
        # Conv3 256x16x8 -> 512x16x8
        self.conv3 = VGGConv2d(nf * 4, nf * 8)
-        self.batch_norm3 = nn.BatchNorm2d(nf * 8)
        # Conv4 512x16x8 -> 512x16x8 (for large dataset)
        self.conv4 = VGGConv2d(nf * 8, nf * 8)
-        self.batch_norm4 = nn.BatchNorm2d(nf * 8)
        # MaxPool3 512x16x8 -> 512x4x2
        self.max_pool3 = nn.AdaptiveMaxPool2d((4, 2))
        # FC 512*4*2 -> 320
-        self.fc = nn.Linear(nf * 8 * 2 * 4, self.em_dim)
-        self.batch_norm_fc = nn.BatchNorm1d(self.em_dim)
+        self.fc = nn.Linear(nf * 8 * 2 * 4, self.em_dim, bias=False)
+        self.bn_fc = nn.BatchNorm1d(self.em_dim)

    def forward(self, x):
-        x = F.leaky_relu(self.batch_norm1(self.conv1(x)), 0.2, inplace=True)
+        x = self.conv1(x)
        x = self.max_pool1(x)
-        x = F.leaky_relu(self.batch_norm2(self.conv2(x)), 0.2, inplace=True)
+        x = self.conv2(x)
        x = self.max_pool2(x)
-        x = F.leaky_relu(self.batch_norm3(self.conv3(x)), 0.2, inplace=True)
-        x = F.leaky_relu(self.batch_norm4(self.conv4(x)), 0.2, inplace=True)
+        x = self.conv3(x)
+        x = self.conv4(x)
        x = self.max_pool3(x)
        x = x.view(-1, (64 * 8) * 2 * 4)
-        embedding = self.batch_norm_fc(self.fc(x))
+        embedding = self.bn_fc(self.fc(x))

        fa, fgs, fgd = embedding.split(
            (self.opt.fa_dim, self.opt.fg_dim / 2, self.opt.fg_dim / 2), dim=1
@@ -61,26 +57,24 @@ class Decoder(nn.Module):
        # Cx[HxW]
        # FC 320 -> 512*4*2
        self.fc = nn.Linear(self.em_dim, nf * 8 * 2 * 4)
-        self.batch_norm_fc = nn.BatchNorm1d(nf * 8 * 2 * 4)
+        self.bn_fc = nn.BatchNorm1d(nf * 8 * 2 * 4)
        # TransConv1 512x4x2 -> 256x8x4
        self.trans_conv1 = DCGANConvTranspose2d(nf * 8, nf * 4)
-        self.batch_norm1 = nn.BatchNorm2d(nf * 4)
        # TransConv2 256x8x4 -> 128x16x8
        self.trans_conv2 = DCGANConvTranspose2d(nf * 4, nf * 2)
-        self.batch_norm2 = nn.BatchNorm2d(nf * 2)
        # TransConv3 128x16x8 -> 64x32x16
        self.trans_conv3 = DCGANConvTranspose2d(nf * 2, nf)
-        self.batch_norm3 = nn.BatchNorm2d(nf)
        # TransConv4 3x32x16
-        self.trans_conv4 = DCGANConvTranspose2d(nf, out_channels)
+        self.trans_conv4 = DCGANConvTranspose2d(nf, out_channels,
+                                                is_last_layer=True)

    def forward(self, fa, fgs, fgd):
        x = torch.cat((fa, fgs, fgd), dim=1).view(-1, self.em_dim)
-        x = F.relu(self.batch_norm_fc(self.fc(x)), True)
-        x = x.view(-1, 64 * 8, 4, 2)
-        x = F.relu(self.batch_norm1(self.trans_conv1(x)), True)
-        x = F.relu(self.batch_norm2(self.trans_conv2(x)), True)
-        x = F.relu(self.batch_norm3(self.trans_conv3(x)), True)
+        x = self.bn_fc(self.fc(x))
+        x = F.relu(x.view(-1, 64 * 8, 4, 2), True)
+        x = self.trans_conv1(x)
+        x = self.trans_conv2(x)
+        x = self.trans_conv3(x)
        x = F.sigmoid(self.trans_conv4(x))
        return x
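
(Not part of the patch: a minimal usage sketch of the new last-layer behaviour. nf = 64 is assumed from the hard-coded view() call above, and out_channels = 3 from the Encoder's input comment.)

    import torch
    from models.layers import DCGANConvTranspose2d

    # With is_last_layer=True the layer skips its BatchNorm + ReLU (see the
    # models/layers.py hunk below), so the Decoder's sigmoid is applied to the
    # raw transposed-convolution output.
    trans_conv4 = DCGANConvTranspose2d(64, 3, is_last_layer=True)
    x = torch.rand(8, 64, 32, 16)        # 64x32x16, per the TransConv3 comment
    out = torch.sigmoid(trans_conv4(x))  # stride-2 transposed conv doubles H and W
    print(out.shape)
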
diff --git a/models/fpfe.py b/models/fpfe.py
index 7d04655..29f7a5f 100644
--- a/models/fpfe.py
+++ b/models/fpfe.py
@@ -25,13 +25,15 @@ class FrameLevelPartFeatureExtractor(nn.Module):
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
-        x = F.leaky_relu(self.focal_conv1(x), inplace=True)
-        x = F.leaky_relu(self.focal_conv2(x), inplace=True)
+        x = self.focal_conv1(x)
+        x = self.focal_conv2(x)
        x = self.max_pool(x)
-        x = F.leaky_relu(self.focal_conv3(x), inplace=True)
-        x = F.leaky_relu(self.focal_conv4(x), inplace=True)
+
+        x = self.focal_conv3(x)
+        x = self.focal_conv4(x)
        x = self.max_pool(x)
-        x = F.leaky_relu(self.focal_conv5(x), inplace=True)
-        x = F.leaky_relu(self.focal_conv6(x), inplace=True)
+
+        x = self.focal_conv5(x)
+        x = self.focal_conv6(x)

        return x
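
(Not part of the patch: a sketch of what a single focal_conv call now does. The channel sizes, kernel size and padding below are illustrative assumptions; the height-wise splitting and trailing leaky ReLU come from FocalConv2d in the models/layers.py hunk below.)

    import torch
    from models.layers import FocalConv2d

    focal = FocalConv2d(32, 64, kernel_size=3, halving=3, padding=1)
    x = torch.rand(4, 32, 64, 32)   # N x C x H x W
    # Height is cut into 2**halving = 8 strips, each strip is passed through
    # the shared conv, the strips are re-concatenated along height, and a
    # leaky ReLU is applied to the result (no BatchNorm in this subclass).
    print(focal(x).shape)           # torch.Size([4, 64, 64, 32])
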
diff --git a/models/layers.py b/models/layers.py
index f3ccbeb..2dc87a1 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -1,6 +1,7 @@
from typing import Union, Tuple

import torch
+import torch.nn.functional as F
import torch.nn as nn
@@ -12,15 +13,18 @@ class BasicConv2d(nn.Module):
            kernel_size: Union[int, Tuple[int, int]],
            **kwargs
    ):
-        super(BasicConv2d, self).__init__()
+        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              bias=False, **kwargs)
+        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
-        return self.conv(x)
+        x = self.conv(x)
+        x = self.bn(x)
+        return F.relu(x, inplace=True)


-class VGGConv2d(nn.Module):
+class VGGConv2d(BasicConv2d):
    def __init__(
            self,
            in_channels: int,
@@ -29,12 +33,13 @@ class VGGConv2d(nn.Module):
            padding: int = 1,
            **kwargs
    ):
-        super(VGGConv2d, self).__init__()
-        self.conv = BasicConv2d(in_channels, out_channels, kernel_size,
-                                padding=padding, **kwargs)
+        super().__init__(in_channels, out_channels, kernel_size,
+                         padding=padding, **kwargs)

    def forward(self, x):
-        return self.conv(x)
+        x = self.conv(x)
+        x = self.bn(x)
+        return F.leaky_relu(x, 0.2, inplace=True)


class BasicConvTranspose2d(nn.Module):
@@ -45,15 +50,18 @@ class BasicConvTranspose2d(nn.Module):
            kernel_size: Union[int, Tuple[int, int]],
            **kwargs
    ):
-        super(BasicConvTranspose2d, self).__init__()
+        super().__init__()
        self.trans_conv = nn.ConvTranspose2d(in_channels, out_channels,
                                             kernel_size, bias=False, **kwargs)
+        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
-        return self.trans_conv(x)
+        x = self.trans_conv(x)
+        x = self.bn(x)
+        return F.relu(x, inplace=True)


-class DCGANConvTranspose2d(nn.Module):
+class DCGANConvTranspose2d(BasicConvTranspose2d):
    def __init__(
            self,
            in_channels: int,
@@ -61,18 +69,21 @@ class DCGANConvTranspose2d(nn.Module):
            kernel_size: Union[int, Tuple[int, int]] = 4,
            stride: int = 2,
            padding: int = 1,
+            is_last_layer: bool = False,
            **kwargs
    ):
-        super(DCGANConvTranspose2d).__init__()
-        self.trans_conv = BasicConvTranspose2d(in_channels, out_channels,
-                                               kernel_size, stride=stride,
-                                               padding=padding, **kwargs)
+        super().__init__(in_channels, out_channels, kernel_size,
+                         stride=stride, padding=padding, **kwargs)
+        self.is_last_layer = is_last_layer

    def forward(self, x):
-        return self.trans_conv(x)
+        if self.is_last_layer:
+            return self.trans_conv(x)
+        else:
+            return super().forward(x)


-class FocalConv2d(nn.Module):
+class FocalConv2d(BasicConv2d):
    def __init__(
            self,
            in_channels: int,
@@ -81,17 +92,15 @@ class FocalConv2d(nn.Module):
            halving: int,
            **kwargs
    ):
-        super(FocalConv2d, self).__init__()
+        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        self.halving = halving
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
-                              bias=False, **kwargs)

    def forward(self, x):
        h = x.size(2)
        split_size = h // 2 ** self.halving
        z = x.split(split_size, dim=2)
        z = torch.cat([self.conv(_) for _ in z], dim=2)
-        return z
+        return F.leaky_relu(z, inplace=True)


class BasicConv1d(nn.Module):
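
(Not part of the patch: a quick sanity check of the refactored building blocks. The import path assumes the repository root is on sys.path; the 3x64x32 input size is the one quoted in the Encoder's Conv1 comment, and the output shapes follow from it.)

    import torch
    from models.layers import BasicConv2d, VGGConv2d

    x = torch.rand(8, 3, 64, 32)    # N x C x H x W, per "Conv1 3x64x32 -> 64x64x32"

    # VGGConv2d now fuses Conv2d(bias=False) -> BatchNorm2d -> LeakyReLU(0.2),
    # so callers such as Encoder no longer keep separate batch_norm members.
    conv1 = VGGConv2d(3, 64)
    print(conv1(x).shape)           # torch.Size([8, 64, 64, 32])

    # BasicConv2d applies the same Conv -> BN pipeline with a plain ReLU;
    # kernel_size has no default here, and padding is passed through **kwargs.
    conv = BasicConv2d(3, 64, kernel_size=3, padding=1)
    print(conv(x).shape)            # torch.Size([8, 64, 64, 32])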