From 50eedae4f320c446544772fb2b0abcbce1be7590 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 20 Feb 2021 14:19:30 +0800 Subject: Separate triplet loss from model --- models/model.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) (limited to 'models/model.py') diff --git a/models/model.py b/models/model.py index 82d6461..5aa6436 100644 --- a/models/model.py +++ b/models/model.py @@ -18,6 +18,7 @@ from utils.configuration import DataloaderConfiguration, \ SystemConfiguration from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses from utils.sampler import TripletSampler +from utils.triplet_loss import JointBatchAllTripletLoss class Model: @@ -67,6 +68,7 @@ class Model: self._dataset_sig: str = 'undefined' self.rgb_pn: Optional[RGBPartNet] = None + self.ba_triplet_loss: Optional[JointBatchAllTripletLoss] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None @@ -140,7 +142,8 @@ class Model: dataset = self._parse_dataset_config(dataset_config) dataloader = self._parse_dataloader_config(dataset, dataloader_config) # Prepare for model, optimizer and scheduler - model_hp = self.hp.get('model', {}) + model_hp: dict = self.hp.get('model', {}).copy() + triplet_margins = model_hp.pop('triplet_margins', (0.2, 0.2)) optim_hp: dict = self.hp.get('optimizer', {}).copy() start_iter = optim_hp.pop('start_iter', 0) ae_optim_hp = optim_hp.pop('auto_encoder', {}) @@ -150,6 +153,9 @@ class Model: sched_hp = self.hp.get('scheduler', {}) self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) + self.ba_triplet_loss = JointBatchAllTripletLoss( + self.rgb_pn.hpm_num_parts, triplet_margins + ) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) self.optimizer = optim.Adam([ @@ -193,10 +199,18 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) + feature, ae_losses, images = self.rgb_pn(x_c1, x_c2) y = batch_c1['label'].to(self.device) # Duplicate labels for each part - y = y.unsqueeze(1).repeat(1, self.rgb_pn.num_total_parts) - losses, images = self.rgb_pn(x_c1, x_c2, y) + y = y.repeat(self.rgb_pn.num_total_parts, 1) + triplet_loss = self.ba_triplet_loss(feature, y) + losses = torch.cat(( + ae_losses, + torch.stack(( + triplet_loss[:self.rgb_pn.hpm_num_parts].mean(), + triplet_loss[self.rgb_pn.hpm_num_parts:].mean() + )) + )) loss = losses.sum() loss.backward() self.optimizer.step() -- cgit v1.2.3