diff options
-rw-r--r-- | config.py | 23 | ||||
-rw-r--r-- | models/auto_encoder.py | 4 | ||||
-rw-r--r-- | models/layers.py | 4 | ||||
-rw-r--r-- | models/model.py | 233 | ||||
-rw-r--r-- | models/rgb_part_net.py | 98 | ||||
-rw-r--r-- | requirements.txt | 8 | ||||
-rw-r--r-- | utils/configuration.py | 6 | ||||
-rw-r--r-- | utils/dataset.py | 6 | ||||
-rw-r--r-- | utils/sampler.py | 35 |
9 files changed, 248 insertions, 169 deletions
@@ -9,7 +9,9 @@ config: Configuration = { # Directory used in training or testing for temporary storage 'save_dir': 'runs/dis_only', # Recorde disentangled image or not - 'image_log_on': True + 'image_log_on': True, + # The number of subjects for validating (Part of testing set) + 'val_size': 10, }, # Dataset settings 'dataset': { @@ -37,7 +39,7 @@ config: Configuration = { # Batch size (pr, k) # `pr` denotes number of persons # `k` denotes number of sequences per person - 'batch_size': (2, 2), + 'batch_size': (4, 6), # Number of workers of Dataloader 'num_workers': 4, # Faster data transfer from RAM to GPU if enabled @@ -61,15 +63,20 @@ config: Configuration = { # Term added to the denominator # 'eps': 1e-8, # Weight decay (L2 penalty) - # 'weight_decay': 0, + 'weight_decay': 0.001, # Use AMSGrad or not # 'amsgrad': False, }, 'scheduler': { - # Period of learning rate decay - 'step_size': 500, - # Multiplicative factor of decay - 'gamma': 0.9, + # Step start to decay + 'start_step': 500, + # Multiplicative factor of decay in the end + 'final_gamma': 0.01, + + # Local parameters (override global ones) + # 'hpm': { + # 'final_gamma': 0.001 + # } } }, # Model metadata @@ -83,6 +90,6 @@ config: Configuration = { # Restoration iteration (multiple models, e.g. nm, bg and cl) 'restore_iters': (0, 0, 0), # Total iteration for training (multiple models) - 'total_iters': (80_000, 80_000, 80_000), + 'total_iters': (30_000, 40_000, 60_000), }, } diff --git a/models/auto_encoder.py b/models/auto_encoder.py index e17caed..b1d51ef 100644 --- a/models/auto_encoder.py +++ b/models/auto_encoder.py @@ -108,15 +108,13 @@ class Decoder(nn.Module): self.trans_conv4 = DCGANConvTranspose2d(feature_channels, out_channels, is_last_layer=True) - def forward(self, f_appearance, f_canonical, f_pose, cano_only=False): + def forward(self, f_appearance, f_canonical, f_pose): x = torch.cat((f_appearance, f_canonical, f_pose), dim=1) x = self.fc(x) x = x.view(-1, self.feature_channels * 8, self.h_0, self.w_0) x = F.relu(x, inplace=True) x = self.trans_conv1(x) x = self.trans_conv2(x) - if cano_only: - return x x = self.trans_conv3(x) x = torch.sigmoid(self.trans_conv4(x)) diff --git a/models/layers.py b/models/layers.py index 8228f49..1da79ef 100644 --- a/models/layers.py +++ b/models/layers.py @@ -79,7 +79,9 @@ class DCGANConvTranspose2d(BasicConvTranspose2d): if self.is_last_layer: return self.trans_conv(x) else: - return super().forward(x) + x = self.trans_conv(x) + x = self.bn(x) + return F.leaky_relu(x, 0.2, inplace=True) class BasicLinear(nn.Module): diff --git a/models/model.py b/models/model.py index c8f0450..667a0a7 100644 --- a/models/model.py +++ b/models/model.py @@ -1,5 +1,6 @@ +import copy import os -from datetime import datetime +import random from typing import Union, Optional, Tuple, List, Dict, Set import numpy as np @@ -47,10 +48,10 @@ class Model: self.meta = model_config self.hp = hyperparameter_config - self.curr_iter = self.meta.get('restore_iter', 0) + self.restore_iter = self.curr_iter = self.meta.get('restore_iter', 0) self.total_iter = self.meta.get('total_iter', 80_000) - self.curr_iters = self.meta.get('restore_iters', (0, 0, 0)) - self.total_iters = self.meta.get('total_iters', (80000, 80000, 80000)) + self.restore_iters = self.meta.get('restore_iters', (self.curr_iter,)) + self.total_iters = self.meta.get('total_iters', (self.total_iter,)) self.is_train: bool = True self.in_channels: int = 3 @@ -70,6 +71,7 @@ class Model: self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None self.image_log_on = system_config.get('image_log_on', False) + self.val_size = system_config.get('val_size', 10) self.CASIAB_GALLERY_SELECTOR = { 'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})} @@ -83,7 +85,7 @@ class Model: @property def _model_sig(self) -> str: return '_'.join( - (self._model_name, str(self.curr_iter), str(self.total_iter)) + (self._model_name, str(self.curr_iter + 1), str(self.total_iter)) ) @property @@ -112,18 +114,18 @@ class Model: ], dataloader_config: DataloaderConfiguration, ): - for (curr_iter, total_iter, (condition, selector)) in zip( - self.curr_iters, self.total_iters, dataset_selectors.items() + for (restore_iter, total_iter, (condition, selector)) in zip( + self.restore_iters, self.total_iters, dataset_selectors.items() ): print(f'Training model {condition} ...') # Skip finished model - if curr_iter == total_iter: + if restore_iter == total_iter: continue # Check invalid restore iter - elif curr_iter > total_iter: + elif restore_iter > total_iter: raise ValueError("Restore iter '{}' should less than total " - "iter '{}'".format(curr_iter, total_iter)) - self.curr_iter = curr_iter + "iter '{}'".format(restore_iter, total_iter)) + self.restore_iter = self.curr_iter = restore_iter self.total_iter = total_iter self.fit( dict(**dataset_config, **{'selector': selector}), @@ -136,70 +138,85 @@ class Model: dataloader_config: DataloaderConfiguration, ): self.is_train = True - dataset = self._parse_dataset_config(dataset_config) - dataloader = self._parse_dataloader_config(dataset, dataloader_config) + # Validation dataset + # (the first `val_size` subjects from evaluation set) + val_dataset_config = copy.deepcopy(dataset_config) + train_size = dataset_config.get('train_size', 74) + val_dataset_config['train_size'] = train_size + self.val_size + val_dataset_config['selector']['classes'] = ClipClasses({ + str(c).zfill(3) + for c in range(train_size + 1, train_size + self.val_size + 1) + }) + val_dataset = self._parse_dataset_config(val_dataset_config) + val_dataloader = iter(self._parse_dataloader_config( + val_dataset, dataloader_config + )) + # Training dataset + train_dataset = self._parse_dataset_config(dataset_config) + train_dataloader = iter(self._parse_dataloader_config( + train_dataset, dataloader_config + )) # Prepare for model, optimizer and scheduler - model_hp = self.hp.get('model', {}) + model_hp: dict = self.hp.get('model', {}).copy() optim_hp: Dict = self.hp.get('optimizer', {}).copy() sched_hp = self.hp.get('scheduler', {}) + self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) + # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp) - sched_gamma = sched_hp.get('gamma', 0.9) - sched_step_size = sched_hp.get('step_size', 500) - self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ - lambda epoch: sched_gamma ** (epoch // sched_step_size), - ]) + start_step = sched_hp.get('start_step', 15_000) + final_gamma = sched_hp.get('final_gamma', 0.001) + all_step = self.total_iter - start_step + self.scheduler = optim.lr_scheduler.LambdaLR( + self.optimizer, + lambda t: final_gamma ** ((t - start_step) / all_step) + if t > start_step else 1, + ) self.writer = SummaryWriter(self._log_name) + # Set seeds for reproducibility + random.seed(0) + torch.manual_seed(0) self.rgb_pn.train() # Init weights at first iter if self.curr_iter == 0: self.rgb_pn.apply(self.init_weights) else: # Load saved state dicts + # Offset a iter to load last checkpoint + self.curr_iter -= 1 checkpoint = torch.load(self._checkpoint_name) - iter_, loss = checkpoint['iter'], checkpoint['loss'] - print('{0:5d} loss: {1:.3f}'.format(iter_, loss)) + random.setstate(checkpoint['rand_states'][0]) + torch.set_rng_state(checkpoint['rand_states'][1]) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optim_state_dict']) self.scheduler.load_state_dict(checkpoint['sched_state_dict']) # Training start - start_time = datetime.now() - running_loss = torch.zeros(3, device=self.device) - print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}", - f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}", - f"{'LR':^9}") - for (batch_c1, batch_c2) in dataloader: - self.curr_iter += 1 + for self.curr_iter in tqdm(range(self.restore_iter, self.total_iter), + desc='Training'): + batch_c1, batch_c2 = next(train_dataloader) # Zero the parameter gradients self.optimizer.zero_grad() # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - losses, images = self.rgb_pn(x_c1, x_c2) + losses, features, images = self.rgb_pn(x_c1, x_c2) loss = losses.sum() loss.backward() self.optimizer.step() + self.scheduler.step() - # Statistics and checkpoint - running_loss += losses.detach() - # Write losses to TensorBoard - self.writer.add_scalar('Loss/all', loss, self.curr_iter) - self.writer.add_scalars('Loss/details', dict(zip([ - 'Cross reconstruction loss', - 'Canonical consistency loss', - 'Pose similarity loss' - ], losses)), self.curr_iter) - - if self.curr_iter % 100 == 0: - lr = self.scheduler.get_last_lr()[0] - # Write learning rates - self.writer.add_scalar( - 'Learning rate/Auto-encoder', lr, self.curr_iter - ) + # Learning rate + self.writer.add_scalar( + 'Learning rate', self.scheduler.get_last_lr()[0], self.curr_iter + ) + # Other stats + self._write_stat('Train', loss, losses) + + if self.curr_iter % 100 == 99: # Write disentangled images if self.image_log_on: i_a, i_c, i_p = images @@ -216,30 +233,54 @@ class Model: self.writer.add_images( f'Pose image/batch {i}', p, self.curr_iter ) - time_used = datetime.now() - start_time - remaining_minute, second = divmod(time_used.seconds, 60) - hour, minute = divmod(remaining_minute, 60) - print(f'{hour:02}:{minute:02}:{second:02}', - f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', - '{:f} {:f} {:f}'.format(*running_loss / 100), - f'{lr:.3e}') - running_loss.zero_() - - # Step scheduler - self.scheduler.step() - - if self.curr_iter % 1000 == 0: + f_a, f_c, f_p = features + for i, (f_a_i, f_c_i, f_p_i) in enumerate( + zip(f_a, f_c, f_p) + ): + self.writer.add_images( + f'Appearance features/Layer {i}', + f_a_i[:, :3, :, :], self.curr_iter + ) + self.writer.add_images( + f'Canonical features/Layer {i}', + f_c_i[:, :3, :, :], self.curr_iter + ) + for j, p in enumerate(f_p_i): + self.writer.add_images( + f'Pose features/Layer {i}/batch{j}', + p[:, :3, :, :], self.curr_iter + ) + + # Calculate losses on testing batch + batch_c1, batch_c2 = next(val_dataloader) + x_c1 = batch_c1['clip'].to(self.device) + x_c2 = batch_c2['clip'].to(self.device) + with torch.no_grad(): + losses, _, _ = self.rgb_pn(x_c1, x_c2) + loss = losses.sum() + + self._write_stat('Val', loss, losses) + + # Checkpoint + if self.curr_iter % 1000 == 999: torch.save({ - 'iter': self.curr_iter, + 'rand_states': (random.getstate(), torch.get_rng_state()), 'model_state_dict': self.rgb_pn.state_dict(), 'optim_state_dict': self.optimizer.state_dict(), 'sched_state_dict': self.scheduler.state_dict(), - 'loss': loss, }, self._checkpoint_name) - if self.curr_iter == self.total_iter: - self.writer.close() - break + self.writer.close() + + def _write_stat( + self, postfix, loss, losses + ): + # Write losses to TensorBoard + self.writer.add_scalar(f'Loss/all {postfix}', loss, self.curr_iter) + self.writer.add_scalars(f'Loss/disentanglement {postfix}', dict(zip(( + 'Cross reconstruction loss', 'Canonical consistency loss', + 'Pose similarity loss' + ), losses)), self.curr_iter) def transform( self, @@ -248,12 +289,12 @@ class Model: dataset_selectors: Dict[ str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]] ], - dataloader_config: DataloaderConfiguration + dataloader_config: DataloaderConfiguration, + is_train: bool = False ): - self.is_train = False # Split gallery and probe dataset gallery_dataloader, probe_dataloaders = self._split_gallery_probe( - dataset_config, dataloader_config + dataset_config, dataloader_config, is_train ) # Get pretrained models at iter_ checkpoints = self._load_pretrained( @@ -261,41 +302,45 @@ class Model: ) # Init models - model_hp = self.hp.get('model', {}) + model_hp: dict = self.hp.get('model', {}).copy() self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp) # Try to accelerate computation using CUDA or others self.rgb_pn = self.rgb_pn.to(self.device) self.rgb_pn.eval() - gallery_samples, probe_samples = [], {} - # Gallery - checkpoint = torch.load(list(checkpoints.values())[0]) - self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) - for sample in tqdm(gallery_dataloader, - desc='Transforming gallery', unit='clips'): - gallery_samples.append(self._get_eval_sample(sample)) - gallery_samples = default_collate(gallery_samples) - # Probe - for (condition, dataloader) in probe_dataloaders.items(): + gallery_samples, probe_samples = {}, {} + for (condition, probe_dataloader) in probe_dataloaders.items(): checkpoint = torch.load(checkpoints[condition]) self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) + # Gallery + gallery_samples_c = [] + for sample in tqdm(gallery_dataloader, + desc=f'Transforming gallery {condition}', + unit='clips'): + gallery_samples_c.append(self._get_eval_sample(sample)) + gallery_samples[condition] = default_collate(gallery_samples_c) + # Probe probe_samples_c = [] - for sample in tqdm(dataloader, + for sample in tqdm(probe_dataloader, desc=f'Transforming probe {condition}', unit='clips'): probe_samples_c.append(self._get_eval_sample(sample)) - probe_samples[condition] = default_collate(probe_samples_c) + probe_samples_c = default_collate(probe_samples_c) + probe_samples_c['meta'] = self._probe_datasets_meta[condition] + probe_samples[condition] = probe_samples_c + gallery_samples['meta'] = self._gallery_dataset_meta return gallery_samples, probe_samples def _get_eval_sample(self, sample: Dict[str, Union[List, torch.Tensor]]): - label = sample.pop('label').item() - clip = sample.pop('clip').to(self.device) - x_c, x_p = self.rgb_pn(clip).detach() + label, condition, view, clip = sample.values() + with torch.no_grad(): + feature_c, feature_p = self.rgb_pn(clip.to(self.device)) return { - **{'label': label}, - **sample, - **{'cano_feature': x_c, 'pose_feature': x_p} + 'label': label.item(), + 'condition': condition[0], + 'view': view[0], + 'feature': torch.cat((feature_c, feature_p)).view(-1) } def _load_pretrained( @@ -307,10 +352,11 @@ class Model: ] ) -> Dict[str, str]: checkpoints = {} - for (iter_, (condition, selector)) in zip( - iters, dataset_selectors.items() + for (iter_, total_iter, (condition, selector)) in zip( + iters, self.total_iters, dataset_selectors.items() ): - self.curr_iter = iter_ + self.curr_iter = iter_ - 1 + self.total_iter = total_iter self._dataset_sig = self._make_signature( dict(**dataset_config, **selector), popped_keys=['root_dir', 'cache_on'] @@ -322,26 +368,29 @@ class Model: self, dataset_config: DatasetConfiguration, dataloader_config: DataloaderConfiguration, + is_train: bool = False ) -> Tuple[DataLoader, Dict[str, DataLoader]]: dataset_name = dataset_config.get('name', 'CASIA-B') if dataset_name == 'CASIA-B': + self.is_train = is_train gallery_dataset = self._parse_dataset_config( dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR) ) - self._gallery_dataset_meta = gallery_dataset.metadata - gallery_dataloader = self._parse_dataloader_config( - gallery_dataset, dataloader_config - ) probe_datasets = { condition: self._parse_dataset_config( dict(**dataset_config, **selector) ) for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items() } + self._gallery_dataset_meta = gallery_dataset.metadata self._probe_datasets_meta = { condition: dataset.metadata for (condition, dataset) in probe_datasets.items() } + self.is_train = False + gallery_dataloader = self._parse_dataloader_config( + gallery_dataset, dataloader_config + ) probe_dataloaders = { condition: self._parse_dataloader_config( dataset, dataloader_config diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 797e02b..1c7a1a2 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -2,6 +2,7 @@ from typing import Tuple import torch import torch.nn as nn +import torch.nn.functional as F from models.auto_encoder import AutoEncoder @@ -16,6 +17,7 @@ class RGBPartNet(nn.Module): image_log_on: bool = False ): super().__init__() + self.h, self.w = ae_in_size (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims self.image_log_on = image_log_on @@ -24,70 +26,64 @@ class RGBPartNet(nn.Module): ) def forward(self, x_c1, x_c2=None): - # Step 1: Disentanglement - # n, t, c, h, w - ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2) + losses, features, images = self._disentangle(x_c1, x_c2) if self.training: losses = torch.stack(losses) - return losses, images + return losses, features, images else: - return x_c, x_p + return features def _disentangle(self, x_c1_t2, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() - device = x_c1_t2.device - x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :] if self.training: + x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :] ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) - # Decode features - with torch.no_grad(): - x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device) + f_a = f_a_.view(n, t, -1) + f_c = f_c_.view(n, t, -1) + f_p = f_p_.view(n, t, -1) - i_a, i_c, i_p = None, None, None - if self.image_log_on: - i_a = self._decode_appr_feature(f_a_, n, t, device) - # Continue decoding canonical features - i_c = self.ae.decoder.trans_conv3(x_c) - i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c)) - i_p = x_p + i_a, i_c, i_p = None, None, None + if self.image_log_on: + with torch.no_grad(): + x_a, i_a = self._separate_decode( + f_a.mean(1), + torch.zeros_like(f_c[:, 0, :]), + torch.zeros_like(f_p[:, 0, :]) + ) + x_c, i_c = self._separate_decode( + torch.zeros_like(f_a[:, 0, :]), + f_c.mean(1), + torch.zeros_like(f_p[:, 0, :]), + ) + x_p_, i_p_ = self._separate_decode( + torch.zeros_like(f_a_), + torch.zeros_like(f_c_), + f_p_ + ) + x_p = tuple(_x_p.view(n, t, *_x_p.size()[1:]) for _x_p in x_p_) + i_p = i_p_.view(n, t, c, h, w) - return (x_c, x_p), losses, (i_a, i_c, i_p) + return losses, (x_a, x_c, x_p), (i_a, i_c, i_p) else: # evaluating f_c_, f_p_ = self.ae(x_c1_t2) - x_c = self._decode_cano_feature(f_c_, n, t, device) - x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device) - return (x_c, x_p), None, None + f_c = f_c_.view(n, t, -1) + f_p = f_p_.view(n, t, -1) + return (f_c, f_p), None, None - def _decode_appr_feature(self, f_a_, n, t, device): - # Decode appearance features - f_a = f_a_.view(n, t, -1) - x_a = self.ae.decoder( - f_a.mean(1), - torch.zeros((n, self.f_c_dim), device=device), - torch.zeros((n, self.f_p_dim), device=device) + def _separate_decode(self, f_a, f_c, f_p): + x_1 = torch.cat((f_a, f_c, f_p), dim=1) + x_1 = self.ae.decoder.fc(x_1).view( + -1, + self.ae.decoder.feature_channels * 8, + self.ae.decoder.h_0, + self.ae.decoder.w_0 ) - return x_a - - def _decode_cano_feature(self, f_c_, n, t, device): - # Decode average canonical features to higher dimension - f_c = f_c_.view(n, t, -1) - x_c = self.ae.decoder( - torch.zeros((n, self.f_a_dim), device=device), - f_c.mean(1), - torch.zeros((n, self.f_p_dim), device=device), - cano_only=True - ) - return x_c - - def _decode_pose_feature(self, f_p_, n, t, c, h, w, device): - # Decode pose features to images - x_p_ = self.ae.decoder( - torch.zeros((n * t, self.f_a_dim), device=device), - torch.zeros((n * t, self.f_c_dim), device=device), - f_p_ - ) - x_p = x_p_.view(n, t, c, h, w) - return x_p + x_1 = F.relu(x_1, inplace=True) + x_2 = self.ae.decoder.trans_conv1(x_1) + x_3 = self.ae.decoder.trans_conv2(x_2) + x_4 = self.ae.decoder.trans_conv3(x_3) + image = torch.sigmoid(self.ae.decoder.trans_conv4(x_4)) + x = (x_1, x_2, x_3, x_4) + return x, image diff --git a/requirements.txt b/requirements.txt index 4d30e17..8b0a41b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -torch~=1.7.1 -torchvision~=0.8.0a0+ecf4e9c -numpy~=1.19.4 -tqdm~=4.57.0 +torch~=1.8.0 +torchvision~=0.9.0a0+3f090d0 +numpy~=1.20.1 +tqdm~=4.59.0 Pillow~=8.1.0 scikit-learn~=0.24.0
\ No newline at end of file diff --git a/utils/configuration.py b/utils/configuration.py index 340815b..46149b3 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -8,6 +8,7 @@ class SystemConfiguration(TypedDict): CUDA_VISIBLE_DEVICES: str save_dir: str image_log_on: bool + val_size: int class DatasetConfiguration(TypedDict): @@ -35,7 +36,6 @@ class ModelHPConfiguration(TypedDict): class OptimizerHPConfiguration(TypedDict): - start_iter: int lr: int betas: Tuple[float, float] eps: float @@ -44,8 +44,8 @@ class OptimizerHPConfiguration(TypedDict): class SchedulerHPConfiguration(TypedDict): - step_size: int - gamma: float + start_step: int + final_gamma: float class HyperparameterConfiguration(TypedDict): diff --git a/utils/dataset.py b/utils/dataset.py index 72cf050..41e2f1e 100644 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -111,9 +111,9 @@ class CASIAB(data.Dataset): # in Bag #2 condition from 90 degree angle classes, conditions, views = [], [], [] if selector: - selected_classes = selector.pop('classes', None) - selected_conditions = selector.pop('conditions', None) - selected_views = selector.pop('views', None) + selected_classes = selector.get('classes', None) + selected_conditions = selector.get('conditions', None) + selected_views = selector.get('views', None) class_regex = r'\d{3}' condition_regex = r'(nm|bg|cl)-0[0-6]' diff --git a/utils/sampler.py b/utils/sampler.py index 0977f94..581d7a2 100644 --- a/utils/sampler.py +++ b/utils/sampler.py @@ -15,7 +15,18 @@ class TripletSampler(data.Sampler): ): super().__init__(data_source) self.metadata_labels = data_source.metadata['labels'] + metadata_conditions = data_source.metadata['conditions'] + self.subsets = {} + for condition in metadata_conditions: + pre, _ = condition.split('-') + if self.subsets.get(pre, None) is None: + self.subsets[pre] = [] + self.subsets[pre].append(condition) + self.num_subsets = len(self.subsets) + self.num_seq = {pre: len(seq) for (pre, seq) in self.subsets.items()} + self.min_num_seq = min(self.num_seq.values()) self.labels = data_source.labels + self.conditions = data_source.conditions self.length = len(self.labels) self.indexes = np.arange(0, self.length) (self.pr, self.k) = batch_size @@ -26,15 +37,31 @@ class TripletSampler(data.Sampler): # Sample pr subjects by sampling labels appeared in dataset sampled_subjects = random.sample(self.metadata_labels, k=self.pr) for label in sampled_subjects: - clips_from_subject = self.indexes[self.labels == label].tolist() + mask = self.labels == label + # Fix unbalanced datasets + if self.num_subsets > 1: + condition_mask = np.zeros(self.conditions.shape, dtype=bool) + for num, conditions_ in zip( + self.num_seq.values(), self.subsets.values() + ): + if num > self.min_num_seq: + conditions = random.sample( + conditions_, self.min_num_seq + ) + else: + conditions = conditions_ + for condition in conditions: + condition_mask |= self.conditions == condition + mask &= condition_mask + clips = self.indexes[mask].tolist() # Sample k clips from the subject without replacement if # have enough clips, k more clips will sampled for # disentanglement k = self.k * 2 - if len(clips_from_subject) >= k: - _sampled_indexes = random.sample(clips_from_subject, k=k) + if len(clips) >= k: + _sampled_indexes = random.sample(clips, k=k) else: - _sampled_indexes = random.choices(clips_from_subject, k=k) + _sampled_indexes = random.choices(clips, k=k) sampled_indexes += _sampled_indexes yield sampled_indexes |