diff options
-rw-r--r-- | config.py | 21 | ||||
-rw-r--r-- | models/hpm.py | 25 | ||||
-rw-r--r-- | models/layers.py | 9 | ||||
-rw-r--r-- | models/model.py | 144 | ||||
-rw-r--r-- | models/part_net.py | 18 | ||||
-rw-r--r-- | models/rgb_part_net.py | 32 | ||||
-rw-r--r-- | test/part_net.py | 2 | ||||
-rw-r--r-- | utils/configuration.py | 20 | ||||
-rw-r--r-- | utils/triplet_loss.py | 40 |
9 files changed, 138 insertions, 173 deletions
@@ -50,19 +50,17 @@ config: Configuration = { 'ae_feature_channels': 64, # Appearance, canonical and pose feature dimensions 'f_a_c_p_dims': (192, 192, 96), - # Use 1x1 convolution in dimensionality reduction - 'hpm_use_1x1conv': False, # HPM pyramid scales, of which sum is number of parts 'hpm_scales': (1, 2, 4, 8), # Global pooling method 'hpm_use_avg_pool': True, 'hpm_use_max_pool': True, - # Attention squeeze ratio - 'tfa_squeeze_ratio': 4, # Number of parts after Part Net 'tfa_num_parts': 16, - # Embedding dimension for each part - 'embedding_dims': 256, + # Attention squeeze ratio + 'tfa_squeeze_ratio': 4, + # Embedding dimensions for each part + 'embedding_dims': (256, 256), # Batch Hard or Batch All 'triplet_is_hard': True, # Use non-zero mean or sum @@ -91,9 +89,14 @@ config: Configuration = { }, 'scheduler': { # Step start to decay - 'start_step': 15_000, + 'start_step': 500, # Multiplicative factor of decay in the end - 'final_gamma': 0.001, + 'final_gamma': 0.01, + + # Local parameters (override global ones) + 'hpm': { + 'final_gamma': 0.001 + } } }, # Model metadata @@ -107,6 +110,6 @@ config: Configuration = { # Restoration iteration (multiple models, e.g. nm, bg and cl) 'restore_iters': (0, 0, 0), # Total iteration for training (multiple models) - 'total_iters': (25_000, 25_000, 25_000), + 'total_iters': (30_000, 40_000, 60_000), }, } diff --git a/models/hpm.py b/models/hpm.py index 9879cfb..8186b20 100644 --- a/models/hpm.py +++ b/models/hpm.py @@ -9,32 +9,26 @@ class HorizontalPyramidMatching(nn.Module): self, in_channels: int, out_channels: int = 128, - use_1x1conv: bool = False, scales: tuple[int, ...] = (1, 2, 4), use_avg_pool: bool = True, use_max_pool: bool = False, - **kwargs ): super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.use_1x1conv = use_1x1conv self.scales = scales + self.num_parts = sum(scales) self.use_avg_pool = use_avg_pool self.use_max_pool = use_max_pool self.pyramids = nn.ModuleList([ - self._make_pyramid(scale, **kwargs) for scale in self.scales + self._make_pyramid(scale) for scale in scales ]) + self.fc_mat = nn.Parameter( + torch.empty(self.num_parts, in_channels, out_channels) + ) - def _make_pyramid(self, scale: int, **kwargs): + def _make_pyramid(self, scale: int): pyramid = nn.ModuleList([ - HorizontalPyramidPooling(self.in_channels, - self.out_channels, - use_1x1conv=self.use_1x1conv, - use_avg_pool=self.use_avg_pool, - use_max_pool=self.use_max_pool, - **kwargs) + HorizontalPyramidPooling(self.use_avg_pool, self.use_max_pool) for _ in range(scale) ]) return pyramid @@ -52,4 +46,9 @@ class HorizontalPyramidMatching(nn.Module): x_slice = x_slice.view(n, -1) feature.append(x_slice) x = torch.stack(feature) + + # p, n, c + x = x @ self.fc_mat + # p, n, d + return x diff --git a/models/layers.py b/models/layers.py index f1d72b6..c609698 100644 --- a/models/layers.py +++ b/models/layers.py @@ -167,17 +167,10 @@ class BasicConv1d(nn.Module): class HorizontalPyramidPooling(nn.Module): def __init__( self, - in_channels: int, - out_channels: int, - use_1x1conv: bool = False, use_avg_pool: bool = True, use_max_pool: bool = False, - **kwargs ): super().__init__() - self.use_1x1conv = use_1x1conv - if use_1x1conv: - self.conv = BasicConv2d(in_channels, out_channels, 1, **kwargs) self.use_avg_pool = use_avg_pool self.use_max_pool = use_max_pool assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.' @@ -191,6 +184,4 @@ class HorizontalPyramidPooling(nn.Module): x = self.avg_pool(x) elif not self.use_avg_pool and self.use_max_pool: x = self.max_pool(x) - if self.use_1x1conv: - x = self.conv(x) return x diff --git a/models/model.py b/models/model.py index f515e05..bf79564 100644 --- a/models/model.py +++ b/models/model.py @@ -13,13 +13,15 @@ from torch.utils.data.dataloader import default_collate from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm +from models.hpm import HorizontalPyramidMatching +from models.part_net import PartNet from models.rgb_part_net import RGBPartNet from utils.configuration import DataloaderConfiguration, \ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \ SystemConfiguration from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses from utils.sampler import TripletSampler -from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss +from utils.triplet_loss import BatchTripletLoss class Model: @@ -69,7 +71,8 @@ class Model: self._dataset_sig: str = 'undefined' self.rgb_pn: Optional[RGBPartNet] = None - self.triplet_loss: Optional[JointBatchTripletLoss] = None + self.triplet_loss_hpm: Optional[BatchTripletLoss] = None + self.triplet_loss_pn: Optional[BatchTripletLoss] = None self.optimizer: Optional[optim.Adam] = None self.scheduler: Optional[optim.lr_scheduler.StepLR] = None self.writer: Optional[SummaryWriter] = None @@ -149,26 +152,28 @@ class Model: triplet_margins = model_hp.pop('triplet_margins', None) optim_hp: dict = self.hp.get('optimizer', {}).copy() ae_optim_hp = optim_hp.pop('auto_encoder', {}) - pn_optim_hp = optim_hp.pop('part_net', {}) hpm_optim_hp = optim_hp.pop('hpm', {}) - fc_optim_hp = optim_hp.pop('fc', {}) + pn_optim_hp = optim_hp.pop('part_net', {}) sched_hp = self.hp.get('scheduler', {}) + ae_sched_hp = sched_hp.get('auto_encoder', {}) + hpm_sched_hp = sched_hp.get('hpm', {}) + pn_sched_hp = sched_hp.get('part_net', {}) + self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp, image_log_on=self.image_log_on) # Hard margins if triplet_margins: - # Same margins - if triplet_margins[0] == triplet_margins[1]: - self.triplet_loss = BatchTripletLoss( - triplet_is_hard, triplet_margins[0] - ) - else: # Different margins - self.triplet_loss = JointBatchTripletLoss( - self.rgb_pn.module.hpm_num_parts, - triplet_is_hard, triplet_is_mean, triplet_margins - ) + self.triplet_loss_hpm = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, triplet_margins[0] + ) + self.triplet_loss_pn = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, triplet_margins[1] + ) else: # Soft margins - self.triplet_loss = BatchTripletLoss( + self.triplet_loss_hpm = BatchTripletLoss( + triplet_is_hard, triplet_is_mean, None + ) + self.triplet_loss_pn = BatchTripletLoss( triplet_is_hard, triplet_is_mean, None ) @@ -179,26 +184,34 @@ class Model: # Try to accelerate computation using CUDA or others self.rgb_pn = nn.DataParallel(self.rgb_pn) self.rgb_pn = self.rgb_pn.to(self.device) - self.triplet_loss = nn.DataParallel(self.triplet_loss) - self.triplet_loss = self.triplet_loss.to(self.device) + self.triplet_loss_hpm = nn.DataParallel(self.triplet_loss_hpm) + self.triplet_loss_hpm = self.triplet_loss_hpm.to(self.device) + self.triplet_loss_pn = nn.DataParallel(self.triplet_loss_pn) + self.triplet_loss_pn = self.triplet_loss_pn.to(self.device) self.optimizer = optim.Adam([ {'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp}, - {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp}, {'params': self.rgb_pn.module.hpm.parameters(), **hpm_optim_hp}, - {'params': self.rgb_pn.module.fc_mat, **fc_optim_hp} + {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp}, ], **optim_hp) - sched_final_gamma = sched_hp.get('final_gamma', 0.001) - sched_start_step = sched_hp.get('start_step', 15_000) - all_step = self.total_iter - sched_start_step - - def lr_lambda(epoch): - if epoch > sched_start_step: - passed_step = epoch - sched_start_step - return sched_final_gamma ** (passed_step / all_step) - else: - return 1 + + start_step = sched_hp.get('start_step', 15_000) + final_gamma = sched_hp.get('final_gamma', 0.001) + ae_start_step = ae_sched_hp.get('start_step', start_step) + ae_final_gamma = ae_sched_hp.get('final_gamma', final_gamma) + ae_all_step = self.total_iter - ae_start_step + hpm_start_step = hpm_sched_hp.get('start_step', start_step) + hpm_final_gamma = hpm_sched_hp.get('final_gamma', final_gamma) + hpm_all_step = self.total_iter - hpm_start_step + pn_start_step = pn_sched_hp.get('start_step', start_step) + pn_final_gamma = pn_sched_hp.get('final_gamma', final_gamma) + pn_all_step = self.total_iter - pn_start_step self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[ - lr_lambda, lr_lambda, lr_lambda, lr_lambda + lambda t: ae_final_gamma ** ((t - ae_start_step) / ae_all_step) + if t > ae_start_step else 1, + lambda t: hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step) + if t > hpm_start_step else 1, + lambda t: pn_final_gamma ** ((t - pn_start_step) / pn_all_step) + if t > pn_start_step else 1, ]) self.writer = SummaryWriter(self._log_name) @@ -223,7 +236,7 @@ class Model: running_loss = torch.zeros(5, device=self.device) print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}", f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}", - f"{'BATripH':^8} {'BATripP':^8} {'LR':^9}") + f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}") for (batch_c1, batch_c2) in dataloader: self.curr_iter += 1 # Zero the parameter gradients @@ -231,7 +244,7 @@ class Model: # forward + backward + optimize x_c1 = batch_c1['clip'].to(self.device) x_c2 = batch_c2['clip'].to(self.device) - embedding, images, feature_for_loss = self.rgb_pn(x_c1, x_c2) + embed_c, embed_p, images, feature_for_loss = self.rgb_pn(x_c1, x_c2) x_c1_pred = feature_for_loss[0] xrecon_loss = torch.stack([ F.mse_loss(x_c1_pred[:, i, :, :, :], x_c1[:, i, :, :, :]) @@ -249,13 +262,16 @@ class Model: ) * 10 y = batch_c1['label'].to(self.device) # Duplicate labels for each part - y = y.repeat(self.rgb_pn.module.num_total_parts, 1) - embedding = embedding.transpose(0, 1) - triplet_loss, dist, num_non_zero = self.triplet_loss(embedding, y) - hpm_loss = triplet_loss[:self.rgb_pn.module.hpm_num_parts].mean() - pn_loss = triplet_loss[self.rgb_pn.module.hpm_num_parts:].mean() + y = y.repeat(self.rgb_pn.module.num_parts, 1) + trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm( + embed_c.transpose(0, 1), y[:self.rgb_pn.module.hpm.num_parts] + ) + trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn( + embed_p.transpose(0, 1), y[self.rgb_pn.module.hpm.num_parts:] + ) losses = torch.stack(( - xrecon_loss, cano_cons_loss, pose_sim_loss, hpm_loss, pn_loss + xrecon_loss, cano_cons_loss, pose_sim_loss, + trip_loss_hpm.mean(), trip_loss_pn.mean() )) loss = losses.sum() loss.backward() @@ -271,37 +287,34 @@ class Model: 'Pose similarity loss': pose_sim_loss }, self.curr_iter) self.writer.add_scalars('Loss/triplet loss', { - 'HPM': hpm_loss, 'PartNet': pn_loss + 'HPM': losses[3], + 'PartNet': losses[4] }, self.curr_iter) # None-zero losses in batch - if num_non_zero is not None: + if hpm_num_non_zero is not None and hpm_num_non_zero is not None: self.writer.add_scalars('Loss/non-zero counts', { - 'HPM': num_non_zero[ - :self.rgb_pn.module.hpm_num_parts].mean(), - 'PartNet': num_non_zero[ - self.rgb_pn.module.hpm_num_parts:].mean() + 'HPM': hpm_num_non_zero.mean(), + 'PartNet': pn_num_non_zero.mean() }, self.curr_iter) # Embedding distance - mean_hpm_dist = dist[:self.rgb_pn.module.hpm_num_parts].mean(0) + mean_hpm_dist = hpm_dist.mean(0) self._add_ranked_scalars( 'Embedding/HPM distance', mean_hpm_dist, num_pos_pairs, num_pairs, self.curr_iter ) - mean_pa_dist = dist[self.rgb_pn.module.hpm_num_parts:].mean(0) + mean_pa_dist = pn_dist.mean(0) self._add_ranked_scalars( 'Embedding/ParNet distance', mean_pa_dist, num_pos_pairs, num_pairs, self.curr_iter ) # Embedding norm - mean_hpm_embedding = embedding[ - :self.rgb_pn.module.hpm_num_parts].mean(0) + mean_hpm_embedding = embed_c.mean(0) mean_hpm_norm = mean_hpm_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/HPM norm', mean_hpm_norm, self.k, self.pr * self.k, self.curr_iter ) - mean_pa_embedding = embedding[ - self.rgb_pn.module.hpm_num_parts:].mean(0) + mean_pa_embedding = embed_p.mean(0) mean_pa_norm = mean_pa_embedding.norm(dim=-1) self._add_ranked_scalars( 'Embedding/PartNet norm', mean_pa_norm, @@ -309,10 +322,9 @@ class Model: ) # Learning rate lrs = self.scheduler.get_last_lr() - # Write learning rates - self.writer.add_scalar( - 'Learning rate', lrs[0], self.curr_iter - ) + self.writer.add_scalars('Learning rate', dict(zip(( + 'Auto-encoder', 'HPM', 'PartNet' + ), lrs)), self.curr_iter) if self.curr_iter % 100 == 0: # Write disentangled images @@ -337,7 +349,7 @@ class Model: print(f'{hour:02}:{minute:02}:{second:02}', f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}', '{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100), - f'{lrs[0]:.3e}') + '{:.3e} {:.3e} {:.3e}'.format(*lrs)) running_loss.zero_() # Step scheduler @@ -432,13 +444,16 @@ class Model: unit='clips'): gallery_samples_c.append(self._get_eval_sample(sample)) gallery_samples[condition] = default_collate(gallery_samples_c) + gallery_samples['meta'] = self._gallery_dataset_meta # Probe probe_samples_c = [] for sample in tqdm(probe_dataloader, desc=f'Transforming probe {condition}', unit='clips'): probe_samples_c.append(self._get_eval_sample(sample)) - probe_samples[condition] = default_collate(probe_samples_c) + probe_samples_c = default_collate(probe_samples_c) + probe_samples_c['meta'] = self._probe_datasets_meta[condition] + probe_samples[condition] = probe_samples_c return gallery_samples, probe_samples @@ -453,15 +468,15 @@ class Model: **{'feature': feature} } + @staticmethod def evaluate( - self, - gallery_samples: dict[str, Union[list[str], torch.Tensor]], - probe_samples: dict[str, dict[str, Union[list[str], torch.Tensor]]], + gallery_samples: dict[str, dict[str, Union[list, torch.Tensor]]], + probe_samples: dict[str, dict[str, Union[list, torch.Tensor]]], num_ranks: int = 5 ) -> dict[str, torch.Tensor]: - conditions = gallery_samples.keys() - gallery_views_meta = self._gallery_dataset_meta['views'] - probe_views_meta = list(self._probe_datasets_meta.values())[0]['views'] + conditions = list(probe_samples.keys()) + gallery_views_meta = gallery_samples['meta']['views'] + probe_views_meta = probe_samples[conditions[0]]['meta']['views'] accuracy = { condition: torch.empty( len(gallery_views_meta), len(probe_views_meta), num_ranks @@ -474,7 +489,7 @@ class Model: (labels_g, _, views_g, features_g) = gallery_samples_c.values() views_g = np.asarray(views_g) probe_samples_c = probe_samples[condition] - (labels_p, _, views_p, features_p) = probe_samples_c.values() + (labels_p, _, views_p, features_p, _) = probe_samples_c.values() views_p = np.asarray(views_p) accuracy_c = accuracy[condition] for (v_g_i, view_g) in enumerate(gallery_views_meta): @@ -499,7 +514,6 @@ class Model: positive_counts = positive_mat.sum(0) total_counts, _ = dist.size() accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts - return accuracy def _load_pretrained( @@ -570,7 +584,7 @@ class Model: nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) - elif isinstance(m, RGBPartNet): + elif isinstance(m, (HorizontalPyramidMatching, PartNet)): nn.init.xavier_uniform_(m.fc_mat) def _parse_dataset_config( diff --git a/models/part_net.py b/models/part_net.py index 29cf9cd..f2236bf 100644 --- a/models/part_net.py +++ b/models/part_net.py @@ -111,17 +111,21 @@ class PartNet(nn.Module): def __init__( self, in_channels: int = 128, + embedding_dims: int = 256, + num_parts: int = 16, squeeze_ratio: int = 4, - num_part: int = 16 ): super().__init__() - self.num_part = num_part - self.tfa = TemporalFeatureAggregator( - in_channels, squeeze_ratio, self.num_part - ) + self.num_part = num_parts self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1) + self.tfa = TemporalFeatureAggregator( + in_channels, squeeze_ratio, self.num_part + ) + self.fc_mat = nn.Parameter( + torch.empty(num_parts, in_channels, embedding_dims) + ) def forward(self, x): n, t, c, h, w = x.size() @@ -138,4 +142,8 @@ class PartNet(nn.Module): # p, n, t, c x = self.tfa(x) + + # p, n, c + x = x @ self.fc_mat + # p, n, d return x diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 7785bb7..fdeed17 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -13,39 +13,31 @@ class RGBPartNet(nn.Module): ae_in_size: tuple[int, int] = (64, 48), ae_feature_channels: int = 64, f_a_c_p_dims: tuple[int, int, int] = (128, 128, 64), - hpm_use_1x1conv: bool = False, hpm_scales: tuple[int, ...] = (1, 2, 4), hpm_use_avg_pool: bool = True, hpm_use_max_pool: bool = True, tfa_squeeze_ratio: int = 4, tfa_num_parts: int = 16, - embedding_dims: int = 256, + embedding_dims: tuple[int] = (256, 256), image_log_on: bool = False ): super().__init__() self.h, self.w = ae_in_size (self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims - self.hpm_num_parts = sum(hpm_scales) self.image_log_on = image_log_on self.ae = AutoEncoder( ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims ) self.pn_in_channels = ae_feature_channels * 2 - self.pn = PartNet( - self.pn_in_channels, tfa_squeeze_ratio, tfa_num_parts - ) self.hpm = HorizontalPyramidMatching( - ae_feature_channels * 2, self.pn_in_channels, hpm_use_1x1conv, - hpm_scales, hpm_use_avg_pool, hpm_use_max_pool + self.pn_in_channels, embedding_dims[0], hpm_scales, + hpm_use_avg_pool, hpm_use_max_pool ) - self.num_total_parts = self.hpm_num_parts + tfa_num_parts - empty_fc = torch.empty(self.num_total_parts, - self.pn_in_channels, embedding_dims) - self.fc_mat = nn.Parameter(empty_fc) + self.pn = PartNet(self.pn_in_channels, embedding_dims[1], + tfa_num_parts, tfa_squeeze_ratio) - def fc(self, x): - return x @ self.fc_mat + self.num_parts = self.hpm.num_parts + tfa_num_parts def forward(self, x_c1, x_c2=None): # Step 1: Disentanglement @@ -55,21 +47,17 @@ class RGBPartNet(nn.Module): # Step 2.a: Static Gait Feature Aggregation & HPM # n, c, h, w x_c = self.hpm(x_c) - # p, n, c + # p, n, d # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation) # n, t, c, h, w x_p = self.pn(x_p) - # p, n, c - - # Step 3: Cat feature map together and fc - x = torch.cat((x_c, x_p)) - x = self.fc(x) + # p, n, d if self.training: - return x.transpose(0, 1), images, f_loss + return x_c.transpose(0, 1), x_p.transpose(0, 1), images, f_loss else: - return x.unsqueeze(1).view(-1) + return torch.cat((x_c, x_p)).unsqueeze(1).view(-1) def _disentangle(self, x_c1_t2, x_c2_t2=None): n, t, c, h, w = x_c1_t2.size() diff --git a/test/part_net.py b/test/part_net.py index 25e92ae..fada2c4 100644 --- a/test/part_net.py +++ b/test/part_net.py @@ -64,7 +64,7 @@ def test_custom_part_net(): paddings=((2, 1), (1, 1), (1, 1), (1, 1)), halving=(1, 1, 3, 3), squeeze_ratio=8, - num_part=8) + num_parts=8) x = torch.rand(T, N, 1, H, W) x = pa(x) diff --git a/utils/configuration.py b/utils/configuration.py index 0f8d9ff..f6ac182 100644 --- a/utils/configuration.py +++ b/utils/configuration.py @@ -33,16 +33,11 @@ class ModelHPConfiguration(TypedDict): ae_feature_channels: int f_a_c_p_dims: tuple[int, int, int] hpm_scales: tuple[int, ...] - hpm_use_1x1conv: bool hpm_use_avg_pool: bool hpm_use_max_pool: bool - fpfe_feature_channels: int - fpfe_kernel_sizes: tuple[tuple, ...] - fpfe_paddings: tuple[tuple, ...] - fpfe_halving: tuple[int, ...] - tfa_squeeze_ratio: int tfa_num_parts: int - embedding_dims: int + tfa_squeeze_ratio: int + embedding_dims: tuple[int] triplet_is_hard: bool triplet_is_mean: bool triplet_margins: tuple[float, float] @@ -63,14 +58,21 @@ class OptimizerHPConfiguration(TypedDict): weight_decay: float amsgrad: bool auto_encoder: SubOptimizerHPConfiguration - part_net: SubOptimizerHPConfiguration hpm: SubOptimizerHPConfiguration - fc: SubOptimizerHPConfiguration + part_net: SubOptimizerHPConfiguration + + +class SubSchedulerHPConfiguration(TypedDict): + start_step: int + final_gamma: float class SchedulerHPConfiguration(TypedDict): start_step: int final_gamma: float + auto_encoder: SubSchedulerHPConfiguration + hpm: SubSchedulerHPConfiguration + part_net: SubSchedulerHPConfiguration class HyperparameterConfiguration(TypedDict): diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index e05b69d..03fff21 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -85,43 +85,3 @@ class BatchTripletLoss(nn.Module): non_zero_mean = losses.sum(1) / non_zero_counts non_zero_mean[non_zero_counts == 0] = 0 return non_zero_mean - - -class JointBatchTripletLoss(BatchTripletLoss): - def __init__( - self, - hpm_num_parts: int, - is_hard: bool = True, - is_mean: bool = True, - margins: tuple[float, float] = (0.2, 0.2) - ): - super().__init__(is_hard, is_mean) - self.hpm_num_parts = hpm_num_parts - self.margin_hpm, self.margin_pn = margins - - def forward(self, x, y): - p, n, c = x.size() - dist = self._batch_distance(x) - flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device) - flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]] - - if self.is_hard: - positive_negative_dist = self._hard_distance(dist, y, p, n) - else: # is_all - positive_negative_dist = self._all_distance(dist, y, p, n) - - hpm_part_loss = F.relu( - self.margin_hpm + positive_negative_dist[:self.hpm_num_parts] - ) - pn_part_loss = F.relu( - self.margin_pn + positive_negative_dist[self.hpm_num_parts:] - ) - losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1) - - non_zero_counts = (losses != 0).sum(1).float() - if self.is_mean: - loss_metric = self._none_zero_mean(losses, non_zero_counts) - else: # is_sum - loss_metric = losses.sum(1) - - return loss_metric, flat_dist, non_zero_counts |