summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-18 18:37:13 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-18 18:37:13 +0800
commit5d4a7460597f2829fad04ab9a8157a0999904d44 (patch)
treeca02bc8c46241f39ce806c45d6e47a5672414c84 /models
parent5097f32e39ca66723808c21288d0c8698e3e313c (diff)
parent84a3d5991f2f7272d1be54ad6cfe6ce695f915a0 (diff)
Merge branch 'master' into data_parallel
Diffstat (limited to 'models')
-rw-r--r--models/auto_encoder.py81
-rw-r--r--models/layers.py8
-rw-r--r--models/model.py14
-rw-r--r--models/rgb_part_net.py17
4 files changed, 70 insertions, 50 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index a9312dd..2d715db 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -11,39 +11,46 @@ class Encoder(nn.Module):
def __init__(
self,
in_channels: int = 3,
+ frame_size: tuple[int, int] = (64, 48),
feature_channels: int = 64,
output_dims: tuple[int, int, int] = (128, 128, 64)
):
super().__init__()
self.feature_channels = feature_channels
+ h_0, w_0 = frame_size
+ h_1, w_1 = h_0 // 2, w_0 // 2
+ h_2, w_2 = h_1 // 2, w_1 // 2
+ self.feature_size = self.h_3, self.w_3 = h_2 // 4, w_2 // 4
# Appearance features, canonical features, pose features
(self.f_a_dim, self.f_c_dim, self.f_p_dim) = output_dims
- # Conv1 in_channels x 64 x 32
- # -> feature_map_size x 64 x 32
+ # Conv1 in_channels x H x W
+ # -> feature_map_size x H x W
self.conv1 = VGGConv2d(in_channels, feature_channels)
- # MaxPool1 feature_map_size x 64 x 32
- # -> feature_map_size x 32 x 16
- self.max_pool1 = nn.AdaptiveMaxPool2d((32, 16))
- # Conv2 feature_map_size x 32 x 16
- # -> (feature_map_size*4) x 32 x 16
+ # MaxPool1 feature_map_size x H x W
+ # -> feature_map_size x H//2 x W//2
+ self.max_pool1 = nn.AdaptiveMaxPool2d((h_1, w_1))
+ # Conv2 feature_map_size x H//2 x W//2
+ # -> feature_map_size*4 x H//2 x W//2
self.conv2 = VGGConv2d(feature_channels, feature_channels * 4)
- # MaxPool2 (feature_map_size*4) x 32 x 16
- # -> (feature_map_size*4) x 16 x 8
- self.max_pool2 = nn.AdaptiveMaxPool2d((16, 8))
- # Conv3 (feature_map_size*4) x 16 x 8
- # -> (feature_map_size*8) x 16 x 8
+ # MaxPool2 feature_map_size*4 x H//2 x W//2
+ # -> feature_map_size*4 x H//4 x W//4
+ self.max_pool2 = nn.AdaptiveMaxPool2d((h_2, w_2))
+ # Conv3 feature_map_size*4 x H//4 x W//4
+ # -> feature_map_size*8 x H//4 x W//4
self.conv3 = VGGConv2d(feature_channels * 4, feature_channels * 8)
- # Conv4 (feature_map_size*8) x 16 x 8
- # -> (feature_map_size*8) x 16 x 8 (for large dataset)
+ # Conv4 feature_map_size*8 x H//4 x W//4
+ # -> feature_map_size*8 x H//4 x W//4 (for large dataset)
self.conv4 = VGGConv2d(feature_channels * 8, feature_channels * 8)
- # MaxPool3 (feature_map_size*8) x 16 x 8
- # -> (feature_map_size*8) x 4 x 2
- self.max_pool3 = nn.AdaptiveMaxPool2d((4, 2))
+ # MaxPool3 feature_map_size*8 x H//4 x W//4
+ # -> feature_map_size*8 x H//16 x W//16
+ self.max_pool3 = nn.AdaptiveMaxPool2d(self.feature_size)
embedding_dim = sum(output_dims)
- # FC (feature_map_size*8) * 4 * 2 -> 320
- self.fc = BasicLinear(feature_channels * 8 * 2 * 4, embedding_dim)
+ # FC feature_map_size*8 * H//16 * W//16 -> embedding_dim
+ self.fc = BasicLinear(
+ (feature_channels * 8) * self.h_3 * self.w_3, embedding_dim
+ )
def forward(self, x):
x = self.conv1(x)
@@ -53,7 +60,7 @@ class Encoder(nn.Module):
x = self.conv3(x)
x = self.conv4(x)
x = self.max_pool3(x)
- x = x.view(-1, (self.feature_channels * 8) * 2 * 4)
+ x = x.view(-1, (self.feature_channels * 8) * self.h_3 * self.w_3)
embedding = self.fc(x)
f_appearance, f_canonical, f_pose = embedding.split(
@@ -69,36 +76,41 @@ class Decoder(nn.Module):
self,
input_dims: tuple[int, int, int] = (128, 128, 64),
feature_channels: int = 64,
+ feature_size: tuple[int, int] = (4, 3),
out_channels: int = 3,
):
super().__init__()
self.feature_channels = feature_channels
+ self.h_0, self.w_0 = feature_size
embedding_dim = sum(input_dims)
- # FC 320 -> (feature_map_size*8) * 4 * 2
- self.fc = BasicLinear(embedding_dim, feature_channels * 8 * 2 * 4)
+ # FC 320 -> feature_map_size*8 * H * W
+ self.fc = BasicLinear(
+ embedding_dim, (feature_channels * 8) * self.h_0 * self.w_0
+ )
- # TransConv1 (feature_map_size*8) x 4 x 2
- # -> (feature_map_size*4) x 8 x 4
+ # TransConv1 feature_map_size*8 x H x W
+ # -> feature_map_size*4 x H*2 x W*2
self.trans_conv1 = DCGANConvTranspose2d(feature_channels * 8,
feature_channels * 4)
- # TransConv2 (feature_map_size*4) x 8 x 4
- # -> (feature_map_size*2) x 16 x 8
+ # TransConv2 feature_map_size*4 x H*2 x W*2
+ # -> feature_map_size*2 x H*4 x W*4
self.trans_conv2 = DCGANConvTranspose2d(feature_channels * 4,
feature_channels * 2)
- # TransConv3 (feature_map_size*2) x 16 x 8
- # -> feature_map_size x 32 x 16
+ # TransConv3 feature_map_size*2 x H*4 x W*4
+ # -> feature_map_size x H*8 x W*8
self.trans_conv3 = DCGANConvTranspose2d(feature_channels * 2,
feature_channels)
- # TransConv4 feature_map_size x 32 x 16
- # -> in_channels x 64 x 32
+ # TransConv4 feature_map_size x H*8 x W*8
+ # -> in_channels x H*16 x W*16
self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
is_last_layer=True)
def forward(self, f_appearance, f_canonical, f_pose, cano_only=False):
x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
x = self.fc(x)
- x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True)
+ x = x.view(-1, self.feature_channels * 8, self.h_0, self.w_0)
+ x = F.relu(x, inplace=True)
x = self.trans_conv1(x)
x = self.trans_conv2(x)
if cano_only:
@@ -113,12 +125,15 @@ class AutoEncoder(nn.Module):
def __init__(
self,
channels: int = 3,
+ frame_size: tuple[int, int] = (64, 48),
feature_channels: int = 64,
embedding_dims: tuple[int, int, int] = (128, 128, 64)
):
super().__init__()
- self.encoder = Encoder(channels, feature_channels, embedding_dims)
- self.decoder = Decoder(embedding_dims, feature_channels, channels)
+ self.encoder = Encoder(channels, frame_size,
+ feature_channels, embedding_dims)
+ self.decoder = Decoder(embedding_dims, feature_channels,
+ self.encoder.feature_size, channels)
def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()
diff --git a/models/layers.py b/models/layers.py
index 7b6ba5c..ef53a95 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -162,7 +162,7 @@ class BasicConv1d(nn.Module):
return self.conv(x)
-class HorizontalPyramidPooling(BasicConv2d):
+class HorizontalPyramidPooling(nn.Module):
def __init__(
self,
in_channels: int,
@@ -172,8 +172,10 @@ class HorizontalPyramidPooling(BasicConv2d):
use_max_pool: bool = False,
**kwargs
):
- super().__init__(in_channels, out_channels, kernel_size=1, **kwargs)
+ super().__init__()
self.use_1x1conv = use_1x1conv
+ if use_1x1conv:
+ self.conv = BasicConv2d(in_channels, out_channels, 1, **kwargs)
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.'
@@ -188,5 +190,5 @@ class HorizontalPyramidPooling(BasicConv2d):
elif not self.use_avg_pool and self.use_max_pool:
x = self.max_pool(x)
if self.use_1x1conv:
- x = super().forward(x)
+ x = self.conv(x)
return x
diff --git a/models/model.py b/models/model.py
index 874aa66..c7ed4b2 100644
--- a/models/model.py
+++ b/models/model.py
@@ -55,6 +55,7 @@ class Model:
self.is_train: bool = True
self.in_channels: int = 3
+ self.in_size: tuple[int, int] = (64, 48)
self.pr: Optional[int] = None
self.k: Optional[int] = None
@@ -147,7 +148,7 @@ class Model:
hpm_optim_hp = optim_hp.pop('hpm', {})
fc_optim_hp = optim_hp.pop('fc', {})
sched_hp = self.hp.get('scheduler', {})
- self.rgb_pn = RGBPartNet(self.in_channels, **model_hp,
+ self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
# Try to accelerate computation using CUDA or others
self.rgb_pn = nn.DataParallel(self.rgb_pn)
@@ -230,16 +231,16 @@ class Model:
if self.image_log_on:
i_a, i_c, i_p = images
self.writer.add_images(
+ 'Appearance image', i_a, self.curr_iter
+ )
+ self.writer.add_images(
'Canonical image', i_c, self.curr_iter
)
- for (i, (o, a, p)) in enumerate(zip(x_c1, i_a, i_p)):
+ for i, (o, p) in enumerate(zip(x_c1, i_p)):
self.writer.add_images(
f'Original image/batch {i}', o, self.curr_iter
)
self.writer.add_images(
- f'Appearance image/batch {i}', a, self.curr_iter
- )
- self.writer.add_images(
f'Pose image/batch {i}', p, self.curr_iter
)
time_used = datetime.now() - start_time
@@ -306,7 +307,7 @@ class Model:
# Init models
model_hp = self.hp.get('model', {})
- self.rgb_pn = RGBPartNet(ae_in_channels=self.in_channels, **model_hp)
+ self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp)
# Try to accelerate computation using CUDA or others
self.rgb_pn = nn.DataParallel(self.rgb_pn)
self.rgb_pn = self.rgb_pn.to(self.device)
@@ -467,6 +468,7 @@ class Model:
dataset_config: DatasetConfiguration
) -> Union[CASIAB]:
self.in_channels = dataset_config.get('num_input_channels', 3)
+ self.in_size = dataset_config.get('frame_size', (64, 48))
self._dataset_sig = self._make_signature(
dataset_config,
popped_keys=['root_dir', 'cache_on']
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 66609fd..bd3a31c 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -11,6 +11,7 @@ class RGBPartNet(nn.Module):
def __init__(
self,
ae_in_channels: int = 3,
+ ae_in_size: tuple[int, int] = (64, 48),
ae_feature_channels: int = 64,
f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64),
hpm_use_1x1conv: bool = False,
@@ -33,7 +34,7 @@ class RGBPartNet(nn.Module):
self.image_log_on = image_log_on
self.ae = AutoEncoder(
- ae_in_channels, ae_feature_channels, f_a_c_p_dims
+ ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
)
self.pn = PartNet(
ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
@@ -101,7 +102,7 @@ class RGBPartNet(nn.Module):
i_a, i_c, i_p = None, None, None
if self.image_log_on:
- i_a = self._decode_appr_feature(f_a_, n, t, c, h, w, device)
+ i_a = self._decode_appr_feature(f_a_, n, t, device)
# Continue decoding canonical features
i_c = self.ae.decoder.trans_conv3(x_c)
i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c))
@@ -115,14 +116,14 @@ class RGBPartNet(nn.Module):
x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
return (x_c, x_p), None, None
- def _decode_appr_feature(self, f_a_, n, t, c, h, w, device):
+ def _decode_appr_feature(self, f_a_, n, t, device):
# Decode appearance features
- x_a_ = self.ae.decoder(
- f_a_,
- torch.zeros((n * t, self.f_c_dim), device=device),
- torch.zeros((n * t, self.f_p_dim), device=device)
+ f_a = f_a_.view(n, t, -1)
+ x_a = self.ae.decoder(
+ f_a.mean(1),
+ torch.zeros((n, self.f_c_dim), device=device),
+ torch.zeros((n, self.f_p_dim), device=device)
)
- x_a = x_a_.view(n, t, c, h, w)
return x_a
def _decode_cano_feature(self, f_c_, n, t, device):