diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-11 23:59:30 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-11 23:59:30 +0800 |
commit | 72a53806746bc7ffa2f3939721e34b5cfdb7330a (patch) | |
tree | 36c549aa32ed9e160381e47de6dbec045f6085cc /models/model.py | |
parent | 7188d71b2b6faf3da527c8d0ade9a32ec4893dc5 (diff) |
Add evaluation script, code review and fix some bugs
1. Add new `train_all` method for one shot calling
2. Print time used in 1k iterations
3. Correct label dimension in predict function
4. Transpose distance matrix for convenient indexing
5. Sort dictionary before generate signature
6. Extract visible CUDA setting function
Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 67 |
1 files changed, 51 insertions, 16 deletions
diff --git a/models/model.py b/models/model.py index 456c2f1..b343f86 100644 --- a/models/model.py +++ b/models/model.py @@ -1,4 +1,5 @@ import os +from datetime import datetime from typing import Union, Optional import numpy as np @@ -86,6 +87,21 @@ class Model: def _checkpoint_name(self) -> str: return os.path.join(self.checkpoint_dir, self._signature) + def fit_all( + self, + dataset_config: DatasetConfiguration, + dataset_selectors: dict[ + str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]] + ], + dataloader_config: DataloaderConfiguration, + ): + for (condition, selector) in dataset_selectors.items(): + print(f'Training model {condition} ...') + self.fit( + dict(**dataset_config, **{'selector': selector}), + dataloader_config + ) + def fit( self, dataset_config: DatasetConfiguration, @@ -115,6 +131,8 @@ class Model: self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optim_state_dict']) + # Training start + start_time = datetime.now() for (batch_c1, batch_c2) in dataloader: self.curr_iter += 1 # Zero the parameter gradients @@ -137,7 +155,7 @@ class Model: ], metrics)), self.curr_iter) if self.curr_iter % 100 == 0: - print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss), + print('{0:5d} loss: {1:6.3f}'.format(self.curr_iter, loss), '(xrecon = {:f}, pose_sim = {:f},' ' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics), 'lr:', self.scheduler.get_last_lr()[0]) @@ -149,8 +167,11 @@ class Model: 'optim_state_dict': self.optimizer.state_dict(), 'loss': loss, }, self._checkpoint_name) + print(datetime.now() - start_time, 'used') + start_time = datetime.now() if self.curr_iter == self.total_iter: + self.curr_iter = 0 self.writer.close() break @@ -160,7 +181,7 @@ class Model: self.rgb_pn = nn.DataParallel(self.rgb_pn) self.rgb_pn = self.rgb_pn.to(self.device) - def predict( + def predict_all( self, iter_: int, dataset_config: DatasetConfiguration, @@ -189,23 +210,36 @@ class Model: gallery_samples, probe_samples = [], {} # Gallery - self.rgb_pn.load_state_dict(torch.load(list(checkpoints.values())[0])) + checkpoint = torch.load(list(checkpoints.values())[0]) + self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) for sample in tqdm(gallery_dataloader, desc='Transforming gallery', unit='clips'): + label = sample.pop('label').item() clip = sample.pop('clip').to(self.device) feature = self.rgb_pn(clip).detach() - gallery_samples.append({**sample, **{'feature': feature}}) + gallery_samples.append({ + **{'label': label}, + **sample, + **{'feature': feature} + }) gallery_samples = default_collate(gallery_samples) # Probe - for (name, dataloader) in probe_dataloaders.items(): - self.rgb_pn.load_state_dict(torch.load(checkpoints[name])) - probe_samples[name] = [] + for (condition, dataloader) in probe_dataloaders.items(): + checkpoint = torch.load(checkpoints[condition]) + self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) + probe_samples[condition] = [] for sample in tqdm(dataloader, - desc=f'Transforming probe {name}', unit='clips'): + desc=f'Transforming probe {condition}', + unit='clips'): + label = sample.pop('label').item() clip = sample.pop('clip').to(self.device) feature = self.rgb_pn(clip).detach() - probe_samples[name].append({**sample, **{'feature': feature}}) + probe_samples[condition].append({ + **{'label': label}, + **sample, + **{'feature': feature} + }) for (k, v) in probe_samples.items(): probe_samples[k] = default_collate(v) @@ -243,11 +277,11 @@ class Model: f_p = features_p[probe_view_mask] y_p = labels_p[probe_view_mask] # Euclidean distance - f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(1) - f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(0) - f_g_times_f_p_sum = f_g @ f_p.T + f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(1) + f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(0) + f_p_times_f_g_sum = f_p @ f_g.T dist = torch.sqrt(F.relu( - f_g_squared_sum - 2*f_g_times_f_p_sum + f_p_squared_sum + f_p_squared_sum - 2*f_p_times_f_g_sum + f_g_squared_sum )) # Ranked accuracy rank_mask = dist.argsort(1)[:, :num_ranks] @@ -354,8 +388,8 @@ class Model: dataloader_config: DataloaderConfiguration ) -> DataLoader: config: dict = dataloader_config.copy() + (self.pr, self.k) = config.pop('batch_size') if self.is_train: - (self.pr, self.k) = config.pop('batch_size') self._log_name = '_'.join( (self._log_name, str(self.pr), str(self.k))) triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) @@ -364,7 +398,6 @@ class Model: collate_fn=self._batch_splitter, **config) else: # is_test - config.pop('batch_size') return DataLoader(dataset, **config) def _batch_splitter( @@ -399,8 +432,10 @@ class Model: for v in values: if isinstance(v, str): strings.append(v) - elif isinstance(v, (tuple, list, set)): + elif isinstance(v, (tuple, list)): strings.append(self._gen_sig(v)) + elif isinstance(v, set): + strings.append(self._gen_sig(sorted(list(v)))) elif isinstance(v, dict): strings.append(self._gen_sig(list(v.values()))) else: |