diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-23 18:59:08 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2020-12-23 18:59:08 +0800 |
commit | 96f345d25237c7e616ea5f524a2fc2d340ed8aff (patch) | |
tree | 9791c394d147f39d45ecab2a1ea5f83b396f3568 | |
parent | 74a3df70a47630b7e95abc09197d23de9b81d4de (diff) |
Split modules to different files
-rw-r--r-- | models/__init__.py | 0 | ||||
-rw-r--r-- | models/auto_encoder.py | 87 | ||||
-rw-r--r-- | models/fpfe.py | 37 | ||||
-rw-r--r-- | models/layers.py | 27 | ||||
-rw-r--r-- | models/tfa.py | 79 | ||||
-rw-r--r-- | modules/layers.py | 220 |
6 files changed, 230 insertions, 220 deletions
diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/models/__init__.py diff --git a/models/auto_encoder.py b/models/auto_encoder.py new file mode 100644 index 0000000..e35ed23 --- /dev/null +++ b/models/auto_encoder.py @@ -0,0 +1,87 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + + +class Encoder(nn.Module): + def __init__(self, in_channels: int, opt): + super(Encoder, self).__init__() + self.opt = opt + self.em_dim = opt.em_dim + nf = 64 + + # Cx[HxW] + # Conv1 3x64x32 -> 64x64x32 + self.conv1 = nn.Conv2d(in_channels, nf, kernel_size=3, padding=1) + self.batch_norm1 = nn.BatchNorm2d(nf) + # MaxPool1 64x64x32 -> 64x32x16 + self.max_pool1 = nn.AdaptiveMaxPool2d((32, 16)) + # Conv2 64x32x16 -> 256x32x16 + self.conv2 = nn.Conv2d(nf, nf * 4, kernel_size=3, padding=1) + self.batch_norm2 = nn.BatchNorm2d(nf * 4) + # MaxPool2 256x32x16 -> 256x16x8 + self.max_pool2 = nn.AdaptiveMaxPool2d((16, 8)) + # Conv3 256x16x8 -> 512x16x8 + self.conv3 = nn.Conv2d(nf * 4, nf * 8, kernel_size=3, padding=1) + self.batch_norm3 = nn.BatchNorm2d(nf * 8) + # Conv4 512x16x8 -> 512x16x8 (for large dataset) + self.conv4 = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, padding=1) + self.batch_norm4 = nn.BatchNorm2d(nf * 8) + # MaxPool3 512x16x8 -> 512x4x2 + self.max_pool3 = nn.AdaptiveMaxPool2d((4, 2)) + # FC 512*4*2 -> 320 + self.fc = nn.Linear(nf * 8 * 2 * 4, self.em_dim) + self.batch_norm_fc = nn.BatchNorm1d(self.em_dim) + + def forward(self, x): + x = F.leaky_relu(self.batch_norm1(self.conv1(x)), 0.2) + x = self.max_pool1(x) + x = F.leaky_relu(self.batch_norm2(self.conv2(x)), 0.2) + x = self.max_pool2(x) + x = F.leaky_relu(self.batch_norm3(self.conv3(x)), 0.2) + x = F.leaky_relu(self.batch_norm4(self.conv4(x)), 0.2) + x = self.max_pool3(x) + x = x.view(-1, (64 * 8) * 2 * 4) + embedding = self.batch_norm_fc(self.fc(x)) + + fa, fgs, fgd = embedding.split( + (self.opt.fa_dim, self.opt.fg_dim / 2, self.opt.fg_dim / 2), dim=1 + ) + return fa, fgs, fgd + + +class Decoder(nn.Module): + def __init__(self, out_channels: int, opt): + super(Decoder, self).__init__() + self.em_dim = opt.em_dim + nf = 64 + + # Cx[HxW] + # FC 320 -> 512*4*2 + self.fc = nn.Linear(self.em_dim, nf * 8 * 2 * 4) + self.batch_norm_fc = nn.BatchNorm1d(nf * 8 * 2 * 4) + # TransConv1 512x4x2 -> 256x8x4 + self.trans_conv1 = nn.ConvTranspose2d(nf * 8, nf * 4, kernel_size=4, + stride=2, padding=1) + self.batch_norm1 = nn.BatchNorm2d(nf * 4) + # TransConv2 256x8x4 -> 128x16x8 + self.trans_conv2 = nn.ConvTranspose2d(nf * 4, nf * 2, kernel_size=4, + stride=2, padding=1) + self.batch_norm2 = nn.BatchNorm2d(nf * 2) + # TransConv3 128x16x8 -> 64x32x16 + self.trans_conv3 = nn.ConvTranspose2d(nf * 2, nf, kernel_size=4, + stride=2, padding=1) + self.batch_norm3 = nn.BatchNorm2d(nf) + # TransConv4 3x32x16 + self.trans_conv4 = nn.ConvTranspose2d(nf, out_channels, kernel_size=4, + stride=2, padding=1) + + def forward(self, fa, fgs, fgd): + x = torch.cat((fa, fgs, fgd), dim=1).view(-1, self.em_dim) + x = F.leaky_relu(self.batch_norm_fc(self.fc(x)), 0.2) + x = F.leaky_relu(self.batch_norm1(self.trans_conv1(x)), 0.2) + x = F.leaky_relu(self.batch_norm2(self.trans_conv2(x)), 0.2) + x = F.leaky_relu(self.batch_norm3(self.trans_conv3(x)), 0.2) + x = F.sigmoid(self.trans_conv4(x)) + + return x
\ No newline at end of file diff --git a/models/fpfe.py b/models/fpfe.py new file mode 100644 index 0000000..4a5cde8 --- /dev/null +++ b/models/fpfe.py @@ -0,0 +1,37 @@ +from torch import nn as nn +from torch.nn import functional as F + +from models.layers import FocalConv2d + + +class FrameLevelPartFeatureExtractor(nn.Module): + + def __init__(self, in_channels: int): + super(FrameLevelPartFeatureExtractor, self).__init__() + nf = 32 + + self.focal_conv1 = FocalConv2d(in_channels, nf, kernel_size=5, + padding=2, halving=1) + self.focal_conv2 = FocalConv2d(nf, nf, kernel_size=3, + padding=1, halving=1) + self.focal_conv3 = FocalConv2d(nf, nf * 2, kernel_size=3, + padding=1, halving=4) + self.focal_conv4 = FocalConv2d(nf * 2, nf * 2, kernel_size=3, + padding=1, halving=4) + self.focal_conv5 = FocalConv2d(nf * 2, nf * 4, kernel_size=3, + padding=1, halving=8) + self.focal_conv6 = FocalConv2d(nf * 4, nf * 4, kernel_size=3, + padding=1, halving=8) + self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2) + + def forward(self, x): + x = F.leaky_relu(self.focal_conv1(x)) + x = F.leaky_relu(self.focal_conv2(x)) + x = self.max_pool(x) + x = F.leaky_relu(self.focal_conv3(x)) + x = F.leaky_relu(self.focal_conv4(x)) + x = self.max_pool(x) + x = F.leaky_relu(self.focal_conv5(x)) + x = F.leaky_relu(self.focal_conv6(x)) + + return x
\ No newline at end of file diff --git a/models/layers.py b/models/layers.py new file mode 100644 index 0000000..e737df2 --- /dev/null +++ b/models/layers.py @@ -0,0 +1,27 @@ +from typing import Union, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FocalConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + halving: int, + **kwargs + ): + super(FocalConv2d, self).__init__() + self.halving = halving + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, + bias=False, **kwargs) + + def forward(self, x): + h = x.size(2) + split_size = h // 2 ** self.halving + z = x.split(split_size, dim=2) + z = torch.cat([self.conv(_) for _ in z], dim=2) + return F.leaky_relu(z, inplace=True) diff --git a/models/tfa.py b/models/tfa.py new file mode 100644 index 0000000..2e4e656 --- /dev/null +++ b/models/tfa.py @@ -0,0 +1,79 @@ +import copy + +import torch +from torch import nn as nn + + +class TemporalFeatureAggregator(nn.Module): + def __init__( + self, + in_channels: int, + squeeze: int = 4, + num_part: int = 16 + ): + super(TemporalFeatureAggregator, self).__init__() + hidden_dim = in_channels // squeeze + self.num_part = num_part + + # MTB1 + conv3x1 = nn.Sequential( + nn.Conv1d(in_channels, hidden_dim, + kernel_size=3, padding=1, bias=False), + nn.LeakyReLU(inplace=True), + nn.Conv1d(hidden_dim, in_channels, + kernel_size=1, padding=0, bias=False) + ) + self.conv1d3x1 = self._parted(conv3x1) + self.avg_pool3x1 = nn.AvgPool1d(kernel_size=3, stride=1, padding=1) + self.max_pool3x1 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1) + + # MTB2 + conv3x3 = nn.Sequential( + nn.Conv1d(in_channels, hidden_dim, + kernel_size=3, padding=1, bias=False), + nn.LeakyReLU(inplace=True), + nn.Conv1d(hidden_dim, in_channels, + kernel_size=3, padding=1, bias=False) + ) + self.conv1d3x3 = self._parted(conv3x3) + self.avg_pool3x3 = nn.AvgPool1d(kernel_size=5, stride=1, padding=2) + self.max_pool3x3 = nn.MaxPool1d(kernel_size=5, stride=1, padding=2) + + def _parted(self, module: nn.Module): + """Duplicate module `part_num` times.""" + return nn.ModuleList([copy.deepcopy(module) + for _ in range(self.num_part)]) + + def forward(self, x): + """ + Input: x, [p, n, c, s] + """ + p, n, c, s = x.size() + feature = x.split(1, 0) + x = x.view(-1, c, s) + + # MTB1: ConvNet1d & Sigmoid + logits3x1 = torch.cat( + [conv(_.squeeze(0)).unsqueeze(0) + for conv, _ in zip(self.conv1d3x1, feature)], dim=0 + ) + scores3x1 = torch.sigmoid(logits3x1) + # MTB1: Template Function + feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x) + feature3x1 = feature3x1.view(p, n, c, s) + feature3x1 = feature3x1 * scores3x1 + + # MTB2: ConvNet1d & Sigmoid + logits3x3 = torch.cat( + [conv(_.squeeze(0)).unsqueeze(0) + for conv, _ in zip(self.conv1d3x3, feature)], dim=0 + ) + scores3x3 = torch.sigmoid(logits3x3) + # MTB2: Template Function + feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x) + feature3x3 = feature3x3.view(p, n, c, s) + feature3x3 = feature3x3 * scores3x3 + + # Temporal Pooling + ret = (feature3x1 + feature3x3).max(-1)[0] + return ret
\ No newline at end of file diff --git a/modules/layers.py b/modules/layers.py deleted file mode 100644 index a0116e2..0000000 --- a/modules/layers.py +++ /dev/null @@ -1,220 +0,0 @@ -import copy -from typing import Union, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Encoder(nn.Module): - def __init__(self, in_channels: int, opt): - super(Encoder, self).__init__() - self.opt = opt - self.em_dim = opt.em_dim - nf = 64 - - # Cx[HxW] - # Conv1 3x64x32 -> 64x64x32 - self.conv1 = nn.Conv2d(in_channels, nf, kernel_size=3, padding=1) - self.batch_norm1 = nn.BatchNorm2d(nf) - # MaxPool1 64x64x32 -> 64x32x16 - self.max_pool1 = nn.AdaptiveMaxPool2d((32, 16)) - # Conv2 64x32x16 -> 256x32x16 - self.conv2 = nn.Conv2d(nf, nf * 4, kernel_size=3, padding=1) - self.batch_norm2 = nn.BatchNorm2d(nf * 4) - # MaxPool2 256x32x16 -> 256x16x8 - self.max_pool2 = nn.AdaptiveMaxPool2d((16, 8)) - # Conv3 256x16x8 -> 512x16x8 - self.conv3 = nn.Conv2d(nf * 4, nf * 8, kernel_size=3, padding=1) - self.batch_norm3 = nn.BatchNorm2d(nf * 8) - # Conv4 512x16x8 -> 512x16x8 (for large dataset) - self.conv4 = nn.Conv2d(nf * 8, nf * 8, kernel_size=3, padding=1) - self.batch_norm4 = nn.BatchNorm2d(nf * 8) - # MaxPool3 512x16x8 -> 512x4x2 - self.max_pool3 = nn.AdaptiveMaxPool2d((4, 2)) - # FC 512*4*2 -> 320 - self.fc = nn.Linear(nf * 8 * 2 * 4, self.em_dim) - self.batch_norm_fc = nn.BatchNorm1d(self.em_dim) - - def forward(self, x): - x = F.leaky_relu(self.batch_norm1(self.conv1(x)), 0.2) - x = self.max_pool1(x) - x = F.leaky_relu(self.batch_norm2(self.conv2(x)), 0.2) - x = self.max_pool2(x) - x = F.leaky_relu(self.batch_norm3(self.conv3(x)), 0.2) - x = F.leaky_relu(self.batch_norm4(self.conv4(x)), 0.2) - x = self.max_pool3(x) - x = x.view(-1, (64 * 8) * 2 * 4) - embedding = self.batch_norm_fc(self.fc(x)) - - fa, fgs, fgd = embedding.split( - (self.opt.fa_dim, self.opt.fg_dim / 2, self.opt.fg_dim / 2), dim=1 - ) - return fa, fgs, fgd - - -class Decoder(nn.Module): - def __init__(self, out_channels: int, opt): - super(Decoder, self).__init__() - self.em_dim = opt.em_dim - nf = 64 - - # Cx[HxW] - # FC 320 -> 512*4*2 - self.fc = nn.Linear(self.em_dim, nf * 8 * 2 * 4) - self.batch_norm_fc = nn.BatchNorm1d(nf * 8 * 2 * 4) - # TransConv1 512x4x2 -> 256x8x4 - self.trans_conv1 = nn.ConvTranspose2d(nf * 8, nf * 4, kernel_size=4, - stride=2, padding=1) - self.batch_norm1 = nn.BatchNorm2d(nf * 4) - # TransConv2 256x8x4 -> 128x16x8 - self.trans_conv2 = nn.ConvTranspose2d(nf * 4, nf * 2, kernel_size=4, - stride=2, padding=1) - self.batch_norm2 = nn.BatchNorm2d(nf * 2) - # TransConv3 128x16x8 -> 64x32x16 - self.trans_conv3 = nn.ConvTranspose2d(nf * 2, nf, kernel_size=4, - stride=2, padding=1) - self.batch_norm3 = nn.BatchNorm2d(nf) - # TransConv4 3x32x16 - self.trans_conv4 = nn.ConvTranspose2d(nf, out_channels, kernel_size=4, - stride=2, padding=1) - - def forward(self, fa, fgs, fgd): - x = torch.cat((fa, fgs, fgd), dim=1).view(-1, self.em_dim) - x = F.leaky_relu(self.batch_norm_fc(self.fc(x)), 0.2) - x = F.leaky_relu(self.batch_norm1(self.trans_conv1(x)), 0.2) - x = F.leaky_relu(self.batch_norm2(self.trans_conv2(x)), 0.2) - x = F.leaky_relu(self.batch_norm3(self.trans_conv3(x)), 0.2) - x = F.sigmoid(self.trans_conv4(x)) - - return x - - -class FocalConv2d(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - halving: int, - **kwargs - ): - super(FocalConv2d, self).__init__() - self.halving = halving - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, - bias=False, **kwargs) - - def forward(self, x): - h = x.size(2) - split_size = h // 2 ** self.halving - z = x.split(split_size, dim=2) - z = torch.cat([self.conv(_) for _ in z], dim=2) - return F.leaky_relu(z, inplace=True) - - -class FrameLevelPartFeatureExtractor(nn.Module): - - def __init__(self, in_channels: int): - super(FrameLevelPartFeatureExtractor, self).__init__() - nf = 32 - - self.focal_conv1 = FocalConv2d(in_channels, nf, kernel_size=5, - padding=2, halving=1) - self.focal_conv2 = FocalConv2d(nf, nf, kernel_size=3, - padding=1, halving=1) - self.focal_conv3 = FocalConv2d(nf, nf * 2, kernel_size=3, - padding=1, halving=4) - self.focal_conv4 = FocalConv2d(nf * 2, nf * 2, kernel_size=3, - padding=1, halving=4) - self.focal_conv5 = FocalConv2d(nf * 2, nf * 4, kernel_size=3, - padding=1, halving=8) - self.focal_conv6 = FocalConv2d(nf * 4, nf * 4, kernel_size=3, - padding=1, halving=8) - self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2) - - def forward(self, x): - x = F.leaky_relu(self.focal_conv1(x)) - x = F.leaky_relu(self.focal_conv2(x)) - x = self.max_pool(x) - x = F.leaky_relu(self.focal_conv3(x)) - x = F.leaky_relu(self.focal_conv4(x)) - x = self.max_pool(x) - x = F.leaky_relu(self.focal_conv5(x)) - x = F.leaky_relu(self.focal_conv6(x)) - - return x - - -class TemporalFeatureAggregator(nn.Module): - def __init__( - self, - in_channels: int, - squeeze: int = 4, - num_part: int = 16 - ): - super(TemporalFeatureAggregator, self).__init__() - hidden_dim = in_channels // squeeze - self.num_part = num_part - - # MTB1 - conv3x1 = nn.Sequential( - nn.Conv1d(in_channels, hidden_dim, - kernel_size=3, padding=1, bias=False), - nn.LeakyReLU(inplace=True), - nn.Conv1d(hidden_dim, in_channels, - kernel_size=1, padding=0, bias=False) - ) - self.conv1d3x1 = self._parted(conv3x1) - self.avg_pool3x1 = nn.AvgPool1d(kernel_size=3, stride=1, padding=1) - self.max_pool3x1 = nn.MaxPool1d(kernel_size=3, stride=1, padding=1) - - # MTB2 - conv3x3 = nn.Sequential( - nn.Conv1d(in_channels, hidden_dim, - kernel_size=3, padding=1, bias=False), - nn.LeakyReLU(inplace=True), - nn.Conv1d(hidden_dim, in_channels, - kernel_size=3, padding=1, bias=False) - ) - self.conv1d3x3 = self._parted(conv3x3) - self.avg_pool3x3 = nn.AvgPool1d(kernel_size=5, stride=1, padding=2) - self.max_pool3x3 = nn.MaxPool1d(kernel_size=5, stride=1, padding=2) - - def _parted(self, module: nn.Module): - """Duplicate module `part_num` times.""" - return nn.ModuleList([copy.deepcopy(module) - for _ in range(self.num_part)]) - - def forward(self, x): - """ - Input: x, [p, n, c, s] - """ - p, n, c, s = x.size() - feature = x.split(1, 0) - x = x.view(-1, c, s) - - # MTB1: ConvNet1d & Sigmoid - logits3x1 = torch.cat( - [conv(_.squeeze(0)).unsqueeze(0) - for conv, _ in zip(self.conv1d3x1, feature)], dim=0 - ) - scores3x1 = torch.sigmoid(logits3x1) - # MTB1: Template Function - feature3x1 = self.avg_pool3x1(x) + self.max_pool3x1(x) - feature3x1 = feature3x1.view(p, n, c, s) - feature3x1 = feature3x1 * scores3x1 - - # MTB2: ConvNet1d & Sigmoid - logits3x3 = torch.cat( - [conv(_.squeeze(0)).unsqueeze(0) - for conv, _ in zip(self.conv1d3x3, feature)], dim=0 - ) - scores3x3 = torch.sigmoid(logits3x3) - # MTB2: Template Function - feature3x3 = self.avg_pool3x3(x) + self.max_pool3x3(x) - feature3x3 = feature3x3.view(p, n, c, s) - feature3x3 = feature3x3 * scores3x3 - - # Temporal Pooling - ret = (feature3x1 + feature3x3).max(-1)[0] - return ret |