summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--models/model.py27
-rw-r--r--models/rgb_part_net.py2
2 files changed, 16 insertions, 13 deletions
diff --git a/models/model.py b/models/model.py
index 42064fe..7aa103e 100644
--- a/models/model.py
+++ b/models/model.py
@@ -187,11 +187,14 @@ class Model:
], **optim_hp)
sched_final_gamma = sched_hp.get('final_gamma', 0.001)
sched_start_step = sched_hp.get('start_step', 15_000)
+ all_step = self.total_iter - sched_start_step
def lr_lambda(epoch):
- passed_step = epoch - sched_start_step
- all_step = self.total_iter - sched_start_step
- return sched_final_gamma ** (passed_step / all_step)
+ if epoch > sched_start_step:
+ passed_step = epoch - sched_start_step
+ return sched_final_gamma ** (passed_step / all_step)
+ else:
+ return 1
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
lr_lambda, lr_lambda, lr_lambda, lr_lambda
])
@@ -227,11 +230,10 @@ class Model:
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
y = y.repeat(self.rgb_pn.module.num_total_parts, 1)
- trip_loss, dist, num_non_zero = self.triplet_loss(
- embedding.contiguous(), y
- )
+ embedding = embedding.transpose(0, 1)
+ trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
losses = torch.cat((
- ae_losses.mean(0),
+ ae_losses.view(-1, 3).mean(0),
torch.stack((
trip_loss[:self.rgb_pn.module.hpm_num_parts].mean(),
trip_loss[self.rgb_pn.module.hpm_num_parts:].mean()
@@ -287,13 +289,14 @@ class Model:
'Embedding/PartNet norm', mean_pa_norm,
self.k, self.pr * self.k, self.curr_iter
)
+ # Learning rate
+ lrs = self.scheduler.get_last_lr()
+ # Write learning rates
+ self.writer.add_scalar(
+ 'Learning rate', lrs[0], self.curr_iter
+ )
if self.curr_iter % 100 == 0:
- lrs = self.scheduler.get_last_lr()
- # Write learning rates
- self.writer.add_scalar(
- 'Learning rate', lrs[0], self.curr_iter
- )
# Write disentangled images
if self.image_log_on:
i_a, i_c, i_p = images
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 310ef25..1cda91c 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -69,7 +69,7 @@ class RGBPartNet(nn.Module):
x = self.fc(x)
if self.training:
- return x, ae_losses, images
+ return x.transpose(0, 1), ae_losses, images
else:
return x.unsqueeze(1).view(-1)