summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-08 18:11:25 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-08 18:25:42 +0800
commit99ddd7c142a4ec97cb8bd14b204651790b3cf4ee (patch)
treea4ccbd08a7155e90df63aba60eb93ab2b7969c9b /models/model.py
parent507e1d163aaa6ea4be23e7f08ff6ce0ef58c830b (diff)
Code refactoring, modifications and new features
1. Decode features outside of auto-encoder 2. Turn off HPM 1x1 conv by default 3. Change canonical feature map size from `feature_channels * 8 x 4 x 2` to `feature_channels * 2 x 16 x 8` 4. Use mean of canonical embeddings instead of mean of static features 5. Calculate static and dynamic loss separately 6. Calculate mean of parts in triplet loss instead of sum of parts 7. Add switch to log disentangled images 8. Change default configuration
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py34
1 files changed, 28 insertions, 6 deletions
diff --git a/models/model.py b/models/model.py
index ddb715d..0418070 100644
--- a/models/model.py
+++ b/models/model.py
@@ -69,6 +69,7 @@ class Model:
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
+ self.image_log_on = system_config.get('image_log_on', False)
self.CASIAB_GALLERY_SELECTOR = {
'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})}
@@ -146,7 +147,8 @@ class Model:
hpm_optim_hp = optim_hp.pop('hpm', {})
fc_optim_hp = optim_hp.pop('fc', {})
sched_hp = self.hp.get('scheduler', {})
- self.rgb_pn = RGBPartNet(self.in_channels, **model_hp)
+ self.rgb_pn = RGBPartNet(self.in_channels, **model_hp,
+ image_log_on=self.image_log_on)
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
self.optimizer = optim.Adam([
@@ -168,9 +170,9 @@ class Model:
# Training start
start_time = datetime.now()
- running_loss = torch.zeros(4).to(self.device)
+ running_loss = torch.zeros(5, device=self.device)
print(f"{'Iter':^5} {'Loss':^6} {'Xrecon':^8} {'PoseSim':^8}",
- f"{'CanoCons':^8} {'BATrip':^8} LR(s)")
+ f"{'CanoCons':^8} {'BATripH':^8} {'BATripP':^8} LR(s)")
for (batch_c1, batch_c2) in dataloader:
if self.curr_iter == start_iter:
self.optimizer.add_param_group(
@@ -189,7 +191,7 @@ class Model:
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
y = batch_c1['label'].to(self.device)
- losses = self.rgb_pn(x_c1, x_c2, y)
+ losses, images = self.rgb_pn(x_c1, x_c2, y)
loss = losses.sum()
loss.backward()
self.optimizer.step()
@@ -200,13 +202,33 @@ class Model:
self.writer.add_scalar('Loss/all', loss, self.curr_iter)
self.writer.add_scalars('Loss/details', dict(zip([
'Cross reconstruction loss', 'Pose similarity loss',
- 'Canonical consistency loss', 'Batch All triplet loss'
+ 'Canonical consistency loss', 'Batch All triplet loss (HPM)',
+ 'Batch All triplet loss (PartNet)'
], losses)), self.curr_iter)
+ if self.image_log_on:
+ (appearance_image, canonical_image, pose_image) = images
+ self.writer.add_images(
+ 'Canonical image', canonical_image, self.curr_iter
+ )
+ for i in range(self.pr * self.k):
+ self.writer.add_images(
+ f'Original image/batch {i}', x_c1[i], self.curr_iter
+ )
+ self.writer.add_images(
+ f'Appearance image/batch {i}',
+ appearance_image[:, i, :, :, :],
+ self.curr_iter
+ )
+ self.writer.add_images(
+ f'Pose image/batch {i}',
+ pose_image[:, i, :, :, :],
+ self.curr_iter
+ )
if self.curr_iter % 100 == 0:
lrs = self.scheduler.get_last_lr()
print(f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
- '{:f} {:f} {:f} {:f}'.format(*running_loss / 100),
+ '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
' '.join(('{:.3e}'.format(lr) for lr in lrs)))
running_loss.zero_()