diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/auto_encoder.py | 57 | ||||
-rw-r--r-- | models/rgb_part_net.py | 128 |
2 files changed, 107 insertions, 78 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 234111a..ac3cfdf 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -122,35 +122,23 @@ class AutoEncoder(nn.Module): self.encoder = Encoder(channels, feature_channels, embedding_dims) self.decoder = Decoder(embedding_dims, feature_channels, channels) - f_c_dim = embedding_dims[1] - self.classifier = nn.Sequential( - nn.LeakyReLU(0.2, inplace=True), - BasicLinear(f_c_dim, num_class) - ) - - self.mse_loss = nn.MSELoss() - self.xent_loss = nn.CrossEntropyLoss() - - def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y): - # t1 is random time step - (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1) + if self.training: + f_c_dim = embedding_dims[1] + self.classifier = nn.Sequential( + nn.LeakyReLU(0.2, inplace=True), + BasicLinear(f_c_dim, num_class) + ) + + def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y=None): + # x_c1_t2 is the frame for later module (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2) - (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2) - - x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2) - xrecon_loss_t2 = self.mse_loss(x_c1_t2, x_c1_t2_) - - y_ = self.classifier(f_c_c1_t2.contiguous()) - cano_cons_loss_t2 = (self.mse_loss(f_c_c1_t1, f_c_c1_t2) - + self.mse_loss(f_c_c1_t2, f_c_c2_t2) - + self.xent_loss(y_, y)) f_a_size, f_c_size, f_p_size = ( f_a_c1_t2.size(), f_c_c1_t2.size(), f_p_c1_t2.size() ) # Decode canonical features for HPM x_c_c1_t2 = self.decoder( - torch.zeros(f_a_size), f_c_c1_t1, torch.zeros(f_p_size), + torch.zeros(f_a_size), f_c_c1_t2, torch.zeros(f_p_size), no_trans_conv=True ) # Decode pose features for Part Net @@ -158,8 +146,23 @@ class AutoEncoder(nn.Module): torch.zeros(f_a_size), torch.zeros(f_c_size), f_p_c1_t2 ) - return ( - (x_c_c1_t2, x_p_c1_t2), - (f_p_c1_t2, f_p_c2_t2), - (xrecon_loss_t2, cano_cons_loss_t2) - ) + if self.training: + # t1 is random time step, c2 is another condition + (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1) + (_, f_c_c2_t2, f_p_c2_t2) = self.encoder(x_c2_t2) + + x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2) + xrecon_loss_t2 = F.mse_loss(x_c1_t2, x_c1_t2_) + + y_ = self.classifier(f_c_c1_t2.contiguous()) + cano_cons_loss_t2 = (F.mse_loss(f_c_c1_t1, f_c_c1_t2) + + F.mse_loss(f_c_c1_t2, f_c_c2_t2) + + F.cross_entropy(y_, y)) + + return ( + (x_c_c1_t2, x_p_c1_t2), + (f_p_c1_t2, f_p_c2_t2), + (xrecon_loss_t2, cano_cons_loss_t2) + ) + else: # evaluating + return x_c_c1_t2, x_p_c1_t2 diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 0ff8251..ba5a00e 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -2,6 +2,7 @@ import random import torch import torch.nn as nn +import torch.nn.functional as F from models import AutoEncoder, HorizontalPyramidMatching, PartNet @@ -36,53 +37,16 @@ class RGBPartNet(nn.Module): hpm_use_avg_pool, hpm_use_max_pool ) - self.mse_loss = nn.MSELoss() - # TODO Weight inti here - def pose_sim_loss(self, f_p_c1: torch.Tensor, - f_p_c2: torch.Tensor) -> torch.Tensor: - f_p_c1_mean = f_p_c1.mean(dim=0) - f_p_c2_mean = f_p_c2.mean(dim=0) - return self.mse_loss(f_p_c1_mean, f_p_c2_mean) - - def forward(self, x_c1, x_c2, y): + def forward(self, x_c1, x_c2, y=None): # Step 0: Swap batch_size and time dimensions for next step # n, t, c, h, w x_c1, x_c2 = x_c1.transpose(0, 1), x_c2.transpose(0, 1) # Step 1: Disentanglement # t, n, c, h, w - num_frames = len(x_c1) - # Decoded canonical features and Pose images - x_c_c1, x_p_c1 = [], [] - # Features required to calculate losses - f_p_c1, f_p_c2 = [], [] - xrecon_loss, cano_cons_loss = torch.zeros(1), torch.zeros(1) - for t2 in range(num_frames): - t1 = random.randrange(num_frames) - output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y) - (x_c1_t2, f_p_t2, losses) = output - - # Decoded features or image - (x_c_c1_t2, x_p_c1_t2) = x_c1_t2 - # Canonical Features for HPM - x_c_c1.append(x_c_c1_t2) - # Pose image for Part Net - x_p_c1.append(x_p_c1_t2) - - # Losses per time step - # Used in pose similarity loss - (f_p_c1_t2, f_p_c2_t2) = f_p_t2 - f_p_c1.append(f_p_c1_t2) - f_p_c2.append(f_p_c2_t2) - # Cross reconstruction loss and canonical loss - (xrecon_loss_t2, cano_cons_loss_t2) = losses - xrecon_loss += xrecon_loss_t2 - cano_cons_loss += cano_cons_loss_t2 - - x_c_c1 = torch.stack(x_c_c1) - x_p_c1 = torch.stack(x_p_c1) + ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2, y) # Step 2.a: HPM & Static Gait Feature Aggregation # t, n, c, h, w @@ -97,15 +61,77 @@ class RGBPartNet(nn.Module): # p, n, c # Step 3: Cat feature map together and calculate losses - x = torch.cat([x_c, x_p]) - # Losses - f_p_c1 = torch.stack(f_p_c1) - f_p_c2 = torch.stack(f_p_c2) - pose_sim_loss = self.pose_sim_loss(f_p_c1, f_p_c2) - cano_cons_loss /= num_frames - # TODO Implement Batch All triplet loss function - batch_all_triplet_loss = 0 - loss = (xrecon_loss + pose_sim_loss + cano_cons_loss - + batch_all_triplet_loss) - - return x, loss + x = torch.cat((x_c, x_p)) + + if self.training: + # TODO Implement Batch All triplet loss function + batch_all_triplet_loss = torch.tensor(0.) + print(*losses, batch_all_triplet_loss) + loss = torch.sum(torch.stack((*losses, batch_all_triplet_loss))) + return loss + else: + return x + + def _disentangle(self, x_c1, x_c2, y): + num_frames = len(x_c1) + # Decoded canonical features and Pose images + x_c_c1, x_p_c1 = [], [] + if self.training: + # Features required to calculate losses + f_p_c1, f_p_c2 = [], [] + xrecon_loss, cano_cons_loss = [], [] + for t2 in range(num_frames): + t1 = random.randrange(num_frames) + output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y) + (x_c1_t2, f_p_t2, losses) = output + + # Decoded features or image + (x_c_c1_t2, x_p_c1_t2) = x_c1_t2 + # Canonical Features for HPM + x_c_c1.append(x_c_c1_t2) + # Pose image for Part Net + x_p_c1.append(x_p_c1_t2) + + # Losses per time step + # Used in pose similarity loss + (f_p_c1_t2, f_p_c2_t2) = f_p_t2 + f_p_c1.append(f_p_c1_t2) + f_p_c2.append(f_p_c2_t2) + # Cross reconstruction loss and canonical loss + (xrecon_loss_t2, cano_cons_loss_t2) = losses + xrecon_loss.append(xrecon_loss_t2) + cano_cons_loss.append(cano_cons_loss_t2) + + x_c_c1 = torch.stack(x_c_c1) + x_p_c1 = torch.stack(x_p_c1) + + # Losses + xrecon_loss = torch.sum(torch.stack(xrecon_loss)) + pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) + cano_cons_loss = torch.mean(torch.stack(cano_cons_loss)) + + return ((x_c_c1, x_p_c1), + (xrecon_loss, pose_sim_loss, cano_cons_loss)) + + else: # evaluating + for t2 in range(num_frames): + t1 = random.randrange(num_frames) + x_c1_t2 = self.ae(x_c1[t1], x_c1[t2], x_c2[t2]) + # Decoded features or image + (x_c_c1_t2, x_p_c1_t2) = x_c1_t2 + # Canonical Features for HPM + x_c_c1.append(x_c_c1_t2) + # Pose image for Part Net + x_p_c1.append(x_p_c1_t2) + + x_c_c1 = torch.stack(x_c_c1) + x_p_c1 = torch.stack(x_p_c1) + + return (x_c_c1, x_p_c1), None + + @staticmethod + def _pose_sim_loss(f_p_c1: list[torch.Tensor], + f_p_c2: list[torch.Tensor]) -> torch.Tensor: + f_p_c1_mean = torch.stack(f_p_c1).mean(dim=0) + f_p_c2_mean = torch.stack(f_p_c2).mean(dim=0) + return F.mse_loss(f_p_c1_mean, f_p_c2_mean) |