summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py98
1 files changed, 47 insertions, 51 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 797e02b..1c7a1a2 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -2,6 +2,7 @@ from typing import Tuple
import torch
import torch.nn as nn
+import torch.nn.functional as F
from models.auto_encoder import AutoEncoder
@@ -16,6 +17,7 @@ class RGBPartNet(nn.Module):
image_log_on: bool = False
):
super().__init__()
+ self.h, self.w = ae_in_size
(self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
self.image_log_on = image_log_on
@@ -24,70 +26,64 @@ class RGBPartNet(nn.Module):
)
def forward(self, x_c1, x_c2=None):
- # Step 1: Disentanglement
- # n, t, c, h, w
- ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2)
+ losses, features, images = self._disentangle(x_c1, x_c2)
if self.training:
losses = torch.stack(losses)
- return losses, images
+ return losses, features, images
else:
- return x_c, x_p
+ return features
def _disentangle(self, x_c1_t2, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()
- device = x_c1_t2.device
- x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
if self.training:
+ x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
- # Decode features
- with torch.no_grad():
- x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+ f_a = f_a_.view(n, t, -1)
+ f_c = f_c_.view(n, t, -1)
+ f_p = f_p_.view(n, t, -1)
- i_a, i_c, i_p = None, None, None
- if self.image_log_on:
- i_a = self._decode_appr_feature(f_a_, n, t, device)
- # Continue decoding canonical features
- i_c = self.ae.decoder.trans_conv3(x_c)
- i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c))
- i_p = x_p
+ i_a, i_c, i_p = None, None, None
+ if self.image_log_on:
+ with torch.no_grad():
+ x_a, i_a = self._separate_decode(
+ f_a.mean(1),
+ torch.zeros_like(f_c[:, 0, :]),
+ torch.zeros_like(f_p[:, 0, :])
+ )
+ x_c, i_c = self._separate_decode(
+ torch.zeros_like(f_a[:, 0, :]),
+ f_c.mean(1),
+ torch.zeros_like(f_p[:, 0, :]),
+ )
+ x_p_, i_p_ = self._separate_decode(
+ torch.zeros_like(f_a_),
+ torch.zeros_like(f_c_),
+ f_p_
+ )
+ x_p = tuple(_x_p.view(n, t, *_x_p.size()[1:]) for _x_p in x_p_)
+ i_p = i_p_.view(n, t, c, h, w)
- return (x_c, x_p), losses, (i_a, i_c, i_p)
+ return losses, (x_a, x_c, x_p), (i_a, i_c, i_p)
else: # evaluating
f_c_, f_p_ = self.ae(x_c1_t2)
- x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
- return (x_c, x_p), None, None
+ f_c = f_c_.view(n, t, -1)
+ f_p = f_p_.view(n, t, -1)
+ return (f_c, f_p), None, None
- def _decode_appr_feature(self, f_a_, n, t, device):
- # Decode appearance features
- f_a = f_a_.view(n, t, -1)
- x_a = self.ae.decoder(
- f_a.mean(1),
- torch.zeros((n, self.f_c_dim), device=device),
- torch.zeros((n, self.f_p_dim), device=device)
+ def _separate_decode(self, f_a, f_c, f_p):
+ x_1 = torch.cat((f_a, f_c, f_p), dim=1)
+ x_1 = self.ae.decoder.fc(x_1).view(
+ -1,
+ self.ae.decoder.feature_channels * 8,
+ self.ae.decoder.h_0,
+ self.ae.decoder.w_0
)
- return x_a
-
- def _decode_cano_feature(self, f_c_, n, t, device):
- # Decode average canonical features to higher dimension
- f_c = f_c_.view(n, t, -1)
- x_c = self.ae.decoder(
- torch.zeros((n, self.f_a_dim), device=device),
- f_c.mean(1),
- torch.zeros((n, self.f_p_dim), device=device),
- cano_only=True
- )
- return x_c
-
- def _decode_pose_feature(self, f_p_, n, t, c, h, w, device):
- # Decode pose features to images
- x_p_ = self.ae.decoder(
- torch.zeros((n * t, self.f_a_dim), device=device),
- torch.zeros((n * t, self.f_c_dim), device=device),
- f_p_
- )
- x_p = x_p_.view(n, t, c, h, w)
- return x_p
+ x_1 = F.relu(x_1, inplace=True)
+ x_2 = self.ae.decoder.trans_conv1(x_1)
+ x_3 = self.ae.decoder.trans_conv2(x_2)
+ x_4 = self.ae.decoder.trans_conv3(x_3)
+ image = torch.sigmoid(self.ae.decoder.trans_conv4(x_4))
+ x = (x_1, x_2, x_3, x_4)
+ return x, image