summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-12 20:30:25 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-12 20:34:00 +0800
commit30b475c0a27e0f848743abf0f909607defc6a3ee (patch)
treeaaab163d3d76a835c32ce5014ce62637550d0b0d
parent3d8fc322623ba61610fd206b9f52b406e85cae61 (diff)
parente83ae0bcb5c763636fd522c2712a3c8aef558f3c (diff)
Merge branch 'data_parallel' into data_parallel_py3.8
# Conflicts: # models/hpm.py # models/model.py # models/rgb_part_net.py # utils/configuration.py # utils/triplet_loss.py
-rw-r--r--config.py21
-rw-r--r--models/hpm.py25
-rw-r--r--models/layers.py9
-rw-r--r--models/model.py143
-rw-r--r--models/part_net.py18
-rw-r--r--models/rgb_part_net.py32
-rw-r--r--test/part_net.py2
-rw-r--r--utils/configuration.py20
-rw-r--r--utils/triplet_loss.py42
9 files changed, 139 insertions, 173 deletions
diff --git a/config.py b/config.py
index 2378282..fdd4328 100644
--- a/config.py
+++ b/config.py
@@ -50,19 +50,17 @@ config: Configuration = {
'ae_feature_channels': 64,
# Appearance, canonical and pose feature dimensions
'f_a_c_p_dims': (192, 192, 96),
- # Use 1x1 convolution in dimensionality reduction
- 'hpm_use_1x1conv': False,
# HPM pyramid scales, of which sum is number of parts
'hpm_scales': (1, 2, 4, 8),
# Global pooling method
'hpm_use_avg_pool': True,
'hpm_use_max_pool': True,
- # Attention squeeze ratio
- 'tfa_squeeze_ratio': 4,
# Number of parts after Part Net
'tfa_num_parts': 16,
- # Embedding dimension for each part
- 'embedding_dims': 256,
+ # Attention squeeze ratio
+ 'tfa_squeeze_ratio': 4,
+ # Embedding dimensions for each part
+ 'embedding_dims': (256, 256),
# Batch Hard or Batch All
'triplet_is_hard': True,
# Use non-zero mean or sum
@@ -91,9 +89,14 @@ config: Configuration = {
},
'scheduler': {
# Step start to decay
- 'start_step': 15_000,
+ 'start_step': 500,
# Multiplicative factor of decay in the end
- 'final_gamma': 0.001,
+ 'final_gamma': 0.01,
+
+ # Local parameters (override global ones)
+ 'hpm': {
+ 'final_gamma': 0.001
+ }
}
},
# Model metadata
@@ -107,6 +110,6 @@ config: Configuration = {
# Restoration iteration (multiple models, e.g. nm, bg and cl)
'restore_iters': (0, 0, 0),
# Total iteration for training (multiple models)
- 'total_iters': (25_000, 25_000, 25_000),
+ 'total_iters': (30_000, 40_000, 60_000),
},
}
diff --git a/models/hpm.py b/models/hpm.py
index b49be3a..8320569 100644
--- a/models/hpm.py
+++ b/models/hpm.py
@@ -11,32 +11,26 @@ class HorizontalPyramidMatching(nn.Module):
self,
in_channels: int,
out_channels: int = 128,
- use_1x1conv: bool = False,
scales: Tuple[int, ...] = (1, 2, 4),
use_avg_pool: bool = True,
use_max_pool: bool = False,
- **kwargs
):
super().__init__()
- self.in_channels = in_channels
- self.out_channels = out_channels
- self.use_1x1conv = use_1x1conv
self.scales = scales
+ self.num_parts = sum(scales)
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
self.pyramids = nn.ModuleList([
- self._make_pyramid(scale, **kwargs) for scale in self.scales
+ self._make_pyramid(scale) for scale in scales
])
+ self.fc_mat = nn.Parameter(
+ torch.empty(self.num_parts, in_channels, out_channels)
+ )
- def _make_pyramid(self, scale: int, **kwargs):
+ def _make_pyramid(self, scale: int):
pyramid = nn.ModuleList([
- HorizontalPyramidPooling(self.in_channels,
- self.out_channels,
- use_1x1conv=self.use_1x1conv,
- use_avg_pool=self.use_avg_pool,
- use_max_pool=self.use_max_pool,
- **kwargs)
+ HorizontalPyramidPooling(self.use_avg_pool, self.use_max_pool)
for _ in range(scale)
])
return pyramid
@@ -54,4 +48,9 @@ class HorizontalPyramidMatching(nn.Module):
x_slice = x_slice.view(n, -1)
feature.append(x_slice)
x = torch.stack(feature)
+
+ # p, n, c
+ x = x @ self.fc_mat
+ # p, n, d
+
return x
diff --git a/models/layers.py b/models/layers.py
index e30d0c4..a0933e8 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -167,17 +167,10 @@ class BasicConv1d(nn.Module):
class HorizontalPyramidPooling(nn.Module):
def __init__(
self,
- in_channels: int,
- out_channels: int,
- use_1x1conv: bool = False,
use_avg_pool: bool = True,
use_max_pool: bool = False,
- **kwargs
):
super().__init__()
- self.use_1x1conv = use_1x1conv
- if use_1x1conv:
- self.conv = BasicConv2d(in_channels, out_channels, 1, **kwargs)
self.use_avg_pool = use_avg_pool
self.use_max_pool = use_max_pool
assert use_avg_pool or use_max_pool, 'Pooling layer(s) required.'
@@ -191,6 +184,4 @@ class HorizontalPyramidPooling(nn.Module):
x = self.avg_pool(x)
elif not self.use_avg_pool and self.use_max_pool:
x = self.max_pool(x)
- if self.use_1x1conv:
- x = self.conv(x)
return x
diff --git a/models/model.py b/models/model.py
index 4335bc9..91d6651 100644
--- a/models/model.py
+++ b/models/model.py
@@ -13,13 +13,15 @@ from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
+from models.hpm import HorizontalPyramidMatching
+from models.part_net import PartNet
from models.rgb_part_net import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
SystemConfiguration
from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
from utils.sampler import TripletSampler
-from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss
+from utils.triplet_loss import BatchTripletLoss
class Model:
@@ -69,7 +71,8 @@ class Model:
self._dataset_sig: str = 'undefined'
self.rgb_pn: Optional[RGBPartNet] = None
- self.triplet_loss: Optional[JointBatchTripletLoss] = None
+ self.triplet_loss_hpm: Optional[BatchTripletLoss] = None
+ self.triplet_loss_pn: Optional[BatchTripletLoss] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
@@ -149,26 +152,28 @@ class Model:
triplet_margins = model_hp.pop('triplet_margins', None)
optim_hp: Dict = self.hp.get('optimizer', {}).copy()
ae_optim_hp = optim_hp.pop('auto_encoder', {})
- pn_optim_hp = optim_hp.pop('part_net', {})
hpm_optim_hp = optim_hp.pop('hpm', {})
- fc_optim_hp = optim_hp.pop('fc', {})
+ pn_optim_hp = optim_hp.pop('part_net', {})
sched_hp = self.hp.get('scheduler', {})
+ ae_sched_hp = sched_hp.get('auto_encoder', {})
+ hpm_sched_hp = sched_hp.get('hpm', {})
+ pn_sched_hp = sched_hp.get('part_net', {})
+
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
# Hard margins
if triplet_margins:
- # Same margins
- if triplet_margins[0] == triplet_margins[1]:
- self.triplet_loss = BatchTripletLoss(
- triplet_is_hard, triplet_margins[0]
- )
- else: # Different margins
- self.triplet_loss = JointBatchTripletLoss(
- self.rgb_pn.module.hpm_num_parts,
- triplet_is_hard, triplet_is_mean, triplet_margins
- )
+ self.triplet_loss_hpm = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, triplet_margins[0]
+ )
+ self.triplet_loss_pn = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, triplet_margins[1]
+ )
else: # Soft margins
- self.triplet_loss = BatchTripletLoss(
+ self.triplet_loss_hpm = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, None
+ )
+ self.triplet_loss_pn = BatchTripletLoss(
triplet_is_hard, triplet_is_mean, None
)
@@ -179,26 +184,34 @@ class Model:
# Try to accelerate computation using CUDA or others
self.rgb_pn = nn.DataParallel(self.rgb_pn)
self.rgb_pn = self.rgb_pn.to(self.device)
- self.triplet_loss = nn.DataParallel(self.triplet_loss)
- self.triplet_loss = self.triplet_loss.to(self.device)
+ self.triplet_loss_hpm = nn.DataParallel(self.triplet_loss_hpm)
+ self.triplet_loss_hpm = self.triplet_loss_hpm.to(self.device)
+ self.triplet_loss_pn = nn.DataParallel(self.triplet_loss_pn)
+ self.triplet_loss_pn = self.triplet_loss_pn.to(self.device)
self.optimizer = optim.Adam([
{'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp},
- {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
{'params': self.rgb_pn.module.hpm.parameters(), **hpm_optim_hp},
- {'params': self.rgb_pn.module.fc_mat, **fc_optim_hp}
+ {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
], **optim_hp)
- sched_final_gamma = sched_hp.get('final_gamma', 0.001)
- sched_start_step = sched_hp.get('start_step', 15_000)
- all_step = self.total_iter - sched_start_step
-
- def lr_lambda(epoch):
- if epoch > sched_start_step:
- passed_step = epoch - sched_start_step
- return sched_final_gamma ** (passed_step / all_step)
- else:
- return 1
+
+ start_step = sched_hp.get('start_step', 15_000)
+ final_gamma = sched_hp.get('final_gamma', 0.001)
+ ae_start_step = ae_sched_hp.get('start_step', start_step)
+ ae_final_gamma = ae_sched_hp.get('final_gamma', final_gamma)
+ ae_all_step = self.total_iter - ae_start_step
+ hpm_start_step = hpm_sched_hp.get('start_step', start_step)
+ hpm_final_gamma = hpm_sched_hp.get('final_gamma', final_gamma)
+ hpm_all_step = self.total_iter - hpm_start_step
+ pn_start_step = pn_sched_hp.get('start_step', start_step)
+ pn_final_gamma = pn_sched_hp.get('final_gamma', final_gamma)
+ pn_all_step = self.total_iter - pn_start_step
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
- lr_lambda, lr_lambda, lr_lambda, lr_lambda
+ lambda t: ae_final_gamma ** ((t - ae_start_step) / ae_all_step)
+ if t > ae_start_step else 1,
+ lambda t: hpm_final_gamma ** ((t - hpm_start_step) / hpm_all_step)
+ if t > hpm_start_step else 1,
+ lambda t: pn_final_gamma ** ((t - pn_start_step) / pn_all_step)
+ if t > pn_start_step else 1,
])
self.writer = SummaryWriter(self._log_name)
@@ -223,7 +236,7 @@ class Model:
running_loss = torch.zeros(5, device=self.device)
print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
- f"{'BATripH':^8} {'BATripP':^8} {'LR':^9}")
+ f"{'TripHPM':^8} {'TripPN':^8} {'LRs':^29}")
for (batch_c1, batch_c2) in dataloader:
self.curr_iter += 1
# Zero the parameter gradients
@@ -231,7 +244,7 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- embedding, images, feature_for_loss = self.rgb_pn(x_c1, x_c2)
+ embed_c, embed_p, images, feature_for_loss = self.rgb_pn(x_c1, x_c2)
x_c1_pred = feature_for_loss[0]
xrecon_loss = torch.stack([
F.mse_loss(x_c1_pred[:, i, :, :, :], x_c1[:, i, :, :, :])
@@ -249,13 +262,16 @@ class Model:
) * 10
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
- y = y.repeat(self.rgb_pn.module.num_total_parts, 1)
- embedding = embedding.transpose(0, 1)
- triplet_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
- hpm_loss = triplet_loss[:self.rgb_pn.module.hpm_num_parts].mean()
- pn_loss = triplet_loss[self.rgb_pn.module.hpm_num_parts:].mean()
+ y = y.repeat(self.rgb_pn.module.num_parts, 1)
+ trip_loss_hpm, hpm_dist, hpm_num_non_zero = self.triplet_loss_hpm(
+ embed_c.transpose(0, 1), y[:self.rgb_pn.module.hpm.num_parts]
+ )
+ trip_loss_pn, pn_dist, pn_num_non_zero = self.triplet_loss_pn(
+ embed_p.transpose(0, 1), y[self.rgb_pn.module.hpm.num_parts:]
+ )
losses = torch.stack((
- xrecon_loss, cano_cons_loss, pose_sim_loss, hpm_loss, pn_loss
+ xrecon_loss, cano_cons_loss, pose_sim_loss,
+ trip_loss_hpm.mean(), trip_loss_pn.mean()
))
loss = losses.sum()
loss.backward()
@@ -271,37 +287,34 @@ class Model:
'Pose similarity loss': pose_sim_loss
}, self.curr_iter)
self.writer.add_scalars('Loss/triplet loss', {
- 'HPM': hpm_loss, 'PartNet': pn_loss
+ 'HPM': losses[3],
+ 'PartNet': losses[4]
}, self.curr_iter)
# None-zero losses in batch
- if num_non_zero is not None:
+ if hpm_num_non_zero is not None and hpm_num_non_zero is not None:
self.writer.add_scalars('Loss/non-zero counts', {
- 'HPM': num_non_zero[
- :self.rgb_pn.module.hpm_num_parts].mean(),
- 'PartNet': num_non_zero[
- self.rgb_pn.module.hpm_num_parts:].mean()
+ 'HPM': hpm_num_non_zero.mean(),
+ 'PartNet': pn_num_non_zero.mean()
}, self.curr_iter)
# Embedding distance
- mean_hpm_dist = dist[:self.rgb_pn.module.hpm_num_parts].mean(0)
+ mean_hpm_dist = hpm_dist.mean(0)
self._add_ranked_scalars(
'Embedding/HPM distance', mean_hpm_dist,
num_pos_pairs, num_pairs, self.curr_iter
)
- mean_pa_dist = dist[self.rgb_pn.module.hpm_num_parts:].mean(0)
+ mean_pa_dist = pn_dist.mean(0)
self._add_ranked_scalars(
'Embedding/ParNet distance', mean_pa_dist,
num_pos_pairs, num_pairs, self.curr_iter
)
# Embedding norm
- mean_hpm_embedding = embedding[
- :self.rgb_pn.module.hpm_num_parts].mean(0)
+ mean_hpm_embedding = embed_c.mean(0)
mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/HPM norm', mean_hpm_norm,
self.k, self.pr * self.k, self.curr_iter
)
- mean_pa_embedding = embedding[
- self.rgb_pn.module.hpm_num_parts:].mean(0)
+ mean_pa_embedding = embed_p.mean(0)
mean_pa_norm = mean_pa_embedding.norm(dim=-1)
self._add_ranked_scalars(
'Embedding/PartNet norm', mean_pa_norm,
@@ -309,10 +322,9 @@ class Model:
)
# Learning rate
lrs = self.scheduler.get_last_lr()
- # Write learning rates
- self.writer.add_scalar(
- 'Learning rate', lrs[0], self.curr_iter
- )
+ self.writer.add_scalars('Learning rate', dict(zip((
+ 'Auto-encoder', 'HPM', 'PartNet'
+ ), lrs)), self.curr_iter)
if self.curr_iter % 100 == 0:
# Write disentangled images
@@ -337,7 +349,7 @@ class Model:
print(f'{hour:02}:{minute:02}:{second:02}',
f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
'{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- f'{lrs[0]:.3e}')
+ '{:.3e} {:.3e} {:.3e}'.format(*lrs))
running_loss.zero_()
# Step scheduler
@@ -432,13 +444,16 @@ class Model:
unit='clips'):
gallery_samples_c.append(self._get_eval_sample(sample))
gallery_samples[condition] = default_collate(gallery_samples_c)
+ gallery_samples['meta'] = self._gallery_dataset_meta
# Probe
probe_samples_c = []
for sample in tqdm(probe_dataloader,
desc=f'Transforming probe {condition}',
unit='clips'):
probe_samples_c.append(self._get_eval_sample(sample))
- probe_samples[condition] = default_collate(probe_samples_c)
+ probe_samples_c = default_collate(probe_samples_c)
+ probe_samples_c['meta'] = self._probe_datasets_meta[condition]
+ probe_samples[condition] = probe_samples_c
return gallery_samples, probe_samples
@@ -453,15 +468,16 @@ class Model:
**{'feature': feature}
}
+ @staticmethod
def evaluate(
self,
- gallery_samples: Dict[str, Union[List[str], torch.Tensor]],
- probe_samples: Dict[str, Dict[str, Union[List[str], torch.Tensor]]],
+ gallery_samples: Dict[str, Dict[str, Union[List, torch.Tensor]]],
+ probe_samples: Dict[str, Dict[str, Union[List, torch.Tensor]]],
num_ranks: int = 5
) -> Dict[str, torch.Tensor]:
- conditions = gallery_samples.keys()
- gallery_views_meta = self._gallery_dataset_meta['views']
- probe_views_meta = list(self._probe_datasets_meta.values())[0]['views']
+ conditions = list(probe_samples.keys())
+ gallery_views_meta = gallery_samples['meta']['views']
+ probe_views_meta = probe_samples[conditions[0]]['meta']['views']
accuracy = {
condition: torch.empty(
len(gallery_views_meta), len(probe_views_meta), num_ranks
@@ -474,7 +490,7 @@ class Model:
(labels_g, _, views_g, features_g) = gallery_samples_c.values()
views_g = np.asarray(views_g)
probe_samples_c = probe_samples[condition]
- (labels_p, _, views_p, features_p) = probe_samples_c.values()
+ (labels_p, _, views_p, features_p, _) = probe_samples_c.values()
views_p = np.asarray(views_p)
accuracy_c = accuracy[condition]
for (v_g_i, view_g) in enumerate(gallery_views_meta):
@@ -499,7 +515,6 @@ class Model:
positive_counts = positive_mat.sum(0)
total_counts, _ = dist.size()
accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts
-
return accuracy
def _load_pretrained(
@@ -570,7 +585,7 @@ class Model:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
- elif isinstance(m, RGBPartNet):
+ elif isinstance(m, (HorizontalPyramidMatching, PartNet)):
nn.init.xavier_uniform_(m.fc_mat)
def _parse_dataset_config(
diff --git a/models/part_net.py b/models/part_net.py
index d9b954f..de19c8c 100644
--- a/models/part_net.py
+++ b/models/part_net.py
@@ -112,17 +112,21 @@ class PartNet(nn.Module):
def __init__(
self,
in_channels: int = 128,
+ embedding_dims: int = 256,
+ num_parts: int = 16,
squeeze_ratio: int = 4,
- num_part: int = 16
):
super().__init__()
- self.num_part = num_part
- self.tfa = TemporalFeatureAggregator(
- in_channels, squeeze_ratio, self.num_part
- )
+ self.num_part = num_parts
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
+ self.tfa = TemporalFeatureAggregator(
+ in_channels, squeeze_ratio, self.num_part
+ )
+ self.fc_mat = nn.Parameter(
+ torch.empty(num_parts, in_channels, embedding_dims)
+ )
def forward(self, x):
n, t, c, h, w = x.size()
@@ -139,4 +143,8 @@ class PartNet(nn.Module):
# p, n, t, c
x = self.tfa(x)
+
+ # p, n, c
+ x = x @ self.fc_mat
+ # p, n, d
return x
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 2853571..81f198e 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -15,39 +15,31 @@ class RGBPartNet(nn.Module):
ae_in_size: Tuple[int, int] = (64, 48),
ae_feature_channels: int = 64,
f_a_c_p_dims: Tuple[int, int, int] = (128, 128, 64),
- hpm_use_1x1conv: bool = False,
hpm_scales: Tuple[int, ...] = (1, 2, 4),
hpm_use_avg_pool: bool = True,
hpm_use_max_pool: bool = True,
tfa_squeeze_ratio: int = 4,
tfa_num_parts: int = 16,
- embedding_dims: int = 256,
+ embedding_dims: Tuple[int] = (256, 256),
image_log_on: bool = False
):
super().__init__()
self.h, self.w = ae_in_size
(self.f_a_dim, self.f_c_dim, self.f_p_dim) = f_a_c_p_dims
- self.hpm_num_parts = sum(hpm_scales)
self.image_log_on = image_log_on
self.ae = AutoEncoder(
ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
)
self.pn_in_channels = ae_feature_channels * 2
- self.pn = PartNet(
- self.pn_in_channels, tfa_squeeze_ratio, tfa_num_parts
- )
self.hpm = HorizontalPyramidMatching(
- ae_feature_channels * 2, self.pn_in_channels, hpm_use_1x1conv,
- hpm_scales, hpm_use_avg_pool, hpm_use_max_pool
+ self.pn_in_channels, embedding_dims[0], hpm_scales,
+ hpm_use_avg_pool, hpm_use_max_pool
)
- self.num_total_parts = self.hpm_num_parts + tfa_num_parts
- empty_fc = torch.empty(self.num_total_parts,
- self.pn_in_channels, embedding_dims)
- self.fc_mat = nn.Parameter(empty_fc)
+ self.pn = PartNet(self.pn_in_channels, embedding_dims[1],
+ tfa_num_parts, tfa_squeeze_ratio)
- def fc(self, x):
- return x @ self.fc_mat
+ self.num_parts = self.hpm.num_parts + tfa_num_parts
def forward(self, x_c1, x_c2=None):
# Step 1: Disentanglement
@@ -57,21 +49,17 @@ class RGBPartNet(nn.Module):
# Step 2.a: Static Gait Feature Aggregation & HPM
# n, c, h, w
x_c = self.hpm(x_c)
- # p, n, c
+ # p, n, d
# Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
# n, t, c, h, w
x_p = self.pn(x_p)
- # p, n, c
-
- # Step 3: Cat feature map together and fc
- x = torch.cat((x_c, x_p))
- x = self.fc(x)
+ # p, n, d
if self.training:
- return x.transpose(0, 1), images, f_loss
+ return x_c.transpose(0, 1), x_p.transpose(0, 1), images, f_loss
else:
- return x.unsqueeze(1).view(-1)
+ return torch.cat((x_c, x_p)).unsqueeze(1).view(-1)
def _disentangle(self, x_c1_t2, x_c2_t2=None):
n, t, c, h, w = x_c1_t2.size()
diff --git a/test/part_net.py b/test/part_net.py
index 25e92ae..fada2c4 100644
--- a/test/part_net.py
+++ b/test/part_net.py
@@ -64,7 +64,7 @@ def test_custom_part_net():
paddings=((2, 1), (1, 1), (1, 1), (1, 1)),
halving=(1, 1, 3, 3),
squeeze_ratio=8,
- num_part=8)
+ num_parts=8)
x = torch.rand(T, N, 1, H, W)
x = pa(x)
diff --git a/utils/configuration.py b/utils/configuration.py
index 376ae0f..8ee08f2 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -33,16 +33,11 @@ class ModelHPConfiguration(TypedDict):
ae_feature_channels: int
f_a_c_p_dims: Tuple[int, int, int]
hpm_scales: Tuple[int, ...]
- hpm_use_1x1conv: bool
hpm_use_avg_pool: bool
hpm_use_max_pool: bool
- fpfe_feature_channels: int
- fpfe_kernel_sizes: Tuple[Tuple, ...]
- fpfe_paddings: Tuple[Tuple, ...]
- fpfe_halving: Tuple[int, ...]
- tfa_squeeze_ratio: int
tfa_num_parts: int
- embedding_dims: int
+ tfa_squeeze_ratio: int
+ embedding_dims: Tuple[int]
triplet_is_hard: bool
triplet_is_mean: bool
triplet_margins: Tuple[float, float]
@@ -63,14 +58,21 @@ class OptimizerHPConfiguration(TypedDict):
weight_decay: float
amsgrad: bool
auto_encoder: SubOptimizerHPConfiguration
- part_net: SubOptimizerHPConfiguration
hpm: SubOptimizerHPConfiguration
- fc: SubOptimizerHPConfiguration
+ part_net: SubOptimizerHPConfiguration
+
+
+class SubSchedulerHPConfiguration(TypedDict):
+ start_step: int
+ final_gamma: float
class SchedulerHPConfiguration(TypedDict):
start_step: int
final_gamma: float
+ auto_encoder: SubSchedulerHPConfiguration
+ hpm: SubSchedulerHPConfiguration
+ part_net: SubSchedulerHPConfiguration
class HyperparameterConfiguration(TypedDict):
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index ae899ec..03fff21 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import Optional
import torch
import torch.nn as nn
@@ -85,43 +85,3 @@ class BatchTripletLoss(nn.Module):
non_zero_mean = losses.sum(1) / non_zero_counts
non_zero_mean[non_zero_counts == 0] = 0
return non_zero_mean
-
-
-class JointBatchTripletLoss(BatchTripletLoss):
- def __init__(
- self,
- hpm_num_parts: int,
- is_hard: bool = True,
- is_mean: bool = True,
- margins: Tuple[float, float] = (0.2, 0.2)
- ):
- super().__init__(is_hard, is_mean)
- self.hpm_num_parts = hpm_num_parts
- self.margin_hpm, self.margin_pn = margins
-
- def forward(self, x, y):
- p, n, c = x.size()
- dist = self._batch_distance(x)
- flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device)
- flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]]
-
- if self.is_hard:
- positive_negative_dist = self._hard_distance(dist, y, p, n)
- else: # is_all
- positive_negative_dist = self._all_distance(dist, y, p, n)
-
- hpm_part_loss = F.relu(
- self.margin_hpm + positive_negative_dist[:self.hpm_num_parts]
- )
- pn_part_loss = F.relu(
- self.margin_pn + positive_negative_dist[self.hpm_num_parts:]
- )
- losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1)
-
- non_zero_counts = (losses != 0).sum(1).float()
- if self.is_mean:
- loss_metric = self._none_zero_mean(losses, non_zero_counts)
- else: # is_sum
- loss_metric = losses.sum(1)
-
- return loss_metric, flat_dist, non_zero_counts