From d51312415a32686793d3f0d14eda7fa7cc3990ea Mon Sep 17 00:00:00 2001
From: Jordan Gong
Date: Mon, 15 Feb 2021 11:23:20 +0800
Subject: Revert "Memory usage improvement"

This reverts commit be508061
---
 models/model.py | 21 +++++----------------
 1 file changed, 5 insertions(+), 16 deletions(-)

(limited to 'models/model.py')

diff --git a/models/model.py b/models/model.py
index bd05115..f79b832 100644
--- a/models/model.py
+++ b/models/model.py
@@ -182,7 +182,7 @@ class Model:
         # Training start
         start_time = datetime.now()
         running_loss = torch.zeros(5, device=self.device)
-        print(f"{'Time':^8} {'Iter':^5} {'Loss':^5}",
+        print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
               f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
               f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
         for (batch_c1, batch_c2) in dataloader:
@@ -190,21 +190,12 @@ class Model:
             # Zero the parameter gradients
             self.optimizer.zero_grad()
             # forward + backward + optimize
-            # Feed data twice in order to reduce memory usage
             x_c1 = batch_c1['clip'].to(self.device)
+            x_c2 = batch_c2['clip'].to(self.device)
             y = batch_c1['label'].to(self.device)
             # Duplicate labels for each part
             y = y.unsqueeze(1).repeat(1, self.rgb_pn.num_total_parts)
-            # Feed condition 1 clips first
-            losses, images = self.rgb_pn(x_c1, y)
-            (xrecon_loss, hpm_ba_trip, pn_ba_trip) = losses
-            x_c2 = batch_c2['clip'].to(self.device)
-            # Then feed condition 2 clips
-            cano_cons_loss, pose_sim_loss = self.rgb_pn(x_c2, is_c1=False)
-            losses = torch.stack((
-                xrecon_loss, cano_cons_loss, pose_sim_loss,
-                hpm_ba_trip, pn_ba_trip
-            ))
+            losses, images = self.rgb_pn(x_c1, x_c2, y)
             loss = losses.sum()
             loss.backward()
             self.optimizer.step()
@@ -234,9 +225,7 @@ class Model:
                 self.writer.add_images(
                     'Canonical image', i_c, self.curr_iter
                 )
-                for (i, (o, a, p)) in enumerate(zip(
-                        batch_c1['clip'], i_a, i_p
-                )):
+                for (i, (o, a, p)) in enumerate(zip(x_c1, i_a, i_p)):
                     self.writer.add_images(
                         f'Original image/batch {i}', o, self.curr_iter
                     )
@@ -250,7 +239,7 @@ class Model:
                 remaining_minute, second = divmod(time_used.seconds, 60)
                 hour, minute = divmod(remaining_minute, 60)
                 print(f'{hour:02}:{minute:02}:{second:02}',
-                      f'{self.curr_iter:5d} {running_loss.sum() / 100:5.3f}',
+                      f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
                       '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
                       '{:.3e} {:.3e}'.format(lrs[0], lrs[1]))
                 running_loss.zero_()
--
cgit v1.2.3


From 5657dd650a8fffab9c8e3096a65c3cd94a5c42f4 Mon Sep 17 00:00:00 2001
From: Jordan Gong
Date: Tue, 16 Feb 2021 11:19:55 +0800
Subject: Split transform and evaluate method

---
 models/model.py | 27 ++++++++++++++++++++++-----
 1 file changed, 22 insertions(+), 5 deletions(-)

(limited to 'models/model.py')

diff --git a/models/model.py b/models/model.py
index f79b832..0c3e5eb 100644
--- a/models/model.py
+++ b/models/model.py
@@ -269,6 +269,24 @@ class Model:
             ],
             dataloader_config: DataloaderConfiguration,
     ) -> dict[str, torch.Tensor]:
+        # Transform data to features
+        gallery_samples, probe_samples = self.transform(
+            iters, dataset_config, dataset_selectors, dataloader_config
+        )
+        # Evaluate features
+        accuracy = self.evaluate(gallery_samples, probe_samples)
+
+        return accuracy
+
+    def transform(
+            self,
+            iters: tuple[int],
+            dataset_config: DatasetConfiguration,
+            dataset_selectors: dict[
+                str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
+            ],
+            dataloader_config: DataloaderConfiguration
+    ):
         self.is_train = False
         # Split gallery and probe dataset
         gallery_dataloader, probe_dataloaders = self._split_gallery_probe(
@@ -278,15 +296,15 @@ class Model:
         checkpoints = self._load_pretrained(
             iters, dataset_config, dataset_selectors
         )
+        # Init models
         model_hp = self.hp.get('model', {})
         self.rgb_pn = RGBPartNet(ae_in_channels=self.in_channels, **model_hp)
         # Try to accelerate computation using CUDA or others
         self.rgb_pn = self.rgb_pn.to(self.device)
-        self.rgb_pn.eval()
-        gallery_samples, probe_samples = [], {}
+        gallery_samples, probe_samples = [], {}
 
         # Gallery
         checkpoint = torch.load(list(checkpoints.values())[0])
         self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
@@ -294,7 +312,6 @@ class Model:
                            desc='Transforming gallery', unit='clips'):
             gallery_samples.append(self._get_eval_sample(sample))
         gallery_samples = default_collate(gallery_samples)
-
         # Probe
         for (condition, dataloader) in probe_dataloaders.items():
             checkpoint = torch.load(checkpoints[condition])
@@ -306,7 +323,7 @@ class Model:
                 probe_samples_c.append(self._get_eval_sample(sample))
             probe_samples[condition] = default_collate(probe_samples_c)
 
-        return self._evaluate(gallery_samples, probe_samples)
+        return gallery_samples, probe_samples
 
     def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]):
         label = sample.pop('label').item()
@@ -318,7 +335,7 @@ class Model:
             **{'feature': feature}
         }
 
-    def _evaluate(
+    def evaluate(
             self,
             gallery_samples: dict[str, Union[list[str], torch.Tensor]],
             probe_samples: dict[str, dict[str, Union[list[str], torch.Tensor]]],
--
cgit v1.2.3
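
A rough, hypothetical sketch of the calling pattern the split above enables (toy class, toy tensor shapes, and a stand-in feature extractor; nothing here is the repository's Model or RGBPartNet API beyond the transform/evaluate/predict_all names visible in the patch): features are produced once by transform, scored by evaluate, and predict_all remains the composition of the two.

import torch


class SplitPipeline:
    """Toy stand-in illustrating the transform/evaluate split."""

    def transform(self, clips: torch.Tensor) -> torch.Tensor:
        # Stand-in feature extraction (RGBPartNet plays this role in the real code):
        # flatten each clip into a single feature vector.
        return clips.flatten(start_dim=1)

    def evaluate(self, gallery: torch.Tensor, probe: torch.Tensor) -> torch.Tensor:
        # Stand-in scoring: nearest-gallery index for every probe feature.
        return torch.cdist(probe, gallery).argmin(dim=1)

    def predict_all(self, gallery_clips: torch.Tensor,
                    probe_clips: torch.Tensor) -> torch.Tensor:
        # Composition mirrors the patched predict_all: transform, then evaluate.
        return self.evaluate(self.transform(gallery_clips),
                             self.transform(probe_clips))


if __name__ == '__main__':
    pipeline = SplitPipeline()
    gallery_clips = torch.randn(4, 3, 8, 8)  # 4 toy gallery clips
    probe_clips = torch.randn(2, 3, 8, 8)    # 2 toy probe clips
    # Features could also be cached and re-scored without re-running transform.
    print(pipeline.predict_all(gallery_clips, probe_clips))

One practical upshot of the split is that transformed gallery and probe features can be kept around and evaluated repeatedly without re-running the network.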