| author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-02 20:22:38 +0800 | 
|---|---|---|
| committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-02 20:22:38 +0800 | 
| commit | b274351a8528bc63e52afd6a5d0c34811cea84b1 (patch) | |
| tree | b41df93b460f2166f930fad5ef34ca92d3dc775f | |
| parent | c3143f388730d2869067f6f259775289c742bb48 (diff) | |
| parent | 7fac206f92602462ad8eecde524b0324f7991bde (diff) | |
Merge branch 'data_parallel' into data_parallel_py3.8
| -rw-r--r-- | models/model.py | 27 |
| -rw-r--r-- | models/rgb_part_net.py | 2 | 
2 files changed, 16 insertions, 13 deletions
```diff
diff --git a/models/model.py b/models/model.py
index 42064fe..7aa103e 100644
--- a/models/model.py
+++ b/models/model.py
@@ -187,11 +187,14 @@ class Model:
         ], **optim_hp)
         sched_final_gamma = sched_hp.get('final_gamma', 0.001)
         sched_start_step = sched_hp.get('start_step', 15_000)
+        all_step = self.total_iter - sched_start_step

         def lr_lambda(epoch):
-            passed_step = epoch - sched_start_step
-            all_step = self.total_iter - sched_start_step
-            return sched_final_gamma ** (passed_step / all_step)
+            if epoch > sched_start_step:
+                passed_step = epoch - sched_start_step
+                return sched_final_gamma ** (passed_step / all_step)
+            else:
+                return 1
         self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
             lr_lambda, lr_lambda, lr_lambda, lr_lambda
         ])
@@ -227,11 +230,10 @@ class Model:
             y = batch_c1['label'].to(self.device)
             # Duplicate labels for each part
             y = y.repeat(self.rgb_pn.module.num_total_parts, 1)
-            trip_loss, dist, num_non_zero = self.triplet_loss(
-                embedding.contiguous(), y
-            )
+            embedding = embedding.transpose(0, 1)
+            trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
             losses = torch.cat((
-                ae_losses.mean(0),
+                ae_losses.view(-1, 3).mean(0),
                 torch.stack((
                     trip_loss[:self.rgb_pn.module.hpm_num_parts].mean(),
                     trip_loss[self.rgb_pn.module.hpm_num_parts:].mean()
@@ -287,13 +289,14 @@ class Model:
                 'Embedding/PartNet norm', mean_pa_norm,
                 self.k, self.pr * self.k, self.curr_iter
             )
+            # Learning rate
+            lrs = self.scheduler.get_last_lr()
+            # Write learning rates
+            self.writer.add_scalar(
+                'Learning rate', lrs[0], self.curr_iter
+            )
             if self.curr_iter % 100 == 0:
-                lrs = self.scheduler.get_last_lr()
-                # Write learning rates
-                self.writer.add_scalar(
-                    'Learning rate', lrs[0], self.curr_iter
-                )
                 # Write disentangled images
                 if self.image_log_on:
                     i_a, i_c, i_p = images
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 310ef25..1cda91c 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -69,7 +69,7 @@ class RGBPartNet(nn.Module):
         x = self.fc(x)

         if self.training:
-            return x, ae_losses, images
+            return x.transpose(0, 1), ae_losses, images
         else:
             return x.unsqueeze(1).view(-1)
```
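Why the scheduler hunk matters: with the old lambda, any step before `sched_start_step` gave a negative `passed_step`, and `sched_final_gamma ** (negative / all_step)` is greater than 1, so the learning rate was silently inflated before the decay was supposed to begin. The new lambda pins the multiplier to 1 until the decay window opens. Below is a minimal standalone sketch of the resulting schedule; the concrete values (`total_iter = 80_000`, the tiny `nn.Linear` model, a single param group rather than the four groups the real optimizer uses) are hypothetical stand-ins, not taken from this repository's config:

```python
import torch
from torch import nn, optim

# Hypothetical hyperparameters standing in for sched_hp / self.total_iter.
total_iter = 80_000
sched_start_step = 15_000
sched_final_gamma = 0.001
all_step = total_iter - sched_start_step

def lr_lambda(step):
    # Hold the multiplier at 1 until the decay window opens, then decay
    # exponentially so it reaches sched_final_gamma exactly at total_iter.
    if step > sched_start_step:
        passed_step = step - sched_start_step
        return sched_final_gamma ** (passed_step / all_step)
    else:
        return 1

model = nn.Linear(4, 4)
optimizer = optim.SGD(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

for step in (0, 15_000, 47_500, 80_000):
    print(step, lr_lambda(step))
# 0 -> 1, 15000 -> 1, 47500 -> ~0.0316 (0.001 ** 0.5), 80000 -> 0.001
```

Hoisting `all_step` out of the lambda is a minor cleanup on top of this: it is constant across steps, so there is no reason to recompute it on every scheduler call.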

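The two `transpose(0, 1)` calls are the DataParallel-facing half of this merge: `nn.DataParallel` scatters inputs and gathers outputs along dim 0, but `RGBPartNet` used to emit embeddings as (parts, batch, dims), so the gather concatenated replicas along the part axis instead of the batch axis. Returning `x.transpose(0, 1)` puts the batch on dim 0 for a correct gather, and the training loop transposes back before the part-wise triplet loss. The same gather semantics explain `ae_losses.view(-1, 3).mean(0)`: each replica emits three autoencoder losses, and they arrive concatenated. A shape-only sketch, with hypothetical sizes (66 parts, batch 8 per replica, 2 replicas) rather than values from the repo:

```python
import torch

# Hypothetical sizes: 66 total parts, per-replica batch of 8,
# 256-dim embeddings, 2 GPUs/replicas.
num_total_parts, batch_per_replica, dims, replicas = 66, 8, 256, 2

# Each replica now returns (batch, parts, dims) thanks to x.transpose(0, 1),
# so DataParallel's gather concatenates replica outputs along dim 0 (batch).
replica_outputs = [torch.randn(batch_per_replica, num_total_parts, dims)
                   for _ in range(replicas)]
gathered = torch.cat(replica_outputs)       # (16, 66, 256)

# The training loop undoes the transpose to get (parts, batch, dims)
# for the part-wise triplet loss.
embedding = gathered.transpose(0, 1)        # (66, 16, 256)
print(embedding.shape)

# Each replica also returns 3 autoencoder losses; gathered along dim 0 they
# arrive flattened, so view(-1, 3) restores (replicas, 3) before averaging.
ae_losses = torch.rand(replicas * 3)
print(ae_losses.view(-1, 3).mean(0))        # tensor of 3 averaged losses
```

The remaining hunk is independent of DataParallel: `get_last_lr()` logging moves out of the `% 100` guard, so the learning rate is written to TensorBoard every iteration instead of every 100.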