summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py249
1 files changed, 154 insertions, 95 deletions
diff --git a/models/model.py b/models/model.py
index 573b1e6..80cec06 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,6 +1,6 @@
+import copy
import os
import random
-from datetime import datetime
from typing import Union, Optional, Tuple, List, Dict, Set
import numpy as np
@@ -52,9 +52,9 @@ class Model:
self.meta = model_config
self.hp = hyperparameter_config
- self.curr_iter = self.meta.get('restore_iter', 0)
+ self.restore_iter = self.curr_iter = self.meta.get('restore_iter', 0)
self.total_iter = self.meta.get('total_iter', 80_000)
- self.curr_iters = self.meta.get('restore_iters', (self.curr_iter,))
+ self.restore_iters = self.meta.get('restore_iters', (self.curr_iter,))
self.total_iters = self.meta.get('total_iters', (self.total_iter,))
self.is_train: bool = True
@@ -62,6 +62,8 @@ class Model:
self.in_size: Tuple[int, int] = (64, 48)
self.pr: Optional[int] = None
self.k: Optional[int] = None
+ self.num_pairs: Optional[int] = None
+ self.num_pos_pairs: Optional[int] = None
self._gallery_dataset_meta: Optional[Dict[str, List]] = None
self._probe_datasets_meta: Optional[Dict[str, Dict[str, List]]] = None
@@ -90,7 +92,7 @@ class Model:
@property
def _model_sig(self) -> str:
return '_'.join(
- (self._model_name, str(self.curr_iter), str(self.total_iter))
+ (self._model_name, str(self.curr_iter + 1), str(self.total_iter))
)
@property
@@ -119,18 +121,18 @@ class Model:
],
dataloader_config: DataloaderConfiguration,
):
- for (curr_iter, total_iter, (condition, selector)) in zip(
- self.curr_iters, self.total_iters, dataset_selectors.items()
+ for (restore_iter, total_iter, (condition, selector)) in zip(
+ self.restore_iters, self.total_iters, dataset_selectors.items()
):
print(f'Training model {condition} ...')
# Skip finished model
- if curr_iter == total_iter:
+ if restore_iter == total_iter:
continue
# Check invalid restore iter
- elif curr_iter > total_iter:
+ elif restore_iter > total_iter:
raise ValueError("Restore iter '{}' should less than total "
- "iter '{}'".format(curr_iter, total_iter))
- self.curr_iter = curr_iter
+ "iter '{}'".format(restore_iter, total_iter))
+ self.restore_iter = self.curr_iter = restore_iter
self.total_iter = total_iter
self.fit(
dict(**dataset_config, **{'selector': selector}),
@@ -143,8 +145,24 @@ class Model:
dataloader_config: DataloaderConfiguration,
):
self.is_train = True
- dataset = self._parse_dataset_config(dataset_config)
- dataloader = self._parse_dataloader_config(dataset, dataloader_config)
+ # Validation dataset
+ # (the first `val_size` subjects from evaluation set)
+ val_size = dataset_config.pop('val_size', 10)
+ val_dataset_config = copy.deepcopy(dataset_config)
+ train_size = dataset_config.get('train_size', 74)
+ val_dataset_config['train_size'] = train_size + val_size
+ val_dataset_config['selector']['classes'] = ClipClasses({
+ str(c).zfill(3) for c in range(train_size, train_size + val_size)
+ })
+ val_dataset = self._parse_dataset_config(val_dataset_config)
+ val_dataloader = iter(self._parse_dataloader_config(
+ val_dataset, dataloader_config
+ ))
+ # Training dataset
+ train_dataset = self._parse_dataset_config(dataset_config)
+ train_dataloader = iter(self._parse_dataloader_config(
+ train_dataset, dataloader_config
+ ))
# Prepare for model, optimizer and scheduler
model_hp: Dict = self.hp.get('model', {}).copy()
triplet_is_hard = model_hp.pop('triplet_is_hard', True)
@@ -177,8 +195,8 @@ class Model:
triplet_is_hard, triplet_is_mean, None
)
- num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
- num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
+ self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
+ self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
@@ -191,6 +209,7 @@ class Model:
{'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
], **optim_hp)
+ # Scheduler
start_step = sched_hp.get('start_step', 15_000)
final_gamma = sched_hp.get('final_gamma', 0.001)
ae_start_step = ae_sched_hp.get('start_step', start_step)
@@ -221,6 +240,8 @@ class Model:
if self.curr_iter == 0:
self.rgb_pn.apply(self.init_weights)
else: # Load saved state dicts
+ # Offset a iter to load last checkpoint
+ self.curr_iter -= 1
checkpoint = torch.load(self._checkpoint_name)
random.setstate(checkpoint['rand_states'][0])
torch.set_rng_state(checkpoint['rand_states'][1])
@@ -229,13 +250,9 @@ class Model:
self.scheduler.load_state_dict(checkpoint['sched_state_dict'])
# Training start
- start_time = datetime.now()
- running_loss = torch.zeros(5, device=self.device)
- print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
- f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
- f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}")
- for (batch_c1, batch_c2) in dataloader:
- self.curr_iter += 1
+ for self.curr_iter in tqdm(range(self.restore_iter, self.total_iter),
+ desc='Training'):
+ batch_c1, batch_c2 = next(train_dataloader)
# Zero the parameter gradients
self.optimizer.zero_grad()
# forward + backward + optimize
@@ -243,72 +260,24 @@ class Model:
x_c2 = batch_c2['clip'].to(self.device)
embed_c, embed_p, ae_losses, images = self.rgb_pn(x_c1, x_c2)
y = batch_c1['label'].to(self.device)
- # Duplicate labels for each part
- y = y.repeat(self.rgb_pn.num_parts, 1)
- trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm(
- embed_c, y[:self.rgb_pn.hpm.num_parts]
+ losses, hpm_result, pn_result = self._classification_loss(
+ embed_c, embed_p, ae_losses, y
)
- trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn(
- embed_p, y[self.rgb_pn.hpm.num_parts:]
- )
- losses = torch.stack((
- *ae_losses,
- trip_loss_hpm.mean(),
- trip_loss_pn.mean()
- ))
loss = losses.sum()
loss.backward()
self.optimizer.step()
+ self.scheduler.step()
- # Statistics and checkpoint
- running_loss += losses.detach()
- # Write losses to TensorBoard
- self.writer.add_scalar('Loss/all', loss, self.curr_iter)
- self.writer.add_scalars('Loss/disentanglement', dict(zip((
- 'Cross reconstruction loss', 'Canonical consistency loss',
- 'Pose similarity loss'
- ), ae_losses)), self.curr_iter)
- self.writer.add_scalars('Loss/triplet loss', {
- 'HPM': losses[3],
- 'PartNet': losses[4]
- }, self.curr_iter)
- # None-zero losses in batch
- if hpm_num_non_zero is not None and pn_num_non_zero is not None:
- self.writer.add_scalars('Loss/non-zero counts', {
- 'HPM': hpm_num_non_zero.mean(),
- 'PartNet': pn_num_non_zero.mean()
- }, self.curr_iter)
- # Embedding distance
- mean_hpm_dist = hpm_dist.mean(0)
- self._add_ranked_scalars(
- 'Embedding/HPM distance', mean_hpm_dist,
- num_pos_pairs, num_pairs, self.curr_iter
- )
- mean_pa_dist = pn_dist.mean(0)
- self._add_ranked_scalars(
- 'Embedding/ParNet distance', mean_pa_dist,
- num_pos_pairs, num_pairs, self.curr_iter
- )
- # Embedding norm
- mean_hpm_embedding = embed_c.mean(0)
- mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
- self._add_ranked_scalars(
- 'Embedding/HPM norm', mean_hpm_norm,
- self.k, self.pr * self.k, self.curr_iter
- )
- mean_pa_embedding = embed_p.mean(0)
- mean_pa_norm = mean_pa_embedding.norm(dim=-1)
- self._add_ranked_scalars(
- 'Embedding/PartNet norm', mean_pa_norm,
- self.k, self.pr * self.k, self.curr_iter
- )
# Learning rate
- lrs = self.scheduler.get_last_lr()
self.writer.add_scalars('Learning rate', dict(zip((
'Auto-encoder', 'HPM', 'PartNet'
- ), lrs)), self.curr_iter)
+ ), self.scheduler.get_last_lr())), self.curr_iter)
+ # Other stats
+ self._write_stat(
+ 'Train', embed_c, embed_p, hpm_result, pn_result, loss, losses
+ )
- if self.curr_iter % 100 == 0:
+ if self.curr_iter % 100 == 99:
# Write disentangled images
if self.image_log_on:
i_a, i_c, i_p = images
@@ -325,19 +294,35 @@ class Model:
self.writer.add_images(
f'Pose image/batch {i}', p, self.curr_iter
)
- time_used = datetime.now() - start_time
- remaining_minute, second = divmod(time_used.seconds, 60)
- hour, minute = divmod(remaining_minute, 60)
- print(f'{hour:02}:{minute:02}:{second:02}',
- f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
- '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- '{:.3e} {:.3e} {:.3e}'.format(*lrs))
- running_loss.zero_()
-
- # Step scheduler
- self.scheduler.step()
- if self.curr_iter % 1000 == 0:
+ # Validation
+ embed_c = self._flatten_embedding(embed_c)
+ embed_p = self._flatten_embedding(embed_p)
+ self._write_embedding('HPM Train', embed_c, x_c1, y)
+ self._write_embedding('PartNet Train', embed_p, x_c1, y)
+
+ # Calculate losses on testing batch
+ batch_c1, batch_c2 = next(val_dataloader)
+ x_c1 = batch_c1['clip'].to(self.device)
+ x_c2 = batch_c2['clip'].to(self.device)
+ with torch.no_grad():
+ embed_c, embed_p, ae_losses, _ = self.rgb_pn(x_c1, x_c2)
+ y = batch_c1['label'].to(self.device)
+ losses, hpm_result, pn_result = self._classification_loss(
+ embed_c, embed_p, ae_losses, y
+ )
+ loss = losses.sum()
+
+ self._write_stat(
+ 'Val', embed_c, embed_p, hpm_result, pn_result, loss, losses
+ )
+ embed_c = self._flatten_embedding(embed_c)
+ embed_p = self._flatten_embedding(embed_p)
+ self._write_embedding('HPM Val', embed_c, x_c1, y)
+ self._write_embedding('PartNet Val', embed_p, x_c1, y)
+
+ # Checkpoint
+ if self.curr_iter % 1000 == 999:
torch.save({
'rand_states': (random.getstate(), torch.get_rng_state()),
'model_state_dict': self.rgb_pn.state_dict(),
@@ -345,9 +330,83 @@ class Model:
'sched_state_dict': self.scheduler.state_dict(),
}, self._checkpoint_name)
- if self.curr_iter == self.total_iter:
- self.writer.close()
- break
+ self.writer.close()
+
+ def _classification_loss(self, embed_c, embed_p, ae_losses, y):
+ # Duplicate labels for each part
+ y_triplet = y.repeat(self.rgb_pn.num_parts, 1)
+ hpm_result = self.triplet_loss_hpm(
+ embed_c, y_triplet[:self.rgb_pn.hpm.num_parts]
+ )
+ pn_result = self.triplet_loss_pn(
+ embed_p, y_triplet[self.rgb_pn.hpm.num_parts:]
+ )
+ losses = torch.stack((
+ *ae_losses,
+ hpm_result.pop('loss').mean(),
+ pn_result.pop('loss').mean()
+ ))
+ return losses, hpm_result, pn_result
+
+ def _write_embedding(self, tag, embed, x, y):
+ frame = x[:, 0, :, :, :].cpu()
+ n, c, h, w = frame.size()
+ padding = torch.zeros(n, c, h, (h-w) // 2)
+ padded_frame = torch.cat((padding, frame, padding), dim=-1)
+ self.writer.add_embedding(
+ embed,
+ metadata=y.cpu().tolist(),
+ label_img=padded_frame,
+ global_step=self.curr_iter,
+ tag=tag
+ )
+
+ def _flatten_embedding(self, embed):
+ return embed.detach().transpose(0, 1).reshape(self.k * self.pr, -1)
+
+ def _write_stat(
+ self, postfix, embed_c, embed_p, hpm_result, pn_result, loss, losses
+ ):
+ # Write losses to TensorBoard
+ self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter)
+ self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip((
+ 'Cross reconstruction loss', 'Canonical consistency loss',
+ 'Pose similarity loss'
+ ), losses[:3])), self.curr_iter)
+ self.writer.add_scalars(f'Loss/triplet loss {postfix}', {
+ 'HPM': losses[3],
+ 'PartNet': losses[4]
+ }, self.curr_iter)
+ # None-zero losses in batch
+ if hpm_result['counts'] is not None and pn_result['counts'] is not None:
+ self.writer.add_scalars(f'Loss/non-zero counts {postfix}', {
+ 'HPM': hpm_result['counts'].mean(),
+ 'PartNet': pn_result['counts'].mean()
+ }, self.curr_iter)
+ # Embedding distance
+ mean_hpm_dist = hpm_result['dist'].mean(0)
+ self._add_ranked_scalars(
+ f'Embedding/HPM distance {postfix}', mean_hpm_dist,
+ self.num_pos_pairs, self.num_pairs, self.curr_iter
+ )
+ mean_pn_dist = pn_result['dist'].mean(0)
+ self._add_ranked_scalars(
+ f'Embedding/ParNet distance {postfix}', mean_pn_dist,
+ self.num_pos_pairs, self.num_pairs, self.curr_iter
+ )
+ # Embedding norm
+ mean_hpm_embedding = embed_c.mean(0)
+ mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
+ self._add_ranked_scalars(
+ f'Embedding/HPM norm {postfix}', mean_hpm_norm,
+ self.k, self.pr * self.k, self.curr_iter
+ )
+ mean_pa_embedding = embed_p.mean(0)
+ mean_pa_norm = mean_pa_embedding.norm(dim=-1)
+ self._add_ranked_scalars(
+ f'Embedding/PartNet norm {postfix}', mean_pa_norm,
+ self.k, self.pr * self.k, self.curr_iter
+ )
def _add_ranked_scalars(
self,
@@ -441,12 +500,12 @@ class Model:
def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]):
label, condition, view, clip = sample.values()
with torch.no_grad():
- feature = self.rgb_pn(clip.to(self.device))
+ feature_c, feature_p = self.rgb_pn(clip.to(self.device))
return {
'label': label.item(),
'condition': condition[0],
'view': view[0],
- 'feature': feature
+ 'feature': torch.cat((feature_c, feature_p)).view(-1)
}
@staticmethod