diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-20 14:48:16 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-20 14:48:16 +0800 |
commit | 9b1828be1db7fd1be8731a7cec66162de9145285 (patch) | |
tree | 9efb5a37856f34e333457e9d7ab2aaa8ba811cf6 /models | |
parent | e33c22e556ed64e1c1fdb011d78a124d1489ad15 (diff) | |
parent | c538919cb69e35a46811aef0b23baefe6a4c499c (diff) |
Merge branch 'python3.8' into data_parallel_py3.8
# Conflicts:
# models/model.py
Diffstat (limited to 'models')
-rw-r--r-- | models/auto_encoder.py | 2 | ||||
-rw-r--r-- | models/model.py | 24 | ||||
-rw-r--r-- | models/rgb_part_net.py | 20 |
3 files changed, 23 insertions, 23 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index e17caed..6f388c2 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -173,7 +173,7 @@ class AutoEncoder(nn.Module): return ( (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_), - (xrecon_loss, cano_cons_loss, pose_sim_loss * 10) + torch.stack((xrecon_loss, cano_cons_loss, pose_sim_loss * 10)) ) else: # evaluating return f_c_c1_t2_, f_p_c1_t2_ diff --git a/models/model.py b/models/model.py index 2ef3b80..e2de476 100644 --- a/models/model.py +++ b/models/model.py @@ -18,6 +18,7 @@ from utils.configuration import DataloaderConfiguration, \ SystemConfiguration from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses from utils.sampler import TripletSampler +from utils.triplet_loss import JointBatchAllTripletLoss class Model: @@ -67,6 +68,7 @@ class Model: self._dataset_sig: str = 'undefined' self.rgb_pn: Optional[RGBPartNet] = None + self.ba_triplet_loss: Optional[JointBatchAllTripletLoss] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None @@ -140,7 +142,8 @@ class Model: dataset = self._parse_dataset_config(dataset_config) dataloader = self._parse_dataloader_config(dataset, dataloader_config) # Prepare for model, optimizer and scheduler - model_hp = self.hp.get('model', {}) + model_hp: Dict = self.hp.get('model', {}).copy() + triplet_margins = model_hp.pop('triplet_margins', (0.2, 0.2)) optim_hp: Dict = self.hp.get('optimizer', {}).copy() start_iter = optim_hp.pop('start_iter', 0) ae_optim_hp = optim_hp.pop('auto_encoder', {}) @@ -150,9 +153,14 @@ class Model: sched_hp = self.hp.get('scheduler', {}) self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) + self.ba_triplet_loss = JointBatchAllTripletLoss( + self.rgb_pn.hpm_num_parts, triplet_margins + ) # Try to accelerate computation using CUDA or others self.rgb_pn = nn.DataParallel(self.rgb_pn) self.rgb_pn = self.rgb_pn.to(self.device) + self.ba_triplet_loss = nn.DataParallel(self.ba_triplet_loss) + self.ba_triplet_loss = self.ba_triplet_loss.to(self.device) self.optimizer = optim.Adam([ {'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp}, {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp}, @@ -194,12 +202,18 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) + feature, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) # Duplicate labels for each part - y = y.unsqueeze(1).repeat(1, self.rgb_pn.module.num_total_parts) - losses, images = self.rgb_pn(x_c1, x_c2, y) - # Combine losses from different data splits - losses = losses.mean() + y = y.repeat(self.rgb_pn.num_total_parts, 1) + triplet_loss = self.ba_triplet_loss(feature, y) + losses = torch.cat(( + ae_losses.mean(0), + torch.stack(( + triplet_loss[:self.rgb_pn.hpm_num_parts].mean(), + triplet_loss[self.rgb_pn.hpm_num_parts:].mean() + )) + )) loss = losses.sum() loss.backward() self.optimizer.step() diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 2af990e..15b69f9 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -6,7 +6,6 @@ import torch.nn as nn from models.auto_encoder import AutoEncoder from models.hpm import HorizontalPyramidMatching from models.part_net import PartNet -from utils.triplet_loss import BatchAllTripletLoss class RGBPartNet(nn.Module): @@ -27,7 +26,6 @@ class RGBPartNet(nn.Module): tfa_squeeze_ratio: int = 4, tfa_num_parts: int = 16, embedding_dims: int = 256, - triplet_margins: Tuple[float, float] = (0.2, 0.2), image_log_on: bool = False ): super().__init__() @@ -52,17 +50,13 @@ class RGBPartNet(nn.Module): out_channels, embedding_dims) self.fc_mat = nn.Parameter(empty_fc) - (hpm_margin, pn_margin) = triplet_margins - self.hpm_ba_trip = BatchAllTripletLoss(hpm_margin) - self.pn_ba_trip = BatchAllTripletLoss(pn_margin) - def fc(self, x): return x @ self.fc_mat - def forward(self, x_c1, x_c2=None, y=None): + def forward(self, x_c1, x_c2=None): # Step 1: Disentanglement # n, t, c, h, w - ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2) + ((x_c, x_p), ae_losses, images) = self._disentangle(x_c1, x_c2) # Step 2.a: Static Gait Feature Aggregation & HPM # n, c, h, w @@ -79,15 +73,7 @@ class RGBPartNet(nn.Module): x = self.fc(x) if self.training: - y = y.T - hpm_ba_trip = self.hpm_ba_trip( - x[:self.hpm_num_parts], y[:self.hpm_num_parts] - ) - pn_ba_trip = self.pn_ba_trip( - x[self.hpm_num_parts:], y[self.hpm_num_parts:] - ) - losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip)) - return losses, images + return x, ae_losses, images else: return x.unsqueeze(1).view(-1) |