summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py114
1 files changed, 13 insertions, 101 deletions
diff --git a/models/model.py b/models/model.py
index 09ddaf1..ae21a1b 100644
--- a/models/model.py
+++ b/models/model.py
@@ -5,7 +5,6 @@ from typing import Union, Optional, Tuple, List, Dict, Set
import numpy as np
import torch
import torch.nn as nn
-import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
@@ -139,29 +138,16 @@ class Model:
# Prepare for model, optimizer and scheduler
model_hp = self.hp.get('model', {})
optim_hp: Dict = self.hp.get('optimizer', {}).copy()
- start_iter = optim_hp.pop('start_iter', 0)
- ae_optim_hp = optim_hp.pop('auto_encoder', {})
- pn_optim_hp = optim_hp.pop('part_net', {})
- hpm_optim_hp = optim_hp.pop('hpm', {})
- fc_optim_hp = optim_hp.pop('fc', {})
sched_hp = self.hp.get('scheduler', {})
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
- self.optimizer = optim.Adam([
- {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
- {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
- {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
- {'params': self.rgb_pn.fc_mat, **fc_optim_hp}
- ], **optim_hp)
+ self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp)
sched_gamma = sched_hp.get('gamma', 0.9)
sched_step_size = sched_hp.get('step_size', 500)
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
lambda epoch: sched_gamma ** (epoch // sched_step_size),
- lambda epoch: 0 if epoch < start_iter else 1,
- lambda epoch: 0 if epoch < start_iter else 1,
- lambda epoch: 0 if epoch < start_iter else 1,
])
self.writer = SummaryWriter(self._log_name)
@@ -179,10 +165,10 @@ class Model:
# Training start
start_time = datetime.now()
- running_loss = torch.zeros(5, device=self.device)
+ running_loss = torch.zeros(3, device=self.device)
print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
- f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
+ f"{'LR':^9}")
for (batch_c1, batch_c2) in dataloader:
self.curr_iter += 1
# Zero the parameter gradients
@@ -190,10 +176,7 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- y = batch_c1['label'].to(self.device)
- # Duplicate labels for each part
- y = y.unsqueeze(1).repeat(1, self.rgb_pn.num_total_parts)
- losses, images = self.rgb_pn(x_c1, x_c2, y)
+ losses, images = self.rgb_pn(x_c1, x_c2)
loss = losses.sum()
loss.backward()
self.optimizer.step()
@@ -203,19 +186,16 @@ class Model:
# Write losses to TensorBoard
self.writer.add_scalar('Loss/all', loss, self.curr_iter)
self.writer.add_scalars('Loss/details', dict(zip([
- 'Cross reconstruction loss', 'Canonical consistency loss',
- 'Pose similarity loss', 'Batch All triplet loss (HPM)',
- 'Batch All triplet loss (PartNet)'
+ 'Cross reconstruction loss',
+ 'Canonical consistency loss',
+ 'Pose similarity loss'
], losses)), self.curr_iter)
if self.curr_iter % 100 == 0:
- lrs = self.scheduler.get_last_lr()
+ lr = self.scheduler.get_last_lr()[0]
# Write learning rates
self.writer.add_scalar(
- 'Learning rate/Auto-encoder', lrs[0], self.curr_iter
- )
- self.writer.add_scalar(
- 'Learning rate/Others', lrs[1], self.curr_iter
+ 'Learning rate/Auto-encoder', lr, self.curr_iter
)
# Write disentangled images
if self.image_log_on:
@@ -238,8 +218,8 @@ class Model:
hour, minute = divmod(remaining_minute, 60)
print(f'{hour:02}:{minute:02}:{second:02}',
f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
- '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- '{:.3e} {:.3e}'.format(lrs[0], lrs[1]))
+ '{:f} {:f} {:f}'.format(*running_loss / 100),
+ f'{lr:.3e}')
running_loss.zero_()
# Step scheduler
@@ -258,24 +238,6 @@ class Model:
self.writer.close()
break
- def predict_all(
- self,
- iters: Tuple[int],
- dataset_config: Dict,
- dataset_selectors: Dict[
- str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
- ],
- dataloader_config: Dict,
- ) -> Dict[str, torch.Tensor]:
- # Transform data to features
- gallery_samples, probe_samples = self.transform(
- iters, dataset_config, dataset_selectors, dataloader_config
- )
- # Evaluate features
- accuracy = self.evaluate(gallery_samples, probe_samples)
-
- return accuracy
-
def transform(
self,
iters: Tuple[int],
@@ -326,61 +288,13 @@ class Model:
def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]):
label = sample.pop('label').item()
clip = sample.pop('clip').to(self.device)
- feature = self.rgb_pn(clip).detach()
+ x_c, x_p = self.rgb_pn(clip).detach()
return {
**{'label': label},
**sample,
- **{'feature': feature}
- }
-
- def evaluate(
- self,
- gallery_samples: Dict[str, Union[List[str], torch.Tensor]],
- probe_samples: Dict[str, Dict[str, Union[List[str], torch.Tensor]]],
- num_ranks: int = 5
- ) -> Dict[str, torch.Tensor]:
- probe_conditions = self._probe_datasets_meta.keys()
- gallery_views_meta = self._gallery_dataset_meta['views']
- probe_views_meta = list(self._probe_datasets_meta.values())[0]['views']
- accuracy = {
- condition: torch.empty(
- len(gallery_views_meta), len(probe_views_meta), num_ranks
- )
- for condition in self._probe_datasets_meta.keys()
+ **{'cano_feature': x_c, 'pose_feature': x_p}
}
- (labels_g, _, views_g, features_g) = gallery_samples.values()
- views_g = np.asarray(views_g)
- for (v_g_i, view_g) in enumerate(gallery_views_meta):
- gallery_view_mask = (views_g == view_g)
- f_g = features_g[gallery_view_mask]
- y_g = labels_g[gallery_view_mask]
- for condition in probe_conditions:
- probe_samples_c = probe_samples[condition]
- accuracy_c = accuracy[condition]
- (labels_p, _, views_p, features_p) = probe_samples_c.values()
- views_p = np.asarray(views_p)
- for (v_p_i, view_p) in enumerate(probe_views_meta):
- probe_view_mask = (views_p == view_p)
- f_p = features_p[probe_view_mask]
- y_p = labels_p[probe_view_mask]
- # Euclidean distance
- f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(1)
- f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(0)
- f_p_times_f_g_sum = f_p @ f_g.T
- dist = torch.sqrt(F.relu(
- f_p_squared_sum - 2*f_p_times_f_g_sum + f_g_squared_sum
- ))
- # Ranked accuracy
- rank_mask = dist.argsort(1)[:, :num_ranks]
- positive_mat = torch.eq(y_p.unsqueeze(1),
- y_g[rank_mask]).cumsum(1).gt(0)
- positive_counts = positive_mat.sum(0)
- total_counts, _ = dist.size()
- accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts
-
- return accuracy
-
def _load_pretrained(
self,
iters: Tuple[int],
@@ -449,8 +363,6 @@ class Model:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
- elif isinstance(m, RGBPartNet):
- nn.init.xavier_uniform_(m.fc_mat)
def _parse_dataset_config(
self,