summary | refs | log | tree | commit | diff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r-- models/auto_encoder.py 57
-rw-r--r-- models/rgb_part_net.py 128
2 files changed, 107 insertions, 78 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 234111a..ac3cfdf 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -122,35 +122,23 @@ class AutoEncoder(nn.Module):
self.encoder = Encoder(channels, feature_channels, embedding_dims)
self.decoder = Decoder(embedding_dims, feature_channels, channels)
- f_c_dim = embedding_dims[1]
- self.classifier = nn.Sequential(
- nn.LeakyReLU(0.2, inplace=True),
- BasicLinear(f_c_dim, num_class)
- )
-
- self.mse_loss = nn.MSELoss()
- self.xent_loss = nn.CrossEntropyLoss()
-
- def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y):
- # t1 is random time step
- (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
+ if self.training:
+ f_c_dim = embedding_dims[1]
+ self.classifier = nn.Sequential(
+ nn.LeakyReLU(0.2, inplace=True),
+ BasicLinear(f_c_dim, num_class)
+ )
+
+ def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y=None):
+ # x_c1_t2 is the frame for later module
(f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
- (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
-
- x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
- xrecon_loss_t2 = self.mse_loss(x_c1_t2, x_c1_t2_)
-
- y_ = self.classifier(f_c_c1_t2.contiguous())
- cano_cons_loss_t2 = (self.mse_loss(f_c_c1_t1, f_c_c1_t2)
- + self.mse_loss(f_c_c1_t2, f_c_c2_t2)
- + self.xent_loss(y_, y))
f_a_size, f_c_size, f_p_size = (
f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size()
)
# Decode canonical features for HPM
x_c_c1_t2 = self.decoder(
- torch.zeros(f_a_size), f_c_c1_t1, torch.zeros(f_p_size),
+ torch.zeros(f_a_size), f_c_c1_t2, torch.zeros(f_p_size),
no_trans_conv=True
)
# Decode pose features for Part Net
@@ -158,8 +146,23 @@ class AutoEncoder(nn.Module):
torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2
)
- return (
- (x_c_c1_t2, x_p_c1_t2),
- (f_p_c1_t2, f_p_c2_t2),
- (xrecon_loss_t2, cano_cons_loss_t2)
- )
+ if self.training:
+ # t1 is random time step, c2 is another condition
+ (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1)
+ (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2)
+
+ x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2)
+ xrecon_loss_t2 = F.mse_loss(x_c1_t2, x_c1_t2_)
+
+ y_ = self.classifier(f_c_c1_t2.contiguous())
+ cano_cons_loss_t2 = (F.mse_loss(f_c_c1_t1, f_c_c1_t2)
+ + F.mse_loss(f_c_c1_t2, f_c_c2_t2)
+ + F.cross_entropy(y_, y))
+
+ return (
+ (x_c_c1_t2, x_p_c1_t2),
+ (f_p_c1_t2, f_p_c2_t2),
+ (xrecon_loss_t2, cano_cons_loss_t2)
+ )
+ else: # evaluating
+ return x_c_c1_t2, x_p_c1_t2
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 0ff8251..ba5a00e 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -2,6 +2,7 @@ import random
import torch
import torch.nn as nn
+import torch.nn.functional as F
from models import AutoEncoder, HorizontalPyramidMatching, PartNet
@@ -36,53 +37,16 @@ class RGBPartNet(nn.Module):
hpm_use_avg_pool, hpm_use_max_pool
)
- self.mse_loss = nn.MSELoss()
-
# TODO Weight inti here
- def pose_sim_loss(self, f_p_c1: torch.Tensor,
- f_p_c2: torch.Tensor) -> torch.Tensor:
- f_p_c1_mean = f_p_c1.mean(dim=0)
- f_p_c2_mean = f_p_c2.mean(dim=0)
- return self.mse_loss(f_p_c1_mean, f_p_c2_mean)
-
- def forward(self, x_c1, x_c2, y):
+ def forward(self, x_c1, x_c2, y=None):
# Step 0: Swap batch_size and time dimensions for next step
# n, t, c, h, w
x_c1, x_c2 = x_c1.transpose(0, 1), x_c2.transpose(0, 1)
# Step 1: Disentanglement
# t, n, c, h, w
- num_frames = len(x_c1)
- # Decoded canonical features and Pose images
- x_c_c1, x_p_c1 = [], []
- # Features required to calculate losses
- f_p_c1, f_p_c2 = [], []
- xrecon_loss, cano_cons_loss = torch.zeros(1), torch.zeros(1)
- for t2 in range(num_frames):
- t1 = random.randrange(num_frames)
- output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y)
- (x_c1_t2, f_p_t2, losses) = output
-
- # Decoded features or image
- (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
- # Canonical Features for HPM
- x_c_c1.append(x_c_c1_t2)
- # Pose image for Part Net
- x_p_c1.append(x_p_c1_t2)
-
- # Losses per time step
- # Used in pose similarity loss
- (f_p_c1_t2, f_p_c2_t2) = f_p_t2
- f_p_c1.append(f_p_c1_t2)
- f_p_c2.append(f_p_c2_t2)
- # Cross reconstruction loss and canonical loss
- (xrecon_loss_t2, cano_cons_loss_t2) = losses
- xrecon_loss += xrecon_loss_t2
- cano_cons_loss += cano_cons_loss_t2
-
- x_c_c1 = torch.stack(x_c_c1)
- x_p_c1 = torch.stack(x_p_c1)
+ ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2, y)
# Step 2.a: HPM & Static Gait Feature Aggregation
# t, n, c, h, w
@@ -97,15 +61,77 @@ class RGBPartNet(nn.Module):
# p, n, c
# Step 3: Cat feature map together and calculate losses
- x = torch.cat([x_c, x_p])
- # Losses
- f_p_c1 = torch.stack(f_p_c1)
- f_p_c2 = torch.stack(f_p_c2)
- pose_sim_loss = self.pose_sim_loss(f_p_c1, f_p_c2)
- cano_cons_loss /= num_frames
- # TODO Implement Batch All triplet loss function
- batch_all_triplet_loss = 0
- loss = (xrecon_loss + pose_sim_loss + cano_cons_loss
- + batch_all_triplet_loss)
-
- return x, loss
+ x = torch.cat((x_c, x_p))
+
+ if self.training:
+ # TODO Implement Batch All triplet loss function
+ batch_all_triplet_loss = torch.tensor(0.)
+ print(*losses, batch_all_triplet_loss)
+ loss = torch.sum(torch.stack((*losses, batch_all_triplet_loss)))
+ return loss
+ else:
+ return x
+
+ def _disentangle(self, x_c1, x_c2, y):
+ num_frames = len(x_c1)
+ # Decoded canonical features and Pose images
+ x_c_c1, x_p_c1 = [], []
+ if self.training:
+ # Features required to calculate losses
+ f_p_c1, f_p_c2 = [], []
+ xrecon_loss, cano_cons_loss = [], []
+ for t2 in range(num_frames):
+ t1 = random.randrange(num_frames)
+ output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y)
+ (x_c1_t2, f_p_t2, losses) = output
+
+ # Decoded features or image
+ (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
+ # Canonical Features for HPM
+ x_c_c1.append(x_c_c1_t2)
+ # Pose image for Part Net
+ x_p_c1.append(x_p_c1_t2)
+
+ # Losses per time step
+ # Used in pose similarity loss
+ (f_p_c1_t2, f_p_c2_t2) = f_p_t2
+ f_p_c1.append(f_p_c1_t2)
+ f_p_c2.append(f_p_c2_t2)
+ # Cross reconstruction loss and canonical loss
+ (xrecon_loss_t2, cano_cons_loss_t2) = losses
+ xrecon_loss.append(xrecon_loss_t2)
+ cano_cons_loss.append(cano_cons_loss_t2)
+
+ x_c_c1 = torch.stack(x_c_c1)
+ x_p_c1 = torch.stack(x_p_c1)
+
+ # Losses
+ xrecon_loss = torch.sum(torch.stack(xrecon_loss))
+ pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2)
+ cano_cons_loss = torch.mean(torch.stack(cano_cons_loss))
+
+ return ((x_c_c1, x_p_c1),
+ (xrecon_loss, pose_sim_loss, cano_cons_loss))
+
+ else: # evaluating
+ for t2 in range(num_frames):
+ t1 = random.randrange(num_frames)
+ x_c1_t2 = self.ae(x_c1[t1], x_c1[t2], x_c2[t2])
+ # Decoded features or image
+ (x_c_c1_t2, x_p_c1_t2) = x_c1_t2
+ # Canonical Features for HPM
+ x_c_c1.append(x_c_c1_t2)
+ # Pose image for Part Net
+ x_p_c1.append(x_p_c1_t2)
+
+ x_c_c1 = torch.stack(x_c_c1)
+ x_p_c1 = torch.stack(x_p_c1)
+
+ return (x_c_c1, x_p_c1), None
+
+ @staticmethod
+ def _pose_sim_loss(f_p_c1: list[torch.Tensor],
+ f_p_c2: list[torch.Tensor]) -> torch.Tensor:
+ f_p_c1_mean = torch.stack(f_p_c1).mean(dim=0)
+ f_p_c2_mean = torch.stack(f_p_c2).mean(dim=0)
+ return F.mse_loss(f_p_c1_mean, f_p_c2_mean)