summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py129
1 files changed, 100 insertions, 29 deletions
diff --git a/models/model.py b/models/model.py
index 2c72270..46d7c4c 100644
--- a/models/model.py
+++ b/models/model.py
@@ -18,7 +18,7 @@ from utils.configuration import DataloaderConfiguration, \
SystemConfiguration
from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
from utils.sampler import TripletSampler
-from utils.triplet_loss import JointBatchAllTripletLoss
+from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss
class Model:
@@ -68,7 +68,7 @@ class Model:
self._dataset_sig: str = 'undefined'
self.rgb_pn: Optional[RGBPartNet] = None
- self.ba_triplet_loss: Optional[JointBatchAllTripletLoss] = None
+ self.triplet_loss: Optional[JointBatchTripletLoss] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
@@ -143,9 +143,10 @@ class Model:
dataloader = self._parse_dataloader_config(dataset, dataloader_config)
# Prepare for model, optimizer and scheduler
model_hp: Dict = self.hp.get('model', {}).copy()
- triplet_margins = model_hp.pop('triplet_margins', (0.2, 0.2))
+ triplet_is_hard = model_hp.pop('triplet_is_hard', True)
+ triplet_is_mean = model_hp.pop('triplet_is_mean', True)
+ triplet_margins = model_hp.pop('triplet_margins', None)
optim_hp: Dict = self.hp.get('optimizer', {}).copy()
- start_iter = optim_hp.pop('start_iter', 0)
ae_optim_hp = optim_hp.pop('auto_encoder', {})
pn_optim_hp = optim_hp.pop('part_net', {})
hpm_optim_hp = optim_hp.pop('hpm', {})
@@ -153,28 +154,48 @@ class Model:
sched_hp = self.hp.get('scheduler', {})
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
- self.ba_triplet_loss = JointBatchAllTripletLoss(
- self.rgb_pn.hpm_num_parts, triplet_margins
- )
+ # Hard margins
+ if triplet_margins:
+ # Same margins
+ if triplet_margins[0] == triplet_margins[1]:
+ self.triplet_loss = BatchTripletLoss(
+ triplet_is_hard, triplet_margins[0]
+ )
+ else: # Different margins
+ self.triplet_loss = JointBatchTripletLoss(
+ self.rgb_pn.hpm_num_parts,
+ triplet_is_hard, triplet_is_mean, triplet_margins
+ )
+ else: # Soft margins
+ self.triplet_loss = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, None
+ )
+
+ num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
+ num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
+
# Try to accelerate computation using CUDA or others
self.rgb_pn = nn.DataParallel(self.rgb_pn)
self.rgb_pn = self.rgb_pn.to(self.device)
- self.ba_triplet_loss = nn.DataParallel(self.ba_triplet_loss)
- self.ba_triplet_loss = self.ba_triplet_loss.to(self.device)
+ self.triplet_loss = nn.DataParallel(self.triplet_loss)
+ self.triplet_loss = self.triplet_loss.to(self.device)
self.optimizer = optim.Adam([
{'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp},
{'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
{'params': self.rgb_pn.module.hpm.parameters(), **hpm_optim_hp},
{'params': self.rgb_pn.module.fc_mat, **fc_optim_hp}
], **optim_hp)
- sched_gamma = sched_hp.get('gamma', 0.9)
- sched_step_size = sched_hp.get('step_size', 500)
+ sched_final_gamma = sched_hp.get('final_gamma', 0.001)
+ sched_start_step = sched_hp.get('start_step', 15_000)
+
+ def lr_lambda(epoch):
+ passed_step = epoch - sched_start_step
+ all_step = self.total_iter - sched_start_step
+ return sched_final_gamma ** (passed_step / all_step)
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
- lambda epoch: sched_gamma ** (epoch // sched_step_size),
- lambda epoch: 0 if epoch < start_iter else 1,
- lambda epoch: 0 if epoch < start_iter else 1,
- lambda epoch: 0 if epoch < start_iter else 1,
+ lr_lambda, lr_lambda, lr_lambda, lr_lambda
])
+
self.writer = SummaryWriter(self._log_name)
self.rgb_pn.train()
@@ -194,7 +215,7 @@ class Model:
running_loss = torch.zeros(5, device=self.device)
print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
- f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
+ f"{'BATripH':^8} {'BATripP':^8} {'LR':^9}")
for (batch_c1, batch_c2) in dataloader:
self.curr_iter += 1
# Zero the parameter gradients
@@ -202,16 +223,16 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- feature, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+ embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2)
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
y = y.repeat(self.rgb_pn.num_total_parts, 1)
- triplet_loss = self.ba_triplet_loss(feature, y)
+ trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
losses = torch.cat((
ae_losses.mean(0),
torch.stack((
- triplet_loss[:self.rgb_pn.hpm_num_parts].mean(),
- triplet_loss[self.rgb_pn.hpm_num_parts:].mean()
+ trip_loss[:self.rgb_pn.hpm_num_parts].mean(),
+ trip_loss[self.rgb_pn.hpm_num_parts:].mean()
))
))
loss = losses.sum()
@@ -222,20 +243,50 @@ class Model:
running_loss += losses.detach()
# Write losses to TensorBoard
self.writer.add_scalar('Loss/all', loss, self.curr_iter)
- self.writer.add_scalars('Loss/details', dict(zip([
+ self.writer.add_scalars('Loss/disentanglement', dict(zip((
'Cross reconstruction loss', 'Canonical consistency loss',
- 'Pose similarity loss', 'Batch All triplet loss (HPM)',
- 'Batch All triplet loss (PartNet)'
- ], losses)), self.curr_iter)
+ 'Pose similarity loss'
+ ), ae_losses)), self.curr_iter)
+ self.writer.add_scalars('Loss/triplet loss', {
+ 'HPM': losses[3],
+ 'PartNet': losses[4]
+ }, self.curr_iter)
+ # None-zero losses in batch
+ if num_non_zero is not None:
+ self.writer.add_scalars('Loss/non-zero counts', {
+ 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(),
+ 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean()
+ }, self.curr_iter)
+ # Embedding distance
+ mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0)
+ self._add_ranked_scalars(
+ 'Embedding/HPM distance', mean_hpm_dist,
+ num_pos_pairs, num_pairs, self.curr_iter
+ )
+ mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0)
+ self._add_ranked_scalars(
+ 'Embedding/ParNet distance', mean_pa_dist,
+ num_pos_pairs, num_pairs, self.curr_iter
+ )
+ # Embedding norm
+ mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0)
+ mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
+ self._add_ranked_scalars(
+ 'Embedding/HPM norm', mean_hpm_norm,
+ self.k, self.pr * self.k, self.curr_iter
+ )
+ mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0)
+ mean_pa_norm = mean_pa_embedding.norm(dim=-1)
+ self._add_ranked_scalars(
+ 'Embedding/PartNet norm', mean_pa_norm,
+ self.k, self.pr * self.k, self.curr_iter
+ )
if self.curr_iter % 100 == 0:
lrs = self.scheduler.get_last_lr()
# Write learning rates
self.writer.add_scalar(
- 'Learning rate/Auto-encoder', lrs[0], self.curr_iter
- )
- self.writer.add_scalar(
- 'Learning rate/Others', lrs[1], self.curr_iter
+ 'Learning rate', lrs[0], self.curr_iter
)
# Write disentangled images
if self.image_log_on:
@@ -259,7 +310,7 @@ class Model:
print(f'{hour:02}:{minute:02}:{second:02}',
f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
'{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- '{:.3e} {:.3e}'.format(lrs[0], lrs[1]))
+ f'{lrs[0]:.3e}')
running_loss.zero_()
# Step scheduler
@@ -278,6 +329,24 @@ class Model:
self.writer.close()
break
+ def _add_ranked_scalars(
+ self,
+ main_tag: str,
+ metric: torch.Tensor,
+ num_pos: int,
+ num_all: int,
+ global_step: int
+ ):
+ rank = metric.argsort()
+ pos_ile = 100 - (num_pos - 1) * 100 // num_all
+ self.writer.add_scalars(main_tag, {
+ '0%-ile': metric[rank[-1]],
+ f'{100 - pos_ile}%-ile': metric[rank[-num_pos]],
+ '50%-ile': metric[rank[num_all // 2 - 1]],
+ f'{pos_ile}%-ile': metric[rank[num_pos - 1]],
+ '100%-ile': metric[rank[0]]
+ }, global_step)
+
def predict_all(
self,
iters: Tuple[int],
@@ -317,6 +386,8 @@ class Model:
# Init models
model_hp: Dict = self.hp.get('model', {}).copy()
+ model_hp.pop('triplet_is_hard', True)
+ model_hp.pop('triplet_is_mean', True)
model_hp.pop('triplet_margins', None)
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp)
# Try to accelerate computation using CUDA or others