summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-04-03 21:30:35 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-04-03 23:06:07 +0800
commitf6f133fa7b926ce0c7d28bbf0ba4de41b3708d4a (patch)
tree4bf9b80c1c7a96f081a4e3b3b751145054fccc39 /models/model.py
parentd12dd6b04a4e7c2b1ee43ab6f36f25d0c35ca364 (diff)
parentb9f35fbe7d78b3c478086ea26c2a76f72ce35687 (diff)
Merge branch 'master' into disentangling_only
# Conflicts: # config.py # models/hpm.py # models/layers.py # models/model.py # models/part_net.py # models/rgb_part_net.py # test/part_net.py # utils/configuration.py # utils/triplet_loss.py
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py233
1 files changed, 141 insertions, 92 deletions
diff --git a/models/model.py b/models/model.py
index 3f24936..25c8a4f 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,5 +1,6 @@
+import copy
import os
-from datetime import datetime
+import random
from typing import Union, Optional
import numpy as np
@@ -47,10 +48,10 @@ class Model:
self.meta = model_config
self.hp = hyperparameter_config
- self.curr_iter = self.meta.get('restore_iter', 0)
+ self.restore_iter = self.curr_iter = self.meta.get('restore_iter', 0)
self.total_iter = self.meta.get('total_iter', 80_000)
- self.curr_iters = self.meta.get('restore_iters', (0, 0, 0))
- self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000))
+ self.restore_iters = self.meta.get('restore_iters', (self.curr_iter,))
+ self.total_iters = self.meta.get('total_iters', (self.total_iter,))
self.is_train: bool = True
self.in_channels: int = 3
@@ -70,6 +71,7 @@ class Model:
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
self.image_log_on = system_config.get('image_log_on', False)
+ self.val_size = system_config.get('val_size', 10)
self.CASIAB_GALLERY_SELECTOR = {
'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})}
@@ -83,7 +85,7 @@ class Model:
@property
def _model_sig(self) -> str:
return '_'.join(
- (self._model_name, str(self.curr_iter), str(self.total_iter))
+ (self._model_name, str(self.curr_iter + 1), str(self.total_iter))
)
@property
@@ -112,18 +114,18 @@ class Model:
],
dataloader_config: DataloaderConfiguration,
):
- for (curr_iter, total_iter, (condition, selector)) in zip(
- self.curr_iters, self.total_iters, dataset_selectors.items()
+ for (restore_iter, total_iter, (condition, selector)) in zip(
+ self.restore_iters, self.total_iters, dataset_selectors.items()
):
print(f'Training model {condition} ...')
# Skip finished model
- if curr_iter == total_iter:
+ if restore_iter == total_iter:
continue
# Check invalid restore iter
- elif curr_iter > total_iter:
+ elif restore_iter > total_iter:
raise ValueError("Restore iter '{}' should less than total "
- "iter '{}'".format(curr_iter, total_iter))
- self.curr_iter = curr_iter
+ "iter '{}'".format(restore_iter, total_iter))
+ self.restore_iter = self.curr_iter = restore_iter
self.total_iter = total_iter
self.fit(
dict(**dataset_config, **{'selector': selector}),
@@ -136,70 +138,85 @@ class Model:
dataloader_config: DataloaderConfiguration,
):
self.is_train = True
- dataset = self._parse_dataset_config(dataset_config)
- dataloader = self._parse_dataloader_config(dataset, dataloader_config)
+ # Validation dataset
+ # (the first `val_size` subjects from evaluation set)
+ val_dataset_config = copy.deepcopy(dataset_config)
+ train_size = dataset_config.get('train_size', 74)
+ val_dataset_config['train_size'] = train_size + self.val_size
+ val_dataset_config['selector']['classes'] = ClipClasses({
+ str(c).zfill(3)
+ for c in range(train_size + 1, train_size + self.val_size + 1)
+ })
+ val_dataset = self._parse_dataset_config(val_dataset_config)
+ val_dataloader = iter(self._parse_dataloader_config(
+ val_dataset, dataloader_config
+ ))
+ # Training dataset
+ train_dataset = self._parse_dataset_config(dataset_config)
+ train_dataloader = iter(self._parse_dataloader_config(
+ train_dataset, dataloader_config
+ ))
# Prepare for model, optimizer and scheduler
- model_hp = self.hp.get('model', {})
+ model_hp: dict = self.hp.get('model', {}).copy()
optim_hp: dict = self.hp.get('optimizer', {}).copy()
sched_hp = self.hp.get('scheduler', {})
+
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
+
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp)
- sched_gamma = sched_hp.get('gamma', 0.9)
- sched_step_size = sched_hp.get('step_size', 500)
- self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
- lambda epoch: sched_gamma ** (epoch // sched_step_size),
- ])
+ start_step = sched_hp.get('start_step', 15_000)
+ final_gamma = sched_hp.get('final_gamma', 0.001)
+ all_step = self.total_iter - start_step
+ self.scheduler = optim.lr_scheduler.LambdaLR(
+ self.optimizer,
+ lambda t: final_gamma ** ((t - start_step) / all_step)
+ if t > start_step else 1,
+ )
self.writer = SummaryWriter(self._log_name)
+ # Set seeds for reproducibility
+ random.seed(0)
+ torch.manual_seed(0)
self.rgb_pn.train()
# Init weights at first iter
if self.curr_iter == 0:
self.rgb_pn.apply(self.init_weights)
else: # Load saved state dicts
+ # Offset a iter to load last checkpoint
+ self.curr_iter -= 1
checkpoint = torch.load(self._checkpoint_name)
- iter_, loss = checkpoint['iter'], checkpoint['loss']
- print('{0:5d} loss: {1:.3f}'.format(iter_, loss))
+ random.setstate(checkpoint['rand_states'][0])
+ torch.set_rng_state(checkpoint['rand_states'][1])
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
self.scheduler.load_state_dict(checkpoint['sched_state_dict'])
# Training start
- start_time = datetime.now()
- running_loss = torch.zeros(3, device=self.device)
- print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
- f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
- f"{'LR':^9}")
- for (batch_c1, batch_c2) in dataloader:
- self.curr_iter += 1
+ for self.curr_iter in tqdm(range(self.restore_iter, self.total_iter),
+ desc='Training'):
+ batch_c1, batch_c2 = next(train_dataloader)
# Zero the parameter gradients
self.optimizer.zero_grad()
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- losses, images = self.rgb_pn(x_c1, x_c2)
+ losses, features, images = self.rgb_pn(x_c1, x_c2)
loss = losses.sum()
loss.backward()
self.optimizer.step()
+ self.scheduler.step()
- # Statistics and checkpoint
- running_loss += losses.detach()
- # Write losses to TensorBoard
- self.writer.add_scalar('Loss/all', loss, self.curr_iter)
- self.writer.add_scalars('Loss/details', dict(zip([
- 'Cross reconstruction loss',
- 'Canonical consistency loss',
- 'Pose similarity loss'
- ], losses)), self.curr_iter)
-
- if self.curr_iter % 100 == 0:
- lr = self.scheduler.get_last_lr()[0]
- # Write learning rates
- self.writer.add_scalar(
- 'Learning rate/Auto-encoder', lr, self.curr_iter
- )
+ # Learning rate
+ self.writer.add_scalar(
+ 'Learning rate', self.scheduler.get_last_lr()[0], self.curr_iter
+ )
+ # Other stats
+ self._write_stat('Train', loss, losses)
+
+ if self.curr_iter % 100 == 99:
# Write disentangled images
if self.image_log_on:
i_a, i_c, i_p = images
@@ -216,30 +233,54 @@ class Model:
self.writer.add_images(
f'Pose image/batch {i}', p, self.curr_iter
)
- time_used = datetime.now() - start_time
- remaining_minute, second = divmod(time_used.seconds, 60)
- hour, minute = divmod(remaining_minute, 60)
- print(f'{hour:02}:{minute:02}:{second:02}',
- f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
- '{:f} {:f} {:f}'.format(*running_loss / 100),
- f'{lr:.3e}')
- running_loss.zero_()
-
- # Step scheduler
- self.scheduler.step()
-
- if self.curr_iter % 1000 == 0:
+ f_a, f_c, f_p = features
+ for i, (f_a_i, f_c_i, f_p_i) in enumerate(
+ zip(f_a, f_c, f_p)
+ ):
+ self.writer.add_images(
+ f'Appearance features/Layer {i}',
+ f_a_i[:, :3, :, :], self.curr_iter
+ )
+ self.writer.add_images(
+ f'Canonical features/Layer {i}',
+ f_c_i[:, :3, :, :], self.curr_iter
+ )
+ for j, p in enumerate(f_p_i):
+ self.writer.add_images(
+ f'Pose features/Layer {i}/batch{j}',
+ p[:, :3, :, :], self.curr_iter
+ )
+
+ # Calculate losses on testing batch
+ batch_c1, batch_c2 = next(val_dataloader)
+ x_c1 = batch_c1['clip'].to(self.device)
+ x_c2 = batch_c2['clip'].to(self.device)
+ with torch.no_grad():
+ losses, _, _ = self.rgb_pn(x_c1, x_c2)
+ loss = losses.sum()
+
+ self._write_stat('Val', loss, losses)
+
+ # Checkpoint
+ if self.curr_iter % 1000 == 999:
torch.save({
- 'iter': self.curr_iter,
+ 'rand_states': (random.getstate(), torch.get_rng_state()),
'model_state_dict': self.rgb_pn.state_dict(),
'optim_state_dict': self.optimizer.state_dict(),
'sched_state_dict': self.scheduler.state_dict(),
- 'loss': loss,
}, self._checkpoint_name)
- if self.curr_iter == self.total_iter:
- self.writer.close()
- break
+ self.writer.close()
+
+ def _write_stat(
+ self, postfix, loss, losses
+ ):
+ # Write losses to TensorBoard
+ self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter)
+ self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip((
+ 'Cross reconstruction loss', 'Canonical consistency loss',
+ 'Pose similarity loss'
+ ), losses)), self.curr_iter)
def transform(
self,
@@ -248,12 +289,12 @@ class Model:
dataset_selectors: dict[
str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
],
- dataloader_config: DataloaderConfiguration
+ dataloader_config: DataloaderConfiguration,
+ is_train: bool = False
):
- self.is_train = False
# Split gallery and probe dataset
gallery_dataloader, probe_dataloaders = self._split_gallery_probe(
- dataset_config, dataloader_config
+ dataset_config, dataloader_config, is_train
)
# Get pretrained models at iter_
checkpoints = self._load_pretrained(
@@ -261,41 +302,45 @@ class Model:
)
# Init models
- model_hp = self.hp.get('model', {})
+ model_hp: dict = self.hp.get('model', {}).copy()
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp)
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
self.rgb_pn.eval()
- gallery_samples, probe_samples = [], {}
- # Gallery
- checkpoint = torch.load(list(checkpoints.values())[0])
- self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
- for sample in tqdm(gallery_dataloader,
- desc='Transforming gallery', unit='clips'):
- gallery_samples.append(self._get_eval_sample(sample))
- gallery_samples = default_collate(gallery_samples)
- # Probe
- for (condition, dataloader) in probe_dataloaders.items():
+ gallery_samples, probe_samples = {}, {}
+ for (condition, probe_dataloader) in probe_dataloaders.items():
checkpoint = torch.load(checkpoints[condition])
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
+ # Gallery
+ gallery_samples_c = []
+ for sample in tqdm(gallery_dataloader,
+ desc=f'Transforming gallery {condition}',
+ unit='clips'):
+ gallery_samples_c.append(self._get_eval_sample(sample))
+ gallery_samples[condition] = default_collate(gallery_samples_c)
+ # Probe
probe_samples_c = []
- for sample in tqdm(dataloader,
+ for sample in tqdm(probe_dataloader,
desc=f'Transforming probe {condition}',
unit='clips'):
probe_samples_c.append(self._get_eval_sample(sample))
- probe_samples[condition] = default_collate(probe_samples_c)
+ probe_samples_c = default_collate(probe_samples_c)
+ probe_samples_c['meta'] = self._probe_datasets_meta[condition]
+ probe_samples[condition] = probe_samples_c
+ gallery_samples['meta'] = self._gallery_dataset_meta
return gallery_samples, probe_samples
def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]):
- label = sample.pop('label').item()
- clip = sample.pop('clip').to(self.device)
- x_c, x_p = self.rgb_pn(clip).detach()
+ label, condition, view, clip = sample.values()
+ with torch.no_grad():
+ feature_c, feature_p = self.rgb_pn(clip.to(self.device))
return {
- **{'label': label},
- **sample,
- **{'cano_feature': x_c, 'pose_feature': x_p}
+ 'label': label.item(),
+ 'condition': condition[0],
+ 'view': view[0],
+ 'feature': torch.cat((feature_c, feature_p)).view(-1)
}
def _load_pretrained(
@@ -307,10 +352,11 @@ class Model:
]
) -> dict[str, str]:
checkpoints = {}
- for (iter_, (condition, selector)) in zip(
- iters, dataset_selectors.items()
+ for (iter_, total_iter, (condition, selector)) in zip(
+ iters, self.total_iters, dataset_selectors.items()
):
- self.curr_iter = iter_
+ self.curr_iter = iter_ - 1
+ self.total_iter = total_iter
self._dataset_sig = self._make_signature(
dict(**dataset_config, **selector),
popped_keys=['root_dir', 'cache_on']
@@ -322,26 +368,29 @@ class Model:
self,
dataset_config: DatasetConfiguration,
dataloader_config: DataloaderConfiguration,
+ is_train: bool = False
) -> tuple[DataLoader, dict[str, DataLoader]]:
dataset_name = dataset_config.get('name', 'CASIA-B')
if dataset_name == 'CASIA-B':
+ self.is_train = is_train
gallery_dataset = self._parse_dataset_config(
dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR)
)
- self._gallery_dataset_meta = gallery_dataset.metadata
- gallery_dataloader = self._parse_dataloader_config(
- gallery_dataset, dataloader_config
- )
probe_datasets = {
condition: self._parse_dataset_config(
dict(**dataset_config, **selector)
)
for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items()
}
+ self._gallery_dataset_meta = gallery_dataset.metadata
self._probe_datasets_meta = {
condition: dataset.metadata
for (condition, dataset) in probe_datasets.items()
}
+ self.is_train = False
+ gallery_dataloader = self._parse_dataloader_config(
+ gallery_dataset, dataloader_config
+ )
probe_dataloaders = {
condition: self._parse_dataloader_config(
dataset, dataloader_config