diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/model.py | 38 | ||||
-rw-r--r-- | models/rgb_part_net.py | 2 |
2 files changed, 24 insertions, 16 deletions
diff --git a/models/model.py b/models/model.py index acccbff..1f8ae23 100644 --- a/models/model.py +++ b/models/model.py @@ -163,7 +163,7 @@ class Model: ) else: # Different margins self.triplet_loss = JointBatchTripletLoss( - self.rgb_pn.hpm_num_parts, + self.rgb_pn.module.hpm_num_parts, triplet_is_hard, triplet_is_mean, triplet_margins ) else: # Soft margins @@ -175,13 +175,15 @@ class Model: num_pos_pairs = (self.k*(self.k-1)//2) * self.pr # Try to accelerate computation using CUDA or others + self.rgb_pn = nn.DataParallel(self.rgb_pn) self.rgb_pn = self.rgb_pn.to(self.device) + self.triplet_loss = nn.DataParallel(self.triplet_loss) self.triplet_loss = self.triplet_loss.to(self.device) self.optimizer = optim.Adam([ - {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp}, - {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, - {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp}, - {'params': self.rgb_pn.fc_mat, **fc_optim_hp} + {'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp}, + {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp}, + {'params': self.rgb_pn.module.hpm.parameters(), **hpm_optim_hp}, + {'params': self.rgb_pn.module.fc_mat, **fc_optim_hp} ], **optim_hp) sched_final_gamma = sched_hp.get('final_gamma', 0.001) sched_start_step = sched_hp.get('start_step', 15_000) @@ -227,13 +229,14 @@ class Model: embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) # Duplicate labels for each part - y = y.repeat(self.rgb_pn.num_total_parts, 1) + y = y.repeat(self.rgb_pn.module.num_total_parts, 1) + embedding = embedding.transpose(0, 1) trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y) losses = torch.cat(( - ae_losses, + ae_losses.view(-1, 3).mean(0), torch.stack(( - trip_loss[:self.rgb_pn.hpm_num_parts].mean(), - trip_loss[self.rgb_pn.hpm_num_parts:].mean() + trip_loss[:self.rgb_pn.module.hpm_num_parts].mean(), + trip_loss[self.rgb_pn.module.hpm_num_parts:].mean() )) )) loss = losses.sum() @@ -255,28 +258,32 @@ class Model: # None-zero losses in batch if num_non_zero is not None: self.writer.add_scalars('Loss/non-zero counts', { - 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(), - 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean() + 'HPM': num_non_zero[ + :self.rgb_pn.module.hpm_num_parts].mean(), + 'PartNet': num_non_zero[ + self.rgb_pn.module.hpm_num_parts:].mean() }, self.curr_iter) # Embedding distance - mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0) + mean_hpm_dist = dist[:self.rgb_pn.module.hpm_num_parts].mean(0) self._add_ranked_scalars( 'Embedding/HPM distance', mean_hpm_dist, num_pos_pairs, num_pairs, self.curr_iter ) - mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0) + mean_pa_dist = dist[self.rgb_pn.module.hpm_num_parts:].mean(0) self._add_ranked_scalars( 'Embedding/ParNet distance', mean_pa_dist, num_pos_pairs, num_pairs, self.curr_iter ) # Embedding norm - mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0) + mean_hpm_embedding = embedding[ + :self.rgb_pn.module.hpm_num_parts].mean(0) mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/HPM norm', mean_hpm_norm, self.k, self.pr * self.k, self.curr_iter ) - mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0) + mean_pa_embedding = embedding[ + self.rgb_pn.module.hpm_num_parts:].mean(0) mean_pa_norm = mean_pa_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/PartNet norm', mean_pa_norm, @@ -393,6 +400,7 @@ class Model: model_hp.pop('triplet_margins', None) self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp) # Try to accelerate computation using CUDA or others + self.rgb_pn = nn.DataParallel(self.rgb_pn) self.rgb_pn = self.rgb_pn.to(self.device) self.rgb_pn.eval() diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 8a0f3a7..cdf579b 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -67,7 +67,7 @@ class RGBPartNet(nn.Module): x = self.fc(x) if self.training: - return x, ae_losses, images + return x.transpose(0, 1), ae_losses, images else: return x.unsqueeze(1).view(-1) |