diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-12 11:29:02 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-12 11:29:02 +0800 |
commit | 966d4431c037b0c4641aa2a5fc22f05be064b331 (patch) | |
tree | 0239ba89d31857a7f086acf627fc1bbf167855a9 /models/model.py | |
parent | 7825f978f198e56958703f0d08f7ccbd8cef49ca (diff) | |
parent | 36cf502afe9b93efe31c244030270b0a62e644b8 (diff) |
Merge branch 'master' into python3.8
# Conflicts:
# models/model.py
Diffstat (limited to 'models/model.py')
-rw-r--r-- | models/model.py | 249 |
1 files changed, 229 insertions, 20 deletions
diff --git a/models/model.py b/models/model.py index 80d4499..1154d7f 100644 --- a/models/model.py +++ b/models/model.py @@ -1,19 +1,22 @@ import os +from datetime import datetime from typing import Union, Optional, Tuple, List, Dict, Set import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm from models.rgb_part_net import RGBPartNet from utils.configuration import DataloaderConfiguration, \ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \ SystemConfiguration -from utils.dataset import CASIAB +from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses from utils.sampler import TripletSampler @@ -52,25 +55,52 @@ class Model: self.pr: Optional[int] = None self.k: Optional[int] = None + self._gallery_dataset_meta: Optional[Dict[str, List]] = None + self._probe_datasets_meta: Optional[Dict[str, Dict[str, List]]] = None + self._model_sig: str = self._make_signature(self.meta, ['restore_iter']) self._hp_sig: str = self._make_signature(self.hp) self._dataset_sig: str = 'undefined' self._log_sig: str = '_'.join((self._model_sig, self._hp_sig)) - self.log_name: str = os.path.join(self.log_dir, self._log_sig) + self._log_name: str = os.path.join(self.log_dir, self._log_sig) self.rgb_pn: Optional[RGBPartNet] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None + self.CASIAB_GALLERY_SELECTOR = { + 'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})} + } + self.CASIAB_PROBE_SELECTORS = { + 'nm': {'selector': {'conditions': ClipConditions({r'nm-0[5-6]'})}}, + 'bg': {'selector': {'conditions': ClipConditions({r'bg-0[1-2]'})}}, + 'cl': {'selector': {'conditions': ClipConditions({r'cl-0[1-2]'})}}, + } + @property - def signature(self) -> str: + def _signature(self) -> str: return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig, self._dataset_sig, str(self.pr), str(self.k))) @property - def checkpoint_name(self) -> str: - return os.path.join(self.checkpoint_dir, self.signature) + def _checkpoint_name(self) -> str: + return os.path.join(self.checkpoint_dir, self._signature) + + def fit_all( + self, + dataset_config: DatasetConfiguration, + dataset_selectors: Dict[ + str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]] + ], + dataloader_config: DataloaderConfiguration, + ): + for (condition, selector) in dataset_selectors.items(): + print(f'Training model {condition} ...') + self.fit( + dict(**dataset_config, **{'selector': selector}), + dataloader_config + ) def fit( self, @@ -86,24 +116,23 @@ class Model: self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp) self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas) self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9) - self.writer = SummaryWriter(self.log_name) - - if not self.disable_acc: - if torch.cuda.device_count() > 1: - self.rgb_pn = nn.DataParallel(self.rgb_pn) - self.rgb_pn = self.rgb_pn.to(self.device) + self.writer = SummaryWriter(self._log_name) + # Try to accelerate computation using CUDA or others + self._accelerate() self.rgb_pn.train() # Init weights at first iter if self.curr_iter == 0: self.rgb_pn.apply(self.init_weights) else: # Load saved state dicts - checkpoint = torch.load(self.checkpoint_name) + checkpoint = torch.load(self._checkpoint_name) iter_, loss = checkpoint['iter'], checkpoint['loss'] print('{0:5d} loss: {1:.3f}'.format(iter_, loss)) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optim_state_dict']) + # Training start + start_time = datetime.now() for (batch_c1, batch_c2) in dataloader: self.curr_iter += 1 # Zero the parameter gradients @@ -126,7 +155,7 @@ class Model: ], metrics)), self.curr_iter) if self.curr_iter % 100 == 0: - print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss), + print('{0:5d} loss: {1:6.3f}'.format(self.curr_iter, loss), '(xrecon = {:f}, pose_sim = {:f},' ' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics), 'lr:', self.scheduler.get_last_lr()[0]) @@ -137,12 +166,190 @@ class Model: 'model_state_dict': self.rgb_pn.state_dict(), 'optim_state_dict': self.optimizer.state_dict(), 'loss': loss, - }, self.checkpoint_name) + }, self._checkpoint_name) + print(datetime.now() - start_time, 'used') + start_time = datetime.now() if self.curr_iter == self.total_iter: + self.curr_iter = 0 self.writer.close() break + def _accelerate(self): + if not self.disable_acc: + if torch.cuda.device_count() > 1: + self.rgb_pn = nn.DataParallel(self.rgb_pn) + self.rgb_pn = self.rgb_pn.to(self.device) + + def predict_all( + self, + iter_: int, + dataset_config: DatasetConfiguration, + dataset_selectors: Dict[ + str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]] + ], + dataloader_config: DataloaderConfiguration, + ) -> Dict[str, torch.Tensor]: + self.is_train = False + # Split gallery and probe dataset + gallery_dataloader, probe_dataloaders = self._split_gallery_probe( + dataset_config, dataloader_config + ) + # Get pretrained models at iter_ + checkpoints = self._load_pretrained( + iter_, dataset_config, dataset_selectors + ) + # Init models + hp = self.hp.copy() + hp.pop('lr'), hp.pop('betas') + self.rgb_pn = RGBPartNet(ae_in_channels=self.in_channels, **hp) + # Try to accelerate computation using CUDA or others + self._accelerate() + + self.rgb_pn.eval() + gallery_samples, probe_samples = [], {} + + # Gallery + checkpoint = torch.load(list(checkpoints.values())[0]) + self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) + for sample in tqdm(gallery_dataloader, + desc='Transforming gallery', unit='clips'): + label = sample.pop('label').item() + clip = sample.pop('clip').to(self.device) + feature = self.rgb_pn(clip).detach() + gallery_samples.append({ + **{'label': label}, + **sample, + **{'feature': feature} + }) + gallery_samples = default_collate(gallery_samples) + + # Probe + for (condition, dataloader) in probe_dataloaders.items(): + checkpoint = torch.load(checkpoints[condition]) + self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) + probe_samples[condition] = [] + for sample in tqdm(dataloader, + desc=f'Transforming probe {condition}', + unit='clips'): + label = sample.pop('label').item() + clip = sample.pop('clip').to(self.device) + feature = self.rgb_pn(clip).detach() + probe_samples[condition].append({ + **{'label': label}, + **sample, + **{'feature': feature} + }) + for (k, v) in probe_samples.items(): + probe_samples[k] = default_collate(v) + + return self._evaluate(gallery_samples, probe_samples) + + def _evaluate( + self, + gallery_samples: Dict[str, Union[List[str], torch.Tensor]], + probe_samples: Dict[str, Dict[str, Union[List[str], torch.Tensor]]], + num_ranks: int = 5 + ) -> Dict[str, torch.Tensor]: + probe_conditions = self._probe_datasets_meta.keys() + gallery_views_meta = self._gallery_dataset_meta['views'] + probe_views_meta = list(self._probe_datasets_meta.values())[0]['views'] + accuracy = { + condition: torch.empty( + len(gallery_views_meta), len(probe_views_meta), num_ranks + ) + for condition in self._probe_datasets_meta.keys() + } + + (labels_g, _, views_g, features_g) = gallery_samples.values() + views_g = np.asarray(views_g) + for (v_g_i, view_g) in enumerate(gallery_views_meta): + gallery_view_mask = (views_g == view_g) + f_g = features_g[gallery_view_mask] + y_g = labels_g[gallery_view_mask] + for condition in probe_conditions: + probe_samples_c = probe_samples[condition] + accuracy_c = accuracy[condition] + (labels_p, _, views_p, features_p) = probe_samples_c.values() + views_p = np.asarray(views_p) + for (v_p_i, view_p) in enumerate(probe_views_meta): + probe_view_mask = (views_p == view_p) + f_p = features_p[probe_view_mask] + y_p = labels_p[probe_view_mask] + # Euclidean distance + f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(1) + f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(0) + f_p_times_f_g_sum = f_p @ f_g.T + dist = torch.sqrt(F.relu( + f_p_squared_sum - 2*f_p_times_f_g_sum + f_g_squared_sum + )) + # Ranked accuracy + rank_mask = dist.argsort(1)[:, :num_ranks] + positive_mat = torch.eq(y_p.unsqueeze(1), + y_g[rank_mask]).cumsum(1).gt(0) + positive_counts = positive_mat.sum(0) + total_counts, _ = dist.size() + accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts + + return accuracy + + def _load_pretrained( + self, + iter_: int, + dataset_config: DatasetConfiguration, + dataset_selectors: Dict[ + str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]] + ] + ) -> Dict[str, str]: + checkpoints = {} + self.curr_iter = iter_ + for (k, v) in dataset_selectors.items(): + self._dataset_sig = self._make_signature( + dict(**dataset_config, **v), + popped_keys=['root_dir', 'cache_on'] + ) + checkpoints[k] = self._checkpoint_name + return checkpoints + + def _split_gallery_probe( + self, + dataset_config: DatasetConfiguration, + dataloader_config: DataloaderConfiguration, + ) -> Tuple[DataLoader, Dict[str: DataLoader]]: + dataset_name = dataset_config.get('name', 'CASIA-B') + if dataset_name == 'CASIA-B': + gallery_dataset = self._parse_dataset_config( + dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR) + ) + self._gallery_dataset_meta = gallery_dataset.metadata + gallery_dataloader = self._parse_dataloader_config( + gallery_dataset, dataloader_config + ) + probe_datasets = { + condition: self._parse_dataset_config( + dict(**dataset_config, **selector) + ) + for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items() + } + self._probe_datasets_meta = { + condition: dataset.metadata + for (condition, dataset) in probe_datasets.items() + } + probe_dataloaders = { + condtion: self._parse_dataloader_config( + dataset, dataloader_config + ) + for (condtion, dataset) in probe_datasets.items() + } + elif dataset_name == 'FVG': + # TODO + gallery_dataloader = None + probe_dataloaders = None + else: + raise ValueError('Invalid dataset: {0}'.format(dataset_name)) + + return gallery_dataloader, probe_dataloaders + @staticmethod def init_weights(m): if isinstance(m, nn.modules.conv._ConvNd): @@ -165,9 +372,9 @@ class Model: dataset_config, popped_keys=['root_dir', 'cache_on'] ) - self.log_name = '_'.join((self.log_name, self._dataset_sig)) + self._log_name = '_'.join((self._log_name, self._dataset_sig)) config: Dict = dataset_config.copy() - name = config.pop('name') + name = config.pop('name', 'CASIA-B') if name == 'CASIA-B': return CASIAB(**config, is_train=self.is_train) elif name == 'FVG': @@ -181,16 +388,16 @@ class Model: dataloader_config: DataloaderConfiguration ) -> DataLoader: config: Dict = dataloader_config.copy() + (self.pr, self.k) = config.pop('batch_size') if self.is_train: - (self.pr, self.k) = config.pop('batch_size') - self.log_name = '_'.join((self.log_name, str(self.pr), str(self.k))) + self._log_name = '_'.join( + (self._log_name, str(self.pr), str(self.k))) triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) return DataLoader(dataset, batch_sampler=triplet_sampler, collate_fn=self._batch_splitter, **config) else: # is_test - config.pop('batch_size') return DataLoader(dataset, **config) def _batch_splitter( @@ -225,8 +432,10 @@ class Model: for v in values: if isinstance(v, str): strings.append(v) - elif isinstance(v, (Tuple, List, Set)): + elif isinstance(v, (Tuple, List)): strings.append(self._gen_sig(v)) + elif isinstance(v, Set): + strings.append(self._gen_sig(sorted(list(v)))) elif isinstance(v, Dict): strings.append(self._gen_sig(list(v.values()))) else: |