summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.idea/csv-plugin.xml16
-rw-r--r--config.py10
-rw-r--r--models/model.py331
-rw-r--r--models/rgb_part_net.py2
-rw-r--r--utils/configuration.py1
-rw-r--r--utils/sampler.py35
-rw-r--r--utils/triplet_loss.py9
7 files changed, 266 insertions, 138 deletions
diff --git a/.idea/csv-plugin.xml b/.idea/csv-plugin.xml
new file mode 100644
index 0000000..5e5cec1
--- /dev/null
+++ b/.idea/csv-plugin.xml
@@ -0,0 +1,16 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+ <component name="CsvFileAttributes">
+ <option name="attributeMap">
+ <map>
+ <entry key="/models/model.py">
+ <value>
+ <Attribute>
+ <option name="separator" value="," />
+ </Attribute>
+ </value>
+ </entry>
+ </map>
+ </option>
+ </component>
+</project> \ No newline at end of file
diff --git a/config.py b/config.py
index fdd4328..97b2a13 100644
--- a/config.py
+++ b/config.py
@@ -9,7 +9,9 @@ config: Configuration = {
# Directory used in training or testing for temporary storage
'save_dir': 'runs',
# Recorde disentangled image or not
- 'image_log_on': False
+ 'image_log_on': False,
+ # The number of subjects for validating (Part of testing set)
+ 'val_size': 10,
},
# Dataset settings
'dataset': {
@@ -94,9 +96,9 @@ config: Configuration = {
'final_gamma': 0.01,
# Local parameters (override global ones)
- 'hpm': {
- 'final_gamma': 0.001
- }
+ # 'hpm': {
+ # 'final_gamma': 0.001
+ # }
}
},
# Model metadata
diff --git a/models/model.py b/models/model.py
index b623171..9cac5e5 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,6 +1,6 @@
+import copy
import os
import random
-from datetime import datetime
from typing import Union, Optional
import numpy as np
@@ -52,16 +52,18 @@ class Model:
self.meta = model_config
self.hp = hyperparameter_config
- self.curr_iter = self.meta.get('restore_iter', 0)
+ self.restore_iter = self.curr_iter = self.meta.get('restore_iter', 0)
self.total_iter = self.meta.get('total_iter', 80_000)
- self.curr_iters = self.meta.get('restore_iters', (0, 0, 0))
- self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000))
+ self.restore_iters = self.meta.get('restore_iters', (self.curr_iter,))
+ self.total_iters = self.meta.get('total_iters', (self.total_iter,))
self.is_train: bool = True
self.in_channels: int = 3
self.in_size: tuple[int, int] = (64, 48)
self.pr: Optional[int] = None
self.k: Optional[int] = None
+ self.num_pairs: Optional[int] = None
+ self.num_pos_pairs: Optional[int] = None
self._gallery_dataset_meta: Optional[dict[str, list]] = None
self._probe_datasets_meta: Optional[dict[str, dict[str, list]]] = None
@@ -77,6 +79,7 @@ class Model:
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
self.image_log_on = system_config.get('image_log_on', False)
+ self.val_size = system_config.get('val_size', 10)
self.CASIAB_GALLERY_SELECTOR = {
'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})}
@@ -90,7 +93,7 @@ class Model:
@property
def _model_sig(self) -> str:
return '_'.join(
- (self._model_name, str(self.curr_iter), str(self.total_iter))
+ (self._model_name, str(self.curr_iter + 1), str(self.total_iter))
)
@property
@@ -119,18 +122,18 @@ class Model:
],
dataloader_config: DataloaderConfiguration,
):
- for (curr_iter, total_iter, (condition, selector)) in zip(
- self.curr_iters, self.total_iters, dataset_selectors.items()
+ for (restore_iter, total_iter, (condition, selector)) in zip(
+ self.restore_iters, self.total_iters, dataset_selectors.items()
):
print(f'Training model {condition} ...')
# Skip finished model
- if curr_iter == total_iter:
+ if restore_iter == total_iter:
continue
# Check invalid restore iter
- elif curr_iter > total_iter:
+ elif restore_iter > total_iter:
raise ValueError("Restore iter '{}' should less than total "
- "iter '{}'".format(curr_iter, total_iter))
- self.curr_iter = curr_iter
+ "iter '{}'".format(restore_iter, total_iter))
+ self.restore_iter = self.curr_iter = restore_iter
self.total_iter = total_iter
self.fit(
dict(**dataset_config, **{'selector': selector}),
@@ -143,8 +146,24 @@ class Model:
dataloader_config: DataloaderConfiguration,
):
self.is_train = True
- dataset = self._parse_dataset_config(dataset_config)
- dataloader = self._parse_dataloader_config(dataset, dataloader_config)
+ # Validation dataset
+ # (the first `val_size` subjects from evaluation set)
+ val_dataset_config = copy.deepcopy(dataset_config)
+ train_size = dataset_config.get('train_size', 74)
+ val_dataset_config['train_size'] = train_size + self.val_size
+ val_dataset_config['selector']['classes'] = ClipClasses({
+ str(c).zfill(3)
+ for c in range(train_size + 1, train_size + self.val_size + 1)
+ })
+ val_dataset = self._parse_dataset_config(val_dataset_config)
+ val_dataloader = iter(self._parse_dataloader_config(
+ val_dataset, dataloader_config
+ ))
+ # Training dataset
+ train_dataset = self._parse_dataset_config(dataset_config)
+ train_dataloader = iter(self._parse_dataloader_config(
+ train_dataset, dataloader_config
+ ))
# Prepare for model, optimizer and scheduler
model_hp: dict = self.hp.get('model', {}).copy()
triplet_is_hard = model_hp.pop('triplet_is_hard', True)
@@ -178,8 +197,8 @@ class Model:
)
num_sampled_frames = dataset_config.get('num_sampled_frames', 30)
- num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
- num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
+ self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
+ self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
# Try to accelerate computation using CUDA or others
self.rgb_pn = nn.DataParallel(self.rgb_pn)
@@ -194,6 +213,7 @@ class Model:
{'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
], **optim_hp)
+ # Scheduler
start_step = sched_hp.get('start_step', 15_000)
final_gamma = sched_hp.get('final_gamma', 0.001)
ae_start_step = ae_sched_hp.get('start_step', start_step)
@@ -224,6 +244,8 @@ class Model:
if self.curr_iter == 0:
self.rgb_pn.apply(self.init_weights)
else: # Load saved state dicts
+ # Offset a iter to load last checkpoint
+ self.curr_iter -= 1
checkpoint = torch.load(self._checkpoint_name)
random.setstate(checkpoint['rand_states'][0])
torch.set_rng_state(checkpoint['rand_states'][1])
@@ -232,101 +254,38 @@ class Model:
self.scheduler.load_state_dict(checkpoint['sched_state_dict'])
# Training start
- start_time = datetime.now()
- running_loss = torch.zeros(5, device=self.device)
- print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
- f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
- f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}")
- for (batch_c1, batch_c2) in dataloader:
- self.curr_iter += 1
+ for self.curr_iter in tqdm(range(self.restore_iter, self.total_iter),
+ desc='Training'):
+ batch_c1, batch_c2 = next(train_dataloader)
# Zero the parameter gradients
self.optimizer.zero_grad()
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- embed_c, embed_p, images, feature_for_loss = self.rgb_pn(x_c1, x_c2)
- x_c1_pred = feature_for_loss[0]
- xrecon_loss = torch.stack([
- F.mse_loss(x_c1_pred[:, i, :, :, :], x_c1[:, i, :, :, :])
- for i in range(num_sampled_frames)
- ]).sum()
- f_c_c1_t1, f_c_c1_t2, f_c_c2_t2 = feature_for_loss[1]
- cano_cons_loss = torch.stack([
- F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
- + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :])
- for i in range(num_sampled_frames)
- ]).mean()
- f_p_c1_t2, f_p_c2_t2 = feature_for_loss[2]
- pose_sim_loss = F.mse_loss(
- f_p_c1_t2.mean(1), f_p_c2_t2.mean(1)
- ) * 10
- y = batch_c1['label'].to(self.device)
- # Duplicate labels for each part
- y = y.repeat(self.rgb_pn.module.num_parts, 1)
- trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm(
- embed_c.transpose(0, 1), y[:self.rgb_pn.module.hpm.num_parts]
+ embed_c, embed_p, images, f_loss = self.rgb_pn(x_c1, x_c2)
+ ae_losses = self._disentangling_loss(
+ x_c1, f_loss, num_sampled_frames
)
- trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn(
- embed_p.transpose(0, 1), y[self.rgb_pn.module.hpm.num_parts:]
+ embed_c, embed_p = embed_c.transpose(0, 1), embed_p.transpose(0, 1)
+ y = batch_c1['label'].to(self.device)
+ losses, hpm_result, pn_result = self._classification_loss(
+ embed_c, embed_p, ae_losses, y
)
- losses = torch.stack((
- xrecon_loss, cano_cons_loss, pose_sim_loss,
- trip_loss_hpm.mean(), trip_loss_pn.mean()
- ))
loss = losses.sum()
loss.backward()
self.optimizer.step()
+ self.scheduler.step()
- # Statistics and checkpoint
- running_loss += losses.detach()
- # Write losses to TensorBoard
- self.writer.add_scalar('Loss/all', loss, self.curr_iter)
- self.writer.add_scalars('Loss/disentanglement', {
- 'Cross reconstruction loss': xrecon_loss,
- 'Canonical consistency loss': cano_cons_loss,
- 'Pose similarity loss': pose_sim_loss
- }, self.curr_iter)
- self.writer.add_scalars('Loss/triplet loss', {
- 'HPM': losses[3],
- 'PartNet': losses[4]
- }, self.curr_iter)
- # None-zero losses in batch
- if hpm_num_non_zero is not None and pn_num_non_zero is not None:
- self.writer.add_scalars('Loss/non-zero counts', {
- 'HPM': hpm_num_non_zero.mean(),
- 'PartNet': pn_num_non_zero.mean()
- }, self.curr_iter)
- # Embedding distance
- mean_hpm_dist = hpm_dist.mean(0)
- self._add_ranked_scalars(
- 'Embedding/HPM distance', mean_hpm_dist,
- num_pos_pairs, num_pairs, self.curr_iter
- )
- mean_pa_dist = pn_dist.mean(0)
- self._add_ranked_scalars(
- 'Embedding/ParNet distance', mean_pa_dist,
- num_pos_pairs, num_pairs, self.curr_iter
- )
- # Embedding norm
- mean_hpm_embedding = embed_c.mean(0)
- mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
- self._add_ranked_scalars(
- 'Embedding/HPM norm', mean_hpm_norm,
- self.k, self.pr * self.k, self.curr_iter
- )
- mean_pa_embedding = embed_p.mean(0)
- mean_pa_norm = mean_pa_embedding.norm(dim=-1)
- self._add_ranked_scalars(
- 'Embedding/PartNet norm', mean_pa_norm,
- self.k, self.pr * self.k, self.curr_iter
- )
# Learning rate
- lrs = self.scheduler.get_last_lr()
self.writer.add_scalars('Learning rate', dict(zip((
'Auto-encoder', 'HPM', 'PartNet'
- ), lrs)), self.curr_iter)
+ ), self.scheduler.get_last_lr())), self.curr_iter)
+ # Other stats
+ self._write_stat(
+ 'Train', embed_c, embed_p, hpm_result, pn_result, loss, losses
+ )
- if self.curr_iter % 100 == 0:
+ if self.curr_iter % 100 == 99:
# Write disentangled images
if self.image_log_on:
i_a, i_c, i_p = images
@@ -343,19 +302,40 @@ class Model:
self.writer.add_images(
f'Pose image/batch {i}', p, self.curr_iter
)
- time_used = datetime.now() - start_time
- remaining_minute, second = divmod(time_used.seconds, 60)
- hour, minute = divmod(remaining_minute, 60)
- print(f'{hour:02}:{minute:02}:{second:02}',
- f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
- '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- '{:.3e} {:.3e} {:.3e}'.format(*lrs))
- running_loss.zero_()
-
- # Step scheduler
- self.scheduler.step()
- if self.curr_iter % 1000 == 0:
+ # Validation
+ embed_c = self._flatten_embedding(embed_c)
+ embed_p = self._flatten_embedding(embed_p)
+ self._write_embedding('HPM Train', embed_c, x_c1, y)
+ self._write_embedding('PartNet Train', embed_p, x_c1, y)
+
+ # Calculate losses on testing batch
+ batch_c1, batch_c2 = next(val_dataloader)
+ x_c1 = batch_c1['clip'].to(self.device)
+ x_c2 = batch_c2['clip'].to(self.device)
+ with torch.no_grad():
+ embed_c, embed_p, _, f_loss = self.rgb_pn(x_c1, x_c2)
+ ae_losses = self._disentangling_loss(
+ x_c1, f_loss, num_sampled_frames
+ )
+ embed_c = embed_c.transpose(0, 1)
+ embed_p = embed_p.transpose(0, 1)
+ y = batch_c1['label'].to(self.device)
+ losses, hpm_result, pn_result = self._classification_loss(
+ embed_c, embed_p, ae_losses, y
+ )
+ loss = losses.sum()
+
+ self._write_stat(
+ 'Val', embed_c, embed_p, hpm_result, pn_result, loss, losses
+ )
+ embed_c = self._flatten_embedding(embed_c)
+ embed_p = self._flatten_embedding(embed_p)
+ self._write_embedding('HPM Val', embed_c, x_c1, y)
+ self._write_embedding('PartNet Val', embed_p, x_c1, y)
+
+ # Checkpoint
+ if self.curr_iter % 1000 == 999:
torch.save({
'rand_states': (random.getstate(), torch.get_rng_state()),
'model_state_dict': self.rgb_pn.state_dict(),
@@ -363,9 +343,102 @@ class Model:
'sched_state_dict': self.scheduler.state_dict(),
}, self._checkpoint_name)
- if self.curr_iter == self.total_iter:
- self.writer.close()
- break
+ self.writer.close()
+
+ @staticmethod
+ def _disentangling_loss(x_c1, feature_for_loss, num_sampled_frames):
+ x_c1_pred = feature_for_loss[0]
+ xrecon_loss = torch.stack([
+ F.mse_loss(x_c1_pred[:, i, :, :, :], x_c1[:, i, :, :, :])
+ for i in range(num_sampled_frames)
+ ]).sum()
+ f_c_c1_t1, f_c_c1_t2, f_c_c2_t2 = feature_for_loss[1]
+ cano_cons_loss = torch.stack([
+ F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
+ + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :])
+ for i in range(num_sampled_frames)
+ ]).mean()
+ f_p_c1_t2, f_p_c2_t2 = feature_for_loss[2]
+ pose_sim_loss = F.mse_loss(
+ f_p_c1_t2.mean(1), f_p_c2_t2.mean(1)
+ ) * 10
+ return xrecon_loss, cano_cons_loss, pose_sim_loss
+
+ def _classification_loss(self, embed_c, embed_p, ae_losses, y):
+ # Duplicate labels for each part
+ y_triplet = y.repeat(self.rgb_pn.module.num_parts, 1)
+ hpm_result = self.triplet_loss_hpm(
+ embed_c, y_triplet[:self.rgb_pn.module.hpm.num_parts]
+ )
+ pn_result = self.triplet_loss_pn(
+ embed_p, y_triplet[self.rgb_pn.module.hpm.num_parts:]
+ )
+ losses = torch.stack((
+ *ae_losses,
+ hpm_result.pop('loss').mean(),
+ pn_result.pop('loss').mean()
+ ))
+ return losses, hpm_result, pn_result
+
+ def _write_embedding(self, tag, embed, x, y):
+ frame = x[:, 0, :, :, :].cpu()
+ n, c, h, w = frame.size()
+ padding = torch.zeros(n, c, h, (h-w) // 2)
+ padded_frame = torch.cat((padding, frame, padding), dim=-1)
+ self.writer.add_embedding(
+ embed,
+ metadata=y.cpu().tolist(),
+ label_img=padded_frame,
+ global_step=self.curr_iter,
+ tag=tag
+ )
+
+ def _flatten_embedding(self, embed):
+ return embed.detach().transpose(0, 1).reshape(self.k * self.pr, -1)
+
+ def _write_stat(
+ self, postfix, embed_c, embed_p, hpm_result, pn_result, loss, losses
+ ):
+ # Write losses to TensorBoard
+ self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter)
+ self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip((
+ 'Cross reconstruction loss', 'Canonical consistency loss',
+ 'Pose similarity loss'
+ ), losses[:3])), self.curr_iter)
+ self.writer.add_scalars(f'Loss/triplet loss {postfix}', {
+ 'HPM': losses[3],
+ 'PartNet': losses[4]
+ }, self.curr_iter)
+ # None-zero losses in batch
+ if hpm_result['counts'] is not None and pn_result['counts'] is not None:
+ self.writer.add_scalars(f'Loss/non-zero counts {postfix}', {
+ 'HPM': hpm_result['counts'].mean(),
+ 'PartNet': pn_result['counts'].mean()
+ }, self.curr_iter)
+ # Embedding distance
+ mean_hpm_dist = hpm_result['dist'].mean(0)
+ self._add_ranked_scalars(
+ f'Embedding/HPM distance {postfix}', mean_hpm_dist,
+ self.num_pos_pairs, self.num_pairs, self.curr_iter
+ )
+ mean_pn_dist = pn_result['dist'].mean(0)
+ self._add_ranked_scalars(
+ f'Embedding/ParNet distance {postfix}', mean_pn_dist,
+ self.num_pos_pairs, self.num_pairs, self.curr_iter
+ )
+ # Embedding norm
+ mean_hpm_embedding = embed_c.mean(0)
+ mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
+ self._add_ranked_scalars(
+ f'Embedding/HPM norm {postfix}', mean_hpm_norm,
+ self.k, self.pr * self.k, self.curr_iter
+ )
+ mean_pa_embedding = embed_p.mean(0)
+ mean_pa_norm = mean_pa_embedding.norm(dim=-1)
+ self._add_ranked_scalars(
+ f'Embedding/PartNet norm {postfix}', mean_pa_norm,
+ self.k, self.pr * self.k, self.curr_iter
+ )
def _add_ranked_scalars(
self,
@@ -410,12 +483,12 @@ class Model:
dataset_selectors: dict[
str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
],
- dataloader_config: DataloaderConfiguration
+ dataloader_config: DataloaderConfiguration,
+ is_train: bool = False
):
- self.is_train = False
# Split gallery and probe dataset
gallery_dataloader, probe_dataloaders = self._split_gallery_probe(
- dataset_config, dataloader_config
+ dataset_config, dataloader_config, is_train
)
# Get pretrained models at iter_
checkpoints = self._load_pretrained(
@@ -444,7 +517,6 @@ class Model:
unit='clips'):
gallery_samples_c.append(self._get_eval_sample(sample))
gallery_samples[condition] = default_collate(gallery_samples_c)
- gallery_samples['meta'] = self._gallery_dataset_meta
# Probe
probe_samples_c = []
for sample in tqdm(probe_dataloader,
@@ -454,18 +526,19 @@ class Model:
probe_samples_c = default_collate(probe_samples_c)
probe_samples_c['meta'] = self._probe_datasets_meta[condition]
probe_samples[condition] = probe_samples_c
+ gallery_samples['meta'] = self._gallery_dataset_meta
return gallery_samples, probe_samples
def _get_eval_sample(self, sample: dict[str, Union[list, torch.Tensor]]):
- label = sample.pop('label').item()
- clip = sample.pop('clip').to(self.device)
+ label, condition, view, clip = sample.values()
with torch.no_grad():
- feature = self.rgb_pn(clip)
+ feature_c, feature_p = self.rgb_pn(clip.to(self.device))
return {
- **{'label': label},
- **sample,
- **{'feature': feature}
+ 'label': label.item(),
+ 'condition': condition[0],
+ 'view': view[0],
+ 'feature': torch.cat((feature_c, feature_p)).view(-1)
}
@staticmethod
@@ -525,10 +598,11 @@ class Model:
]
) -> dict[str, str]:
checkpoints = {}
- for (iter_, (condition, selector)) in zip(
- iters, dataset_selectors.items()
+ for (iter_, total_iter, (condition, selector)) in zip(
+ iters, self.total_iters, dataset_selectors.items()
):
- self.curr_iter = iter_
+ self.curr_iter = iter_ - 1
+ self.total_iter = total_iter
self._dataset_sig = self._make_signature(
dict(**dataset_config, **selector),
popped_keys=['root_dir', 'cache_on']
@@ -540,26 +614,29 @@ class Model:
self,
dataset_config: DatasetConfiguration,
dataloader_config: DataloaderConfiguration,
+ is_train: bool = False
) -> tuple[DataLoader, dict[str, DataLoader]]:
dataset_name = dataset_config.get('name', 'CASIA-B')
if dataset_name == 'CASIA-B':
+ self.is_train = is_train
gallery_dataset = self._parse_dataset_config(
dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR)
)
- self._gallery_dataset_meta = gallery_dataset.metadata
- gallery_dataloader = self._parse_dataloader_config(
- gallery_dataset, dataloader_config
- )
probe_datasets = {
condition: self._parse_dataset_config(
dict(**dataset_config, **selector)
)
for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items()
}
+ self._gallery_dataset_meta = gallery_dataset.metadata
self._probe_datasets_meta = {
condition: dataset.metadata
for (condition, dataset) in probe_datasets.items()
}
+ self.is_train = False
+ gallery_dataloader = self._parse_dataloader_config(
+ gallery_dataset, dataloader_config
+ )
probe_dataloaders = {
condition: self._parse_dataloader_config(
dataset, dataloader_config
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index fdeed17..5d2c142 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -57,7 +57,7 @@ class RGBPartNet(nn.Module):
if self.training:
return x_c.transpose(0, 1), x_p.transpose(0, 1), images, f_loss
else:
- return torch.cat((x_c, x_p)).unsqueeze(1).view(-1)
+ return x_c, x_p
def _disentangle(self, x_c1_t2, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()
diff --git a/utils/configuration.py b/utils/configuration.py
index f6ac182..5a5bc0c 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -8,6 +8,7 @@ class SystemConfiguration(TypedDict):
CUDA_VISIBLE_DEVICES: str
save_dir: str
image_log_on: bool
+ val_size: int
class DatasetConfiguration(TypedDict):
diff --git a/utils/sampler.py b/utils/sampler.py
index cdf1984..0c9872c 100644
--- a/utils/sampler.py
+++ b/utils/sampler.py
@@ -16,7 +16,18 @@ class TripletSampler(data.Sampler):
):
super().__init__(data_source)
self.metadata_labels = data_source.metadata['labels']
+ metadata_conditions = data_source.metadata['conditions']
+ self.subsets = {}
+ for condition in metadata_conditions:
+ pre, _ = condition.split('-')
+ if self.subsets.get(pre, None) is None:
+ self.subsets[pre] = []
+ self.subsets[pre].append(condition)
+ self.num_subsets = len(self.subsets)
+ self.num_seq = {pre: len(seq) for (pre, seq) in self.subsets.items()}
+ self.min_num_seq = min(self.num_seq.values())
self.labels = data_source.labels
+ self.conditions = data_source.conditions
self.length = len(self.labels)
self.indexes = np.arange(0, self.length)
(self.pr, self.k) = batch_size
@@ -27,15 +38,31 @@ class TripletSampler(data.Sampler):
# Sample pr subjects by sampling labels appeared in dataset
sampled_subjects = random.sample(self.metadata_labels, k=self.pr)
for label in sampled_subjects:
- clips_from_subject = self.indexes[self.labels == label].tolist()
+ mask = self.labels == label
+ # Fix unbalanced datasets
+ if self.num_subsets > 1:
+ condition_mask = np.zeros(self.conditions.shape, dtype=bool)
+ for num, conditions_ in zip(
+ self.num_seq.values(), self.subsets.values()
+ ):
+ if num > self.min_num_seq:
+ conditions = random.sample(
+ conditions_, self.min_num_seq
+ )
+ else:
+ conditions = conditions_
+ for condition in conditions:
+ condition_mask |= self.conditions == condition
+ mask &= condition_mask
+ clips = self.indexes[mask].tolist()
# Sample k clips from the subject without replacement if
# have enough clips, k more clips will sampled for
# disentanglement
k = self.k * 2
- if len(clips_from_subject) >= k:
- _sampled_indexes = random.sample(clips_from_subject, k=k)
+ if len(clips) >= k:
+ _sampled_indexes = random.sample(clips, k=k)
else:
- _sampled_indexes = random.choices(clips_from_subject, k=k)
+ _sampled_indexes = random.choices(clips, k=k)
sampled_indexes += _sampled_indexes
yield sampled_indexes
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 03fff21..5e3a97a 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -28,6 +28,7 @@ class BatchTripletLoss(nn.Module):
else: # is_all
positive_negative_dist = self._all_distance(dist, y, p, n)
+ non_zero_counts = None
if self.margin:
losses = F.relu(self.margin + positive_negative_dist).view(p, -1)
non_zero_counts = (losses != 0).sum(1).float()
@@ -35,14 +36,18 @@ class BatchTripletLoss(nn.Module):
loss_metric = self._none_zero_mean(losses, non_zero_counts)
else: # is_sum
loss_metric = losses.sum(1)
- return loss_metric, flat_dist, non_zero_counts
else: # Soft margin
losses = F.softplus(positive_negative_dist).view(p, -1)
if self.is_mean:
loss_metric = losses.mean(1)
else: # is_sum
loss_metric = losses.sum(1)
- return loss_metric, flat_dist, None
+
+ return {
+ 'loss': loss_metric,
+ 'dist': flat_dist,
+ 'counts': non_zero_counts
+ }
@staticmethod
def _batch_distance(x):