summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/model.py163
1 files changed, 94 insertions, 69 deletions
diff --git a/models/model.py b/models/model.py
index af37bb2..aad0a99 100644
--- a/models/model.py
+++ b/models/model.py
@@ -7,6 +7,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
+import torch_xla.core.xla_model as xm
+import torch_xla.distributed.parallel_loader as pl
+import torch_xla.distributed.xla_multiprocessing as xmp
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
@@ -24,16 +27,7 @@ class Model:
model_config: Dict,
hyperparameter_config: Dict
):
- self.disable_acc = system_config['disable_acc']
- if self.disable_acc:
- self.device = torch.device('cpu')
- else: # Enable accelerator
- if torch.cuda.is_available():
- self.device = torch.device('cuda')
- else:
- print('No accelerator available, fallback to CPU.')
- self.device = torch.device('cpu')
-
+ self.nprocs = system_config['nprocs']
self.save_dir = system_config['save_dir']
self.checkpoint_dir = os.path.join(self.save_dir, 'checkpoint')
self.log_dir = os.path.join(self.save_dir, 'logs')
@@ -61,11 +55,6 @@ class Model:
self._log_sig: str = '_'.join((self._model_sig, self._hp_sig))
self._log_name: str = os.path.join(self.log_dir, self._log_sig)
- self.rgb_pn: Optional[RGBPartNet] = None
- self.optimizer: Optional[optim.Adam] = None
- self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
- self.writer: Optional[SummaryWriter] = None
-
self.CASIAB_GALLERY_SELECTOR = {
'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})}
}
@@ -104,81 +93,118 @@ class Model:
dataset_config: Dict,
dataloader_config: Dict,
):
+ # Only instantiate model weights once in memory.
+ model_hp = self.hp.get('model', {})
+ rgb_pn = RGBPartNet(self.train_size, self.in_channels, **model_hp)
+ wrapped_rgb_pn = xmp.MpModelWrapper(rgb_pn)
+
+ xmp.spawn(
+ self._fit_map_fn,
+ args=(wrapped_rgb_pn, dataset_config, dataloader_config),
+ nprocs=self.nprocs,
+ start_method='fork'
+ )
+
+ def _fit_map_fn(
+ self,
+ rank: int,
+ wrapped_rgb_pn: xmp.MpModelWrapper,
+ dataset_config: Dict,
+ dataloader_config: Dict,
+ ):
self.is_train = True
dataset = self._parse_dataset_config(dataset_config)
dataloader = self._parse_dataloader_config(dataset, dataloader_config)
- # Prepare for model, optimizer and scheduler
- model_hp = self.hp.get('model', {})
+ # Prepare for optimizer and scheduler
optim_hp = self.hp.get('optimizer', {})
+ # Scale learning rate to world size
+ if optim_hp['lr']:
+ optim_hp['lr'] *= xm.xrt_world_size()
sched_hp = self.hp.get('scheduler', {})
- self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **model_hp)
- self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp)
- self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, **sched_hp)
- self.writer = SummaryWriter(self._log_name)
- # Try to accelerate computation using CUDA or others
- self._accelerate()
-
- self.rgb_pn.train()
+ device = xm.xla_device()
+ rgb_pn = wrapped_rgb_pn.to(device)
+ optimizer = optim.Adam(rgb_pn.parameters(), **optim_hp)
+ scheduler = optim.lr_scheduler.StepLR(optimizer, **sched_hp)
+ writer = SummaryWriter(self._log_name)
+
+ para_loader = pl.ParallelLoader(dataloader, [device])
+ self._train_loop(
+ rank,
+ para_loader.per_device_loader(device),
+ rgb_pn, optimizer, scheduler, writer
+ )
+
+ def _train_loop(
+ self,
+ rank: int,
+ dataloader: pl.PerDeviceLoader,
+ rgb_pn: RGBPartNet,
+ optimizer: optim.Adam,
+ scheduler: optim.lr_scheduler.StepLR,
+ writer: SummaryWriter
+ ):
+ rgb_pn.train()
# Init weights at first iter
if self.curr_iter == 0:
- self.rgb_pn.apply(self.init_weights)
+ rgb_pn.apply(self.init_weights)
else: # Load saved state dicts
checkpoint = torch.load(self._checkpoint_name)
iter_, loss = checkpoint['iter'], checkpoint['loss']
print('{0:5d} loss: {1:.3f}'.format(iter_, loss))
- self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
- self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
-
+ rgb_pn.load_state_dict(checkpoint['model_state_dict'])
+ optimizer.load_state_dict(checkpoint['optim_state_dict'])
# Training start
start_time = datetime.now()
- for (batch_c1, batch_c2) in dataloader:
- self.curr_iter += 1
+ for (iter_i, (batch_c1, batch_c2)) in enumerate(dataloader):
# Zero the parameter gradients
- self.optimizer.zero_grad()
+ optimizer.zero_grad()
# forward + backward + optimize
- x_c1 = batch_c1['clip'].to(self.device)
- x_c2 = batch_c2['clip'].to(self.device)
- y = batch_c1['label'].to(self.device)
- loss, metrics = self.rgb_pn(x_c1, x_c2, y)
+ x_c1 = batch_c1['clip']
+ x_c2 = batch_c2['clip']
+ y = batch_c1['label']
+ loss, metrics = rgb_pn(x_c1, x_c2, y)
loss.backward()
- self.optimizer.step()
+ xm.optimizer_step(optimizer)
# Step scheduler
- self.scheduler.step()
+ scheduler.step()
# Write losses to TensorBoard
- self.writer.add_scalar('Loss/all', loss.item(), self.curr_iter)
- self.writer.add_scalars('Loss/details', dict(zip([
- 'Cross reconstruction loss', 'Pose similarity loss',
- 'Canonical consistency loss', 'Batch All triplet loss'
- ], metrics)), self.curr_iter)
-
- if self.curr_iter % 100 == 0:
- print('{0:5d} loss: {1:6.3f}'.format(self.curr_iter, loss),
+ writer.add_scalar(
+ f'[xla:{rank}]Loss/all', loss.item(), iter_i + 1
+ )
+ writer.add_scalars(
+ f'[xla:{rank}]Loss/details', dict(zip([
+ 'Cross reconstruction loss', 'Pose similarity loss',
+ 'Canonical consistency loss', 'Batch All triplet loss'
+ ], metrics)),
+ iter_i + 1
+ )
+
+ if iter_i % 100 == 99:
+ print('[xla:{0}]({1:5d})'.format(rank, iter_i + 1),
+ 'loss: {:6.3f}'.format(loss),
'(xrecon = {:f}, pose_sim = {:f},'
' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics),
- 'lr:', self.scheduler.get_last_lr()[0])
+ 'lr:', scheduler.get_last_lr()[0])
+ xm.master_print(iter_i + 1, 'iteration finished')
- if self.curr_iter % 1000 == 0:
+ if iter_i % 1000 == 999 and xm.is_master_ordinal():
+ self.curr_iter = iter_i + 1
torch.save({
'iter': self.curr_iter,
- 'model_state_dict': self.rgb_pn.state_dict(),
- 'optim_state_dict': self.optimizer.state_dict(),
+ 'model_state_dict': rgb_pn.state_dict(),
+ 'optim_state_dict': optimizer.state_dict(),
'loss': loss,
}, self._checkpoint_name)
print(datetime.now() - start_time, 'used')
start_time = datetime.now()
- if self.curr_iter == self.total_iter:
- self.curr_iter = 0
- self.writer.close()
+ if iter_i == self.total_iter - 1:
+ if xm.is_master_ordinal():
+ self.curr_iter = 0
+ writer.close()
break
- def _accelerate(self):
- if not self.disable_acc:
- if torch.cuda.device_count() > 1:
- self.rgb_pn = nn.DataParallel(self.rgb_pn)
- self.rgb_pn = self.rgb_pn.to(self.device)
-
def predict_all(
self,
iter_: int,
@@ -189,6 +215,7 @@ class Model:
dataloader_config: Dict,
) -> Dict[str, torch.Tensor]:
self.is_train = False
+ device = xm.xla_device()
# Split gallery and probe dataset
gallery_dataloader, probe_dataloaders = self._split_gallery_probe(
dataset_config, dataloader_config
@@ -199,21 +226,19 @@ class Model:
)
# Init models
model_hp = self.hp.get('model', {})
- self.rgb_pn = RGBPartNet(ae_in_channels=self.in_channels, **model_hp)
- # Try to accelerate computation using CUDA or others
- self._accelerate()
+ rgb_pn = RGBPartNet(ae_in_channels=self.in_channels, **model_hp)
- self.rgb_pn.eval()
+ rgb_pn.eval()
gallery_samples, probe_samples = [], {}
# Gallery
checkpoint = torch.load(list(checkpoints.values())[0])
- self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
+ rgb_pn.load_state_dict(checkpoint['model_state_dict'])
for sample in tqdm(gallery_dataloader,
desc='Transforming gallery', unit='clips'):
label = sample.pop('label').item()
- clip = sample.pop('clip').to(self.device)
- feature = self.rgb_pn(clip).detach()
+ clip = sample.pop('clip').to(device)
+ feature = rgb_pn(clip).detach()
gallery_samples.append({
**{'label': label},
**sample,
@@ -224,14 +249,14 @@ class Model:
# Probe
for (condition, dataloader) in probe_dataloaders.items():
checkpoint = torch.load(checkpoints[condition])
- self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
+ rgb_pn.load_state_dict(checkpoint['model_state_dict'])
probe_samples[condition] = []
for sample in tqdm(dataloader,
desc=f'Transforming probe {condition}',
unit='clips'):
label = sample.pop('label').item()
- clip = sample.pop('clip').to(self.device)
- feature = self.rgb_pn(clip).detach()
+ clip = sample.pop('clip').to(device)
+ feature = rgb_pn(clip).detach()
probe_samples[condition].append({
**{'label': label},
**sample,