summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py143
1 files changed, 79 insertions, 64 deletions
diff --git a/models/model.py b/models/model.py
index 4335bc9..91d6651 100644
--- a/models/model.py
+++ b/models/model.py
@@ -13,13 +13,15 @@ from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
+from models.hpm import HorizontalPyramidMatching
+from models.part_net import PartNet
from models.rgb_part_net import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
SystemConfiguration
from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
from utils.sampler import TripletSampler
-from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss
+from utils.triplet_loss import BatchTripletLoss
class Model:
@@ -69,7 +71,8 @@ class Model:
self._dataset_sig: str = 'undefined'
self.rgb_pn: Optional[RGBPartNet] = None
- self.triplet_loss: Optional[JointBatchTripletLoss] = None
+ self.triplet_loss_hpm: Optional[BatchTripletLoss] = None
+ self.triplet_loss_pn: Optional[BatchTripletLoss] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
@@ -149,26 +152,28 @@ class Model:
triplet_margins = model_hp.pop('triplet_margins', None)
optim_hp: Dict = self.hp.get('optimizer', {}).copy()
ae_optim_hp = optim_hp.pop('auto_encoder', {})
- pn_optim_hp = optim_hp.pop('part_net', {})
hpm_optim_hp = optim_hp.pop('hpm', {})
- fc_optim_hp = optim_hp.pop('fc', {})
+ pn_optim_hp = optim_hp.pop('part_net', {})
sched_hp = self.hp.get('scheduler', {})
+ ae_sched_hp = sched_hp.get('auto_encoder', {})
+ hpm_sched_hp = sched_hp.get('hpm', {})
+ pn_sched_hp = sched_hp.get('part_net', {})
+
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
# Hard margins
if triplet_margins:
- # Same margins
- if triplet_margins[0] == triplet_margins[1]:
- self.triplet_loss = BatchTripletLoss(
- triplet_is_hard, triplet_margins[0]
- )
- else: # Different margins
- self.triplet_loss = JointBatchTripletLoss(
- self.rgb_pn.module.hpm_num_parts,
- triplet_is_hard, triplet_is_mean, triplet_margins
- )
+ self.triplet_loss_hpm = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, triplet_margins[0]
+ )
+ self.triplet_loss_pn = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, triplet_margins[1]
+ )
else: # Soft margins
- self.triplet_loss = BatchTripletLoss(
+ self.triplet_loss_hpm = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, None
+ )
+ self.triplet_loss_pn = BatchTripletLoss(
triplet_is_hard, triplet_is_mean, None
)
@@ -179,26 +184,34 @@ class Model:
# Try to accelerate computation using CUDA or others
self.rgb_pn = nn.DataParallel(self.rgb_pn)
self.rgb_pn = self.rgb_pn.to(self.device)
- self.triplet_loss = nn.DataParallel(self.triplet_loss)
- self.triplet_loss = self.triplet_loss.to(self.device)
+ self.triplet_loss_hpm = nn.DataParallel(self.triplet_loss_hpm)
+ self.triplet_loss_hpm = self.triplet_loss_hpm.to(self.device)
+ self.triplet_loss_pn = nn.DataParallel(self.triplet_loss_pn)
+ self.triplet_loss_pn = self.triplet_loss_pn.to(self.device)
self.optimizer = optim.Adam([
{'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp},
- {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
{'params': self.rgb_pn.module.hpm.parameters(), **hpm_optim_hp},
- {'params': self.rgb_pn.module.fc_mat, **fc_optim_hp}
+ {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
], **optim_hp)
- sched_final_gamma = sched_hp.get('final_gamma', 0.001)
- sched_start_step = sched_hp.get('start_step', 15_000)
- all_step = self.total_iter - sched_start_step
-
- def lr_lambda(epoch):
- if epoch > sched_start_step:
- passed_step = epoch - sched_start_step
- return sched_final_gamma ** (passed_step / all_step)
- else:
- return 1
+
+ start_step = sched_hp.get('start_step', 15_000)
+ final_gamma = sched_hp.get('final_gamma', 0.001)
+ ae_start_step = ae_sched_hp.get('start_step', start_step)
+ ae_final_gamma = ae_sched_hp.get('final_gamma', final_gamma)
+ ae_all_step = self.total_iter - ae_start_step
+ hpm_start_step = hpm_sched_hp.get('start_step', start_step)
+ hpm_final_gamma = hpm_sched_hp.get('final_gamma', final_gamma)
+ hpm_all_step = self.total_iter - hpm_start_step
+ pn_start_step = pn_sched_hp.get('start_step', start_step)
+ pn_final_gamma = pn_sched_hp.get('final_gamma', final_gamma)
+ pn_all_step = self.total_iter - pn_start_step
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
- lr_lambda, lr_lambda, lr_lambda, lr_lambda
+ lambda t: ae_final_gamma ** ((t - ae_start_step) / ae_all_step)
+ if t > ae_start_step else 1,
+ lambda t: hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step)
+ if t > hpm_start_step else 1,
+ lambda t: pn_final_gamma ** ((t - pn_start_step) / pn_all_step)
+ if t > pn_start_step else 1,
])
self.writer = SummaryWriter(self._log_name)
@@ -223,7 +236,7 @@ class Model:
running_loss = torch.zeros(5, device=self.device)
print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
- f"{'BATripH':^8} {'BATripP':^8} {'LR':^9}")
+ f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}")
for (batch_c1, batch_c2) in dataloader:
self.curr_iter += 1
# Zero the parameter gradients
@@ -231,7 +244,7 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- embedding, images, feature_for_loss = self.rgb_pn(x_c1, x_c2)
+ embed_c, embed_p, images, feature_for_loss = self.rgb_pn(x_c1, x_c2)
x_c1_pred = feature_for_loss[0]
xrecon_loss = torch.stack([
F.mse_loss(x_c1_pred[:, i, :, :, :], x_c1[:, i, :, :, :])
@@ -249,13 +262,16 @@ class Model:
) * 10
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
- y = y.repeat(self.rgb_pn.module.num_total_parts, 1)
- embedding = embedding.transpose(0, 1)
- triplet_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
- hpm_loss = triplet_loss[:self.rgb_pn.module.hpm_num_parts].mean()
- pn_loss = triplet_loss[self.rgb_pn.module.hpm_num_parts:].mean()
+ y = y.repeat(self.rgb_pn.module.num_parts, 1)
+ trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm(
+ embed_c.transpose(0, 1), y[:self.rgb_pn.module.hpm.num_parts]
+ )
+ trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn(
+ embed_p.transpose(0, 1), y[self.rgb_pn.module.hpm.num_parts:]
+ )
losses = torch.stack((
- xrecon_loss, cano_cons_loss, pose_sim_loss, hpm_loss, pn_loss
+ xrecon_loss, cano_cons_loss, pose_sim_loss,
+ trip_loss_hpm.mean(), trip_loss_pn.mean()
))
loss = losses.sum()
loss.backward()
@@ -271,37 +287,34 @@ class Model:
'Pose similarity loss': pose_sim_loss
}, self.curr_iter)
self.writer.add_scalars('Loss/triplet loss', {
- 'HPM': hpm_loss, 'PartNet': pn_loss
+ 'HPM': losses[3],
+ 'PartNet': losses[4]
}, self.curr_iter)
# None-zero losses in batch
- if num_non_zero is not None:
+ if hpm_num_non_zero is not None and hpm_num_non_zero is not None:
self.writer.add_scalars('Loss/non-zero counts', {
- 'HPM': num_non_zero[
- :self.rgb_pn.module.hpm_num_parts].mean(),
- 'PartNet': num_non_zero[
- self.rgb_pn.module.hpm_num_parts:].mean()
+ 'HPM': hpm_num_non_zero.mean(),
+ 'PartNet': pn_num_non_zero.mean()
}, self.curr_iter)
# Embedding distance
- mean_hpm_dist = dist[:self.rgb_pn.module.hpm_num_parts].mean(0)
+ mean_hpm_dist = hpm_dist.mean(0)
self._add_ranked_scalars(
'Embedding/HPM distance', mean_hpm_dist,
num_pos_pairs, num_pairs, self.curr_iter
)
- mean_pa_dist = dist[self.rgb_pn.module.hpm_num_parts:].mean(0)
+ mean_pa_dist = pn_dist.mean(0)
self._add_ranked_scalars(
'Embedding/ParNet distance', mean_pa_dist,
num_pos_pairs, num_pairs, self.curr_iter
)
# Embedding norm
- mean_hpm_embedding = embedding[
- :self.rgb_pn.module.hpm_num_parts].mean(0)
+ mean_hpm_embedding = embed_c.mean(0)
mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/HPM norm', mean_hpm_norm,
self.k, self.pr * self.k, self.curr_iter
)
- mean_pa_embedding = embedding[
- self.rgb_pn.module.hpm_num_parts:].mean(0)
+ mean_pa_embedding = embed_p.mean(0)
mean_pa_norm = mean_pa_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/PartNet norm', mean_pa_norm,
@@ -309,10 +322,9 @@ class Model:
)
# Learning rate
lrs = self.scheduler.get_last_lr()
- # Write learning rates
- self.writer.add_scalar(
- 'Learning rate', lrs[0], self.curr_iter
- )
+ self.writer.add_scalars('Learning rate', dict(zip((
+ 'Auto-encoder', 'HPM', 'PartNet'
+ ), lrs)), self.curr_iter)
if self.curr_iter % 100 == 0:
# Write disentangled images
@@ -337,7 +349,7 @@ class Model:
print(f'{hour:02}:{minute:02}:{second:02}',
f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
'{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- f'{lrs[0]:.3e}')
+ '{:.3e} {:.3e} {:.3e}'.format(*lrs))
running_loss.zero_()
# Step scheduler
@@ -432,13 +444,16 @@ class Model:
unit='clips'):
gallery_samples_c.append(self._get_eval_sample(sample))
gallery_samples[condition] = default_collate(gallery_samples_c)
+ gallery_samples['meta'] = self._gallery_dataset_meta
# Probe
probe_samples_c = []
for sample in tqdm(probe_dataloader,
desc=f'Transforming probe {condition}',
unit='clips'):
probe_samples_c.append(self._get_eval_sample(sample))
- probe_samples[condition] = default_collate(probe_samples_c)
+ probe_samples_c = default_collate(probe_samples_c)
+ probe_samples_c['meta'] = self._probe_datasets_meta[condition]
+ probe_samples[condition] = probe_samples_c
return gallery_samples, probe_samples
@@ -453,15 +468,16 @@ class Model:
**{'feature': feature}
}
+ @staticmethod
def evaluate(
self,
- gallery_samples: Dict[str, Union[List[str], torch.Tensor]],
- probe_samples: Dict[str, Dict[str, Union[List[str], torch.Tensor]]],
+ gallery_samples: Dict[str, Dict[str, Union[List, torch.Tensor]]],
+ probe_samples: Dict[str, Dict[str, Union[List, torch.Tensor]]],
num_ranks: int = 5
) -> Dict[str, torch.Tensor]:
- conditions = gallery_samples.keys()
- gallery_views_meta = self._gallery_dataset_meta['views']
- probe_views_meta = list(self._probe_datasets_meta.values())[0]['views']
+ conditions = list(probe_samples.keys())
+ gallery_views_meta = gallery_samples['meta']['views']
+ probe_views_meta = probe_samples[conditions[0]]['meta']['views']
accuracy = {
condition: torch.empty(
len(gallery_views_meta), len(probe_views_meta), num_ranks
@@ -474,7 +490,7 @@ class Model:
(labels_g, _, views_g, features_g) = gallery_samples_c.values()
views_g = np.asarray(views_g)
probe_samples_c = probe_samples[condition]
- (labels_p, _, views_p, features_p) = probe_samples_c.values()
+ (labels_p, _, views_p, features_p, _) = probe_samples_c.values()
views_p = np.asarray(views_p)
accuracy_c = accuracy[condition]
for (v_g_i, view_g) in enumerate(gallery_views_meta):
@@ -499,7 +515,6 @@ class Model:
positive_counts = positive_mat.sum(0)
total_counts, _ = dist.size()
accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts
-
return accuracy
def _load_pretrained(
@@ -570,7 +585,7 @@ class Model:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
- elif isinstance(m, RGBPartNet):
+ elif isinstance(m, (HorizontalPyramidMatching, PartNet)):
nn.init.xavier_uniform_(m.fc_mat)
def _parse_dataset_config(