diff options
Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 114 |
1 files changed, 13 insertions, 101 deletions
diff --git a/models/model.py b/models/model.py index 82d6461..3f24936 100644 --- a/models/model.py +++ b/models/model.py @@ -5,7 +5,6 @@ from typing import Union, Optional import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate @@ -142,29 +141,16 @@ class Model: # Prepare for model, optimizer and scheduler model_hp = self.hp.get('model', {}) optim_hp: dict = self.hp.get('optimizer', {}).copy() - start_iter = optim_hp.pop('start_iter', 0) - ae_optim_hp = optim_hp.pop('auto_encoder', {}) - pn_optim_hp = optim_hp.pop('part_net', {}) - hpm_optim_hp = optim_hp.pop('hpm', {}) - fc_optim_hp = optim_hp.pop('fc', {}) sched_hp = self.hp.get('scheduler', {}) self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) - self.optimizer = optim.Adam([ - {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp}, - {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp}, - {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp}, - {'params': self.rgb_pn.fc_mat, **fc_optim_hp} - ], **optim_hp) + self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp) sched_gamma = sched_hp.get('gamma', 0.9) sched_step_size = sched_hp.get('step_size', 500) self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ lambda epoch: sched_gamma ** (epoch // sched_step_size), - lambda epoch: 0 if epoch < start_iter else 1, - lambda epoch: 0 if epoch < start_iter else 1, - lambda epoch: 0 if epoch < start_iter else 1, ]) self.writer = SummaryWriter(self._log_name) @@ -182,10 +168,10 @@ class Model: # Training start start_time = datetime.now() - running_loss = torch.zeros(5, device=self.device) + running_loss = torch.zeros(3, device=self.device) print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}", f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}", - f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}") + f"{'LR':^9}") for (batch_c1, batch_c2) in dataloader: self.curr_iter += 1 # Zero the parameter gradients @@ -193,10 +179,7 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - y = batch_c1['label'].to(self.device) - # Duplicate labels for each part - y = y.unsqueeze(1).repeat(1, self.rgb_pn.num_total_parts) - losses, images = self.rgb_pn(x_c1, x_c2, y) + losses, images = self.rgb_pn(x_c1, x_c2) loss = losses.sum() loss.backward() self.optimizer.step() @@ -206,19 +189,16 @@ class Model: # Write losses to TensorBoard self.writer.add_scalar('Loss/all', loss, self.curr_iter) self.writer.add_scalars('Loss/details', dict(zip([ - 'Cross reconstruction loss', 'Canonical consistency loss', - 'Pose similarity loss', 'Batch All triplet loss (HPM)', - 'Batch All triplet loss (PartNet)' + 'Cross reconstruction loss', + 'Canonical consistency loss', + 'Pose similarity loss' ], losses)), self.curr_iter) if self.curr_iter % 100 == 0: - lrs = self.scheduler.get_last_lr() + lr = self.scheduler.get_last_lr()[0] # Write learning rates self.writer.add_scalar( - 'Learning rate/Auto-encoder', lrs[0], self.curr_iter - ) - self.writer.add_scalar( - 'Learning rate/Others', lrs[1], self.curr_iter + 'Learning rate/Auto-encoder', lr, self.curr_iter ) # Write disentangled images if self.image_log_on: @@ -241,8 +221,8 @@ class Model: hour, minute = divmod(remaining_minute, 60) print(f'{hour:02}:{minute:02}:{second:02}', f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', - '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100), - '{:.3e} {:.3e}'.format(lrs[0], lrs[1])) + '{:f} {:f} {:f}'.format(*running_loss / 100), + f'{lr:.3e}') running_loss.zero_() # Step scheduler @@ -261,24 +241,6 @@ class Model: self.writer.close() break - def predict_all( - self, - iters: tuple[int], - dataset_config: DatasetConfiguration, - dataset_selectors: dict[ - str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]] - ], - dataloader_config: DataloaderConfiguration, - ) -> dict[str, torch.Tensor]: - # Transform data to features - gallery_samples, probe_samples = self.transform( - iters, dataset_config, dataset_selectors, dataloader_config - ) - # Evaluate features - accuracy = self.evaluate(gallery_samples, probe_samples) - - return accuracy - def transform( self, iters: tuple[int], @@ -329,61 +291,13 @@ class Model: def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]): label = sample.pop('label').item() clip = sample.pop('clip').to(self.device) - feature = self.rgb_pn(clip).detach() + x_c, x_p = self.rgb_pn(clip).detach() return { **{'label': label}, **sample, - **{'feature': feature} - } - - def evaluate( - self, - gallery_samples: dict[str, Union[list[str], torch.Tensor]], - probe_samples: dict[str, dict[str, Union[list[str], torch.Tensor]]], - num_ranks: int = 5 - ) -> dict[str, torch.Tensor]: - probe_conditions = self._probe_datasets_meta.keys() - gallery_views_meta = self._gallery_dataset_meta['views'] - probe_views_meta = list(self._probe_datasets_meta.values())[0]['views'] - accuracy = { - condition: torch.empty( - len(gallery_views_meta), len(probe_views_meta), num_ranks - ) - for condition in self._probe_datasets_meta.keys() + **{'cano_feature': x_c, 'pose_feature': x_p} } - (labels_g, _, views_g, features_g) = gallery_samples.values() - views_g = np.asarray(views_g) - for (v_g_i, view_g) in enumerate(gallery_views_meta): - gallery_view_mask = (views_g == view_g) - f_g = features_g[gallery_view_mask] - y_g = labels_g[gallery_view_mask] - for condition in probe_conditions: - probe_samples_c = probe_samples[condition] - accuracy_c = accuracy[condition] - (labels_p, _, views_p, features_p) = probe_samples_c.values() - views_p = np.asarray(views_p) - for (v_p_i, view_p) in enumerate(probe_views_meta): - probe_view_mask = (views_p == view_p) - f_p = features_p[probe_view_mask] - y_p = labels_p[probe_view_mask] - # Euclidean distance - f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(1) - f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(0) - f_p_times_f_g_sum = f_p @ f_g.T - dist = torch.sqrt(F.relu( - f_p_squared_sum - 2*f_p_times_f_g_sum + f_g_squared_sum - )) - # Ranked accuracy - rank_mask = dist.argsort(1)[:, :num_ranks] - positive_mat = torch.eq(y_p.unsqueeze(1), - y_g[rank_mask]).cumsum(1).gt(0) - positive_counts = positive_mat.sum(0) - total_counts, _ = dist.size() - accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts - - return accuracy - def _load_pretrained( self, iters: tuple[int], @@ -452,8 +366,6 @@ class Model: nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) - elif isinstance(m, RGBPartNet): - nn.init.xavier_uniform_(m.fc_mat) def _parse_dataset_config( self, |