diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/auto_encoder.py | 2 | ||||
-rw-r--r-- | models/model.py | 249 | ||||
-rw-r--r-- | models/rgb_part_net.py | 15 |
3 files changed, 238 insertions, 28 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 1e7c323..64c52e3 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -130,7 +130,7 @@ class AutoEncoder(nn.Module): BasicLinear(f_c_dim, num_class) ) - def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y=None): + def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None, y=None): # x_c1_t2 is the frame for later module (f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2) diff --git a/models/model.py b/models/model.py index 80d4499..1154d7f 100644 --- a/models/model.py +++ b/models/model.py @@ -1,19 +1,22 @@ import os +from datetime import datetime from typing import Union, Optional, Tuple, List, Dict, Set import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm from models.rgb_part_net import RGBPartNet from utils.configuration import DataloaderConfiguration, \ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \ SystemConfiguration -from utils.dataset import CASIAB +from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses from utils.sampler import TripletSampler @@ -52,25 +55,52 @@ class Model: self.pr: Optional[int] = None self.k: Optional[int] = None + self._gallery_dataset_meta: Optional[Dict[str, List]] = None + self._probe_datasets_meta: Optional[Dict[str, Dict[str, List]]] = None + self._model_sig: str = self._make_signature(self.meta, ['restore_iter']) self._hp_sig: str = self._make_signature(self.hp) self._dataset_sig: str = 'undefined' self._log_sig: str = '_'.join((self._model_sig, self._hp_sig)) - self.log_name: str = os.path.join(self.log_dir, self._log_sig) + self._log_name: str = os.path.join(self.log_dir, self._log_sig) self.rgb_pn: Optional[RGBPartNet] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None + self.CASIAB_GALLERY_SELECTOR = { + 'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})} + } + self.CASIAB_PROBE_SELECTORS = { + 'nm': {'selector': {'conditions': ClipConditions({r'nm-0[5-6]'})}}, + 'bg': {'selector': {'conditions': ClipConditions({r'bg-0[1-2]'})}}, + 'cl': {'selector': {'conditions': ClipConditions({r'cl-0[1-2]'})}}, + } + @property - def signature(self) -> str: + def _signature(self) -> str: return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig, self._dataset_sig, str(self.pr), str(self.k))) @property - def checkpoint_name(self) -> str: - return os.path.join(self.checkpoint_dir, self.signature) + def _checkpoint_name(self) -> str: + return os.path.join(self.checkpoint_dir, self._signature) + + def fit_all( + self, + dataset_config: DatasetConfiguration, + dataset_selectors: Dict[ + str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]] + ], + dataloader_config: DataloaderConfiguration, + ): + for (condition, selector) in dataset_selectors.items(): + print(f'Training model {condition} ...') + self.fit( + dict(**dataset_config, **{'selector': selector}), + dataloader_config + ) def fit( self, @@ -86,24 +116,23 @@ class Model: self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp) self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas) self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9) - self.writer = SummaryWriter(self.log_name) - - if not self.disable_acc: - if torch.cuda.device_count() > 1: - self.rgb_pn = nn.DataParallel(self.rgb_pn) - self.rgb_pn = self.rgb_pn.to(self.device) + self.writer = SummaryWriter(self._log_name) + # Try to accelerate computation using CUDA or others + self._accelerate() self.rgb_pn.train() # Init weights at first iter if self.curr_iter == 0: self.rgb_pn.apply(self.init_weights) else: # Load saved state dicts - checkpoint = torch.load(self.checkpoint_name) + checkpoint = torch.load(self._checkpoint_name) iter_, loss = checkpoint['iter'], checkpoint['loss'] print('{0:5d} loss: {1:.3f}'.format(iter_, loss)) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optim_state_dict']) + # Training start + start_time = datetime.now() for (batch_c1, batch_c2) in dataloader: self.curr_iter += 1 # Zero the parameter gradients @@ -126,7 +155,7 @@ class Model: ], metrics)), self.curr_iter) if self.curr_iter % 100 == 0: - print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss), + print('{0:5d} loss: {1:6.3f}'.format(self.curr_iter, loss), '(xrecon = {:f}, pose_sim = {:f},' ' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics), 'lr:', self.scheduler.get_last_lr()[0]) @@ -137,12 +166,190 @@ class Model: 'model_state_dict': self.rgb_pn.state_dict(), 'optim_state_dict': self.optimizer.state_dict(), 'loss': loss, - }, self.checkpoint_name) + }, self._checkpoint_name) + print(datetime.now() - start_time, 'used') + start_time = datetime.now() if self.curr_iter == self.total_iter: + self.curr_iter = 0 self.writer.close() break + def _accelerate(self): + if not self.disable_acc: + if torch.cuda.device_count() > 1: + self.rgb_pn = nn.DataParallel(self.rgb_pn) + self.rgb_pn = self.rgb_pn.to(self.device) + + def predict_all( + self, + iter_: int, + dataset_config: DatasetConfiguration, + dataset_selectors: Dict[ + str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]] + ], + dataloader_config: DataloaderConfiguration, + ) -> Dict[str, torch.Tensor]: + self.is_train = False + # Split gallery and probe dataset + gallery_dataloader, probe_dataloaders = self._split_gallery_probe( + dataset_config, dataloader_config + ) + # Get pretrained models at iter_ + checkpoints = self._load_pretrained( + iter_, dataset_config, dataset_selectors + ) + # Init models + hp = self.hp.copy() + hp.pop('lr'), hp.pop('betas') + self.rgb_pn = RGBPartNet(ae_in_channels=self.in_channels, **hp) + # Try to accelerate computation using CUDA or others + self._accelerate() + + self.rgb_pn.eval() + gallery_samples, probe_samples = [], {} + + # Gallery + checkpoint = torch.load(list(checkpoints.values())[0]) + self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) + for sample in tqdm(gallery_dataloader, + desc='Transforming gallery', unit='clips'): + label = sample.pop('label').item() + clip = sample.pop('clip').to(self.device) + feature = self.rgb_pn(clip).detach() + gallery_samples.append({ + **{'label': label}, + **sample, + **{'feature': feature} + }) + gallery_samples = default_collate(gallery_samples) + + # Probe + for (condition, dataloader) in probe_dataloaders.items(): + checkpoint = torch.load(checkpoints[condition]) + self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) + probe_samples[condition] = [] + for sample in tqdm(dataloader, + desc=f'Transforming probe {condition}', + unit='clips'): + label = sample.pop('label').item() + clip = sample.pop('clip').to(self.device) + feature = self.rgb_pn(clip).detach() + probe_samples[condition].append({ + **{'label': label}, + **sample, + **{'feature': feature} + }) + for (k, v) in probe_samples.items(): + probe_samples[k] = default_collate(v) + + return self._evaluate(gallery_samples, probe_samples) + + def _evaluate( + self, + gallery_samples: Dict[str, Union[List[str], torch.Tensor]], + probe_samples: Dict[str, Dict[str, Union[List[str], torch.Tensor]]], + num_ranks: int = 5 + ) -> Dict[str, torch.Tensor]: + probe_conditions = self._probe_datasets_meta.keys() + gallery_views_meta = self._gallery_dataset_meta['views'] + probe_views_meta = list(self._probe_datasets_meta.values())[0]['views'] + accuracy = { + condition: torch.empty( + len(gallery_views_meta), len(probe_views_meta), num_ranks + ) + for condition in self._probe_datasets_meta.keys() + } + + (labels_g, _, views_g, features_g) = gallery_samples.values() + views_g = np.asarray(views_g) + for (v_g_i, view_g) in enumerate(gallery_views_meta): + gallery_view_mask = (views_g == view_g) + f_g = features_g[gallery_view_mask] + y_g = labels_g[gallery_view_mask] + for condition in probe_conditions: + probe_samples_c = probe_samples[condition] + accuracy_c = accuracy[condition] + (labels_p, _, views_p, features_p) = probe_samples_c.values() + views_p = np.asarray(views_p) + for (v_p_i, view_p) in enumerate(probe_views_meta): + probe_view_mask = (views_p == view_p) + f_p = features_p[probe_view_mask] + y_p = labels_p[probe_view_mask] + # Euclidean distance + f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(1) + f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(0) + f_p_times_f_g_sum = f_p @ f_g.T + dist = torch.sqrt(F.relu( + f_p_squared_sum - 2*f_p_times_f_g_sum + f_g_squared_sum + )) + # Ranked accuracy + rank_mask = dist.argsort(1)[:, :num_ranks] + positive_mat = torch.eq(y_p.unsqueeze(1), + y_g[rank_mask]).cumsum(1).gt(0) + positive_counts = positive_mat.sum(0) + total_counts, _ = dist.size() + accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts + + return accuracy + + def _load_pretrained( + self, + iter_: int, + dataset_config: DatasetConfiguration, + dataset_selectors: Dict[ + str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]] + ] + ) -> Dict[str, str]: + checkpoints = {} + self.curr_iter = iter_ + for (k, v) in dataset_selectors.items(): + self._dataset_sig = self._make_signature( + dict(**dataset_config, **v), + popped_keys=['root_dir', 'cache_on'] + ) + checkpoints[k] = self._checkpoint_name + return checkpoints + + def _split_gallery_probe( + self, + dataset_config: DatasetConfiguration, + dataloader_config: DataloaderConfiguration, + ) -> Tuple[DataLoader, Dict[str: DataLoader]]: + dataset_name = dataset_config.get('name', 'CASIA-B') + if dataset_name == 'CASIA-B': + gallery_dataset = self._parse_dataset_config( + dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR) + ) + self._gallery_dataset_meta = gallery_dataset.metadata + gallery_dataloader = self._parse_dataloader_config( + gallery_dataset, dataloader_config + ) + probe_datasets = { + condition: self._parse_dataset_config( + dict(**dataset_config, **selector) + ) + for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items() + } + self._probe_datasets_meta = { + condition: dataset.metadata + for (condition, dataset) in probe_datasets.items() + } + probe_dataloaders = { + condtion: self._parse_dataloader_config( + dataset, dataloader_config + ) + for (condtion, dataset) in probe_datasets.items() + } + elif dataset_name == 'FVG': + # TODO + gallery_dataloader = None + probe_dataloaders = None + else: + raise ValueError('Invalid dataset: {0}'.format(dataset_name)) + + return gallery_dataloader, probe_dataloaders + @staticmethod def init_weights(m): if isinstance(m, nn.modules.conv._ConvNd): @@ -165,9 +372,9 @@ class Model: dataset_config, popped_keys=['root_dir', 'cache_on'] ) - self.log_name = '_'.join((self.log_name, self._dataset_sig)) + self._log_name = '_'.join((self._log_name, self._dataset_sig)) config: Dict = dataset_config.copy() - name = config.pop('name') + name = config.pop('name', 'CASIA-B') if name == 'CASIA-B': return CASIAB(**config, is_train=self.is_train) elif name == 'FVG': @@ -181,16 +388,16 @@ class Model: dataloader_config: DataloaderConfiguration ) -> DataLoader: config: Dict = dataloader_config.copy() + (self.pr, self.k) = config.pop('batch_size') if self.is_train: - (self.pr, self.k) = config.pop('batch_size') - self.log_name = '_'.join((self.log_name, str(self.pr), str(self.k))) + self._log_name = '_'.join( + (self._log_name, str(self.pr), str(self.k))) triplet_sampler = TripletSampler(dataset, (self.pr, self.k)) return DataLoader(dataset, batch_sampler=triplet_sampler, collate_fn=self._batch_splitter, **config) else: # is_test - config.pop('batch_size') return DataLoader(dataset, **config) def _batch_splitter( @@ -225,8 +432,10 @@ class Model: for v in values: if isinstance(v, str): strings.append(v) - elif isinstance(v, (Tuple, List, Set)): + elif isinstance(v, (Tuple, List)): strings.append(self._gen_sig(v)) + elif isinstance(v, Set): + strings.append(self._gen_sig(sorted(list(v)))) elif isinstance(v, Dict): strings.append(self._gen_sig(list(v.values()))) else: diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 39cbed6..95a3f2e 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -52,10 +52,12 @@ class RGBPartNet(nn.Module): def fc(self, x): return x @ self.fc_mat - def forward(self, x_c1, x_c2, y=None): + def forward(self, x_c1, x_c2=None, y=None): # Step 0: Swap batch_size and time dimensions for next step # n, t, c, h, w - x_c1, x_c2 = x_c1.transpose(0, 1), x_c2.transpose(0, 1) + x_c1 = x_c1.transpose(0, 1) + if self.training: + x_c2 = x_c2.transpose(0, 1) # Step 1: Disentanglement # t, n, c, h, w @@ -83,9 +85,9 @@ class RGBPartNet(nn.Module): loss = torch.sum(torch.stack(losses)) return loss, [loss.item() for loss in losses] else: - return x + return x.unsqueeze(1).view(-1) - def _disentangle(self, x_c1, x_c2, y): + def _disentangle(self, x_c1, x_c2=None, y=None): num_frames = len(x_c1) # Decoded canonical features and Pose images x_c_c1, x_p_c1 = [], [] @@ -95,7 +97,7 @@ class RGBPartNet(nn.Module): xrecon_loss, cano_cons_loss = [], [] for t2 in range(num_frames): t1 = random.randrange(num_frames) - output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y) + output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2], y) (x_c1_t2, f_p_t2, losses) = output # Decoded features or image @@ -128,8 +130,7 @@ class RGBPartNet(nn.Module): else: # evaluating for t2 in range(num_frames): - t1 = random.randrange(num_frames) - x_c1_t2 = self.ae(x_c1[t1], x_c1[t2], x_c2[t2]) + x_c1_t2 = self.ae(x_c1[t2]) # Decoded features or image (x_c_c1_t2, x_p_c1_t2) = x_c1_t2 # Canonical Features for HPM |