diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-25 15:46:31 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-03-25 15:46:31 +0800 |
commit | 8dd0c42dc341445987c6a907df909155a8b9abd2 (patch) | |
tree | c8b2ceedbb2513662b4cf151121ddf3ca2deb187 | |
parent | 478ecf50f3dd9e2c2f8ca62e6ecd4a65b2cc7c3a (diff) | |
parent | 104c6fbf0686828ed299b2a8bda1806a9b45f440 (diff) |
Merge branch 'data_parallel' into data_parallel_py3.8data_parallel_py3.8
# Conflicts:
# models/model.py
-rw-r--r-- | .idea/csv-plugin.xml | 16 | ||||
-rw-r--r-- | config.py | 10 | ||||
-rw-r--r-- | models/model.py | 331 | ||||
-rw-r--r-- | models/rgb_part_net.py | 2 | ||||
-rw-r--r-- | utils/configuration.py | 1 | ||||
-rw-r--r-- | utils/sampler.py | 35 | ||||
-rw-r--r-- | utils/triplet_loss.py | 9 |
7 files changed, 266 insertions, 138 deletions
diff --git a/.idea/csv-plugin.xml b/.idea/csv-plugin.xml new file mode 100644 index 0000000..5e5cec1 --- /dev/null +++ b/.idea/csv-plugin.xml @@ -0,0 +1,16 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project version="4"> + <component name="CsvFileAttributes"> + <option name="attributeMap"> + <map> + <entry key="/models/model.py"> + <value> + <Attribute> + <option name="separator" value="," /> + </Attribute> + </value> + </entry> + </map> + </option> + </component> +</project>
\ No newline at end of file @@ -9,7 +9,9 @@ config: Configuration = { # Directory used in training or testing for temporary storage 'save_dir': 'runs', # Recorde disentangled image or not - 'image_log_on': False + 'image_log_on': False, + # The number of subjects for validating (Part of testing set) + 'val_size': 10, }, # Dataset settings 'dataset': { @@ -94,9 +96,9 @@ config: Configuration = { 'final_gamma': 0.01, # Local parameters (override global ones) - 'hpm': { - 'final_gamma': 0.001 - } + # 'hpm': { + # 'final_gamma': 0.001 + # } } }, # Model metadata diff --git a/models/model.py b/models/model.py index 07ef37e..0c2e2b9 100644 --- a/models/model.py +++ b/models/model.py @@ -1,6 +1,6 @@ +import copy import os import random -from datetime import datetime from typing import Union, Optional, Tuple, List, Dict, Set import numpy as np @@ -52,16 +52,18 @@ class Model: self.meta = model_config self.hp = hyperparameter_config - self.curr_iter = self.meta.get('restore_iter', 0) + self.restore_iter = self.curr_iter = self.meta.get('restore_iter', 0) self.total_iter = self.meta.get('total_iter', 80_000) - self.curr_iters = self.meta.get('restore_iters', (0, 0, 0)) - self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000)) + self.restore_iters = self.meta.get('restore_iters', (self.curr_iter,)) + self.total_iters = self.meta.get('total_iters', (self.total_iter,)) self.is_train: bool = True self.in_channels: int = 3 self.in_size: Tuple[int, int] = (64, 48) self.pr: Optional[int] = None self.k: Optional[int] = None + self.num_pairs: Optional[int] = None + self.num_pos_pairs: Optional[int] = None self._gallery_dataset_meta: Optional[Dict[str, List]] = None self._probe_datasets_meta: Optional[Dict[str, Dict[str, List]]] = None @@ -77,6 +79,7 @@ class Model: self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None self.image_log_on = system_config.get('image_log_on', False) + self.val_size = system_config.get('val_size', 10) self.CASIAB_GALLERY_SELECTOR = { 'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})} @@ -90,7 +93,7 @@ class Model: @property def _model_sig(self) -> str: return '_'.join( - (self._model_name, str(self.curr_iter), str(self.total_iter)) + (self._model_name, str(self.curr_iter + 1), str(self.total_iter)) ) @property @@ -119,18 +122,18 @@ class Model: ], dataloader_config: DataloaderConfiguration, ): - for (curr_iter, total_iter, (condition, selector)) in zip( - self.curr_iters, self.total_iters, dataset_selectors.items() + for (restore_iter, total_iter, (condition, selector)) in zip( + self.restore_iters, self.total_iters, dataset_selectors.items() ): print(f'Training model {condition} ...') # Skip finished model - if curr_iter == total_iter: + if restore_iter == total_iter: continue # Check invalid restore iter - elif curr_iter > total_iter: + elif restore_iter > total_iter: raise ValueError("Restore iter '{}' should less than total " - "iter '{}'".format(curr_iter, total_iter)) - self.curr_iter = curr_iter + "iter '{}'".format(restore_iter, total_iter)) + self.restore_iter = self.curr_iter = restore_iter self.total_iter = total_iter self.fit( dict(**dataset_config, **{'selector': selector}), @@ -143,8 +146,24 @@ class Model: dataloader_config: DataloaderConfiguration, ): self.is_train = True - dataset = self._parse_dataset_config(dataset_config) - dataloader = self._parse_dataloader_config(dataset, dataloader_config) + # Validation dataset + # (the first `val_size` subjects from evaluation set) + val_dataset_config = copy.deepcopy(dataset_config) + train_size = dataset_config.get('train_size', 74) + val_dataset_config['train_size'] = train_size + self.val_size + val_dataset_config['selector']['classes'] = ClipClasses({ + str(c).zfill(3) + for c in range(train_size + 1, train_size + self.val_size + 1) + }) + val_dataset = self._parse_dataset_config(val_dataset_config) + val_dataloader = iter(self._parse_dataloader_config( + val_dataset, dataloader_config + )) + # Training dataset + train_dataset = self._parse_dataset_config(dataset_config) + train_dataloader = iter(self._parse_dataloader_config( + train_dataset, dataloader_config + )) # Prepare for model, optimizer and scheduler model_hp: Dict = self.hp.get('model', {}).copy() triplet_is_hard = model_hp.pop('triplet_is_hard', True) @@ -178,8 +197,8 @@ class Model: ) num_sampled_frames = dataset_config.get('num_sampled_frames', 30) - num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 - num_pos_pairs = (self.k*(self.k-1)//2) * self.pr + self.num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2 + self.num_pos_pairs = (self.k*(self.k-1)//2) * self.pr # Try to accelerate computation using CUDA or others self.rgb_pn = nn.DataParallel(self.rgb_pn) @@ -194,6 +213,7 @@ class Model: {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp}, ], **optim_hp) + # Scheduler start_step = sched_hp.get('start_step', 15_000) final_gamma = sched_hp.get('final_gamma', 0.001) ae_start_step = ae_sched_hp.get('start_step', start_step) @@ -224,6 +244,8 @@ class Model: if self.curr_iter == 0: self.rgb_pn.apply(self.init_weights) else: # Load saved state dicts + # Offset a iter to load last checkpoint + self.curr_iter -= 1 checkpoint = torch.load(self._checkpoint_name) random.setstate(checkpoint['rand_states'][0]) torch.set_rng_state(checkpoint['rand_states'][1]) @@ -232,101 +254,38 @@ class Model: self.scheduler.load_state_dict(checkpoint['sched_state_dict']) # Training start - start_time = datetime.now() - running_loss = torch.zeros(5, device=self.device) - print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}", - f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}", - f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}") - for (batch_c1, batch_c2) in dataloader: - self.curr_iter += 1 + for self.curr_iter in tqdm(range(self.restore_iter, self.total_iter), + desc='Training'): + batch_c1, batch_c2 = next(train_dataloader) # Zero the parameter gradients self.optimizer.zero_grad() # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - embed_c, embed_p, images, feature_for_loss = self.rgb_pn(x_c1, x_c2) - x_c1_pred = feature_for_loss[0] - xrecon_loss = torch.stack([ - F.mse_loss(x_c1_pred[:, i, :, :, :], x_c1[:, i, :, :, :]) - for i in range(num_sampled_frames) - ]).sum() - f_c_c1_t1, f_c_c1_t2, f_c_c2_t2 = feature_for_loss[1] - cano_cons_loss = torch.stack([ - F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :]) - + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :]) - for i in range(num_sampled_frames) - ]).mean() - f_p_c1_t2, f_p_c2_t2 = feature_for_loss[2] - pose_sim_loss = F.mse_loss( - f_p_c1_t2.mean(1), f_p_c2_t2.mean(1) - ) * 10 - y = batch_c1['label'].to(self.device) - # Duplicate labels for each part - y = y.repeat(self.rgb_pn.module.num_parts, 1) - trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm( - embed_c.transpose(0, 1), y[:self.rgb_pn.module.hpm.num_parts] + embed_c, embed_p, images, f_loss = self.rgb_pn(x_c1, x_c2) + ae_losses = self._disentangling_loss( + x_c1, f_loss, num_sampled_frames ) - trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn( - embed_p.transpose(0, 1), y[self.rgb_pn.module.hpm.num_parts:] + embed_c, embed_p = embed_c.transpose(0, 1), embed_p.transpose(0, 1) + y = batch_c1['label'].to(self.device) + losses, hpm_result, pn_result = self._classification_loss( + embed_c, embed_p, ae_losses, y ) - losses = torch.stack(( - xrecon_loss, cano_cons_loss, pose_sim_loss, - trip_loss_hpm.mean(), trip_loss_pn.mean() - )) loss = losses.sum() loss.backward() self.optimizer.step() + self.scheduler.step() - # Statistics and checkpoint - running_loss += losses.detach() - # Write losses to TensorBoard - self.writer.add_scalar('Loss/all', loss, self.curr_iter) - self.writer.add_scalars('Loss/disentanglement', { - 'Cross reconstruction loss': xrecon_loss, - 'Canonical consistency loss': cano_cons_loss, - 'Pose similarity loss': pose_sim_loss - }, self.curr_iter) - self.writer.add_scalars('Loss/triplet loss', { - 'HPM': losses[3], - 'PartNet': losses[4] - }, self.curr_iter) - # None-zero losses in batch - if hpm_num_non_zero is not None and pn_num_non_zero is not None: - self.writer.add_scalars('Loss/non-zero counts', { - 'HPM': hpm_num_non_zero.mean(), - 'PartNet': pn_num_non_zero.mean() - }, self.curr_iter) - # Embedding distance - mean_hpm_dist = hpm_dist.mean(0) - self._add_ranked_scalars( - 'Embedding/HPM distance', mean_hpm_dist, - num_pos_pairs, num_pairs, self.curr_iter - ) - mean_pa_dist = pn_dist.mean(0) - self._add_ranked_scalars( - 'Embedding/ParNet distance', mean_pa_dist, - num_pos_pairs, num_pairs, self.curr_iter - ) - # Embedding norm - mean_hpm_embedding = embed_c.mean(0) - mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) - self._add_ranked_scalars( - 'Embedding/HPM norm', mean_hpm_norm, - self.k, self.pr * self.k, self.curr_iter - ) - mean_pa_embedding = embed_p.mean(0) - mean_pa_norm = mean_pa_embedding.norm(dim=-1) - self._add_ranked_scalars( - 'Embedding/PartNet norm', mean_pa_norm, - self.k, self.pr * self.k, self.curr_iter - ) # Learning rate - lrs = self.scheduler.get_last_lr() self.writer.add_scalars('Learning rate', dict(zip(( 'Auto-encoder', 'HPM', 'PartNet' - ), lrs)), self.curr_iter) + ), self.scheduler.get_last_lr())), self.curr_iter) + # Other stats + self._write_stat( + 'Train', embed_c, embed_p, hpm_result, pn_result, loss, losses + ) - if self.curr_iter % 100 == 0: + if self.curr_iter % 100 == 99: # Write disentangled images if self.image_log_on: i_a, i_c, i_p = images @@ -343,19 +302,40 @@ class Model: self.writer.add_images( f'Pose image/batch {i}', p, self.curr_iter ) - time_used = datetime.now() - start_time - remaining_minute, second = divmod(time_used.seconds, 60) - hour, minute = divmod(remaining_minute, 60) - print(f'{hour:02}:{minute:02}:{second:02}', - f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', - '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100), - '{:.3e} {:.3e} {:.3e}'.format(*lrs)) - running_loss.zero_() - - # Step scheduler - self.scheduler.step() - if self.curr_iter % 1000 == 0: + # Validation + embed_c = self._flatten_embedding(embed_c) + embed_p = self._flatten_embedding(embed_p) + self._write_embedding('HPM Train', embed_c, x_c1, y) + self._write_embedding('PartNet Train', embed_p, x_c1, y) + + # Calculate losses on testing batch + batch_c1, batch_c2 = next(val_dataloader) + x_c1 = batch_c1['clip'].to(self.device) + x_c2 = batch_c2['clip'].to(self.device) + with torch.no_grad(): + embed_c, embed_p, _, f_loss = self.rgb_pn(x_c1, x_c2) + ae_losses = self._disentangling_loss( + x_c1, f_loss, num_sampled_frames + ) + embed_c = embed_c.transpose(0, 1) + embed_p = embed_p.transpose(0, 1) + y = batch_c1['label'].to(self.device) + losses, hpm_result, pn_result = self._classification_loss( + embed_c, embed_p, ae_losses, y + ) + loss = losses.sum() + + self._write_stat( + 'Val', embed_c, embed_p, hpm_result, pn_result, loss, losses + ) + embed_c = self._flatten_embedding(embed_c) + embed_p = self._flatten_embedding(embed_p) + self._write_embedding('HPM Val', embed_c, x_c1, y) + self._write_embedding('PartNet Val', embed_p, x_c1, y) + + # Checkpoint + if self.curr_iter % 1000 == 999: torch.save({ 'rand_states': (random.getstate(), torch.get_rng_state()), 'model_state_dict': self.rgb_pn.state_dict(), @@ -363,9 +343,102 @@ class Model: 'sched_state_dict': self.scheduler.state_dict(), }, self._checkpoint_name) - if self.curr_iter == self.total_iter: - self.writer.close() - break + self.writer.close() + + @staticmethod + def _disentangling_loss(x_c1, feature_for_loss, num_sampled_frames): + x_c1_pred = feature_for_loss[0] + xrecon_loss = torch.stack([ + F.mse_loss(x_c1_pred[:, i, :, :, :], x_c1[:, i, :, :, :]) + for i in range(num_sampled_frames) + ]).sum() + f_c_c1_t1, f_c_c1_t2, f_c_c2_t2 = feature_for_loss[1] + cano_cons_loss = torch.stack([ + F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :]) + + F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :]) + for i in range(num_sampled_frames) + ]).mean() + f_p_c1_t2, f_p_c2_t2 = feature_for_loss[2] + pose_sim_loss = F.mse_loss( + f_p_c1_t2.mean(1), f_p_c2_t2.mean(1) + ) * 10 + return xrecon_loss, cano_cons_loss, pose_sim_loss + + def _classification_loss(self, embed_c, embed_p, ae_losses, y): + # Duplicate labels for each part + y_triplet = y.repeat(self.rgb_pn.module.num_parts, 1) + hpm_result = self.triplet_loss_hpm( + embed_c, y_triplet[:self.rgb_pn.module.hpm.num_parts] + ) + pn_result = self.triplet_loss_pn( + embed_p, y_triplet[self.rgb_pn.module.hpm.num_parts:] + ) + losses = torch.stack(( + *ae_losses, + hpm_result.pop('loss').mean(), + pn_result.pop('loss').mean() + )) + return losses, hpm_result, pn_result + + def _write_embedding(self, tag, embed, x, y): + frame = x[:, 0, :, :, :].cpu() + n, c, h, w = frame.size() + padding = torch.zeros(n, c, h, (h-w) // 2) + padded_frame = torch.cat((padding, frame, padding), dim=-1) + self.writer.add_embedding( + embed, + metadata=y.cpu().tolist(), + label_img=padded_frame, + global_step=self.curr_iter, + tag=tag + ) + + def _flatten_embedding(self, embed): + return embed.detach().transpose(0, 1).reshape(self.k * self.pr, -1) + + def _write_stat( + self, postfix, embed_c, embed_p, hpm_result, pn_result, loss, losses + ): + # Write losses to TensorBoard + self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter) + self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip(( + 'Cross reconstruction loss', 'Canonical consistency loss', + 'Pose similarity loss' + ), losses[:3])), self.curr_iter) + self.writer.add_scalars(f'Loss/triplet loss {postfix}', { + 'HPM': losses[3], + 'PartNet': losses[4] + }, self.curr_iter) + # None-zero losses in batch + if hpm_result['counts'] is not None and pn_result['counts'] is not None: + self.writer.add_scalars(f'Loss/non-zero counts {postfix}', { + 'HPM': hpm_result['counts'].mean(), + 'PartNet': pn_result['counts'].mean() + }, self.curr_iter) + # Embedding distance + mean_hpm_dist = hpm_result['dist'].mean(0) + self._add_ranked_scalars( + f'Embedding/HPM distance {postfix}', mean_hpm_dist, + self.num_pos_pairs, self.num_pairs, self.curr_iter + ) + mean_pn_dist = pn_result['dist'].mean(0) + self._add_ranked_scalars( + f'Embedding/ParNet distance {postfix}', mean_pn_dist, + self.num_pos_pairs, self.num_pairs, self.curr_iter + ) + # Embedding norm + mean_hpm_embedding = embed_c.mean(0) + mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) + self._add_ranked_scalars( + f'Embedding/HPM norm {postfix}', mean_hpm_norm, + self.k, self.pr * self.k, self.curr_iter + ) + mean_pa_embedding = embed_p.mean(0) + mean_pa_norm = mean_pa_embedding.norm(dim=-1) + self._add_ranked_scalars( + f'Embedding/PartNet norm {postfix}', mean_pa_norm, + self.k, self.pr * self.k, self.curr_iter + ) def _add_ranked_scalars( self, @@ -410,12 +483,12 @@ class Model: dataset_selectors: Dict[ str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]] ], - dataloader_config: DataloaderConfiguration + dataloader_config: DataloaderConfiguration, + is_train: bool = False ): - self.is_train = False # Split gallery and probe dataset gallery_dataloader, probe_dataloaders = self._split_gallery_probe( - dataset_config, dataloader_config + dataset_config, dataloader_config, is_train ) # Get pretrained models at iter_ checkpoints = self._load_pretrained( @@ -444,7 +517,6 @@ class Model: unit='clips'): gallery_samples_c.append(self._get_eval_sample(sample)) gallery_samples[condition] = default_collate(gallery_samples_c) - gallery_samples['meta'] = self._gallery_dataset_meta # Probe probe_samples_c = [] for sample in tqdm(probe_dataloader, @@ -454,18 +526,19 @@ class Model: probe_samples_c = default_collate(probe_samples_c) probe_samples_c['meta'] = self._probe_datasets_meta[condition] probe_samples[condition] = probe_samples_c + gallery_samples['meta'] = self._gallery_dataset_meta return gallery_samples, probe_samples def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]): - label = sample.pop('label').item() - clip = sample.pop('clip').to(self.device) + label, condition, view, clip = sample.values() with torch.no_grad(): - feature = self.rgb_pn(clip) + feature_c, feature_p = self.rgb_pn(clip.to(self.device)) return { - **{'label': label}, - **sample, - **{'feature': feature} + 'label': label.item(), + 'condition': condition[0], + 'view': view[0], + 'feature': torch.cat((feature_c, feature_p)).view(-1) } @staticmethod @@ -525,10 +598,11 @@ class Model: ] ) -> Dict[str, str]: checkpoints = {} - for (iter_, (condition, selector)) in zip( - iters, dataset_selectors.items() + for (iter_, total_iter, (condition, selector)) in zip( + iters, self.total_iters, dataset_selectors.items() ): - self.curr_iter = iter_ + self.curr_iter = iter_ - 1 + self.total_iter = total_iter self._dataset_sig = self._make_signature( dict(**dataset_config, **selector), popped_keys=['root_dir', 'cache_on'] @@ -540,26 +614,29 @@ class Model: self, dataset_config: DatasetConfiguration, dataloader_config: DataloaderConfiguration, + is_train: bool = False ) -> Tuple[DataLoader, Dict[str, DataLoader]]: dataset_name = dataset_config.get('name', 'CASIA-B') if dataset_name == 'CASIA-B': + self.is_train = is_train gallery_dataset = self._parse_dataset_config( dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR) ) - self._gallery_dataset_meta = gallery_dataset.metadata - gallery_dataloader = self._parse_dataloader_config( - gallery_dataset, dataloader_config - ) probe_datasets = { condition: self._parse_dataset_config( dict(**dataset_config, **selector) ) for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items() } + self._gallery_dataset_meta = gallery_dataset.metadata self._probe_datasets_meta = { condition: dataset.metadata for (condition, dataset) in probe_datasets.items() } + self.is_train = False + gallery_dataloader = self._parse_dataloader_config( + gallery_dataset, dataloader_config + ) probe_dataloaders = { condition: self._parse_dataloader_config( dataset, dataloader_config diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 81f198e..e8e320d 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -59,7 +59,7 @@ class RGBPartNet(nn.Module): if self.training: return x_c.transpose(0, 1), x_p.transpose(0, 1), images, f_loss else: - return torch.cat((x_c, x_p)).unsqueeze(1).view(-1) + return x_c, x_p def _disentangle(self, x_c1_t2, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() diff --git a/utils/configuration.py b/utils/configuration.py index 8ee08f2..8dcae07 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -8,6 +8,7 @@ class SystemConfiguration(TypedDict): CUDA_VISIBLE_DEVICES: str save_dir: str image_log_on: bool + val_size: int class DatasetConfiguration(TypedDict): diff --git a/utils/sampler.py b/utils/sampler.py index 0977f94..581d7a2 100644 --- a/utils/sampler.py +++ b/utils/sampler.py @@ -15,7 +15,18 @@ class TripletSampler(data.Sampler): ): super().__init__(data_source) self.metadata_labels = data_source.metadata['labels'] + metadata_conditions = data_source.metadata['conditions'] + self.subsets = {} + for condition in metadata_conditions: + pre, _ = condition.split('-') + if self.subsets.get(pre, None) is None: + self.subsets[pre] = [] + self.subsets[pre].append(condition) + self.num_subsets = len(self.subsets) + self.num_seq = {pre: len(seq) for (pre, seq) in self.subsets.items()} + self.min_num_seq = min(self.num_seq.values()) self.labels = data_source.labels + self.conditions = data_source.conditions self.length = len(self.labels) self.indexes = np.arange(0, self.length) (self.pr, self.k) = batch_size @@ -26,15 +37,31 @@ class TripletSampler(data.Sampler): # Sample pr subjects by sampling labels appeared in dataset sampled_subjects = random.sample(self.metadata_labels, k=self.pr) for label in sampled_subjects: - clips_from_subject = self.indexes[self.labels == label].tolist() + mask = self.labels == label + # Fix unbalanced datasets + if self.num_subsets > 1: + condition_mask = np.zeros(self.conditions.shape, dtype=bool) + for num, conditions_ in zip( + self.num_seq.values(), self.subsets.values() + ): + if num > self.min_num_seq: + conditions = random.sample( + conditions_, self.min_num_seq + ) + else: + conditions = conditions_ + for condition in conditions: + condition_mask |= self.conditions == condition + mask &= condition_mask + clips = self.indexes[mask].tolist() # Sample k clips from the subject without replacement if # have enough clips, k more clips will sampled for # disentanglement k = self.k * 2 - if len(clips_from_subject) >= k: - _sampled_indexes = random.sample(clips_from_subject, k=k) + if len(clips) >= k: + _sampled_indexes = random.sample(clips, k=k) else: - _sampled_indexes = random.choices(clips_from_subject, k=k) + _sampled_indexes = random.choices(clips, k=k) sampled_indexes += _sampled_indexes yield sampled_indexes diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 03fff21..5e3a97a 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -28,6 +28,7 @@ class BatchTripletLoss(nn.Module): else: # is_all positive_negative_dist = self._all_distance(dist, y, p, n) + non_zero_counts = None if self.margin: losses = F.relu(self.margin + positive_negative_dist).view(p, -1) non_zero_counts = (losses != 0).sum(1).float() @@ -35,14 +36,18 @@ class BatchTripletLoss(nn.Module): loss_metric = self._none_zero_mean(losses, non_zero_counts) else: # is_sum loss_metric = losses.sum(1) - return loss_metric, flat_dist, non_zero_counts else: # Soft margin losses = F.softplus(positive_negative_dist).view(p, -1) if self.is_mean: loss_metric = losses.mean(1) else: # is_sum loss_metric = losses.sum(1) - return loss_metric, flat_dist, None + + return { + 'loss': loss_metric, + 'dist': flat_dist, + 'counts': non_zero_counts + } @staticmethod def _batch_distance(x): |