From 8f5bef7f3d10ba0994ce51d9f84100c26218d6ee Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 23 Jan 2021 13:44:12 +0800 Subject: Transform all frames together in evaluation --- models/rgb_part_net.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index e707c26..2cc0958 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -86,15 +86,15 @@ class RGBPartNet(nn.Module): return x.unsqueeze(1).view(-1) def _disentangle(self, x_c1, x_c2=None, y=None): - num_frames = len(x_c1) - # Decoded canonical features and Pose images - x_c_c1, x_p_c1 = [], [] + t, n, c, h, w = x_c1.size() if self.training: + # Decoded canonical features and Pose images + x_c_c1, x_p_c1 = [], [] # Features required to calculate losses f_p_c1, f_p_c2 = [], [] xrecon_loss, cano_cons_loss = [], [] - for t2 in range(num_frames): - t1 = random.randrange(num_frames) + for t2 in range(t): + t1 = random.randrange(t) output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2], y) (x_c1_t2, f_p_t2, losses) = output @@ -127,17 +127,11 @@ class RGBPartNet(nn.Module): (xrecon_loss, pose_sim_loss, cano_cons_loss)) else: # evaluating - for t2 in range(num_frames): - x_c1_t2 = self.ae(x_c1[t2]) - # Decoded features or image - (x_c_c1_t2, x_p_c1_t2) = x_c1_t2 - # Canonical Features for HPM - x_c_c1.append(x_c_c1_t2) - # Pose image for Part Net - x_p_c1.append(x_p_c1_t2) - - x_c_c1 = torch.stack(x_c_c1) - x_p_c1 = torch.stack(x_p_c1) + x_c1 = x_c1.view(-1, c, h, w) + x_c_c1, x_p_c1 = self.ae(x_c1) + _, c_c, h_c, w_c = x_c_c1.size() + x_c_c1 = x_c_c1.view(t, n, c_c, h_c, w_c) + x_p_c1 = x_p_c1.view(t, n, c, h, w) return (x_c_c1, x_p_c1), None -- cgit v1.2.3 From 59ccf61fed4d95b7fe91bb9552f0deb2f2c75b76 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 23 Jan 2021 16:04:35 +0800 Subject: Add late start support for non-disentangling parts --- config.py | 2 ++ models/model.py | 24 +++++++++++++++++------- utils/configuration.py | 1 + 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/config.py b/config.py index 991a3a6..641e8fb 100644 --- a/config.py +++ b/config.py @@ -70,6 +70,8 @@ config: Configuration = { }, 'optimizer': { # Global parameters + # Iteration start to optimize non-disentangling parts + # 'start_iter': 10, # Initial learning rate of Adam Optimizer 'lr': 1e-4, # Coefficients used for computing running averages of diff --git a/models/model.py b/models/model.py index 6b799ad..cccb6c4 100644 --- a/models/model.py +++ b/models/model.py @@ -141,6 +141,7 @@ class Model: # Prepare for model, optimizer and scheduler model_hp = self.hp.get('model', {}) optim_hp: dict = self.hp.get('optimizer', {}).copy() + start_iter = optim_hp.pop('start_iter', 0) ae_optim_hp = optim_hp.pop('auto_encoder', {}) pn_optim_hp = optim_hp.pop('part_net', {}) hpm_optim_hp = optim_hp.pop('hpm', {}) @@ -151,9 +152,6 @@ class Model: self.rgb_pn = self.rgb_pn.to(self.device) self.optimizer = optim.Adam([ {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp}, - {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, - {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp}, - {'params': self.rgb_pn.fc_mat, **fc_optim_hp}, ], **optim_hp) self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, **sched_hp) self.writer = SummaryWriter(self._log_name) @@ -173,8 +171,18 @@ class Model: start_time = datetime.now() running_loss = torch.zeros(4).to(self.device) print(f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}", - f"{'CanoCons':^8} {'BATrip':^8} {'LR':^9}") + f"{'CanoCons':^8} {'BATrip':^8} LR(s)") for (batch_c1, batch_c2) in dataloader: + if self.curr_iter == start_iter: + self.optimizer.add_param_group( + {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp} + ) + self.optimizer.add_param_group( + {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp} + ) + self.optimizer.add_param_group( + {'params': self.rgb_pn.fc_mat, **fc_optim_hp} + ) self.curr_iter += 1 # Zero the parameter gradients self.optimizer.zero_grad() @@ -186,8 +194,6 @@ class Model: loss = losses.sum() loss.backward() self.optimizer.step() - # Step scheduler - self.scheduler.step() # Statistics and checkpoint running_loss += losses.detach() @@ -199,11 +205,15 @@ class Model: ], losses)), self.curr_iter) if self.curr_iter % 100 == 0: + lrs = self.scheduler.get_last_lr() print(f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', '{:f} {:f} {:f} {:f}'.format(*running_loss / 100), - f'{self.scheduler.get_last_lr()[0]:.3e}') + ' '.join(('{:.3e}'.format(lr) for lr in lrs))) running_loss.zero_() + # Step scheduler + self.scheduler.step() + if self.curr_iter % 1000 == 0: torch.save({ 'iter': self.curr_iter, diff --git a/utils/configuration.py b/utils/configuration.py index 8b265e8..c4c4b4d 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -52,6 +52,7 @@ class SubOptimizerHPConfiguration(TypedDict): class OptimizerHPConfiguration(TypedDict): + start_iter: int lr: int betas: tuple[float, float] eps: float -- cgit v1.2.3 From 507e1d163aaa6ea4be23e7f08ff6ce0ef58c830b Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 23 Jan 2021 22:19:51 +0800 Subject: Remove the third term in canonical consistency loss --- models/auto_encoder.py | 22 ++++++++-------------- models/model.py | 4 +--- models/rgb_part_net.py | 9 ++++----- 3 files changed, 13 insertions(+), 22 deletions(-) diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 36be868..35cb629 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -113,7 +113,6 @@ class Decoder(nn.Module): class AutoEncoder(nn.Module): def __init__( self, - num_class: int = 74, channels: int = 3, feature_channels: int = 64, embedding_dims: tuple[int, int, int] = (128, 128, 64) @@ -122,25 +121,23 @@ class AutoEncoder(nn.Module): self.encoder = Encoder(channels, feature_channels, embedding_dims) self.decoder = Decoder(embedding_dims, feature_channels, channels) - f_c_dim = embedding_dims[1] - self.classifier = nn.Sequential( - nn.LeakyReLU(0.2, inplace=True), - BasicLinear(f_c_dim, num_class) - ) - - def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None, y=None): + def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None): # x_c1_t2 is the frame for later module (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2) with torch.no_grad(): # Decode canonical features for HPM x_c_c1_t2 = self.decoder( - torch.zeros_like(f_a_c1_t2), f_c_c1_t2, torch.zeros_like(f_p_c1_t2), + torch.zeros_like(f_a_c1_t2), + f_c_c1_t2, + torch.zeros_like(f_p_c1_t2), no_trans_conv=True ) # Decode pose features for Part Net x_p_c1_t2 = self.decoder( - torch.zeros_like(f_a_c1_t2), torch.zeros_like(f_c_c1_t2), f_p_c1_t2 + torch.zeros_like(f_a_c1_t2), + torch.zeros_like(f_c_c1_t2), + f_p_c1_t2 ) if self.training: @@ -150,11 +147,8 @@ class AutoEncoder(nn.Module): x_c1_t2_ = self.decoder(f_a_c1_t1, f_c_c1_t1, f_p_c1_t2) xrecon_loss_t2 = F.mse_loss(x_c1_t2, x_c1_t2_) - - y_ = self.classifier(f_c_c1_t2.contiguous()) cano_cons_loss_t2 = (F.mse_loss(f_c_c1_t1, f_c_c1_t2) - + F.mse_loss(f_c_c1_t2, f_c_c2_t2) - + F.cross_entropy(y_, y)) + + F.mse_loss(f_c_c1_t2, f_c_c2_t2)) return ( (x_c_c1_t2, x_p_c1_t2), diff --git a/models/model.py b/models/model.py index cccb6c4..ddb715d 100644 --- a/models/model.py +++ b/models/model.py @@ -54,7 +54,6 @@ class Model: self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000)) self.is_train: bool = True - self.train_size: int = 74 self.in_channels: int = 3 self.pr: Optional[int] = None self.k: Optional[int] = None @@ -147,7 +146,7 @@ class Model: hpm_optim_hp = optim_hp.pop('hpm', {}) fc_optim_hp = optim_hp.pop('fc', {}) sched_hp = self.hp.get('scheduler', {}) - self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **model_hp) + self.rgb_pn = RGBPartNet(self.in_channels, **model_hp) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) self.optimizer = optim.Adam([ @@ -409,7 +408,6 @@ class Model: self, dataset_config: DatasetConfiguration ) -> Union[CASIAB]: - self.train_size = dataset_config.get('train_size', 74) self.in_channels = dataset_config.get('num_input_channels', 3) self._dataset_sig = self._make_signature( dataset_config, diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 2cc0958..755d5dc 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -13,7 +13,6 @@ from utils.triplet_loss import BatchAllTripletLoss class RGBPartNet(nn.Module): def __init__( self, - num_class: int = 74, ae_in_channels: int = 3, ae_feature_channels: int = 64, f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), @@ -31,7 +30,7 @@ class RGBPartNet(nn.Module): ): super().__init__() self.ae = AutoEncoder( - num_class, ae_in_channels, ae_feature_channels, f_a_c_p_dims + ae_in_channels, ae_feature_channels, f_a_c_p_dims ) self.pn = PartNet( ae_in_channels, fpfe_feature_channels, fpfe_kernel_sizes, @@ -60,7 +59,7 @@ class RGBPartNet(nn.Module): # Step 1: Disentanglement # t, n, c, h, w - ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2, y) + ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2) # Step 2.a: HPM & Static Gait Feature Aggregation # t, n, c, h, w @@ -85,7 +84,7 @@ class RGBPartNet(nn.Module): else: return x.unsqueeze(1).view(-1) - def _disentangle(self, x_c1, x_c2=None, y=None): + def _disentangle(self, x_c1, x_c2=None): t, n, c, h, w = x_c1.size() if self.training: # Decoded canonical features and Pose images @@ -95,7 +94,7 @@ class RGBPartNet(nn.Module): xrecon_loss, cano_cons_loss = [], [] for t2 in range(t): t1 = random.randrange(t) - output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2], y) + output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2]) (x_c1_t2, f_p_t2, losses) = output # Decoded features or image -- cgit v1.2.3 From 99ddd7c142a4ec97cb8bd14b204651790b3cf4ee Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 8 Feb 2021 18:11:25 +0800 Subject: Code refactoring, modifications and new features 1. Decode features outside of auto-encoder 2. Turn off HPM 1x1 conv by default 3. Change canonical feature map size from `feature_channels * 8 x 4 x 2` to `feature_channels * 2 x 16 x 8` 4. Use mean of canonical embeddings instead of mean of static features 5. Calculate static and dynamic loss separately 6. Calculate mean of parts in triplet loss instead of sum of parts 7. Add switch to log disentangled images 8. Change default configuration --- config.py | 12 +++-- models/auto_encoder.py | 26 ++------- models/hpm.py | 20 +++---- models/layers.py | 10 ++-- models/model.py | 34 +++++++++--- models/rgb_part_net.py | 141 +++++++++++++++++++++++++++++++++++-------------- utils/configuration.py | 4 +- utils/triplet_loss.py | 2 +- 8 files changed, 160 insertions(+), 89 deletions(-) diff --git a/config.py b/config.py index 641e8fb..04a22b9 100644 --- a/config.py +++ b/config.py @@ -8,6 +8,8 @@ config: Configuration = { 'CUDA_VISIBLE_DEVICES': '0', # Directory used in training or testing for temporary storage 'save_dir': 'runs', + # Recorde disentangled image or not + 'image_log_on': False }, # Dataset settings 'dataset': { @@ -46,11 +48,13 @@ config: Configuration = { 'ae_feature_channels': 64, # Appearance, canonical and pose feature dimensions 'f_a_c_p_dims': (128, 128, 64), + # Use 1x1 convolution in dimensionality reduction + 'hpm_use_1x1conv': False, # HPM pyramid scales, of which sum is number of parts 'hpm_scales': (1, 2, 4), # Global pooling method 'hpm_use_avg_pool': True, - 'hpm_use_max_pool': True, + 'hpm_use_max_pool': False, # FConv feature channels coefficient 'fpfe_feature_channels': 32, # FConv blocks kernel sizes @@ -65,13 +69,13 @@ config: Configuration = { 'tfa_num_parts': 16, # Embedding dimension for each part 'embedding_dims': 256, - # Triplet loss margin - 'triplet_margin': 0.2, + # Triplet loss margins for HPM and PartNet + 'triplet_margins': (0.2, 0.2), }, 'optimizer': { # Global parameters # Iteration start to optimize non-disentangling parts - # 'start_iter': 10, + # 'start_iter': 0, # Initial learning rate of Adam Optimizer 'lr': 1e-4, # Coefficients used for computing running averages of diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 35cb629..f04ffdb 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -95,15 +95,14 @@ class Decoder(nn.Module): self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels, is_last_layer=True) - def forward(self, f_appearance, f_canonical, f_pose, no_trans_conv=False): + def forward(self, f_appearance, f_canonical, f_pose, cano_only=False): x = torch.cat((f_appearance, f_canonical, f_pose), dim=1) x = self.fc(x) x = F.relu(x.view(-1, self.feature_channels * 8, 4, 2), inplace=True) - # Decode canonical features without transpose convolutions - if no_trans_conv: - return x x = self.trans_conv1(x) x = self.trans_conv2(x) + if cano_only: + return x x = self.trans_conv3(x) x = torch.sigmoid(self.trans_conv4(x)) @@ -125,21 +124,6 @@ class AutoEncoder(nn.Module): # x_c1_t2 is the frame for later module (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2) - with torch.no_grad(): - # Decode canonical features for HPM - x_c_c1_t2 = self.decoder( - torch.zeros_like(f_a_c1_t2), - f_c_c1_t2, - torch.zeros_like(f_p_c1_t2), - no_trans_conv=True - ) - # Decode pose features for Part Net - x_p_c1_t2 = self.decoder( - torch.zeros_like(f_a_c1_t2), - torch.zeros_like(f_c_c1_t2), - f_p_c1_t2 - ) - if self.training: # t1 is random time step, c2 is another condition (f_a_c1_t1, f_c_c1_t1, _) = self.encoder(x_c1_t1) @@ -151,9 +135,9 @@ class AutoEncoder(nn.Module): + F.mse_loss(f_c_c1_t2, f_c_c2_t2)) return ( - (x_c_c1_t2, x_p_c1_t2), + (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2), (f_p_c1_t2, f_p_c2_t2), (xrecon_loss_t2, cano_cons_loss_t2) ) else: # evaluating - return x_c_c1_t2, x_p_c1_t2 + return f_c_c1_t2, f_p_c1_t2 diff --git a/models/hpm.py b/models/hpm.py index 66503e3..9879cfb 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -9,14 +9,16 @@ class HorizontalPyramidMatching(nn.Module): self, in_channels: int, out_channels: int = 128, + use_1x1conv: bool = False, scales: tuple[int, ...] = (1, 2, 4), use_avg_pool: bool = True, - use_max_pool: bool = True, + use_max_pool: bool = False, **kwargs ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels + self.use_1x1conv = use_1x1conv self.scales = scales self.use_avg_pool = use_avg_pool self.use_max_pool = use_max_pool @@ -29,6 +31,7 @@ class HorizontalPyramidMatching(nn.Module): pyramid = nn.ModuleList([ HorizontalPyramidPooling(self.in_channels, self.out_channels, + use_1x1conv=self.use_1x1conv, use_avg_pool=self.use_avg_pool, use_max_pool=self.use_max_pool, **kwargs) @@ -37,23 +40,16 @@ class HorizontalPyramidMatching(nn.Module): return pyramid def forward(self, x): - # Flatten canonical features in all batches - t, n, c, h, w = x.size() - x = x.view(t * n, c, h, w) - + n, c, h, w = x.size() feature = [] - for pyramid_index, pyramid in enumerate(self.pyramids): - h_per_hpp = h // self.scales[pyramid_index] + for scale, pyramid in zip(self.scales, self.pyramids): + h_per_hpp = h // scale for hpp_index, hpp in enumerate(pyramid): h_filter = torch.arange(hpp_index * h_per_hpp, (hpp_index + 1) * h_per_hpp) x_slice = x[:, :, h_filter, :] x_slice = hpp(x_slice) - x_slice = x_slice.view(t * n, -1) + x_slice = x_slice.view(n, -1) feature.append(x_slice) x = torch.stack(feature) - - # Unfold frames to original batch - p, _, c = x.size() - x = x.view(p, t, n, c) return x diff --git a/models/layers.py b/models/layers.py index a9f04b3..7b6ba5c 100644 --- a/models/layers.py +++ b/models/layers.py @@ -167,12 +167,13 @@ class HorizontalPyramidPooling(BasicConv2d): self, in_channels: int, out_channels: int, - kernel_size: Union[int, tuple[int, int]] = 1, + use_1x1conv: bool = False, use_avg_pool: bool = True, - use_max_pool: bool = True, + use_max_pool: bool = False, **kwargs ): - super().__init__(in_channels, out_channels, kernel_size, **kwargs) + super().__init__(in_channels, out_channels, kernel_size=1, **kwargs) + self.use_1x1conv = use_1x1conv self.use_avg_pool = use_avg_pool self.use_max_pool = use_max_pool assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.' @@ -186,5 +187,6 @@ class HorizontalPyramidPooling(BasicConv2d): x = self.avg_pool(x) elif not self.use_avg_pool and self.use_max_pool: x = self.max_pool(x) - x = super().forward(x) + if self.use_1x1conv: + x = super().forward(x) return x diff --git a/models/model.py b/models/model.py index ddb715d..0418070 100644 --- a/models/model.py +++ b/models/model.py @@ -69,6 +69,7 @@ class Model: self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None + self.image_log_on = system_config.get('image_log_on', False) self.CASIAB_GALLERY_SELECTOR = { 'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})} @@ -146,7 +147,8 @@ class Model: hpm_optim_hp = optim_hp.pop('hpm', {}) fc_optim_hp = optim_hp.pop('fc', {}) sched_hp = self.hp.get('scheduler', {}) - self.rgb_pn = RGBPartNet(self.in_channels, **model_hp) + self.rgb_pn = RGBPartNet(self.in_channels, **model_hp, + image_log_on=self.image_log_on) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) self.optimizer = optim.Adam([ @@ -168,9 +170,9 @@ class Model: # Training start start_time = datetime.now() - running_loss = torch.zeros(4).to(self.device) + running_loss = torch.zeros(5, device=self.device) print(f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}", - f"{'CanoCons':^8} {'BATrip':^8} LR(s)") + f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} LR(s)") for (batch_c1, batch_c2) in dataloader: if self.curr_iter == start_iter: self.optimizer.add_param_group( @@ -189,7 +191,7 @@ class Model: x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) y = batch_c1['label'].to(self.device) - losses = self.rgb_pn(x_c1, x_c2, y) + losses, images = self.rgb_pn(x_c1, x_c2, y) loss = losses.sum() loss.backward() self.optimizer.step() @@ -200,13 +202,33 @@ class Model: self.writer.add_scalar('Loss/all', loss, self.curr_iter) self.writer.add_scalars('Loss/details', dict(zip([ 'Cross reconstruction loss', 'Pose similarity loss', - 'Canonical consistency loss', 'Batch All triplet loss' + 'Canonical consistency loss', 'Batch All triplet loss (HPM)', + 'Batch All triplet loss (PartNet)' ], losses)), self.curr_iter) + if self.image_log_on: + (appearance_image, canonical_image, pose_image) = images + self.writer.add_images( + 'Canonical image', canonical_image, self.curr_iter + ) + for i in range(self.pr * self.k): + self.writer.add_images( + f'Original image/batch {i}', x_c1[i], self.curr_iter + ) + self.writer.add_images( + f'Appearance image/batch {i}', + appearance_image[:, i, :, :, :], + self.curr_iter + ) + self.writer.add_images( + f'Pose image/batch {i}', + pose_image[:, i, :, :, :], + self.curr_iter + ) if self.curr_iter % 100 == 0: lrs = self.scheduler.get_last_lr() print(f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', - '{:f} {:f} {:f} {:f}'.format(*running_loss / 100), + '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100), ' '.join(('{:.3e}'.format(lr) for lr in lrs))) running_loss.zero_() diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 755d5dc..0e7d8b3 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -16,6 +16,7 @@ class RGBPartNet(nn.Module): ae_in_channels: int = 3, ae_feature_channels: int = 64, f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), + hpm_use_1x1conv: bool = False, hpm_scales: tuple[int, ...] = (1, 2, 4), hpm_use_avg_pool: bool = True, hpm_use_max_pool: bool = True, @@ -26,9 +27,14 @@ class RGBPartNet(nn.Module): tfa_squeeze_ratio: int = 4, tfa_num_parts: int = 16, embedding_dims: int = 256, - triplet_margin: float = 0.2 + triplet_margins: tuple[float, float] = (0.2, 0.2), + image_log_on: bool = False ): super().__init__() + (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims + self.hpm_num_parts = sum(hpm_scales) + self.image_log_on = image_log_on + self.ae = AutoEncoder( ae_in_channels, ae_feature_channels, f_a_c_p_dims ) @@ -38,14 +44,16 @@ class RGBPartNet(nn.Module): ) out_channels = self.pn.tfa_in_channels self.hpm = HorizontalPyramidMatching( - ae_feature_channels * 8, out_channels, hpm_scales, - hpm_use_avg_pool, hpm_use_max_pool + ae_feature_channels * 2, out_channels, hpm_use_1x1conv, + hpm_scales, hpm_use_avg_pool, hpm_use_max_pool ) - total_parts = sum(hpm_scales) + tfa_num_parts - empty_fc = torch.empty(total_parts, out_channels, embedding_dims) + empty_fc = torch.empty(self.hpm_num_parts + tfa_num_parts, + out_channels, embedding_dims) self.fc_mat = nn.Parameter(empty_fc) - self.ba_triplet_loss = BatchAllTripletLoss(triplet_margin) + (hpm_margin, pn_margin) = triplet_margins + self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin) + self.pn_ba_trip = BatchAllTripletLoss(pn_margin) def fc(self, x): return x @ self.fc_mat @@ -59,13 +67,11 @@ class RGBPartNet(nn.Module): # Step 1: Disentanglement # t, n, c, h, w - ((x_c_c1, x_p_c1), losses) = self._disentangle(x_c1, x_c2) + ((x_c_c1, x_p_c1), images, losses) = self._disentangle(x_c1, x_c2) - # Step 2.a: HPM & Static Gait Feature Aggregation - # t, n, c, h, w + # Step 2.a: Static Gait Feature Aggregation & HPM + # n, c, h, w x_c = self.hpm(x_c_c1) - # p, t, n, c - x_c = x_c.mean(dim=1) # p, n, c # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) @@ -78,44 +84,83 @@ class RGBPartNet(nn.Module): x = self.fc(x) if self.training: - batch_all_triplet_loss = self.ba_triplet_loss(x, y) - losses = torch.stack((*losses, batch_all_triplet_loss)) - return losses + hpm_ba_trip = self.hpm_ba_trip(x[:self.hpm_num_parts], y) + pn_ba_trip = self.pn_ba_trip(x[self.hpm_num_parts:], y) + losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip)) + return losses, images else: return x.unsqueeze(1).view(-1) def _disentangle(self, x_c1, x_c2=None): t, n, c, h, w = x_c1.size() + device = x_c1.device if self.training: - # Decoded canonical features and Pose images - x_c_c1, x_p_c1 = [], [] + # Encoded appearance, canonical and pose features + f_a_c1, f_c_c1, f_p_c1 = [], [], [] # Features required to calculate losses - f_p_c1, f_p_c2 = [], [] + f_p_c2 = [] xrecon_loss, cano_cons_loss = [], [] for t2 in range(t): t1 = random.randrange(t) output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2]) - (x_c1_t2, f_p_t2, losses) = output + (f_c1_t2, f_p_t2, losses) = output - # Decoded features or image - (x_c_c1_t2, x_p_c1_t2) = x_c1_t2 - # Canonical Features for HPM - x_c_c1.append(x_c_c1_t2) - # Pose image for Part Net - x_p_c1.append(x_p_c1_t2) + (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = f_c1_t2 + if self.image_log_on: + f_a_c1.append(f_a_c1_t2) + # Save canonical features and pose features + f_c_c1.append(f_c_c1_t2) + f_p_c1.append(f_p_c1_t2) # Losses per time step # Used in pose similarity loss - (f_p_c1_t2, f_p_c2_t2) = f_p_t2 - f_p_c1.append(f_p_c1_t2) + (_, f_p_c2_t2) = f_p_t2 f_p_c2.append(f_p_c2_t2) + # Cross reconstruction loss and canonical loss (xrecon_loss_t2, cano_cons_loss_t2) = losses xrecon_loss.append(xrecon_loss_t2) cano_cons_loss.append(cano_cons_loss_t2) - - x_c_c1 = torch.stack(x_c_c1) - x_p_c1 = torch.stack(x_p_c1) + if self.image_log_on: + f_a_c1 = torch.stack(f_a_c1) + f_c_c1_mean = torch.stack(f_c_c1).mean(0) + f_p_c1 = torch.stack(f_p_c1) + f_p_c2 = torch.stack(f_p_c2) + + # Decode features + appearance_image, canonical_image, pose_image = None, None, None + with torch.no_grad(): + # Decode average canonical features to higher dimension + x_c_c1 = self.ae.decoder( + torch.zeros((n, self.f_a_dim), device=device), + f_c_c1_mean, + torch.zeros((n, self.f_p_dim), device=device), + cano_only=True + ) + # Decode pose features to images + f_p_c1_ = f_p_c1.view(t * n, -1) + x_p_c1_ = self.ae.decoder( + torch.zeros((t * n, self.f_a_dim), device=device), + torch.zeros((t * n, self.f_c_dim), device=device), + f_p_c1_ + ) + x_p_c1 = x_p_c1_.view(t, n, c, h, w) + + if self.image_log_on: + # Decode appearance features + f_a_c1_ = f_a_c1.view(t * n, -1) + appearance_image_ = self.ae.decoder( + f_a_c1_, + torch.zeros((t * n, self.f_c_dim), device=device), + torch.zeros((t * n, self.f_p_dim), device=device) + ) + appearance_image = appearance_image_.view(t, n, c, h, w) + # Continue decoding canonical features + canonical_image = self.ae.decoder.trans_conv3(x_c_c1) + canonical_image = torch.sigmoid( + self.ae.decoder.trans_conv4(canonical_image) + ) + pose_image = x_p_c1 # Losses xrecon_loss = torch.sum(torch.stack(xrecon_loss)) @@ -123,20 +168,36 @@ class RGBPartNet(nn.Module): cano_cons_loss = torch.mean(torch.stack(cano_cons_loss)) return ((x_c_c1, x_p_c1), + (appearance_image, canonical_image, pose_image), (xrecon_loss, pose_sim_loss, cano_cons_loss)) else: # evaluating - x_c1 = x_c1.view(-1, c, h, w) - x_c_c1, x_p_c1 = self.ae(x_c1) - _, c_c, h_c, w_c = x_c_c1.size() - x_c_c1 = x_c_c1.view(t, n, c_c, h_c, w_c) - x_p_c1 = x_p_c1.view(t, n, c, h, w) - - return (x_c_c1, x_p_c1), None + x_c1_ = x_c1.view(t * n, c, h, w) + (f_c_c1_, f_p_c1_) = self.ae(x_c1_) + + # Canonical features + f_c_c1 = f_c_c1_.view(t, n, -1) + f_c_c1_mean = f_c_c1.mean(0) + x_c_c1 = self.ae.decoder( + torch.zeros((n, self.f_a_dim)), + f_c_c1_mean, + torch.zeros((n, self.f_p_dim)), + cano_only=True + ) + + # Pose features + x_p_c1_ = self.ae.decoder( + torch.zeros((t * n, self.f_a_dim)), + torch.zeros((t * n, self.f_c_dim)), + f_p_c1_ + ) + x_p_c1 = x_p_c1_.view(t, n, c, h, w) + + return (x_c_c1, x_p_c1), None, None @staticmethod - def _pose_sim_loss(f_p_c1: list[torch.Tensor], - f_p_c2: list[torch.Tensor]) -> torch.Tensor: - f_p_c1_mean = torch.stack(f_p_c1).mean(dim=0) - f_p_c2_mean = torch.stack(f_p_c2).mean(dim=0) + def _pose_sim_loss(f_p_c1: torch.Tensor, + f_p_c2: torch.Tensor) -> torch.Tensor: + f_p_c1_mean = f_p_c1.mean(dim=0) + f_p_c2_mean = f_p_c2.mean(dim=0) return F.mse_loss(f_p_c1_mean, f_p_c2_mean) diff --git a/utils/configuration.py b/utils/configuration.py index c4c4b4d..4ab1520 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -7,6 +7,7 @@ class SystemConfiguration(TypedDict): disable_acc: bool CUDA_VISIBLE_DEVICES: str save_dir: str + image_log_on: bool class DatasetConfiguration(TypedDict): @@ -31,6 +32,7 @@ class ModelHPConfiguration(TypedDict): ae_feature_channels: int f_a_c_p_dims: tuple[int, int, int] hpm_scales: tuple[int, ...] + hpm_use_1x1conv: bool hpm_use_avg_pool: bool hpm_use_max_pool: bool fpfe_feature_channels: int @@ -40,7 +42,7 @@ class ModelHPConfiguration(TypedDict): tfa_squeeze_ratio: int tfa_num_parts: int embedding_dims: int - triplet_margin: float + triplet_margins: tuple[float, float] class SubOptimizerHPConfiguration(TypedDict): diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 8c143d6..d573ef4 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -34,5 +34,5 @@ class BatchAllTripletLoss(nn.Module): parted_loss_mean = all_loss.sum(1) / non_zero_counts parted_loss_mean[non_zero_counts == 0] = 0 - loss = parted_loss_mean.sum() + loss = parted_loss_mean.mean() return loss -- cgit v1.2.3