summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.py24
-rw-r--r--models/auto_encoder.py4
-rw-r--r--models/model.py3
-rw-r--r--models/part_net.py18
-rw-r--r--models/rgb_part_net.py42
5 files changed, 37 insertions, 54 deletions
diff --git a/config.py b/config.py
index 424bf5b..03f2f0d 100644
--- a/config.py
+++ b/config.py
@@ -37,7 +37,7 @@ config: Configuration = {
# Batch size (pr, k)
# `pr` denotes number of persons
# `k` denotes number of sequences per person
- 'batch_size': (4, 8),
+ 'batch_size': (4, 6),
# Number of workers of Dataloader
'num_workers': 4,
# Faster data transfer from RAM to GPU if enabled
@@ -49,22 +49,14 @@ config: Configuration = {
# Auto-encoder feature channels coefficient
'ae_feature_channels': 64,
# Appearance, canonical and pose feature dimensions
- 'f_a_c_p_dims': (128, 128, 64),
+ 'f_a_c_p_dims': (192, 192, 96),
# Use 1x1 convolution in dimensionality reduction
'hpm_use_1x1conv': False,
# HPM pyramid scales, of which sum is number of parts
'hpm_scales': (1, 2, 4),
# Global pooling method
'hpm_use_avg_pool': True,
- 'hpm_use_max_pool': False,
- # FConv feature channels coefficient
- 'fpfe_feature_channels': 32,
- # FConv blocks kernel sizes
- 'fpfe_kernel_sizes': ((5, 3), (3, 3), (3, 3)),
- # FConv blocks paddings
- 'fpfe_paddings': ((2, 1), (1, 1), (1, 1)),
- # FConv blocks halving
- 'fpfe_halving': (0, 2, 3),
+ 'hpm_use_max_pool': True,
# Attention squeeze ratio
'tfa_squeeze_ratio': 4,
# Number of parts after Part Net
@@ -72,7 +64,7 @@ config: Configuration = {
# Embedding dimension for each part
'embedding_dims': 256,
# Triplet loss margins for HPM and PartNet
- 'triplet_margins': (0.2, 0.2),
+ 'triplet_margins': (1.5, 1.5),
},
'optimizer': {
# Global parameters
@@ -91,15 +83,15 @@ config: Configuration = {
# 'amsgrad': False,
# Local parameters (override global ones)
- 'auto_encoder': {
- 'weight_decay': 0.001
- },
+ # 'auto_encoder': {
+ # 'weight_decay': 0.001
+ # },
},
'scheduler': {
# Period of learning rate decay
'step_size': 500,
# Multiplicative factor of decay
- 'gamma': 0.9,
+ 'gamma': 1,
}
},
# Model metadata
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 1ef7494..e6a3e60 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -106,14 +106,14 @@ class Decoder(nn.Module):
self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
is_last_layer=True)
- def forward(self, f_appearance, f_canonical, f_pose, cano_only=False):
+ def forward(self, f_appearance, f_canonical, f_pose, is_feature_map=False):
x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
x = self.fc(x)
x = x.view(-1, self.feature_channels * 8, self.h_0, self.w_0)
x = F.relu(x, inplace=True)
x = self.trans_conv1(x)
x = self.trans_conv2(x)
- if cano_only:
+ if is_feature_map:
return x
x = self.trans_conv3(x)
x = torch.sigmoid(self.trans_conv4(x))
diff --git a/models/model.py b/models/model.py
index a69a9b0..a086e7b 100644
--- a/models/model.py
+++ b/models/model.py
@@ -316,7 +316,8 @@ class Model:
)
# Init models
- model_hp = self.hp.get('model', {})
+ model_hp: dict = self.hp.get('model', {}).copy()
+ model_hp.pop('triplet_margins', None)
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp)
# Try to accelerate computation using CUDA or others
self.rgb_pn = nn.DataParallel(self.rgb_pn)
diff --git a/models/part_net.py b/models/part_net.py
index 62a2bac..29cf9cd 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -110,32 +110,22 @@ class TemporalFeatureAggregator(nn.Module):
class PartNet(nn.Module):
def __init__(
self,
- in_channels: int = 3,
- feature_channels: int = 32,
- kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)),
- paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
- halving: tuple[int, ...] = (0, 2, 3),
+ in_channels: int = 128,
squeeze_ratio: int = 4,
num_part: int = 16
):
super().__init__()
self.num_part = num_part
- self.fpfe = FrameLevelPartFeatureExtractor(
- in_channels, feature_channels, kernel_sizes, paddings, halving
- )
-
- num_fconv_blocks = len(self.fpfe.fconv_blocks)
- self.tfa_in_channels = feature_channels * 2 ** (num_fconv_blocks - 1)
self.tfa = TemporalFeatureAggregator(
- self.tfa_in_channels, squeeze_ratio, self.num_part
+ in_channels, squeeze_ratio, self.num_part
)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
def forward(self, x):
- n, t, _, _, _ = x.size()
- x = self.fpfe(x)
+ n, t, c, h, w = x.size()
+ x = x.view(n * t, c, h, w)
# n * t x c x h x w
# Horizontal Pooling
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 408bca0..4367c62 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -17,16 +17,13 @@ class RGBPartNet(nn.Module):
hpm_scales: tuple[int, ...] = (1, 2, 4),
hpm_use_avg_pool: bool = True,
hpm_use_max_pool: bool = True,
- fpfe_feature_channels: int = 32,
- fpfe_kernel_sizes: tuple[tuple, ...] = ((5, 3), (3, 3), (3, 3)),
- fpfe_paddings: tuple[tuple, ...] = ((2, 1), (1, 1), (1, 1)),
- fpfe_halving: tuple[int, ...] = (0, 2, 3),
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
embedding_dims: int = 256,
image_log_on: bool = False
):
super().__init__()
+ self.h, self.w = ae_in_size
(self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
self.hpm_num_parts = sum(hpm_scales)
self.image_log_on = image_log_on
@@ -34,18 +31,17 @@ class RGBPartNet(nn.Module):
self.ae = AutoEncoder(
ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
)
+ self.pn_in_channels = ae_feature_channels * 2
self.pn = PartNet(
- ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes,
- fpfe_paddings, fpfe_halving, tfa_squeeze_ratio, tfa_num_parts
+ self.pn_in_channels, tfa_squeeze_ratio, tfa_num_parts
)
- out_channels = self.pn.tfa_in_channels
self.hpm = HorizontalPyramidMatching(
- ae_feature_channels * 2, out_channels, hpm_use_1x1conv,
+ ae_feature_channels * 2, self.pn_in_channels, hpm_use_1x1conv,
hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
)
self.num_total_parts = self.hpm_num_parts + tfa_num_parts
empty_fc = torch.empty(self.num_total_parts,
- out_channels, embedding_dims)
+ self.pn_in_channels, embedding_dims)
self.fc_mat = nn.Parameter(empty_fc)
def fc(self, x):
@@ -78,28 +74,32 @@ class RGBPartNet(nn.Module):
def _disentangle(self, x_c1_t2, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()
device = x_c1_t2.device
- x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
if self.training:
+ x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
# Decode features
- with torch.no_grad():
- x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+ x_c = self._decode_cano_feature(f_c_, n, t, device)
+ x_p_ = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+ x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4)
- i_a, i_c, i_p = None, None, None
- if self.image_log_on:
+ i_a, i_c, i_p = None, None, None
+ if self.image_log_on:
+ with torch.no_grad():
i_a = self._decode_appr_feature(f_a_, n, t, device)
# Continue decoding canonical features
i_c = self.ae.decoder.trans_conv3(x_c)
i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c))
- i_p = x_p
+ i_p_ = self.ae.decoder.trans_conv3(x_p_)
+ i_p_ = torch.sigmoid(self.ae.decoder.trans_conv4(i_p_))
+ i_p = i_p_.view(n, t, c, h, w)
return (x_c, x_p), losses, (i_a, i_c, i_p)
else: # evaluating
f_c_, f_p_ = self.ae(x_c1_t2)
x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+ x_p_ = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+ x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4)
return (x_c, x_p), None, None
def _decode_appr_feature(self, f_a_, n, t, device):
@@ -119,7 +119,7 @@ class RGBPartNet(nn.Module):
torch.zeros((n, self.f_a_dim), device=device),
f_c.mean(1),
torch.zeros((n, self.f_p_dim), device=device),
- cano_only=True
+ is_feature_map=True
)
return x_c
@@ -128,7 +128,7 @@ class RGBPartNet(nn.Module):
x_p_ = self.ae.decoder(
torch.zeros((n * t, self.f_a_dim), device=device),
torch.zeros((n * t, self.f_c_dim), device=device),
- f_p_
+ f_p_,
+ is_feature_map=True
)
- x_p = x_p_.view(n, t, c, h, w)
- return x_p
+ return x_p_