From d30cf2cb280e83e4a4abe1e9c2abdbba17d903a3 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sun, 10 Jan 2021 19:54:42 +0800 Subject: Make predict function transform samples different conditions in a single shot --- models/model.py | 142 ++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 102 insertions(+), 40 deletions(-) (limited to 'models') diff --git a/models/model.py b/models/model.py index ba29ede..cabfa97 100644 --- a/models/model.py +++ b/models/model.py @@ -14,7 +14,7 @@ from models.rgb_part_net import RGBPartNet from utils.configuration import DataloaderConfiguration, \ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \ SystemConfiguration -from utils.dataset import CASIAB +from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses from utils.sampler import TripletSampler @@ -57,21 +57,30 @@ class Model: self._hp_sig: str = self._make_signature(self.hp) self._dataset_sig: str = 'undefined' self._log_sig: str = '_'.join((self._model_sig, self._hp_sig)) - self.log_name: str = os.path.join(self.log_dir, self._log_sig) + self._log_name: str = os.path.join(self.log_dir, self._log_sig) self.rgb_pn: Optional[RGBPartNet] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None + self.CASIAB_GALLERY_SELECTOR = { + 'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})} + } + self.CASIAB_PROBE_SELECTORS = { + 'nm': {'selector': {'conditions': ClipConditions({r'nm-0[5-6]'})}}, + 'bg': {'selector': {'conditions': ClipConditions({r'bg-0[1-2]'})}}, + 'cl': {'selector': {'conditions': ClipConditions({r'cl-0[1-2]'})}}, + } + @property - def signature(self) -> str: + def _signature(self) -> str: return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig, self._dataset_sig, str(self.pr), str(self.k))) @property - def checkpoint_name(self) -> str: - return os.path.join(self.checkpoint_dir, self.signature) + def _checkpoint_name(self) -> str: + return os.path.join(self.checkpoint_dir, self._signature) def fit( self, @@ -87,8 +96,8 @@ class Model: self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp) self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas) self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9) - self.writer = SummaryWriter(self.log_name) - + self.writer = SummaryWriter(self._log_name) + # Try to accelerate computation using CUDA or others self._accelerate() self.rgb_pn.train() @@ -96,7 +105,7 @@ class Model: if self.curr_iter == 0: self.rgb_pn.apply(self.init_weights) else: # Load saved state dicts - checkpoint = torch.load(self.checkpoint_name) + checkpoint = torch.load(self._checkpoint_name) iter_, loss = checkpoint['iter'], checkpoint['loss'] print('{0:5d} loss: {1:.3f}'.format(iter_, loss)) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) @@ -135,7 +144,7 @@ class Model: 'model_state_dict': self.rgb_pn.state_dict(), 'optim_state_dict': self.optimizer.state_dict(), 'loss': loss, - }, self.checkpoint_name) + }, self._checkpoint_name) if self.curr_iter == self.total_iter: self.writer.close() @@ -151,46 +160,98 @@ class Model: self, iter_: int, dataset_config: DatasetConfiguration, + dataset_selectors: dict[ + str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]] + ], dataloader_config: DataloaderConfiguration, ): self.is_train = False - dataset = self._parse_dataset_config(dataset_config) - dataloader = self._parse_dataloader_config(dataset, dataloader_config) + # Split gallery and probe dataset + gallery_dataloader, probe_dataloaders = self._split_gallery_probe( + dataset_config, dataloader_config + ) + # Get pretrained models at iter_ + checkpoints = self._load_pretrained( + iter_, dataset_config, dataset_selectors + ) + # Init models hp = self.hp.copy() - _, _ = hp.pop('lr'), hp.pop('betas') + hp.pop('lr'), hp.pop('betas') + self.rgb_pn = RGBPartNet(ae_in_channels=self.in_channels, **hp) + # Try to accelerate computation using CUDA or others + self._accelerate() + + self.rgb_pn.eval() + gallery_samples, probe_samples = [], {} + + # Gallery + self.rgb_pn.load_state_dict(torch.load(list(checkpoints.values())[0])) + for sample in tqdm(gallery_dataloader, + desc='Transforming gallery', unit='clips'): + clip = sample.pop('clip').to(self.device) + feature = self.rgb_pn(clip).detach().cpu() + gallery_samples.append({**sample, **{'feature': feature}}) + gallery_samples = default_collate(gallery_samples) + + # Probe + for (name, dataloader) in probe_dataloaders.items(): + self.rgb_pn.load_state_dict(torch.load(checkpoints[name])) + probe_samples[name] = [] + for sample in tqdm(dataloader, + desc=f'Transforming probe {name}', unit='clips'): + clip = sample.pop('clip').to(self.device) + feature = self.rgb_pn(clip) + probe_samples[name].append({**sample, **{'feature': feature}}) + for (k, v) in probe_samples.items(): + probe_samples[k] = default_collate(v) + + # TODO Implement evaluation function here + + def _load_pretrained( + self, + iter_: int, + dataset_config: DatasetConfiguration, + dataset_selectors: dict[ + str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]] + ] + ) -> dict[str, str]: + checkpoints = {} + self.curr_iter = iter_ + for (k, v) in dataset_selectors.items(): + self._dataset_sig = self._make_signature( + dict(**dataset_config, **v), + popped_keys=['root_dir', 'cache_on'] + ) + checkpoints[k] = self._checkpoint_name + return checkpoints + + def _split_gallery_probe( + self, + dataset_config: DatasetConfiguration, + dataloader_config: DataloaderConfiguration, + ) -> tuple[DataLoader, dict[str: DataLoader]]: dataset_name = dataset_config.get('name', 'CASIA-B') if dataset_name == 'CASIA-B': - self.rgb_pn = RGBPartNet(124 - self.train_size, - self.in_channels, - **hp) + gallery_dataset = self._parse_dataset_config( + dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR) + ) + gallery_dataloader = self._parse_dataloader_config( + gallery_dataset, dataloader_config + ) + probe_datasets = {k: self._parse_dataset_config( + dict(**dataset_config, **v) + ) for (k, v) in self.CASIAB_PROBE_SELECTORS.items()} + probe_dataloaders = {k: self._parse_dataloader_config( + v, dataloader_config + ) for (k, v) in probe_datasets.items()} elif dataset_name == 'FVG': # TODO - pass + gallery_dataloader = None + probe_dataloaders = None else: raise ValueError('Invalid dataset: {0}'.format(dataset_name)) - self._accelerate() - - self.rgb_pn.eval() - # Load checkpoint at iter_ - self.curr_iter = iter_ - checkpoint = torch.load(self.checkpoint_name) - self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) - - labels, conditions, views, features = [], [], [], [] - for sample in tqdm(dataloader, desc='Transforming', unit='clips'): - label, condition, view, clip = sample.values() - feature = self.rgb_pn(clip).detach().cpu().numpy() - labels.append(label) - conditions.append(condition) - views.append(view) - features.append(feature) - labels = np.asarray(labels) - conditions = np.asarray(conditions) - views = np.asarray(views) - features = np.asarray(features) - - # TODO Implement evaluation function here + return gallery_dataloader, probe_dataloaders @staticmethod def init_weights(m): @@ -214,7 +275,7 @@ class Model: dataset_config, popped_keys=['root_dir', 'cache_on'] ) - self.log_name = '_'.join((self.log_name, self._dataset_sig)) + self._log_name = '_'.join((self._log_name, self._dataset_sig)) config: dict = dataset_config.copy() name = config.pop('name', 'CASIA-B') if name == 'CASIA-B': @@ -232,7 +293,8 @@ class Model: config: dict = dataloader_config.copy() if self.is_train: (self.pr, self.k) = config.pop('batch_size') - self.log_name = '_'.join((self.log_name, str(self.pr), str(self.k))) + self._log_name = '_'.join( + (self._log_name, str(self.pr), str(self.k))) triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) return DataLoader(dataset, batch_sampler=triplet_sampler, -- cgit v1.2.3