| author    | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-21 23:47:52 +0800 |
|-----------|------------------------------------------|---------------------------|
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-21 23:47:52 +0800 |
| commit    | c0c8299354bd41bfd668ff0fb3edb5997f590c5d (patch) | |
| tree      | 25502d76a218a71ad5c2d51001876b08d734ee47 /models | |
| parent    | 42847b721a99350e1eed423dce99574c584d97ef (diff) | |
| parent    | d750dd9dafe3cda3b1331ad2bfecb53c8c2b1267 (diff) | |
Merge branch 'python3.8' into python3.7
# Conflicts:
# utils/configuration.py
Diffstat (limited to 'models')
| -rw-r--r-- | models/auto_encoder.py | 19 |
| -rw-r--r-- | models/model.py        | 33 |
| -rw-r--r-- | models/rgb_part_net.py |  5 |
3 files changed, 36 insertions, 21 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 64c52e3..befd2d3 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -134,15 +134,16 @@ class AutoEncoder(nn.Module):
         # x_c1_t2 is the frame for later module
         (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
 
-        # Decode canonical features for HPM
-        x_c_c1_t2 = self.decoder(
-            torch.zeros_like(f_a_c1_t2), f_c_c1_t2, torch.zeros_like(f_p_c1_t2),
-            no_trans_conv=True
-        )
-        # Decode pose features for Part Net
-        x_p_c1_t2 = self.decoder(
-            torch.zeros_like(f_a_c1_t2), torch.zeros_like(f_c_c1_t2), f_p_c1_t2
-        )
+        with torch.no_grad():
+            # Decode canonical features for HPM
+            x_c_c1_t2 = self.decoder(
+                torch.zeros_like(f_a_c1_t2), f_c_c1_t2, torch.zeros_like(f_p_c1_t2),
+                no_trans_conv=True
+            )
+            # Decode pose features for Part Net
+            x_p_c1_t2 = self.decoder(
+                torch.zeros_like(f_a_c1_t2), torch.zeros_like(f_c_c1_t2), f_p_c1_t2
+            )
 
         if self.training:
             # t1 is random time step, c2 is another condition
diff --git a/models/model.py b/models/model.py
index aa45d66..5a8c0e8 100644
--- a/models/model.py
+++ b/models/model.py
@@ -130,12 +130,21 @@ class Model:
         dataloader = self._parse_dataloader_config(dataset, dataloader_config)
         # Prepare for model, optimizer and scheduler
         model_hp = self.hp.get('model', {})
-        optim_hp = self.hp.get('optimizer', {})
+        optim_hp: Dict = self.hp.get('optimizer', {}).copy()
+        ae_optim_hp = optim_hp.pop('auto_encoder', {})
+        pn_optim_hp = optim_hp.pop('part_net', {})
+        hpm_optim_hp = optim_hp.pop('hpm', {})
+        fc_optim_hp = optim_hp.pop('fc', {})
         sched_hp = self.hp.get('scheduler', {})
         self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **model_hp)
         # Try to accelerate computation using CUDA or others
         self.rgb_pn = self.rgb_pn.to(self.device)
-        self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp)
+        self.optimizer = optim.Adam([
+            {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
+            {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
+            {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
+            {'params': self.rgb_pn.fc_mat, **fc_optim_hp},
+        ], **optim_hp)
         self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, **sched_hp)
         self.writer = SummaryWriter(self._log_name)
 
@@ -152,6 +161,9 @@ class Model:
 
         # Training start
         start_time = datetime.now()
+        running_loss = torch.zeros(4).to(self.device)
+        print(f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}",
+              f"{'CanoCons':^8} {'BATrip':^8} {'LR':^9}")
         for (batch_c1, batch_c2) in dataloader:
             self.curr_iter += 1
             # Zero the parameter gradients
@@ -160,24 +172,27 @@ class Model:
             x_c1 = batch_c1['clip'].to(self.device)
             x_c2 = batch_c2['clip'].to(self.device)
             y = batch_c1['label'].to(self.device)
-            loss, metrics = self.rgb_pn(x_c1, x_c2, y)
+            losses = self.rgb_pn(x_c1, x_c2, y)
+            loss = losses.sum()
             loss.backward()
             self.optimizer.step()
             # Step scheduler
             self.scheduler.step()
 
+            # Statistics and checkpoint
+            running_loss += losses.detach()
             # Write losses to TensorBoard
-            self.writer.add_scalar('Loss/all', loss.item(), self.curr_iter)
+            self.writer.add_scalar('Loss/all', loss, self.curr_iter)
             self.writer.add_scalars('Loss/details', dict(zip([
                 'Cross reconstruction loss', 'Pose similarity loss',
                 'Canonical consistency loss', 'Batch All triplet loss'
-            ], metrics)), self.curr_iter)
+            ], losses)), self.curr_iter)
 
             if self.curr_iter % 100 == 0:
-                print('{0:5d} loss: {1:6.3f}'.format(self.curr_iter, loss),
-                      '(xrecon = {:f}, pose_sim = {:f},'
-                      ' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics),
-                      'lr:', self.scheduler.get_last_lr()[0])
+                print(f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
+                      '{:f} {:f} {:f} {:f}'.format(*running_loss / 100),
+                      f'{self.scheduler.get_last_lr()[0]:.3e}')
+                running_loss.zero_()
 
             if self.curr_iter % 1000 == 0:
                 torch.save({
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 95a3f2e..326ec81 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -81,9 +81,8 @@ class RGBPartNet(nn.Module):
 
         if self.training:
             batch_all_triplet_loss = self.ba_triplet_loss(x, y)
-            losses = (*losses, batch_all_triplet_loss)
-            loss = torch.sum(torch.stack(losses))
-            return loss, [loss.item() for loss in losses]
+            losses = torch.stack((*losses, batch_all_triplet_loss))
+            return losses
         else:
             return x.unsqueeze(1).view(-1)
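For context on the optimizer change in models/model.py above: torch.optim.Adam accepts a list of parameter groups, and options given per group override the defaults passed alongside the list, which is what lets the auto-encoder, part net, HPM and FC matrix each carry their own hyperparameters after this commit. The torch.no_grad() block added in models/auto_encoder.py simply disables gradient tracking for those two extra decoder passes, so autograd treats their outputs as constants. Below is a minimal, self-contained sketch of the parameter-group pattern; the ToyNet module and the concrete hyperparameter values are invented for illustration and are not part of this repository.

```python
import torch
from torch import nn, optim


class ToyNet(nn.Module):
    """Stand-in with the same sub-module layout the commit assumes
    (ae, pn, hpm, fc_mat on RGBPartNet); not the real model."""

    def __init__(self):
        super().__init__()
        self.ae = nn.Linear(8, 8)
        self.pn = nn.Linear(8, 8)
        self.hpm = nn.Linear(8, 8)
        self.fc_mat = nn.Parameter(torch.randn(8, 8))


net = ToyNet()

# Top-level defaults plus an optional per-module override, mirroring how the
# commit pops 'auto_encoder', 'part_net', 'hpm' and 'fc' out of the
# 'optimizer' hyperparameter dict (values here are made up).
optim_hp = {'lr': 1e-4, 'betas': (0.9, 0.999), 'auto_encoder': {'lr': 1e-3}}
ae_optim_hp = optim_hp.pop('auto_encoder', {})

optimizer = optim.Adam([
    {'params': net.ae.parameters(), **ae_optim_hp},  # per-group lr = 1e-3
    {'params': net.pn.parameters()},                 # inherits lr = 1e-4
    {'params': net.hpm.parameters()},
    {'params': [net.fc_mat]},
], **optim_hp)

print([group['lr'] for group in optimizer.param_groups])
# -> [0.001, 0.0001, 0.0001, 0.0001]
```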