summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-25 15:46:31 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-25 15:46:31 +0800
commit8dd0c42dc341445987c6a907df909155a8b9abd2 (patch)
treec8b2ceedbb2513662b4cf151121ddf3ca2deb187 /models
parent478ecf50f3dd9e2c2f8ca62e6ecd4a65b2cc7c3a (diff)
parent104c6fbf0686828ed299b2a8bda1806a9b45f440 (diff)
Merge branch 'data_parallel' into data_parallel_py3.8data_parallel_py3.8
# Conflicts: # models/model.py
Diffstat (limited to 'models')
-rw-r--r--models/model.py331
-rw-r--r--models/rgb_part_net.py2
2 files changed, 205 insertions, 128 deletions
diff --git a/models/model.py b/models/model.py
index 07ef37e..0c2e2b9 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,6 +1,6 @@
+import copy
import os
import random
-from datetime import datetime
from typing import Union, Optional, Tuple, List, Dict, Set
import numpy as np
@@ -52,16 +52,18 @@ class Model:
self.meta = model_config
self.hp = hyperparameter_config
- self.curr_iter = self.meta.get('restore_iter', 0)
+ self.restore_iter = self.curr_iter = self.meta.get('restore_iter', 0)
self.total_iter = self.meta.get('total_iter', 80_000)
- self.curr_iters = self.meta.get('restore_iters', (0, 0, 0))
- self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000))
+ self.restore_iters = self.meta.get('restore_iters', (self.curr_iter,))
+ self.total_iters = self.meta.get('total_iters', (self.total_iter,))
self.is_train: bool = True
self.in_channels: int = 3
self.in_size: Tuple[int, int] = (64, 48)
self.pr: Optional[int] = None
self.k: Optional[int] = None
+ self.num_pairs: Optional[int] = None
+ self.num_pos_pairs: Optional[int] = None
self._gallery_dataset_meta: Optional[Dict[str, List]] = None
self._probe_datasets_meta: Optional[Dict[str, Dict[str, List]]] = None
@@ -77,6 +79,7 @@ class Model:
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
self.image_log_on = system_config.get('image_log_on', False)
+ self.val_size = system_config.get('val_size', 10)
self.CASIAB_GALLERY_SELECTOR = {
'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})}
@@ -90,7 +93,7 @@ class Model:
@property
def _model_sig(self) -> str:
return '_'.join(
- (self._model_name, str(self.curr_iter), str(self.total_iter))
+ (self._model_name, str(self.curr_iter + 1), str(self.total_iter))
)
@property
@@ -119,18 +122,18 @@ class Model:
],
dataloader_config: DataloaderConfiguration,
):
- for (curr_iter, total_iter, (condition, selector)) in zip(
- self.curr_iters, self.total_iters, dataset_selectors.items()
+ for (restore_iter, total_iter, (condition, selector)) in zip(
+ self.restore_iters, self.total_iters, dataset_selectors.items()
):
print(f'Training model {condition} ...')
# Skip finished model
- if curr_iter == total_iter:
+ if restore_iter == total_iter:
continue
# Check invalid restore iter
- elif curr_iter > total_iter:
+ elif restore_iter > total_iter:
raise ValueError("Restore iter '{}' should less than total "
- "iter '{}'".format(curr_iter, total_iter))
- self.curr_iter = curr_iter
+ "iter '{}'".format(restore_iter, total_iter))
+ self.restore_iter = self.curr_iter = restore_iter
self.total_iter = total_iter
self.fit(
dict(**dataset_config, **{'selector': selector}),
@@ -143,8 +146,24 @@ class Model:
dataloader_config: DataloaderConfiguration,
):
self.is_train = True
- dataset = self._parse_dataset_config(dataset_config)
- dataloader = self._parse_dataloader_config(dataset, dataloader_config)
+ # Validation dataset
+ # (the first `val_size` subjects from evaluation set)
+ val_dataset_config = copy.deepcopy(dataset_config)
+ train_size = dataset_config.get('train_size', 74)
+ val_dataset_config['train_size'] = train_size + self.val_size
+ val_dataset_config['selector']['classes'] = ClipClasses({
+ str(c).zfill(3)
+ for c in range(train_size + 1, train_size + self.val_size + 1)
+ })
+ val_dataset = self._parse_dataset_config(val_dataset_config)
+ val_dataloader = iter(self._parse_dataloader_config(
+ val_dataset, dataloader_config
+ ))
+ # Training dataset
+ train_dataset = self._parse_dataset_config(dataset_config)
+ train_dataloader = iter(self._parse_dataloader_config(
+ train_dataset, dataloader_config
+ ))
# Prepare for model, optimizer and scheduler
model_hp: Dict = self.hp.get('model', {}).copy()
triplet_is_hard = model_hp.pop('triplet_is_hard', True)
@@ -178,8 +197,8 @@ class Model:
)
num_sampled_frames = dataset_config.get('num_sampled_frames', 30)
- num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
- num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
+ self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
+ self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
# Try to accelerate computation using CUDA or others
self.rgb_pn = nn.DataParallel(self.rgb_pn)
@@ -194,6 +213,7 @@ class Model:
{'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
], **optim_hp)
+ # Scheduler
start_step = sched_hp.get('start_step', 15_000)
final_gamma = sched_hp.get('final_gamma', 0.001)
ae_start_step = ae_sched_hp.get('start_step', start_step)
@@ -224,6 +244,8 @@ class Model:
if self.curr_iter == 0:
self.rgb_pn.apply(self.init_weights)
else: # Load saved state dicts
+ # Offset a iter to load last checkpoint
+ self.curr_iter -= 1
checkpoint = torch.load(self._checkpoint_name)
random.setstate(checkpoint['rand_states'][0])
torch.set_rng_state(checkpoint['rand_states'][1])
@@ -232,101 +254,38 @@ class Model:
self.scheduler.load_state_dict(checkpoint['sched_state_dict'])
# Training start
- start_time = datetime.now()
- running_loss = torch.zeros(5, device=self.device)
- print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
- f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
- f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}")
- for (batch_c1, batch_c2) in dataloader:
- self.curr_iter += 1
+ for self.curr_iter in tqdm(range(self.restore_iter, self.total_iter),
+ desc='Training'):
+ batch_c1, batch_c2 = next(train_dataloader)
# Zero the parameter gradients
self.optimizer.zero_grad()
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- embed_c, embed_p, images, feature_for_loss = self.rgb_pn(x_c1, x_c2)
- x_c1_pred = feature_for_loss[0]
- xrecon_loss = torch.stack([
- F.mse_loss(x_c1_pred[:, i, :, :, :], x_c1[:, i, :, :, :])
- for i in range(num_sampled_frames)
- ]).sum()
- f_c_c1_t1, f_c_c1_t2, f_c_c2_t2 = feature_for_loss[1]
- cano_cons_loss = torch.stack([
- F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
- + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :])
- for i in range(num_sampled_frames)
- ]).mean()
- f_p_c1_t2, f_p_c2_t2 = feature_for_loss[2]
- pose_sim_loss = F.mse_loss(
- f_p_c1_t2.mean(1), f_p_c2_t2.mean(1)
- ) * 10
- y = batch_c1['label'].to(self.device)
- # Duplicate labels for each part
- y = y.repeat(self.rgb_pn.module.num_parts, 1)
- trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm(
- embed_c.transpose(0, 1), y[:self.rgb_pn.module.hpm.num_parts]
+ embed_c, embed_p, images, f_loss = self.rgb_pn(x_c1, x_c2)
+ ae_losses = self._disentangling_loss(
+ x_c1, f_loss, num_sampled_frames
)
- trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn(
- embed_p.transpose(0, 1), y[self.rgb_pn.module.hpm.num_parts:]
+ embed_c, embed_p = embed_c.transpose(0, 1), embed_p.transpose(0, 1)
+ y = batch_c1['label'].to(self.device)
+ losses, hpm_result, pn_result = self._classification_loss(
+ embed_c, embed_p, ae_losses, y
)
- losses = torch.stack((
- xrecon_loss, cano_cons_loss, pose_sim_loss,
- trip_loss_hpm.mean(), trip_loss_pn.mean()
- ))
loss = losses.sum()
loss.backward()
self.optimizer.step()
+ self.scheduler.step()
- # Statistics and checkpoint
- running_loss += losses.detach()
- # Write losses to TensorBoard
- self.writer.add_scalar('Loss/all', loss, self.curr_iter)
- self.writer.add_scalars('Loss/disentanglement', {
- 'Cross reconstruction loss': xrecon_loss,
- 'Canonical consistency loss': cano_cons_loss,
- 'Pose similarity loss': pose_sim_loss
- }, self.curr_iter)
- self.writer.add_scalars('Loss/triplet loss', {
- 'HPM': losses[3],
- 'PartNet': losses[4]
- }, self.curr_iter)
- # None-zero losses in batch
- if hpm_num_non_zero is not None and pn_num_non_zero is not None:
- self.writer.add_scalars('Loss/non-zero counts', {
- 'HPM': hpm_num_non_zero.mean(),
- 'PartNet': pn_num_non_zero.mean()
- }, self.curr_iter)
- # Embedding distance
- mean_hpm_dist = hpm_dist.mean(0)
- self._add_ranked_scalars(
- 'Embedding/HPM distance', mean_hpm_dist,
- num_pos_pairs, num_pairs, self.curr_iter
- )
- mean_pa_dist = pn_dist.mean(0)
- self._add_ranked_scalars(
- 'Embedding/ParNet distance', mean_pa_dist,
- num_pos_pairs, num_pairs, self.curr_iter
- )
- # Embedding norm
- mean_hpm_embedding = embed_c.mean(0)
- mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
- self._add_ranked_scalars(
- 'Embedding/HPM norm', mean_hpm_norm,
- self.k, self.pr * self.k, self.curr_iter
- )
- mean_pa_embedding = embed_p.mean(0)
- mean_pa_norm = mean_pa_embedding.norm(dim=-1)
- self._add_ranked_scalars(
- 'Embedding/PartNet norm', mean_pa_norm,
- self.k, self.pr * self.k, self.curr_iter
- )
# Learning rate
- lrs = self.scheduler.get_last_lr()
self.writer.add_scalars('Learning rate', dict(zip((
'Auto-encoder', 'HPM', 'PartNet'
- ), lrs)), self.curr_iter)
+ ), self.scheduler.get_last_lr())), self.curr_iter)
+ # Other stats
+ self._write_stat(
+ 'Train', embed_c, embed_p, hpm_result, pn_result, loss, losses
+ )
- if self.curr_iter % 100 == 0:
+ if self.curr_iter % 100 == 99:
# Write disentangled images
if self.image_log_on:
i_a, i_c, i_p = images
@@ -343,19 +302,40 @@ class Model:
self.writer.add_images(
f'Pose image/batch {i}', p, self.curr_iter
)
- time_used = datetime.now() - start_time
- remaining_minute, second = divmod(time_used.seconds, 60)
- hour, minute = divmod(remaining_minute, 60)
- print(f'{hour:02}:{minute:02}:{second:02}',
- f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
- '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- '{:.3e} {:.3e} {:.3e}'.format(*lrs))
- running_loss.zero_()
-
- # Step scheduler
- self.scheduler.step()
- if self.curr_iter % 1000 == 0:
+ # Validation
+ embed_c = self._flatten_embedding(embed_c)
+ embed_p = self._flatten_embedding(embed_p)
+ self._write_embedding('HPM Train', embed_c, x_c1, y)
+ self._write_embedding('PartNet Train', embed_p, x_c1, y)
+
+ # Calculate losses on testing batch
+ batch_c1, batch_c2 = next(val_dataloader)
+ x_c1 = batch_c1['clip'].to(self.device)
+ x_c2 = batch_c2['clip'].to(self.device)
+ with torch.no_grad():
+ embed_c, embed_p, _, f_loss = self.rgb_pn(x_c1, x_c2)
+ ae_losses = self._disentangling_loss(
+ x_c1, f_loss, num_sampled_frames
+ )
+ embed_c = embed_c.transpose(0, 1)
+ embed_p = embed_p.transpose(0, 1)
+ y = batch_c1['label'].to(self.device)
+ losses, hpm_result, pn_result = self._classification_loss(
+ embed_c, embed_p, ae_losses, y
+ )
+ loss = losses.sum()
+
+ self._write_stat(
+ 'Val', embed_c, embed_p, hpm_result, pn_result, loss, losses
+ )
+ embed_c = self._flatten_embedding(embed_c)
+ embed_p = self._flatten_embedding(embed_p)
+ self._write_embedding('HPM Val', embed_c, x_c1, y)
+ self._write_embedding('PartNet Val', embed_p, x_c1, y)
+
+ # Checkpoint
+ if self.curr_iter % 1000 == 999:
torch.save({
'rand_states': (random.getstate(), torch.get_rng_state()),
'model_state_dict': self.rgb_pn.state_dict(),
@@ -363,9 +343,102 @@ class Model:
'sched_state_dict': self.scheduler.state_dict(),
}, self._checkpoint_name)
- if self.curr_iter == self.total_iter:
- self.writer.close()
- break
+ self.writer.close()
+
+ @staticmethod
+ def _disentangling_loss(x_c1, feature_for_loss, num_sampled_frames):
+ x_c1_pred = feature_for_loss[0]
+ xrecon_loss = torch.stack([
+ F.mse_loss(x_c1_pred[:, i, :, :, :], x_c1[:, i, :, :, :])
+ for i in range(num_sampled_frames)
+ ]).sum()
+ f_c_c1_t1, f_c_c1_t2, f_c_c2_t2 = feature_for_loss[1]
+ cano_cons_loss = torch.stack([
+ F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
+ + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :])
+ for i in range(num_sampled_frames)
+ ]).mean()
+ f_p_c1_t2, f_p_c2_t2 = feature_for_loss[2]
+ pose_sim_loss = F.mse_loss(
+ f_p_c1_t2.mean(1), f_p_c2_t2.mean(1)
+ ) * 10
+ return xrecon_loss, cano_cons_loss, pose_sim_loss
+
+ def _classification_loss(self, embed_c, embed_p, ae_losses, y):
+ # Duplicate labels for each part
+ y_triplet = y.repeat(self.rgb_pn.module.num_parts, 1)
+ hpm_result = self.triplet_loss_hpm(
+ embed_c, y_triplet[:self.rgb_pn.module.hpm.num_parts]
+ )
+ pn_result = self.triplet_loss_pn(
+ embed_p, y_triplet[self.rgb_pn.module.hpm.num_parts:]
+ )
+ losses = torch.stack((
+ *ae_losses,
+ hpm_result.pop('loss').mean(),
+ pn_result.pop('loss').mean()
+ ))
+ return losses, hpm_result, pn_result
+
+ def _write_embedding(self, tag, embed, x, y):
+ frame = x[:, 0, :, :, :].cpu()
+ n, c, h, w = frame.size()
+ padding = torch.zeros(n, c, h, (h-w) // 2)
+ padded_frame = torch.cat((padding, frame, padding), dim=-1)
+ self.writer.add_embedding(
+ embed,
+ metadata=y.cpu().tolist(),
+ label_img=padded_frame,
+ global_step=self.curr_iter,
+ tag=tag
+ )
+
+ def _flatten_embedding(self, embed):
+ return embed.detach().transpose(0, 1).reshape(self.k * self.pr, -1)
+
+ def _write_stat(
+ self, postfix, embed_c, embed_p, hpm_result, pn_result, loss, losses
+ ):
+ # Write losses to TensorBoard
+ self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter)
+ self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip((
+ 'Cross reconstruction loss', 'Canonical consistency loss',
+ 'Pose similarity loss'
+ ), losses[:3])), self.curr_iter)
+ self.writer.add_scalars(f'Loss/triplet loss {postfix}', {
+ 'HPM': losses[3],
+ 'PartNet': losses[4]
+ }, self.curr_iter)
+ # None-zero losses in batch
+ if hpm_result['counts'] is not None and pn_result['counts'] is not None:
+ self.writer.add_scalars(f'Loss/non-zero counts {postfix}', {
+ 'HPM': hpm_result['counts'].mean(),
+ 'PartNet': pn_result['counts'].mean()
+ }, self.curr_iter)
+ # Embedding distance
+ mean_hpm_dist = hpm_result['dist'].mean(0)
+ self._add_ranked_scalars(
+ f'Embedding/HPM distance {postfix}', mean_hpm_dist,
+ self.num_pos_pairs, self.num_pairs, self.curr_iter
+ )
+ mean_pn_dist = pn_result['dist'].mean(0)
+ self._add_ranked_scalars(
+ f'Embedding/ParNet distance {postfix}', mean_pn_dist,
+ self.num_pos_pairs, self.num_pairs, self.curr_iter
+ )
+ # Embedding norm
+ mean_hpm_embedding = embed_c.mean(0)
+ mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
+ self._add_ranked_scalars(
+ f'Embedding/HPM norm {postfix}', mean_hpm_norm,
+ self.k, self.pr * self.k, self.curr_iter
+ )
+ mean_pa_embedding = embed_p.mean(0)
+ mean_pa_norm = mean_pa_embedding.norm(dim=-1)
+ self._add_ranked_scalars(
+ f'Embedding/PartNet norm {postfix}', mean_pa_norm,
+ self.k, self.pr * self.k, self.curr_iter
+ )
def _add_ranked_scalars(
self,
@@ -410,12 +483,12 @@ class Model:
dataset_selectors: Dict[
str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
],
- dataloader_config: DataloaderConfiguration
+ dataloader_config: DataloaderConfiguration,
+ is_train: bool = False
):
- self.is_train = False
# Split gallery and probe dataset
gallery_dataloader, probe_dataloaders = self._split_gallery_probe(
- dataset_config, dataloader_config
+ dataset_config, dataloader_config, is_train
)
# Get pretrained models at iter_
checkpoints = self._load_pretrained(
@@ -444,7 +517,6 @@ class Model:
unit='clips'):
gallery_samples_c.append(self._get_eval_sample(sample))
gallery_samples[condition] = default_collate(gallery_samples_c)
- gallery_samples['meta'] = self._gallery_dataset_meta
# Probe
probe_samples_c = []
for sample in tqdm(probe_dataloader,
@@ -454,18 +526,19 @@ class Model:
probe_samples_c = default_collate(probe_samples_c)
probe_samples_c['meta'] = self._probe_datasets_meta[condition]
probe_samples[condition] = probe_samples_c
+ gallery_samples['meta'] = self._gallery_dataset_meta
return gallery_samples, probe_samples
def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]):
- label = sample.pop('label').item()
- clip = sample.pop('clip').to(self.device)
+ label, condition, view, clip = sample.values()
with torch.no_grad():
- feature = self.rgb_pn(clip)
+ feature_c, feature_p = self.rgb_pn(clip.to(self.device))
return {
- **{'label': label},
- **sample,
- **{'feature': feature}
+ 'label': label.item(),
+ 'condition': condition[0],
+ 'view': view[0],
+ 'feature': torch.cat((feature_c, feature_p)).view(-1)
}
@staticmethod
@@ -525,10 +598,11 @@ class Model:
]
) -> Dict[str, str]:
checkpoints = {}
- for (iter_, (condition, selector)) in zip(
- iters, dataset_selectors.items()
+ for (iter_, total_iter, (condition, selector)) in zip(
+ iters, self.total_iters, dataset_selectors.items()
):
- self.curr_iter = iter_
+ self.curr_iter = iter_ - 1
+ self.total_iter = total_iter
self._dataset_sig = self._make_signature(
dict(**dataset_config, **selector),
popped_keys=['root_dir', 'cache_on']
@@ -540,26 +614,29 @@ class Model:
self,
dataset_config: DatasetConfiguration,
dataloader_config: DataloaderConfiguration,
+ is_train: bool = False
) -> Tuple[DataLoader, Dict[str, DataLoader]]:
dataset_name = dataset_config.get('name', 'CASIA-B')
if dataset_name == 'CASIA-B':
+ self.is_train = is_train
gallery_dataset = self._parse_dataset_config(
dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR)
)
- self._gallery_dataset_meta = gallery_dataset.metadata
- gallery_dataloader = self._parse_dataloader_config(
- gallery_dataset, dataloader_config
- )
probe_datasets = {
condition: self._parse_dataset_config(
dict(**dataset_config, **selector)
)
for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items()
}
+ self._gallery_dataset_meta = gallery_dataset.metadata
self._probe_datasets_meta = {
condition: dataset.metadata
for (condition, dataset) in probe_datasets.items()
}
+ self.is_train = False
+ gallery_dataloader = self._parse_dataloader_config(
+ gallery_dataset, dataloader_config
+ )
probe_dataloaders = {
condition: self._parse_dataloader_config(
dataset, dataloader_config
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 81f198e..e8e320d 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -59,7 +59,7 @@ class RGBPartNet(nn.Module):
if self.training:
return x_c.transpose(0, 1), x_p.transpose(0, 1), images, f_loss
else:
- return torch.cat((x_c, x_p)).unsqueeze(1).view(-1)
+ return x_c, x_p
def _disentangle(self, x_c1_t2, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()