diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-03 21:30:35 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-04-03 23:06:07 +0800 |
commit | f6f133fa7b926ce0c7d28bbf0ba4de41b3708d4a (patch) | |
tree | 4bf9b80c1c7a96f081a4e3b3b751145054fccc39 /models | |
parent | d12dd6b04a4e7c2b1ee43ab6f36f25d0c35ca364 (diff) | |
parent | b9f35fbe7d78b3c478086ea26c2a76f72ce35687 (diff) |
Merge branch 'master' into disentangling_only
# Conflicts:
# config.py
# models/hpm.py
# models/layers.py
# models/model.py
# models/part_net.py
# models/rgb_part_net.py
# test/part_net.py
# utils/configuration.py
# utils/triplet_loss.py
Diffstat (limited to 'models')
-rw-r--r-- | models/auto_encoder.py | 4 | ||||
-rw-r--r-- | models/layers.py | 4 | ||||
-rw-r--r-- | models/model.py | 233 | ||||
-rw-r--r-- | models/rgb_part_net.py | 98 |
4 files changed, 192 insertions, 147 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py index 2d715db..91071dd 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -106,15 +106,13 @@ class Decoder(nn.Module): self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels, is_last_layer=True) - def forward(self, f_appearance, f_canonical, f_pose, cano_only=False): + def forward(self, f_appearance, f_canonical, f_pose): x = torch.cat((f_appearance, f_canonical, f_pose), dim=1) x = self.fc(x) x = x.view(-1, self.feature_channels * 8, self.h_0, self.w_0) x = F.relu(x, inplace=True) x = self.trans_conv1(x) x = self.trans_conv2(x) - if cano_only: - return x x = self.trans_conv3(x) x = torch.sigmoid(self.trans_conv4(x)) diff --git a/models/layers.py b/models/layers.py index 1b4640f..5306769 100644 --- a/models/layers.py +++ b/models/layers.py @@ -79,7 +79,9 @@ class DCGANConvTranspose2d(BasicConvTranspose2d): if self.is_last_layer: return self.trans_conv(x) else: - return super().forward(x) + x = self.trans_conv(x) + x = self.bn(x) + return F.leaky_relu(x, 0.2, inplace=True) class BasicLinear(nn.Module): diff --git a/models/model.py b/models/model.py index 3f24936..25c8a4f 100644 --- a/models/model.py +++ b/models/model.py @@ -1,5 +1,6 @@ +import copy import os -from datetime import datetime +import random from typing import Union, Optional import numpy as np @@ -47,10 +48,10 @@ class Model: self.meta = model_config self.hp = hyperparameter_config - self.curr_iter = self.meta.get('restore_iter', 0) + self.restore_iter = self.curr_iter = self.meta.get('restore_iter', 0) self.total_iter = self.meta.get('total_iter', 80_000) - self.curr_iters = self.meta.get('restore_iters', (0, 0, 0)) - self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000)) + self.restore_iters = self.meta.get('restore_iters', (self.curr_iter,)) + self.total_iters = self.meta.get('total_iters', (self.total_iter,)) self.is_train: bool = True self.in_channels: int = 3 @@ -70,6 +71,7 @@ class Model: self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None self.image_log_on = system_config.get('image_log_on', False) + self.val_size = system_config.get('val_size', 10) self.CASIAB_GALLERY_SELECTOR = { 'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})} @@ -83,7 +85,7 @@ class Model: @property def _model_sig(self) -> str: return '_'.join( - (self._model_name, str(self.curr_iter), str(self.total_iter)) + (self._model_name, str(self.curr_iter + 1), str(self.total_iter)) ) @property @@ -112,18 +114,18 @@ class Model: ], dataloader_config: DataloaderConfiguration, ): - for (curr_iter, total_iter, (condition, selector)) in zip( - self.curr_iters, self.total_iters, dataset_selectors.items() + for (restore_iter, total_iter, (condition, selector)) in zip( + self.restore_iters, self.total_iters, dataset_selectors.items() ): print(f'Training model {condition} ...') # Skip finished model - if curr_iter == total_iter: + if restore_iter == total_iter: continue # Check invalid restore iter - elif curr_iter > total_iter: + elif restore_iter > total_iter: raise ValueError("Restore iter '{}' should less than total " - "iter '{}'".format(curr_iter, total_iter)) - self.curr_iter = curr_iter + "iter '{}'".format(restore_iter, total_iter)) + self.restore_iter = self.curr_iter = restore_iter self.total_iter = total_iter self.fit( dict(**dataset_config, **{'selector': selector}), @@ -136,70 +138,85 @@ class Model: dataloader_config: DataloaderConfiguration, ): self.is_train = True - dataset = self._parse_dataset_config(dataset_config) - dataloader = self._parse_dataloader_config(dataset, dataloader_config) + # Validation dataset + # (the first `val_size` subjects from evaluation set) + val_dataset_config = copy.deepcopy(dataset_config) + train_size = dataset_config.get('train_size', 74) + val_dataset_config['train_size'] = train_size + self.val_size + val_dataset_config['selector']['classes'] = ClipClasses({ + str(c).zfill(3) + for c in range(train_size + 1, train_size + self.val_size + 1) + }) + val_dataset = self._parse_dataset_config(val_dataset_config) + val_dataloader = iter(self._parse_dataloader_config( + val_dataset, dataloader_config + )) + # Training dataset + train_dataset = self._parse_dataset_config(dataset_config) + train_dataloader = iter(self._parse_dataloader_config( + train_dataset, dataloader_config + )) # Prepare for model, optimizer and scheduler - model_hp = self.hp.get('model', {}) + model_hp: dict = self.hp.get('model', {}).copy() optim_hp: dict = self.hp.get('optimizer', {}).copy() sched_hp = self.hp.get('scheduler', {}) + self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) + # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp) - sched_gamma = sched_hp.get('gamma', 0.9) - sched_step_size = sched_hp.get('step_size', 500) - self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ - lambda epoch: sched_gamma ** (epoch // sched_step_size), - ]) + start_step = sched_hp.get('start_step', 15_000) + final_gamma = sched_hp.get('final_gamma', 0.001) + all_step = self.total_iter - start_step + self.scheduler = optim.lr_scheduler.LambdaLR( + self.optimizer, + lambda t: final_gamma ** ((t - start_step) / all_step) + if t > start_step else 1, + ) self.writer = SummaryWriter(self._log_name) + # Set seeds for reproducibility + random.seed(0) + torch.manual_seed(0) self.rgb_pn.train() # Init weights at first iter if self.curr_iter == 0: self.rgb_pn.apply(self.init_weights) else: # Load saved state dicts + # Offset a iter to load last checkpoint + self.curr_iter -= 1 checkpoint = torch.load(self._checkpoint_name) - iter_, loss = checkpoint['iter'], checkpoint['loss'] - print('{0:5d} loss: {1:.3f}'.format(iter_, loss)) + random.setstate(checkpoint['rand_states'][0]) + torch.set_rng_state(checkpoint['rand_states'][1]) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optim_state_dict']) self.scheduler.load_state_dict(checkpoint['sched_state_dict']) # Training start - start_time = datetime.now() - running_loss = torch.zeros(3, device=self.device) - print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}", - f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}", - f"{'LR':^9}") - for (batch_c1, batch_c2) in dataloader: - self.curr_iter += 1 + for self.curr_iter in tqdm(range(self.restore_iter, self.total_iter), + desc='Training'): + batch_c1, batch_c2 = next(train_dataloader) # Zero the parameter gradients self.optimizer.zero_grad() # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - losses, images = self.rgb_pn(x_c1, x_c2) + losses, features, images = self.rgb_pn(x_c1, x_c2) loss = losses.sum() loss.backward() self.optimizer.step() + self.scheduler.step() - # Statistics and checkpoint - running_loss += losses.detach() - # Write losses to TensorBoard - self.writer.add_scalar('Loss/all', loss, self.curr_iter) - self.writer.add_scalars('Loss/details', dict(zip([ - 'Cross reconstruction loss', - 'Canonical consistency loss', - 'Pose similarity loss' - ], losses)), self.curr_iter) - - if self.curr_iter % 100 == 0: - lr = self.scheduler.get_last_lr()[0] - # Write learning rates - self.writer.add_scalar( - 'Learning rate/Auto-encoder', lr, self.curr_iter - ) + # Learning rate + self.writer.add_scalar( + 'Learning rate', self.scheduler.get_last_lr()[0], self.curr_iter + ) + # Other stats + self._write_stat('Train', loss, losses) + + if self.curr_iter % 100 == 99: # Write disentangled images if self.image_log_on: i_a, i_c, i_p = images @@ -216,30 +233,54 @@ class Model: self.writer.add_images( f'Pose image/batch {i}', p, self.curr_iter ) - time_used = datetime.now() - start_time - remaining_minute, second = divmod(time_used.seconds, 60) - hour, minute = divmod(remaining_minute, 60) - print(f'{hour:02}:{minute:02}:{second:02}', - f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', - '{:f} {:f} {:f}'.format(*running_loss / 100), - f'{lr:.3e}') - running_loss.zero_() - - # Step scheduler - self.scheduler.step() - - if self.curr_iter % 1000 == 0: + f_a, f_c, f_p = features + for i, (f_a_i, f_c_i, f_p_i) in enumerate( + zip(f_a, f_c, f_p) + ): + self.writer.add_images( + f'Appearance features/Layer {i}', + f_a_i[:, :3, :, :], self.curr_iter + ) + self.writer.add_images( + f'Canonical features/Layer {i}', + f_c_i[:, :3, :, :], self.curr_iter + ) + for j, p in enumerate(f_p_i): + self.writer.add_images( + f'Pose features/Layer {i}/batch{j}', + p[:, :3, :, :], self.curr_iter + ) + + # Calculate losses on testing batch + batch_c1, batch_c2 = next(val_dataloader) + x_c1 = batch_c1['clip'].to(self.device) + x_c2 = batch_c2['clip'].to(self.device) + with torch.no_grad(): + losses, _, _ = self.rgb_pn(x_c1, x_c2) + loss = losses.sum() + + self._write_stat('Val', loss, losses) + + # Checkpoint + if self.curr_iter % 1000 == 999: torch.save({ - 'iter': self.curr_iter, + 'rand_states': (random.getstate(), torch.get_rng_state()), 'model_state_dict': self.rgb_pn.state_dict(), 'optim_state_dict': self.optimizer.state_dict(), 'sched_state_dict': self.scheduler.state_dict(), - 'loss': loss, }, self._checkpoint_name) - if self.curr_iter == self.total_iter: - self.writer.close() - break + self.writer.close() + + def _write_stat( + self, postfix, loss, losses + ): + # Write losses to TensorBoard + self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter) + self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip(( + 'Cross reconstruction loss', 'Canonical consistency loss', + 'Pose similarity loss' + ), losses)), self.curr_iter) def transform( self, @@ -248,12 +289,12 @@ class Model: dataset_selectors: dict[ str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]] ], - dataloader_config: DataloaderConfiguration + dataloader_config: DataloaderConfiguration, + is_train: bool = False ): - self.is_train = False # Split gallery and probe dataset gallery_dataloader, probe_dataloaders = self._split_gallery_probe( - dataset_config, dataloader_config + dataset_config, dataloader_config, is_train ) # Get pretrained models at iter_ checkpoints = self._load_pretrained( @@ -261,41 +302,45 @@ class Model: ) # Init models - model_hp = self.hp.get('model', {}) + model_hp: dict = self.hp.get('model', {}).copy() self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) self.rgb_pn.eval() - gallery_samples, probe_samples = [], {} - # Gallery - checkpoint = torch.load(list(checkpoints.values())[0]) - self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) - for sample in tqdm(gallery_dataloader, - desc='Transforming gallery', unit='clips'): - gallery_samples.append(self._get_eval_sample(sample)) - gallery_samples = default_collate(gallery_samples) - # Probe - for (condition, dataloader) in probe_dataloaders.items(): + gallery_samples, probe_samples = {}, {} + for (condition, probe_dataloader) in probe_dataloaders.items(): checkpoint = torch.load(checkpoints[condition]) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) + # Gallery + gallery_samples_c = [] + for sample in tqdm(gallery_dataloader, + desc=f'Transforming gallery {condition}', + unit='clips'): + gallery_samples_c.append(self._get_eval_sample(sample)) + gallery_samples[condition] = default_collate(gallery_samples_c) + # Probe probe_samples_c = [] - for sample in tqdm(dataloader, + for sample in tqdm(probe_dataloader, desc=f'Transforming probe {condition}', unit='clips'): probe_samples_c.append(self._get_eval_sample(sample)) - probe_samples[condition] = default_collate(probe_samples_c) + probe_samples_c = default_collate(probe_samples_c) + probe_samples_c['meta'] = self._probe_datasets_meta[condition] + probe_samples[condition] = probe_samples_c + gallery_samples['meta'] = self._gallery_dataset_meta return gallery_samples, probe_samples def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]): - label = sample.pop('label').item() - clip = sample.pop('clip').to(self.device) - x_c, x_p = self.rgb_pn(clip).detach() + label, condition, view, clip = sample.values() + with torch.no_grad(): + feature_c, feature_p = self.rgb_pn(clip.to(self.device)) return { - **{'label': label}, - **sample, - **{'cano_feature': x_c, 'pose_feature': x_p} + 'label': label.item(), + 'condition': condition[0], + 'view': view[0], + 'feature': torch.cat((feature_c, feature_p)).view(-1) } def _load_pretrained( @@ -307,10 +352,11 @@ class Model: ] ) -> dict[str, str]: checkpoints = {} - for (iter_, (condition, selector)) in zip( - iters, dataset_selectors.items() + for (iter_, total_iter, (condition, selector)) in zip( + iters, self.total_iters, dataset_selectors.items() ): - self.curr_iter = iter_ + self.curr_iter = iter_ - 1 + self.total_iter = total_iter self._dataset_sig = self._make_signature( dict(**dataset_config, **selector), popped_keys=['root_dir', 'cache_on'] @@ -322,26 +368,29 @@ class Model: self, dataset_config: DatasetConfiguration, dataloader_config: DataloaderConfiguration, + is_train: bool = False ) -> tuple[DataLoader, dict[str, DataLoader]]: dataset_name = dataset_config.get('name', 'CASIA-B') if dataset_name == 'CASIA-B': + self.is_train = is_train gallery_dataset = self._parse_dataset_config( dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR) ) - self._gallery_dataset_meta = gallery_dataset.metadata - gallery_dataloader = self._parse_dataloader_config( - gallery_dataset, dataloader_config - ) probe_datasets = { condition: self._parse_dataset_config( dict(**dataset_config, **selector) ) for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items() } + self._gallery_dataset_meta = gallery_dataset.metadata self._probe_datasets_meta = { condition: dataset.metadata for (condition, dataset) in probe_datasets.items() } + self.is_train = False + gallery_dataloader = self._parse_dataloader_config( + gallery_dataset, dataloader_config + ) probe_dataloaders = { condition: self._parse_dataloader_config( dataset, dataloader_config diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index f18d675..d3f8ade 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import torch.nn.functional as F from models.auto_encoder import AutoEncoder @@ -14,6 +15,7 @@ class RGBPartNet(nn.Module): image_log_on: bool = False ): super().__init__() + self.h, self.w = ae_in_size (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims self.image_log_on = image_log_on @@ -22,70 +24,64 @@ class RGBPartNet(nn.Module): ) def forward(self, x_c1, x_c2=None): - # Step 1: Disentanglement - # n, t, c, h, w - ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2) + losses, features, images = self._disentangle(x_c1, x_c2) if self.training: losses = torch.stack(losses) - return losses, images + return losses, features, images else: - return x_c, x_p + return features def _disentangle(self, x_c1_t2, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() - device = x_c1_t2.device - x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :] if self.training: + x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :] ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) - # Decode features - with torch.no_grad(): - x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device) + f_a = f_a_.view(n, t, -1) + f_c = f_c_.view(n, t, -1) + f_p = f_p_.view(n, t, -1) - i_a, i_c, i_p = None, None, None - if self.image_log_on: - i_a = self._decode_appr_feature(f_a_, n, t, device) - # Continue decoding canonical features - i_c = self.ae.decoder.trans_conv3(x_c) - i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c)) - i_p = x_p + i_a, i_c, i_p = None, None, None + if self.image_log_on: + with torch.no_grad(): + x_a, i_a = self._separate_decode( + f_a.mean(1), + torch.zeros_like(f_c[:, 0, :]), + torch.zeros_like(f_p[:, 0, :]) + ) + x_c, i_c = self._separate_decode( + torch.zeros_like(f_a[:, 0, :]), + f_c.mean(1), + torch.zeros_like(f_p[:, 0, :]), + ) + x_p_, i_p_ = self._separate_decode( + torch.zeros_like(f_a_), + torch.zeros_like(f_c_), + f_p_ + ) + x_p = tuple(_x_p.view(n, t, *_x_p.size()[1:]) for _x_p in x_p_) + i_p = i_p_.view(n, t, c, h, w) - return (x_c, x_p), losses, (i_a, i_c, i_p) + return losses, (x_a, x_c, x_p), (i_a, i_c, i_p) else: # evaluating f_c_, f_p_ = self.ae(x_c1_t2) - x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device) - return (x_c, x_p), None, None + f_c = f_c_.view(n, t, -1) + f_p = f_p_.view(n, t, -1) + return (f_c, f_p), None, None - def _decode_appr_feature(self, f_a_, n, t, device): - # Decode appearance features - f_a = f_a_.view(n, t, -1) - x_a = self.ae.decoder( - f_a.mean(1), - torch.zeros((n, self.f_c_dim), device=device), - torch.zeros((n, self.f_p_dim), device=device) + def _separate_decode(self, f_a, f_c, f_p): + x_1 = torch.cat((f_a, f_c, f_p), dim=1) + x_1 = self.ae.decoder.fc(x_1).view( + -1, + self.ae.decoder.feature_channels * 8, + self.ae.decoder.h_0, + self.ae.decoder.w_0 ) - return x_a - - def _decode_cano_feature(self, f_c_, n, t, device): - # Decode average canonical features to higher dimension - f_c = f_c_.view(n, t, -1) - x_c = self.ae.decoder( - torch.zeros((n, self.f_a_dim), device=device), - f_c.mean(1), - torch.zeros((n, self.f_p_dim), device=device), - cano_only=True - ) - return x_c - - def _decode_pose_feature(self, f_p_, n, t, c, h, w, device): - # Decode pose features to images - x_p_ = self.ae.decoder( - torch.zeros((n * t, self.f_a_dim), device=device), - torch.zeros((n * t, self.f_c_dim), device=device), - f_p_ - ) - x_p = x_p_.view(n, t, c, h, w) - return x_p + x_1 = F.relu(x_1, inplace=True) + x_2 = self.ae.decoder.trans_conv1(x_1) + x_3 = self.ae.decoder.trans_conv2(x_2) + x_4 = self.ae.decoder.trans_conv3(x_3) + image = torch.sigmoid(self.ae.decoder.trans_conv4(x_4)) + x = (x_1, x_2, x_3, x_4) + return x, image |