author    Jordan Gong <jordan.gong@protonmail.com>    2021-04-03 23:07:23 +0800
committer Jordan Gong <jordan.gong@protonmail.com>    2021-04-03 23:07:23 +0800
commit    258efcafe4d34ed5ffeebcaab9389f75a17e4717 (patch)
tree      0f4ffe75990b63b8e17956eeec269e3589852769
parent    4049566103a00aa6d5a0b1f73569bdc5435714ca (diff)
parent    f6f133fa7b926ce0c7d28bbf0ba4de41b3708d4a (diff)
Merge branch 'disentangling_only' into disentangling_only_py3.8
# Conflicts:
#   models/model.py
-rw-r--r--  config.py                23
-rw-r--r--  models/auto_encoder.py    4
-rw-r--r--  models/layers.py          4
-rw-r--r--  models/model.py         233
-rw-r--r--  models/rgb_part_net.py   98
-rw-r--r--  requirements.txt          8
-rw-r--r--  utils/configuration.py    6
-rw-r--r--  utils/dataset.py          6
-rw-r--r--  utils/sampler.py         35
9 files changed, 248 insertions, 169 deletions
diff --git a/config.py b/config.py
index afd40d5..dc4e0ba 100644
--- a/config.py
+++ b/config.py
@@ -9,7 +9,9 @@ config: Configuration = {
# Directory used in training or testing for temporary storage
'save_dir': 'runs/dis_only',
# Record disentangled image or not
- 'image_log_on': True
+ 'image_log_on': True,
+ # The number of subjects for validating (Part of testing set)
+ 'val_size': 10,
},
# Dataset settings
'dataset': {
@@ -37,7 +39,7 @@ config: Configuration = {
# Batch size (pr, k)
# `pr` denotes number of persons
# `k` denotes number of sequences per person
- 'batch_size': (2, 2),
+ 'batch_size': (4, 6),
# Number of workers of Dataloader
'num_workers': 4,
# Faster data transfer from RAM to GPU if enabled
@@ -61,15 +63,20 @@ config: Configuration = {
# Term added to the denominator
# 'eps': 1e-8,
# Weight decay (L2 penalty)
- # 'weight_decay': 0,
+ 'weight_decay': 0.001,
# Use AMSGrad or not
# 'amsgrad': False,
},
'scheduler': {
- # Period of learning rate decay
- 'step_size': 500,
- # Multiplicative factor of decay
- 'gamma': 0.9,
+ # Step start to decay
+ 'start_step': 500,
+ # Multiplicative factor of decay in the end
+ 'final_gamma': 0.01,
+
+ # Local parameters (override global ones)
+ # 'hpm': {
+ # 'final_gamma': 0.001
+ # }
}
},
# Model metadata
@@ -83,6 +90,6 @@ config: Configuration = {
# Restoration iteration (multiple models, e.g. nm, bg and cl)
'restore_iters': (0, 0, 0),
# Total iteration for training (multiple models)
- 'total_iters': (80_000, 80_000, 80_000),
+ 'total_iters': (30_000, 40_000, 60_000),
},
}
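
The scheduler block above replaces the periodic StepLR-style decay (step_size/gamma) with a delayed exponential decay: the learning rate stays at its base value until start_step and is then multiplied down so it reaches final_gamma of the initial value at the last iteration, matching the LambdaLR set up in models/model.py below. A minimal sketch assuming the config values above and a hypothetical base LR of 1e-4 on a placeholder module:

    import torch.nn as nn
    import torch.optim as optim

    total_iter = 30_000      # first entry of 'total_iters' above
    start_step = 500         # 'start_step'
    final_gamma = 0.01       # 'final_gamma'

    model = nn.Linear(8, 8)  # placeholder module, not the project model
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.001)
    all_step = total_iter - start_step
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        # Multiplier is 1 before start_step, then decays smoothly so that
        # lr(t) = lr0 * final_gamma ** ((t - start_step) / all_step)
        lambda t: final_gamma ** ((t - start_step) / all_step)
        if t > start_step else 1,
    )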
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index e17caed..b1d51ef 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -108,15 +108,13 @@ class Decoder(nn.Module):
self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels,
is_last_layer=True)
- def forward(self, f_appearance, f_canonical, f_pose, cano_only=False):
+ def forward(self, f_appearance, f_canonical, f_pose):
x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
x = self.fc(x)
x = x.view(-1, self.feature_channels * 8, self.h_0, self.w_0)
x = F.relu(x, inplace=True)
x = self.trans_conv1(x)
x = self.trans_conv2(x)
- if cano_only:
- return x
x = self.trans_conv3(x)
x = torch.sigmoid(self.trans_conv4(x))
diff --git a/models/layers.py b/models/layers.py
index 8228f49..1da79ef 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -79,7 +79,9 @@ class DCGANConvTranspose2d(BasicConvTranspose2d):
if self.is_last_layer:
return self.trans_conv(x)
else:
- return super().forward(x)
+ x = self.trans_conv(x)
+ x = self.bn(x)
+ return F.leaky_relu(x, 0.2, inplace=True)
class BasicLinear(nn.Module):
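
The layers.py fix makes every non-final transposed-convolution layer apply batch normalization and LeakyReLU explicitly instead of delegating to the parent's forward. A self-contained sketch of that DCGAN-style block, with hypothetical kernel/stride settings (the project's BasicConvTranspose2d may use different ones):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class DCGANDeconvBlock(nn.Module):
        def __init__(self, in_channels, out_channels, is_last_layer=False):
            super().__init__()
            self.is_last_layer = is_last_layer
            self.trans_conv = nn.ConvTranspose2d(in_channels, out_channels,
                                                 kernel_size=4, stride=2,
                                                 padding=1)
            self.bn = nn.BatchNorm2d(out_channels)

        def forward(self, x):
            if self.is_last_layer:
                # Last layer: raw transposed convolution; activation applied outside
                return self.trans_conv(x)
            # Intermediate layers: deconv -> batch norm -> LeakyReLU(0.2)
            x = self.trans_conv(x)
            x = self.bn(x)
            return F.leaky_relu(x, 0.2, inplace=True)

    # DCGANDeconvBlock(64, 32)(torch.randn(2, 64, 8, 8)).shape == (2, 32, 16, 16)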
diff --git a/models/model.py b/models/model.py
index c8f0450..667a0a7 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,5 +1,6 @@
+import copy
import os
-from datetime import datetime
+import random
from typing import Union, Optional, Tuple, List, Dict, Set
import numpy as np
@@ -47,10 +48,10 @@ class Model:
self.meta = model_config
self.hp = hyperparameter_config
- self.curr_iter = self.meta.get('restore_iter', 0)
+ self.restore_iter = self.curr_iter = self.meta.get('restore_iter', 0)
self.total_iter = self.meta.get('total_iter', 80_000)
- self.curr_iters = self.meta.get('restore_iters', (0, 0, 0))
- self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000))
+ self.restore_iters = self.meta.get('restore_iters', (self.curr_iter,))
+ self.total_iters = self.meta.get('total_iters', (self.total_iter,))
self.is_train: bool = True
self.in_channels: int = 3
@@ -70,6 +71,7 @@ class Model:
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
self.image_log_on = system_config.get('image_log_on', False)
+ self.val_size = system_config.get('val_size', 10)
self.CASIAB_GALLERY_SELECTOR = {
'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})}
@@ -83,7 +85,7 @@ class Model:
@property
def _model_sig(self) -> str:
return '_'.join(
- (self._model_name, str(self.curr_iter), str(self.total_iter))
+ (self._model_name, str(self.curr_iter + 1), str(self.total_iter))
)
@property
@@ -112,18 +114,18 @@ class Model:
],
dataloader_config: DataloaderConfiguration,
):
- for (curr_iter, total_iter, (condition, selector)) in zip(
- self.curr_iters, self.total_iters, dataset_selectors.items()
+ for (restore_iter, total_iter, (condition, selector)) in zip(
+ self.restore_iters, self.total_iters, dataset_selectors.items()
):
print(f'Training model {condition} ...')
# Skip finished model
- if curr_iter == total_iter:
+ if restore_iter == total_iter:
continue
# Check invalid restore iter
- elif curr_iter > total_iter:
+ elif restore_iter > total_iter:
raise ValueError("Restore iter '{}' should less than total "
- "iter '{}'".format(curr_iter, total_iter))
- self.curr_iter = curr_iter
+ "iter '{}'".format(restore_iter, total_iter))
+ self.restore_iter = self.curr_iter = restore_iter
self.total_iter = total_iter
self.fit(
dict(**dataset_config, **{'selector': selector}),
@@ -136,70 +138,85 @@ class Model:
dataloader_config: DataloaderConfiguration,
):
self.is_train = True
- dataset = self._parse_dataset_config(dataset_config)
- dataloader = self._parse_dataloader_config(dataset, dataloader_config)
+ # Validation dataset
+ # (the first `val_size` subjects from evaluation set)
+ val_dataset_config = copy.deepcopy(dataset_config)
+ train_size = dataset_config.get('train_size', 74)
+ val_dataset_config['train_size'] = train_size + self.val_size
+ val_dataset_config['selector']['classes'] = ClipClasses({
+ str(c).zfill(3)
+ for c in range(train_size + 1, train_size + self.val_size + 1)
+ })
+ val_dataset = self._parse_dataset_config(val_dataset_config)
+ val_dataloader = iter(self._parse_dataloader_config(
+ val_dataset, dataloader_config
+ ))
+ # Training dataset
+ train_dataset = self._parse_dataset_config(dataset_config)
+ train_dataloader = iter(self._parse_dataloader_config(
+ train_dataset, dataloader_config
+ ))
# Prepare for model, optimizer and scheduler
- model_hp = self.hp.get('model', {})
+ model_hp: dict = self.hp.get('model', {}).copy()
optim_hp: Dict = self.hp.get('optimizer', {}).copy()
sched_hp = self.hp.get('scheduler', {})
+
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
+
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp)
- sched_gamma = sched_hp.get('gamma', 0.9)
- sched_step_size = sched_hp.get('step_size', 500)
- self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
- lambda epoch: sched_gamma ** (epoch // sched_step_size),
- ])
+ start_step = sched_hp.get('start_step', 15_000)
+ final_gamma = sched_hp.get('final_gamma', 0.001)
+ all_step = self.total_iter - start_step
+ self.scheduler = optim.lr_scheduler.LambdaLR(
+ self.optimizer,
+ lambda t: final_gamma ** ((t - start_step) / all_step)
+ if t > start_step else 1,
+ )
self.writer = SummaryWriter(self._log_name)
+ # Set seeds for reproducibility
+ random.seed(0)
+ torch.manual_seed(0)
self.rgb_pn.train()
# Init weights at first iter
if self.curr_iter == 0:
self.rgb_pn.apply(self.init_weights)
else: # Load saved state dicts
+            # Offset an iter to load the last checkpoint
+ self.curr_iter -= 1
checkpoint = torch.load(self._checkpoint_name)
- iter_, loss = checkpoint['iter'], checkpoint['loss']
- print('{0:5d} loss: {1:.3f}'.format(iter_, loss))
+ random.setstate(checkpoint['rand_states'][0])
+ torch.set_rng_state(checkpoint['rand_states'][1])
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
self.scheduler.load_state_dict(checkpoint['sched_state_dict'])
# Training start
- start_time = datetime.now()
- running_loss = torch.zeros(3, device=self.device)
- print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
- f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
- f"{'LR':^9}")
- for (batch_c1, batch_c2) in dataloader:
- self.curr_iter += 1
+ for self.curr_iter in tqdm(range(self.restore_iter, self.total_iter),
+ desc='Training'):
+ batch_c1, batch_c2 = next(train_dataloader)
# Zero the parameter gradients
self.optimizer.zero_grad()
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- losses, images = self.rgb_pn(x_c1, x_c2)
+ losses, features, images = self.rgb_pn(x_c1, x_c2)
loss = losses.sum()
loss.backward()
self.optimizer.step()
+ self.scheduler.step()
- # Statistics and checkpoint
- running_loss += losses.detach()
- # Write losses to TensorBoard
- self.writer.add_scalar('Loss/all', loss, self.curr_iter)
- self.writer.add_scalars('Loss/details', dict(zip([
- 'Cross reconstruction loss',
- 'Canonical consistency loss',
- 'Pose similarity loss'
- ], losses)), self.curr_iter)
-
- if self.curr_iter % 100 == 0:
- lr = self.scheduler.get_last_lr()[0]
- # Write learning rates
- self.writer.add_scalar(
- 'Learning rate/Auto-encoder', lr, self.curr_iter
- )
+ # Learning rate
+ self.writer.add_scalar(
+ 'Learning rate', self.scheduler.get_last_lr()[0], self.curr_iter
+ )
+ # Other stats
+ self._write_stat('Train', loss, losses)
+
+ if self.curr_iter % 100 == 99:
# Write disentangled images
if self.image_log_on:
i_a, i_c, i_p = images
@@ -216,30 +233,54 @@ class Model:
self.writer.add_images(
f'Pose image/batch {i}', p, self.curr_iter
)
- time_used = datetime.now() - start_time
- remaining_minute, second = divmod(time_used.seconds, 60)
- hour, minute = divmod(remaining_minute, 60)
- print(f'{hour:02}:{minute:02}:{second:02}',
- f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
- '{:f} {:f} {:f}'.format(*running_loss / 100),
- f'{lr:.3e}')
- running_loss.zero_()
-
- # Step scheduler
- self.scheduler.step()
-
- if self.curr_iter % 1000 == 0:
+ f_a, f_c, f_p = features
+ for i, (f_a_i, f_c_i, f_p_i) in enumerate(
+ zip(f_a, f_c, f_p)
+ ):
+ self.writer.add_images(
+ f'Appearance features/Layer {i}',
+ f_a_i[:, :3, :, :], self.curr_iter
+ )
+ self.writer.add_images(
+ f'Canonical features/Layer {i}',
+ f_c_i[:, :3, :, :], self.curr_iter
+ )
+ for j, p in enumerate(f_p_i):
+ self.writer.add_images(
+ f'Pose features/Layer {i}/batch{j}',
+ p[:, :3, :, :], self.curr_iter
+ )
+
+ # Calculate losses on testing batch
+ batch_c1, batch_c2 = next(val_dataloader)
+ x_c1 = batch_c1['clip'].to(self.device)
+ x_c2 = batch_c2['clip'].to(self.device)
+ with torch.no_grad():
+ losses, _, _ = self.rgb_pn(x_c1, x_c2)
+ loss = losses.sum()
+
+ self._write_stat('Val', loss, losses)
+
+ # Checkpoint
+ if self.curr_iter % 1000 == 999:
torch.save({
- 'iter': self.curr_iter,
+ 'rand_states': (random.getstate(), torch.get_rng_state()),
'model_state_dict': self.rgb_pn.state_dict(),
'optim_state_dict': self.optimizer.state_dict(),
'sched_state_dict': self.scheduler.state_dict(),
- 'loss': loss,
}, self._checkpoint_name)
- if self.curr_iter == self.total_iter:
- self.writer.close()
- break
+ self.writer.close()
+
+ def _write_stat(
+ self, postfix, loss, losses
+ ):
+ # Write losses to TensorBoard
+ self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter)
+ self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip((
+ 'Cross reconstruction loss', 'Canonical consistency loss',
+ 'Pose similarity loss'
+ ), losses)), self.curr_iter)
def transform(
self,
@@ -248,12 +289,12 @@ class Model:
dataset_selectors: Dict[
str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
],
- dataloader_config: DataloaderConfiguration
+ dataloader_config: DataloaderConfiguration,
+ is_train: bool = False
):
- self.is_train = False
# Split gallery and probe dataset
gallery_dataloader, probe_dataloaders = self._split_gallery_probe(
- dataset_config, dataloader_config
+ dataset_config, dataloader_config, is_train
)
# Get pretrained models at iter_
checkpoints = self._load_pretrained(
@@ -261,41 +302,45 @@ class Model:
)
# Init models
- model_hp = self.hp.get('model', {})
+ model_hp: dict = self.hp.get('model', {}).copy()
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp)
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
self.rgb_pn.eval()
- gallery_samples, probe_samples = [], {}
- # Gallery
- checkpoint = torch.load(list(checkpoints.values())[0])
- self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
- for sample in tqdm(gallery_dataloader,
- desc='Transforming gallery', unit='clips'):
- gallery_samples.append(self._get_eval_sample(sample))
- gallery_samples = default_collate(gallery_samples)
- # Probe
- for (condition, dataloader) in probe_dataloaders.items():
+ gallery_samples, probe_samples = {}, {}
+ for (condition, probe_dataloader) in probe_dataloaders.items():
checkpoint = torch.load(checkpoints[condition])
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
+ # Gallery
+ gallery_samples_c = []
+ for sample in tqdm(gallery_dataloader,
+ desc=f'Transforming gallery {condition}',
+ unit='clips'):
+ gallery_samples_c.append(self._get_eval_sample(sample))
+ gallery_samples[condition] = default_collate(gallery_samples_c)
+ # Probe
probe_samples_c = []
- for sample in tqdm(dataloader,
+ for sample in tqdm(probe_dataloader,
desc=f'Transforming probe {condition}',
unit='clips'):
probe_samples_c.append(self._get_eval_sample(sample))
- probe_samples[condition] = default_collate(probe_samples_c)
+ probe_samples_c = default_collate(probe_samples_c)
+ probe_samples_c['meta'] = self._probe_datasets_meta[condition]
+ probe_samples[condition] = probe_samples_c
+ gallery_samples['meta'] = self._gallery_dataset_meta
return gallery_samples, probe_samples
def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]):
- label = sample.pop('label').item()
- clip = sample.pop('clip').to(self.device)
- x_c, x_p = self.rgb_pn(clip).detach()
+ label, condition, view, clip = sample.values()
+ with torch.no_grad():
+ feature_c, feature_p = self.rgb_pn(clip.to(self.device))
return {
- **{'label': label},
- **sample,
- **{'cano_feature': x_c, 'pose_feature': x_p}
+ 'label': label.item(),
+ 'condition': condition[0],
+ 'view': view[0],
+ 'feature': torch.cat((feature_c, feature_p)).view(-1)
}
def _load_pretrained(
@@ -307,10 +352,11 @@ class Model:
]
) -> Dict[str, str]:
checkpoints = {}
- for (iter_, (condition, selector)) in zip(
- iters, dataset_selectors.items()
+ for (iter_, total_iter, (condition, selector)) in zip(
+ iters, self.total_iters, dataset_selectors.items()
):
- self.curr_iter = iter_
+ self.curr_iter = iter_ - 1
+ self.total_iter = total_iter
self._dataset_sig = self._make_signature(
dict(**dataset_config, **selector),
popped_keys=['root_dir', 'cache_on']
@@ -322,26 +368,29 @@ class Model:
self,
dataset_config: DatasetConfiguration,
dataloader_config: DataloaderConfiguration,
+ is_train: bool = False
) -> Tuple[DataLoader, Dict[str, DataLoader]]:
dataset_name = dataset_config.get('name', 'CASIA-B')
if dataset_name == 'CASIA-B':
+ self.is_train = is_train
gallery_dataset = self._parse_dataset_config(
dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR)
)
- self._gallery_dataset_meta = gallery_dataset.metadata
- gallery_dataloader = self._parse_dataloader_config(
- gallery_dataset, dataloader_config
- )
probe_datasets = {
condition: self._parse_dataset_config(
dict(**dataset_config, **selector)
)
for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items()
}
+ self._gallery_dataset_meta = gallery_dataset.metadata
self._probe_datasets_meta = {
condition: dataset.metadata
for (condition, dataset) in probe_datasets.items()
}
+ self.is_train = False
+ gallery_dataloader = self._parse_dataloader_config(
+ gallery_dataset, dataloader_config
+ )
probe_dataloaders = {
condition: self._parse_dataloader_config(
dataset, dataloader_config
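
Alongside the tqdm-driven loop, the checkpoint format above drops the stored iteration and loss and instead saves the Python and PyTorch RNG states, so a resumed run replays the same sampling sequence. A small save/restore sketch using the same dictionary keys (model, optimizer and scheduler are assumed to already exist):

    import random
    import torch

    def save_checkpoint(path, model, optimizer, scheduler):
        # Persist both RNG states together with the learnable state
        torch.save({
            'rand_states': (random.getstate(), torch.get_rng_state()),
            'model_state_dict': model.state_dict(),
            'optim_state_dict': optimizer.state_dict(),
            'sched_state_dict': scheduler.state_dict(),
        }, path)

    def load_checkpoint(path, model, optimizer, scheduler):
        checkpoint = torch.load(path)
        # Restore RNG states first so subsequent sampling matches the saved run
        random.setstate(checkpoint['rand_states'][0])
        torch.set_rng_state(checkpoint['rand_states'][1])
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
        scheduler.load_state_dict(checkpoint['sched_state_dict'])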
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 797e02b..1c7a1a2 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -2,6 +2,7 @@ from typing import Tuple
import torch
import torch.nn as nn
+import torch.nn.functional as F
from models.auto_encoder import AutoEncoder
@@ -16,6 +17,7 @@ class RGBPartNet(nn.Module):
image_log_on: bool = False
):
super().__init__()
+ self.h, self.w = ae_in_size
(self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
self.image_log_on = image_log_on
@@ -24,70 +26,64 @@ class RGBPartNet(nn.Module):
)
def forward(self, x_c1, x_c2=None):
- # Step 1: Disentanglement
- # n, t, c, h, w
- ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2)
+ losses, features, images = self._disentangle(x_c1, x_c2)
if self.training:
losses = torch.stack(losses)
- return losses, images
+ return losses, features, images
else:
- return x_c, x_p
+ return features
def _disentangle(self, x_c1_t2, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()
- device = x_c1_t2.device
- x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
if self.training:
+ x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
- # Decode features
- with torch.no_grad():
- x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+ f_a = f_a_.view(n, t, -1)
+ f_c = f_c_.view(n, t, -1)
+ f_p = f_p_.view(n, t, -1)
- i_a, i_c, i_p = None, None, None
- if self.image_log_on:
- i_a = self._decode_appr_feature(f_a_, n, t, device)
- # Continue decoding canonical features
- i_c = self.ae.decoder.trans_conv3(x_c)
- i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c))
- i_p = x_p
+ i_a, i_c, i_p = None, None, None
+ if self.image_log_on:
+ with torch.no_grad():
+ x_a, i_a = self._separate_decode(
+ f_a.mean(1),
+ torch.zeros_like(f_c[:, 0, :]),
+ torch.zeros_like(f_p[:, 0, :])
+ )
+ x_c, i_c = self._separate_decode(
+ torch.zeros_like(f_a[:, 0, :]),
+ f_c.mean(1),
+ torch.zeros_like(f_p[:, 0, :]),
+ )
+ x_p_, i_p_ = self._separate_decode(
+ torch.zeros_like(f_a_),
+ torch.zeros_like(f_c_),
+ f_p_
+ )
+ x_p = tuple(_x_p.view(n, t, *_x_p.size()[1:]) for _x_p in x_p_)
+ i_p = i_p_.view(n, t, c, h, w)
- return (x_c, x_p), losses, (i_a, i_c, i_p)
+ return losses, (x_a, x_c, x_p), (i_a, i_c, i_p)
else: # evaluating
f_c_, f_p_ = self.ae(x_c1_t2)
- x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
- return (x_c, x_p), None, None
+ f_c = f_c_.view(n, t, -1)
+ f_p = f_p_.view(n, t, -1)
+ return (f_c, f_p), None, None
- def _decode_appr_feature(self, f_a_, n, t, device):
- # Decode appearance features
- f_a = f_a_.view(n, t, -1)
- x_a = self.ae.decoder(
- f_a.mean(1),
- torch.zeros((n, self.f_c_dim), device=device),
- torch.zeros((n, self.f_p_dim), device=device)
+ def _separate_decode(self, f_a, f_c, f_p):
+ x_1 = torch.cat((f_a, f_c, f_p), dim=1)
+ x_1 = self.ae.decoder.fc(x_1).view(
+ -1,
+ self.ae.decoder.feature_channels * 8,
+ self.ae.decoder.h_0,
+ self.ae.decoder.w_0
)
- return x_a
-
- def _decode_cano_feature(self, f_c_, n, t, device):
- # Decode average canonical features to higher dimension
- f_c = f_c_.view(n, t, -1)
- x_c = self.ae.decoder(
- torch.zeros((n, self.f_a_dim), device=device),
- f_c.mean(1),
- torch.zeros((n, self.f_p_dim), device=device),
- cano_only=True
- )
- return x_c
-
- def _decode_pose_feature(self, f_p_, n, t, c, h, w, device):
- # Decode pose features to images
- x_p_ = self.ae.decoder(
- torch.zeros((n * t, self.f_a_dim), device=device),
- torch.zeros((n * t, self.f_c_dim), device=device),
- f_p_
- )
- x_p = x_p_.view(n, t, c, h, w)
- return x_p
+ x_1 = F.relu(x_1, inplace=True)
+ x_2 = self.ae.decoder.trans_conv1(x_1)
+ x_3 = self.ae.decoder.trans_conv2(x_2)
+ x_4 = self.ae.decoder.trans_conv3(x_3)
+ image = torch.sigmoid(self.ae.decoder.trans_conv4(x_4))
+ x = (x_1, x_2, x_3, x_4)
+ return x, image
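
_separate_decode reconstructs an image from a single latent factor by zeroing the other two codes and running the decoder stage by stage, keeping every intermediate activation so the training loop can log per-layer feature maps. A condensed sketch of the pattern over a generic list of decoder stages (the stage list is a placeholder, not the project's decoder attributes):

    import torch

    def separate_decode(decoder_stages, f_a, f_c, f_p):
        """Decode one factor at a time; the caller zeroes the codes to mute."""
        x = torch.cat((f_a, f_c, f_p), dim=1)
        intermediates = []
        for stage in decoder_stages[:-1]:
            x = stage(x)
            intermediates.append(x)
        # Final stage produces the reconstructed image in [0, 1]
        image = torch.sigmoid(decoder_stages[-1](x))
        return tuple(intermediates), image

    # Canonical-only reconstruction, analogous to the diff above:
    # xs, i_c = separate_decode(stages, torch.zeros_like(f_a),
    #                           f_c, torch.zeros_like(f_p))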
diff --git a/requirements.txt b/requirements.txt
index 4d30e17..8b0a41b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
-torch~=1.7.1
-torchvision~=0.8.0a0+ecf4e9c
-numpy~=1.19.4
-tqdm~=4.57.0
+torch~=1.8.0
+torchvision~=0.9.0a0+3f090d0
+numpy~=1.20.1
+tqdm~=4.59.0
Pillow~=8.1.0
scikit-learn~=0.24.0
\ No newline at end of file
diff --git a/utils/configuration.py b/utils/configuration.py
index 340815b..46149b3 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -8,6 +8,7 @@ class SystemConfiguration(TypedDict):
CUDA_VISIBLE_DEVICES: str
save_dir: str
image_log_on: bool
+ val_size: int
class DatasetConfiguration(TypedDict):
@@ -35,7 +36,6 @@ class ModelHPConfiguration(TypedDict):
class OptimizerHPConfiguration(TypedDict):
- start_iter: int
lr: int
betas: Tuple[float, float]
eps: float
@@ -44,8 +44,8 @@ class OptimizerHPConfiguration(TypedDict):
class SchedulerHPConfiguration(TypedDict):
- step_size: int
- gamma: float
+ start_step: int
+ final_gamma: float
class HyperparameterConfiguration(TypedDict):
diff --git a/utils/dataset.py b/utils/dataset.py
index 72cf050..41e2f1e 100644
--- a/utils/dataset.py
+++ b/utils/dataset.py
@@ -111,9 +111,9 @@ class CASIAB(data.Dataset):
# in Bag #2 condition from 90 degree angle
classes, conditions, views = [], [], []
if selector:
- selected_classes = selector.pop('classes', None)
- selected_conditions = selector.pop('conditions', None)
- selected_views = selector.pop('views', None)
+ selected_classes = selector.get('classes', None)
+ selected_conditions = selector.get('conditions', None)
+ selected_views = selector.get('views', None)
class_regex = r'\d{3}'
condition_regex = r'(nm|bg|cl)-0[0-6]'
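
Replacing pop with get stops selector parsing from mutating the caller's dictionary, so the same selector can be parsed more than once (fit() now builds both a training and a validation dataset, and transform() walks several probe selectors). A tiny illustration of the difference, using a made-up selector:

    selector = {'classes': {'001', '002'}}

    # pop() empties the dict: a second parse of the same selector sees nothing
    classes = selector.pop('classes', None)   # {'001', '002'}
    classes = selector.pop('classes', None)   # None -- key was removed

    selector = {'classes': {'001', '002'}}
    # get() leaves the dict intact, so it can be reused for another dataset
    classes = selector.get('classes', None)   # {'001', '002'}
    classes = selector.get('classes', None)   # still {'001', '002'}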
diff --git a/utils/sampler.py b/utils/sampler.py
index 0977f94..581d7a2 100644
--- a/utils/sampler.py
+++ b/utils/sampler.py
@@ -15,7 +15,18 @@ class TripletSampler(data.Sampler):
):
super().__init__(data_source)
self.metadata_labels = data_source.metadata['labels']
+ metadata_conditions = data_source.metadata['conditions']
+ self.subsets = {}
+ for condition in metadata_conditions:
+ pre, _ = condition.split('-')
+ if self.subsets.get(pre, None) is None:
+ self.subsets[pre] = []
+ self.subsets[pre].append(condition)
+ self.num_subsets = len(self.subsets)
+ self.num_seq = {pre: len(seq) for (pre, seq) in self.subsets.items()}
+ self.min_num_seq = min(self.num_seq.values())
self.labels = data_source.labels
+ self.conditions = data_source.conditions
self.length = len(self.labels)
self.indexes = np.arange(0, self.length)
(self.pr, self.k) = batch_size
@@ -26,15 +37,31 @@ class TripletSampler(data.Sampler):
# Sample pr subjects by sampling labels appeared in dataset
sampled_subjects = random.sample(self.metadata_labels, k=self.pr)
for label in sampled_subjects:
- clips_from_subject = self.indexes[self.labels == label].tolist()
+ mask = self.labels == label
+ # Fix unbalanced datasets
+ if self.num_subsets > 1:
+ condition_mask = np.zeros(self.conditions.shape, dtype=bool)
+ for num, conditions_ in zip(
+ self.num_seq.values(), self.subsets.values()
+ ):
+ if num > self.min_num_seq:
+ conditions = random.sample(
+ conditions_, self.min_num_seq
+ )
+ else:
+ conditions = conditions_
+ for condition in conditions:
+ condition_mask |= self.conditions == condition
+ mask &= condition_mask
+ clips = self.indexes[mask].tolist()
# Sample k clips from the subject without replacement if there
# are enough clips; k more clips will be sampled for
# disentanglement
k = self.k * 2
- if len(clips_from_subject) >= k:
- _sampled_indexes = random.sample(clips_from_subject, k=k)
+ if len(clips) >= k:
+ _sampled_indexes = random.sample(clips, k=k)
else:
- _sampled_indexes = random.choices(clips_from_subject, k=k)
+ _sampled_indexes = random.choices(clips, k=k)
sampled_indexes += _sampled_indexes
yield sampled_indexes
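
The reworked sampler balances unequal condition subsets: CASIA-B provides six nm sequences per subject but only two bg and two cl, so subsets larger than the smallest one are randomly down-sampled before clips are drawn. A standalone sketch of that masking step with made-up arrays:

    import random
    import numpy as np

    conditions = np.array(['nm-01', 'nm-02', 'nm-03', 'nm-04', 'nm-05', 'nm-06',
                           'bg-01', 'bg-02', 'cl-01', 'cl-02'])
    subsets = {'nm': ['nm-01', 'nm-02', 'nm-03', 'nm-04', 'nm-05', 'nm-06'],
               'bg': ['bg-01', 'bg-02'],
               'cl': ['cl-01', 'cl-02']}
    min_num_seq = min(len(s) for s in subsets.values())

    mask = np.zeros(conditions.shape, dtype=bool)
    for conditions_ in subsets.values():
        # Down-sample over-represented subsets to the smallest subset size
        chosen = (random.sample(conditions_, min_num_seq)
                  if len(conditions_) > min_num_seq else conditions_)
        for condition in chosen:
            mask |= conditions == condition
    # `mask` now selects an equal number of sequences per condition subset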