diff options
Diffstat (limited to 'models')
-rw-r--r-- | models/model.py | 163 |
1 files changed, 94 insertions, 69 deletions
diff --git a/models/model.py b/models/model.py index af37bb2..aad0a99 100644 --- a/models/model.py +++ b/models/model.py @@ -7,6 +7,9 @@ import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim +import torch_xla.core.xla_model as xm +import torch_xla.distributed.parallel_loader as pl +import torch_xla.distributed.xla_multiprocessing as xmp from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate from torch.utils.tensorboard import SummaryWriter @@ -24,16 +27,7 @@ class Model: model_config: Dict, hyperparameter_config: Dict ): - self.disable_acc = system_config['disable_acc'] - if self.disable_acc: - self.device = torch.device('cpu') - else: # Enable accelerator - if torch.cuda.is_available(): - self.device = torch.device('cuda') - else: - print('No accelerator available, fallback to CPU.') - self.device = torch.device('cpu') - + self.nprocs = system_config['nprocs'] self.save_dir = system_config['save_dir'] self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint') self.log_dir = os.path.join(self.save_dir, 'logs') @@ -61,11 +55,6 @@ class Model: self._log_sig: str = '_'.join((self._model_sig, self._hp_sig)) self._log_name: str = os.path.join(self.log_dir, self._log_sig) - self.rgb_pn: Optional[RGBPartNet] = None - self.optimizer: Optional[optim.Adam] = None - self.scheduler: Optional[optim.lr_scheduler.StepLR] = None - self.writer: Optional[SummaryWriter] = None - self.CASIAB_GALLERY_SELECTOR = { 'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})} } @@ -104,81 +93,118 @@ class Model: dataset_config: Dict, dataloader_config: Dict, ): + # Only instantiate model weights once in memory. + model_hp = self.hp.get('model', {}) + rgb_pn = RGBPartNet(self.train_size, self.in_channels, **model_hp) + wrapped_rgb_pn = xmp.MpModelWrapper(rgb_pn) + + xmp.spawn( + self._fit_map_fn, + args=(wrapped_rgb_pn, dataset_config, dataloader_config), + nprocs=self.nprocs, + start_method='fork' + ) + + def _fit_map_fn( + self, + rank: int, + wrapped_rgb_pn: xmp.MpModelWrapper, + dataset_config: Dict, + dataloader_config: Dict, + ): self.is_train = True dataset = self._parse_dataset_config(dataset_config) dataloader = self._parse_dataloader_config(dataset, dataloader_config) - # Prepare for model, optimizer and scheduler - model_hp = self.hp.get('model', {}) + # Prepare for optimizer and scheduler optim_hp = self.hp.get('optimizer', {}) + # Scale learning rate to world size + if optim_hp['lr']: + optim_hp['lr'] *= xm.xrt_world_size() sched_hp = self.hp.get('scheduler', {}) - self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **model_hp) - self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp) - self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, **sched_hp) - self.writer = SummaryWriter(self._log_name) - # Try to accelerate computation using CUDA or others - self._accelerate() - - self.rgb_pn.train() + device = xm.xla_device() + rgb_pn = wrapped_rgb_pn.to(device) + optimizer = optim.Adam(rgb_pn.parameters(), **optim_hp) + scheduler = optim.lr_scheduler.StepLR(optimizer, **sched_hp) + writer = SummaryWriter(self._log_name) + + para_loader = pl.ParallelLoader(dataloader, [device]) + self._train_loop( + rank, + para_loader.per_device_loader(device), + rgb_pn, optimizer, scheduler, writer + ) + + def _train_loop( + self, + rank: int, + dataloader: pl.PerDeviceLoader, + rgb_pn: RGBPartNet, + optimizer: optim.Adam, + scheduler: optim.lr_scheduler.StepLR, + writer: SummaryWriter + ): + rgb_pn.train() # Init weights at first iter if self.curr_iter == 0: - self.rgb_pn.apply(self.init_weights) + rgb_pn.apply(self.init_weights) else: # Load saved state dicts checkpoint = torch.load(self._checkpoint_name) iter_, loss = checkpoint['iter'], checkpoint['loss'] print('{0:5d} loss: {1:.3f}'.format(iter_, loss)) - self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) - self.optimizer.load_state_dict(checkpoint['optim_state_dict']) - + rgb_pn.load_state_dict(checkpoint['model_state_dict']) + optimizer.load_state_dict(checkpoint['optim_state_dict']) # Training start start_time = datetime.now() - for (batch_c1, batch_c2) in dataloader: - self.curr_iter += 1 + for (iter_i, (batch_c1, batch_c2)) in enumerate(dataloader): # Zero the parameter gradients - self.optimizer.zero_grad() + optimizer.zero_grad() # forward + backward + optimize - x_c1 = batch_c1['clip'].to(self.device) - x_c2 = batch_c2['clip'].to(self.device) - y = batch_c1['label'].to(self.device) - loss, metrics = self.rgb_pn(x_c1, x_c2, y) + x_c1 = batch_c1['clip'] + x_c2 = batch_c2['clip'] + y = batch_c1['label'] + loss, metrics = rgb_pn(x_c1, x_c2, y) loss.backward() - self.optimizer.step() + xm.optimizer_step(optimizer) # Step scheduler - self.scheduler.step() + scheduler.step() # Write losses to TensorBoard - self.writer.add_scalar('Loss/all', loss.item(), self.curr_iter) - self.writer.add_scalars('Loss/details', dict(zip([ - 'Cross reconstruction loss', 'Pose similarity loss', - 'Canonical consistency loss', 'Batch All triplet loss' - ], metrics)), self.curr_iter) - - if self.curr_iter % 100 == 0: - print('{0:5d} loss: {1:6.3f}'.format(self.curr_iter, loss), + writer.add_scalar( + f'[xla:{rank}]Loss/all', loss.item(), iter_i + 1 + ) + writer.add_scalars( + f'[xla:{rank}]Loss/details', dict(zip([ + 'Cross reconstruction loss', 'Pose similarity loss', + 'Canonical consistency loss', 'Batch All triplet loss' + ], metrics)), + iter_i + 1 + ) + + if iter_i % 100 == 99: + print('[xla:{0}]({1:5d})'.format(rank, iter_i + 1), + 'loss: {:6.3f}'.format(loss), '(xrecon = {:f}, pose_sim = {:f},' ' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics), - 'lr:', self.scheduler.get_last_lr()[0]) + 'lr:', scheduler.get_last_lr()[0]) + xm.master_print(iter_i + 1, 'iteration finished') - if self.curr_iter % 1000 == 0: + if iter_i % 1000 == 999 and xm.is_master_ordinal(): + self.curr_iter = iter_i + 1 torch.save({ 'iter': self.curr_iter, - 'model_state_dict': self.rgb_pn.state_dict(), - 'optim_state_dict': self.optimizer.state_dict(), + 'model_state_dict': rgb_pn.state_dict(), + 'optim_state_dict': optimizer.state_dict(), 'loss': loss, }, self._checkpoint_name) print(datetime.now() - start_time, 'used') start_time = datetime.now() - if self.curr_iter == self.total_iter: - self.curr_iter = 0 - self.writer.close() + if iter_i == self.total_iter - 1: + if xm.is_master_ordinal(): + self.curr_iter = 0 + writer.close() break - def _accelerate(self): - if not self.disable_acc: - if torch.cuda.device_count() > 1: - self.rgb_pn = nn.DataParallel(self.rgb_pn) - self.rgb_pn = self.rgb_pn.to(self.device) - def predict_all( self, iter_: int, @@ -189,6 +215,7 @@ class Model: dataloader_config: Dict, ) -> Dict[str, torch.Tensor]: self.is_train = False + device = xm.xla_device() # Split gallery and probe dataset gallery_dataloader, probe_dataloaders = self._split_gallery_probe( dataset_config, dataloader_config @@ -199,21 +226,19 @@ class Model: ) # Init models model_hp = self.hp.get('model', {}) - self.rgb_pn = RGBPartNet(ae_in_channels=self.in_channels, **model_hp) - # Try to accelerate computation using CUDA or others - self._accelerate() + rgb_pn = RGBPartNet(ae_in_channels=self.in_channels, **model_hp) - self.rgb_pn.eval() + rgb_pn.eval() gallery_samples, probe_samples = [], {} # Gallery checkpoint = torch.load(list(checkpoints.values())[0]) - self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) + rgb_pn.load_state_dict(checkpoint['model_state_dict']) for sample in tqdm(gallery_dataloader, desc='Transforming gallery', unit='clips'): label = sample.pop('label').item() - clip = sample.pop('clip').to(self.device) - feature = self.rgb_pn(clip).detach() + clip = sample.pop('clip').to(device) + feature = rgb_pn(clip).detach() gallery_samples.append({ **{'label': label}, **sample, @@ -224,14 +249,14 @@ class Model: # Probe for (condition, dataloader) in probe_dataloaders.items(): checkpoint = torch.load(checkpoints[condition]) - self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) + rgb_pn.load_state_dict(checkpoint['model_state_dict']) probe_samples[condition] = [] for sample in tqdm(dataloader, desc=f'Transforming probe {condition}', unit='clips'): label = sample.pop('label').item() - clip = sample.pop('clip').to(self.device) - feature = self.rgb_pn(clip).detach() + clip = sample.pop('clip').to(device) + feature = rgb_pn(clip).detach() probe_samples[condition].append({ **{'label': label}, **sample, |