Diffstat (limited to 'models')
-rw-r--r-- | models/auto_encoder.py | 36
-rw-r--r-- | models/fpfe.py | 14
-rw-r--r-- | models/layers.py | 49
3 files changed, 52 insertions, 47 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 1be878f..6483bd9 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -15,36 +15,32 @@ class Encoder(nn.Module):
         # Cx[HxW]
         # Conv1 3x64x32 -> 64x64x32
         self.conv1 = VGGConv2d(in_channels, nf)
-        self.batch_norm1 = nn.BatchNorm2d(nf)
         # MaxPool1 64x64x32 -> 64x32x16
         self.max_pool1 = nn.AdaptiveMaxPool2d((32, 16))
         # Conv2 64x32x16 -> 256x32x16
         self.conv2 = VGGConv2d(nf, nf * 4)
-        self.batch_norm2 = nn.BatchNorm2d(nf * 4)
         # MaxPool2 256x32x16 -> 256x16x8
         self.max_pool2 = nn.AdaptiveMaxPool2d((16, 8))
         # Conv3 256x16x8 -> 512x16x8
         self.conv3 = VGGConv2d(nf * 4, nf * 8)
-        self.batch_norm3 = nn.BatchNorm2d(nf * 8)
         # Conv4 512x16x8 -> 512x16x8 (for large dataset)
         self.conv4 = VGGConv2d(nf * 8, nf * 8)
-        self.batch_norm4 = nn.BatchNorm2d(nf * 8)
         # MaxPool3 512x16x8 -> 512x4x2
         self.max_pool3 = nn.AdaptiveMaxPool2d((4, 2))
         # FC 512*4*2 -> 320
-        self.fc = nn.Linear(nf * 8 * 2 * 4, self.em_dim)
-        self.batch_norm_fc = nn.BatchNorm1d(self.em_dim)
+        self.fc = nn.Linear(nf * 8 * 2 * 4, self.em_dim, bias=False)
+        self.bn_fc = nn.BatchNorm1d(self.em_dim)
 
     def forward(self, x):
-        x = F.leaky_relu(self.batch_norm1(self.conv1(x)), 0.2, inplace=True)
+        x = self.conv1(x)
         x = self.max_pool1(x)
-        x = F.leaky_relu(self.batch_norm2(self.conv2(x)), 0.2, inplace=True)
+        x = self.conv2(x)
         x = self.max_pool2(x)
-        x = F.leaky_relu(self.batch_norm3(self.conv3(x)), 0.2, inplace=True)
-        x = F.leaky_relu(self.batch_norm4(self.conv4(x)), 0.2, inplace=True)
+        x = self.conv3(x)
+        x = self.conv4(x)
         x = self.max_pool3(x)
         x = x.view(-1, (64 * 8) * 2 * 4)
-        embedding = self.batch_norm_fc(self.fc(x))
+        embedding = self.bn_fc(self.fc(x))
 
         fa, fgs, fgd = embedding.split(
             (self.opt.fa_dim, self.opt.fg_dim / 2, self.opt.fg_dim / 2), dim=1
@@ -61,26 +57,24 @@ class Decoder(nn.Module):
         # Cx[HxW]
         # FC 320 -> 512*4*2
         self.fc = nn.Linear(self.em_dim, nf * 8 * 2 * 4)
-        self.batch_norm_fc = nn.BatchNorm1d(nf * 8 * 2 * 4)
+        self.bn_fc = nn.BatchNorm1d(nf * 8 * 2 * 4)
         # TransConv1 512x4x2 -> 256x8x4
         self.trans_conv1 = DCGANConvTranspose2d(nf * 8, nf * 4)
-        self.batch_norm1 = nn.BatchNorm2d(nf * 4)
         # TransConv2 256x8x4 -> 128x16x8
         self.trans_conv2 = DCGANConvTranspose2d(nf * 4, nf * 2)
-        self.batch_norm2 = nn.BatchNorm2d(nf * 2)
         # TransConv3 128x16x8 -> 64x32x16
         self.trans_conv3 = DCGANConvTranspose2d(nf * 2, nf)
-        self.batch_norm3 = nn.BatchNorm2d(nf)
         # TransConv4 3x32x16
-        self.trans_conv4 = DCGANConvTranspose2d(nf, out_channels)
+        self.trans_conv4 = DCGANConvTranspose2d(nf, out_channels,
+                                                is_last_layer=True)
 
     def forward(self, fa, fgs, fgd):
         x = torch.cat((fa, fgs, fgd), dim=1).view(-1, self.em_dim)
-        x = F.relu(self.batch_norm_fc(self.fc(x)), True)
-        x = x.view(-1, 64 * 8, 4, 2)
-        x = F.relu(self.batch_norm1(self.trans_conv1(x)), True)
-        x = F.relu(self.batch_norm2(self.trans_conv2(x)), True)
-        x = F.relu(self.batch_norm3(self.trans_conv3(x)), True)
+        x = self.bn_fc(self.fc(x))
+        x = F.relu(x.view(-1, 64 * 8, 4, 2), True)
+        x = self.trans_conv1(x)
+        x = self.trans_conv2(x)
+        x = self.trans_conv3(x)
         x = F.sigmoid(self.trans_conv4(x))
 
         return x
diff --git a/models/fpfe.py b/models/fpfe.py
index 7d04655..29f7a5f 100644
--- a/models/fpfe.py
+++ b/models/fpfe.py
@@ -25,13 +25,15 @@ class FrameLevelPartFeatureExtractor(nn.Module):
         self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
 
     def forward(self, x):
-        x = F.leaky_relu(self.focal_conv1(x), inplace=True)
-        x = F.leaky_relu(self.focal_conv2(x), inplace=True)
+        x = self.focal_conv1(x)
+        x = self.focal_conv2(x)
         x = self.max_pool(x)
-        x = F.leaky_relu(self.focal_conv3(x), inplace=True)
-        x = F.leaky_relu(self.focal_conv4(x), inplace=True)
+
+        x = self.focal_conv3(x)
+        x = self.focal_conv4(x)
         x = self.max_pool(x)
-        x = F.leaky_relu(self.focal_conv5(x), inplace=True)
-        x = F.leaky_relu(self.focal_conv6(x), inplace=True)
+
+        x = self.focal_conv5(x)
+        x = self.focal_conv6(x)
 
         return x
diff --git a/models/layers.py b/models/layers.py
index f3ccbeb..2dc87a1 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -1,6 +1,7 @@
 from typing import Union, Tuple
 
 import torch
+import torch.nn.functional as F
 import torch.nn as nn
 
 
@@ -12,15 +13,18 @@ class BasicConv2d(nn.Module):
             kernel_size: Union[int, Tuple[int, int]],
             **kwargs
     ):
-        super(BasicConv2d, self).__init__()
+        super().__init__()
         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                               bias=False, **kwargs)
+        self.bn = nn.BatchNorm2d(out_channels)
 
     def forward(self, x):
-        return self.conv(x)
+        x = self.conv(x)
+        x = self.bn(x)
+        return F.relu(x, inplace=True)
 
 
-class VGGConv2d(nn.Module):
+class VGGConv2d(BasicConv2d):
     def __init__(
             self,
             in_channels: int,
@@ -29,12 +33,13 @@
             padding: int = 1,
             **kwargs
     ):
-        super(VGGConv2d, self).__init__()
-        self.conv = BasicConv2d(in_channels, out_channels, kernel_size,
-                                padding=padding, **kwargs)
+        super().__init__(in_channels, out_channels, kernel_size,
+                         padding=padding, **kwargs)
 
     def forward(self, x):
-        return self.conv(x)
+        x = self.conv(x)
+        x = self.bn(x)
+        return F.leaky_relu(x, 0.2, inplace=True)
 
 
 class BasicConvTranspose2d(nn.Module):
@@ -45,15 +50,18 @@
             kernel_size: Union[int, Tuple[int, int]],
             **kwargs
    ):
-        super(BasicConvTranspose2d, self).__init__()
+        super().__init__()
         self.trans_conv = nn.ConvTranspose2d(in_channels, out_channels,
                                              kernel_size, bias=False, **kwargs)
+        self.bn = nn.BatchNorm2d(out_channels)
 
     def forward(self, x):
-        return self.trans_conv(x)
+        x = self.trans_conv(x)
+        x = self.bn(x)
+        return F.relu(x, inplace=True)
 
 
-class DCGANConvTranspose2d(nn.Module):
+class DCGANConvTranspose2d(BasicConvTranspose2d):
     def __init__(
             self,
             in_channels: int,
@@ -61,18 +69,21 @@
             out_channels: int,
             kernel_size: Union[int, Tuple[int, int]] = 4,
             stride: int = 2,
             padding: int = 1,
+            is_last_layer: bool = False,
             **kwargs
     ):
-        super(DCGANConvTranspose2d).__init__()
-        self.trans_conv = BasicConvTranspose2d(in_channels, out_channels,
-                                               kernel_size, stride=stride,
-                                               padding=padding, **kwargs)
+        super().__init__(in_channels, out_channels, kernel_size,
+                         stride=stride, padding=padding, **kwargs)
+        self.is_last_layer = is_last_layer
 
     def forward(self, x):
-        return self.trans_conv(x)
+        if self.is_last_layer:
+            return self.trans_conv(x)
+        else:
+            return super().forward(x)
 
 
-class FocalConv2d(nn.Module):
+class FocalConv2d(BasicConv2d):
     def __init__(
             self,
             in_channels: int,
@@ -81,17 +92,15 @@
             out_channels: int,
             kernel_size: Union[int, Tuple[int, int]],
             halving: int,
             **kwargs
     ):
-        super(FocalConv2d, self).__init__()
+        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
         self.halving = halving
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
-                              bias=False, **kwargs)
 
     def forward(self, x):
         h = x.size(2)
         split_size = h // 2 ** self.halving
         z = x.split(split_size, dim=2)
         z = torch.cat([self.conv(_) for _ in z], dim=2)
-        return z
+        return F.leaky_relu(z, inplace=True)
 
 
 class BasicConv1d(nn.Module):
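For orientation, below is a minimal runnable sketch of the transposed-convolution block after this refactor, condensed from the models/layers.py hunks above. The tensor shapes in the smoke test are illustrative assumptions (taken from the Decoder's shape comments), not part of the commit.

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicConvTranspose2d(nn.Module):
    # ConvTranspose2d -> BatchNorm2d -> ReLU folded into one block,
    # as in the new models/layers.py.
    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__()
        self.trans_conv = nn.ConvTranspose2d(in_channels, out_channels,
                                             kernel_size, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return F.relu(self.bn(self.trans_conv(x)), inplace=True)


class DCGANConvTranspose2d(BasicConvTranspose2d):
    # 4x4, stride-2 upsampling block; is_last_layer=True skips BN and ReLU
    # so the Decoder's sigmoid acts on the raw transposed convolution.
    def __init__(self, in_channels, out_channels, kernel_size=4,
                 stride=2, padding=1, is_last_layer=False, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding, **kwargs)
        self.is_last_layer = is_last_layer

    def forward(self, x):
        if self.is_last_layer:
            return self.trans_conv(x)
        return super().forward(x)


# Illustrative shapes only (batch of 8): a 512x4x2 code is upsampled to
# 256x8x4 as in the Decoder comments; the second call shows the last-layer
# path with an assumed input shape.
x = torch.randn(8, 512, 4, 2)
print(DCGANConvTranspose2d(512, 256)(x).shape)  # torch.Size([8, 256, 8, 4])
y = torch.randn(8, 64, 32, 16)
print(DCGANConvTranspose2d(64, 3, is_last_layer=True)(y).shape)  # torch.Size([8, 3, 64, 32])

Folding normalization and activation into the blocks is what lets the Encoder and Decoder forward passes above drop their per-layer batch_norm* members, while the is_last_layer flag preserves the raw pre-sigmoid output of the final transposed convolution.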