From 543a35d814754c86ccafa1243bece387c1a780d6 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Tue, 2 Mar 2021 19:22:25 +0800 Subject: Fix bugs in new scheduler --- models/model.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'models/model.py') diff --git a/models/model.py b/models/model.py index 497a0ea..3242141 100644 --- a/models/model.py +++ b/models/model.py @@ -185,11 +185,14 @@ class Model: ], **optim_hp) sched_final_gamma = sched_hp.get('final_gamma', 0.001) sched_start_step = sched_hp.get('start_step', 15_000) + all_step = self.total_iter - sched_start_step def lr_lambda(epoch): - passed_step = epoch - sched_start_step - all_step = self.total_iter - sched_start_step - return sched_final_gamma ** (passed_step / all_step) + if epoch > sched_start_step: + passed_step = epoch - sched_start_step + return sched_final_gamma ** (passed_step / all_step) + else: + return 1 self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ lr_lambda, lr_lambda, lr_lambda, lr_lambda ]) -- cgit v1.2.3 From 02780c31385af7e1103448bd1994012ac95dd2bb Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Tue, 2 Mar 2021 19:26:11 +0800 Subject: Record learning rate every step --- models/model.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'models/model.py') diff --git a/models/model.py b/models/model.py index 3242141..acccbff 100644 --- a/models/model.py +++ b/models/model.py @@ -282,13 +282,14 @@ class Model: 'Embedding/PartNet norm', mean_pa_norm, self.k, self.pr * self.k, self.curr_iter ) + # Learning rate + lrs = self.scheduler.get_last_lr() + # Write learning rates + self.writer.add_scalar( + 'Learning rate', lrs[0], self.curr_iter + ) if self.curr_iter % 100 == 0: - lrs = self.scheduler.get_last_lr() - # Write learning rates - self.writer.add_scalar( - 'Learning rate', lrs[0], self.curr_iter - ) # Write disentangled images if self.image_log_on: i_a, i_c, i_p = images -- cgit v1.2.3 From 09c2af6f4881bb4a2ed5e3685b8ce65ad41695cb Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Tue, 2 Mar 2021 20:10:14 +0800 Subject: Fix DataParallel specific bugs --- models/model.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'models/model.py') diff --git a/models/model.py b/models/model.py index 70596ad..a8e0316 100644 --- a/models/model.py +++ b/models/model.py @@ -227,11 +227,10 @@ class Model: y = batch_c1['label'].to(self.device) # Duplicate labels for each part y = y.repeat(self.rgb_pn.module.num_total_parts, 1) - trip_loss, dist, num_non_zero = self.triplet_loss( - embedding.contiguous(), y - ) + embedding = embedding.transpose(0, 1) + trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y) losses = torch.cat(( - ae_losses.mean(0), + ae_losses.view(-1, 3).mean(0), torch.stack(( trip_loss[:self.rgb_pn.module.hpm_num_parts].mean(), trip_loss[self.rgb_pn.module.hpm_num_parts:].mean() -- cgit v1.2.3