summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-03-01 18:28:35 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-03-01 18:28:35 +0800
commitd01458560bc03c59852f6e3b45c6408a90ba9b6f (patch)
treec7187695c8fd16faf047cd93b66ed5ce9b7b0033
parent8745ffeb51b29cfe7d3eee2d845910086ba8b046 (diff)
parentd88e40217f56d96e568335ccee1f14ff3ea5a696 (diff)
Merge branch 'data_parallel' into data_parallel_py3.8
# Conflicts: # models/model.py # utils/configuration.py # utils/triplet_loss.py
-rw-r--r--config.py34
-rw-r--r--models/layers.py4
-rw-r--r--models/model.py129
-rw-r--r--models/rgb_part_net.py6
-rw-r--r--requirements.txt2
-rw-r--r--utils/configuration.py7
-rw-r--r--utils/triplet_loss.py117
7 files changed, 211 insertions, 88 deletions
diff --git a/config.py b/config.py
index 03f2f0d..e1ee2fb 100644
--- a/config.py
+++ b/config.py
@@ -5,7 +5,7 @@ config: Configuration = {
# Disable accelerator
'disable_acc': False,
# GPU(s) used in training or testing if available
- 'CUDA_VISIBLE_DEVICES': '0',
+ 'CUDA_VISIBLE_DEVICES': '0,1',
# Directory used in training or testing for temporary storage
'save_dir': 'runs',
# Recorde disentangled image or not
@@ -30,14 +30,14 @@ config: Configuration = {
# Resolution after resize, can be divided 16
'frame_size': (64, 48),
# Cache dataset or not
- 'cache_on': False,
+ 'cache_on': True,
},
# Dataloader settings
'dataloader': {
# Batch size (pr, k)
# `pr` denotes number of persons
# `k` denotes number of sequences per person
- 'batch_size': (4, 6),
+ 'batch_size': (6, 8),
# Number of workers of Dataloader
'num_workers': 4,
# Faster data transfer from RAM to GPU if enabled
@@ -53,7 +53,7 @@ config: Configuration = {
# Use 1x1 convolution in dimensionality reduction
'hpm_use_1x1conv': False,
# HPM pyramid scales, of which sum is number of parts
- 'hpm_scales': (1, 2, 4),
+ 'hpm_scales': (1, 2, 4, 8),
# Global pooling method
'hpm_use_avg_pool': True,
'hpm_use_max_pool': True,
@@ -63,13 +63,15 @@ config: Configuration = {
'tfa_num_parts': 16,
# Embedding dimension for each part
'embedding_dims': 256,
- # Triplet loss margins for HPM and PartNet
- 'triplet_margins': (1.5, 1.5),
+ # Batch Hard or Batch All
+ 'triplet_is_hard': True,
+ # Use non-zero mean or sum
+ 'triplet_is_mean': True,
+ # Triplet loss margins for HPM and PartNet, None for soft margin
+ 'triplet_margins': None,
},
'optimizer': {
# Global parameters
- # Iteration start to optimize non-disentangling parts
- # 'start_iter': 0,
# Initial learning rate of Adam Optimizer
'lr': 1e-4,
# Coefficients used for computing running averages of
@@ -83,15 +85,15 @@ config: Configuration = {
# 'amsgrad': False,
# Local parameters (override global ones)
- # 'auto_encoder': {
- # 'weight_decay': 0.001
- # },
+ 'auto_encoder': {
+ 'weight_decay': 0.001
+ },
},
'scheduler': {
- # Period of learning rate decay
- 'step_size': 500,
- # Multiplicative factor of decay
- 'gamma': 1,
+ # Step start to decay
+ 'start_step': 15_000,
+ # Multiplicative factor of decay in the end
+ 'final_gamma': 0.001,
}
},
# Model metadata
@@ -105,6 +107,6 @@ config: Configuration = {
# Restoration iteration (multiple models, e.g. nm, bg and cl)
'restore_iters': (0, 0, 0),
# Total iteration for training (multiple models)
- 'total_iters': (80_000, 80_000, 80_000),
+ 'total_iters': (25_000, 25_000, 25_000),
},
}
diff --git a/models/layers.py b/models/layers.py
index ae61583..e30d0c4 100644
--- a/models/layers.py
+++ b/models/layers.py
@@ -80,7 +80,9 @@ class DCGANConvTranspose2d(BasicConvTranspose2d):
if self.is_last_layer:
return self.trans_conv(x)
else:
- return super().forward(x)
+ x = self.trans_conv(x)
+ x = self.bn(x)
+ return F.leaky_relu(x, 0.2, inplace=True)
class BasicLinear(nn.Module):
diff --git a/models/model.py b/models/model.py
index 2c72270..46d7c4c 100644
--- a/models/model.py
+++ b/models/model.py
@@ -18,7 +18,7 @@ from utils.configuration import DataloaderConfiguration, \
SystemConfiguration
from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
from utils.sampler import TripletSampler
-from utils.triplet_loss import JointBatchAllTripletLoss
+from utils.triplet_loss import JointBatchTripletLoss, BatchTripletLoss
class Model:
@@ -68,7 +68,7 @@ class Model:
self._dataset_sig: str = 'undefined'
self.rgb_pn: Optional[RGBPartNet] = None
- self.ba_triplet_loss: Optional[JointBatchAllTripletLoss] = None
+ self.triplet_loss: Optional[JointBatchTripletLoss] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
@@ -143,9 +143,10 @@ class Model:
dataloader = self._parse_dataloader_config(dataset, dataloader_config)
# Prepare for model, optimizer and scheduler
model_hp: Dict = self.hp.get('model', {}).copy()
- triplet_margins = model_hp.pop('triplet_margins', (0.2, 0.2))
+ triplet_is_hard = model_hp.pop('triplet_is_hard', True)
+ triplet_is_mean = model_hp.pop('triplet_is_mean', True)
+ triplet_margins = model_hp.pop('triplet_margins', None)
optim_hp: Dict = self.hp.get('optimizer', {}).copy()
- start_iter = optim_hp.pop('start_iter', 0)
ae_optim_hp = optim_hp.pop('auto_encoder', {})
pn_optim_hp = optim_hp.pop('part_net', {})
hpm_optim_hp = optim_hp.pop('hpm', {})
@@ -153,28 +154,48 @@ class Model:
sched_hp = self.hp.get('scheduler', {})
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
image_log_on=self.image_log_on)
- self.ba_triplet_loss = JointBatchAllTripletLoss(
- self.rgb_pn.hpm_num_parts, triplet_margins
- )
+ # Hard margins
+ if triplet_margins:
+ # Same margins
+ if triplet_margins[0] == triplet_margins[1]:
+ self.triplet_loss = BatchTripletLoss(
+ triplet_is_hard, triplet_margins[0]
+ )
+ else: # Different margins
+ self.triplet_loss = JointBatchTripletLoss(
+ self.rgb_pn.hpm_num_parts,
+ triplet_is_hard, triplet_is_mean, triplet_margins
+ )
+ else: # Soft margins
+ self.triplet_loss = BatchTripletLoss(
+ triplet_is_hard, triplet_is_mean, None
+ )
+
+ num_pairs = (self.pr*self.k-1) * (self.pr*self.k) // 2
+ num_pos_pairs = (self.k*(self.k-1)//2) * self.pr
+
# Try to accelerate computation using CUDA or others
self.rgb_pn = nn.DataParallel(self.rgb_pn)
self.rgb_pn = self.rgb_pn.to(self.device)
- self.ba_triplet_loss = nn.DataParallel(self.ba_triplet_loss)
- self.ba_triplet_loss = self.ba_triplet_loss.to(self.device)
+ self.triplet_loss = nn.DataParallel(self.triplet_loss)
+ self.triplet_loss = self.triplet_loss.to(self.device)
self.optimizer = optim.Adam([
{'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp},
{'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
{'params': self.rgb_pn.module.hpm.parameters(), **hpm_optim_hp},
{'params': self.rgb_pn.module.fc_mat, **fc_optim_hp}
], **optim_hp)
- sched_gamma = sched_hp.get('gamma', 0.9)
- sched_step_size = sched_hp.get('step_size', 500)
+ sched_final_gamma = sched_hp.get('final_gamma', 0.001)
+ sched_start_step = sched_hp.get('start_step', 15_000)
+
+ def lr_lambda(epoch):
+ passed_step = epoch - sched_start_step
+ all_step = self.total_iter - sched_start_step
+ return sched_final_gamma ** (passed_step / all_step)
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=[
- lambda epoch: sched_gamma ** (epoch // sched_step_size),
- lambda epoch: 0 if epoch < start_iter else 1,
- lambda epoch: 0 if epoch < start_iter else 1,
- lambda epoch: 0 if epoch < start_iter else 1,
+ lr_lambda, lr_lambda, lr_lambda, lr_lambda
])
+
self.writer = SummaryWriter(self._log_name)
self.rgb_pn.train()
@@ -194,7 +215,7 @@ class Model:
running_loss = torch.zeros(5, device=self.device)
print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
- f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
+ f"{'BATripH':^8} {'BATripP':^8} {'LR':^9}")
for (batch_c1, batch_c2) in dataloader:
self.curr_iter += 1
# Zero the parameter gradients
@@ -202,16 +223,16 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- feature, ae_losses, images = self.rgb_pn(x_c1, x_c2)
+ embedding, ae_losses, images = self.rgb_pn(x_c1, x_c2)
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
y = y.repeat(self.rgb_pn.num_total_parts, 1)
- triplet_loss = self.ba_triplet_loss(feature, y)
+ trip_loss, dist, num_non_zero = self.triplet_loss(embedding, y)
losses = torch.cat((
ae_losses.mean(0),
torch.stack((
- triplet_loss[:self.rgb_pn.hpm_num_parts].mean(),
- triplet_loss[self.rgb_pn.hpm_num_parts:].mean()
+ trip_loss[:self.rgb_pn.hpm_num_parts].mean(),
+ trip_loss[self.rgb_pn.hpm_num_parts:].mean()
))
))
loss = losses.sum()
@@ -222,20 +243,50 @@ class Model:
running_loss += losses.detach()
# Write losses to TensorBoard
self.writer.add_scalar('Loss/all', loss, self.curr_iter)
- self.writer.add_scalars('Loss/details', dict(zip([
+ self.writer.add_scalars('Loss/disentanglement', dict(zip((
'Cross reconstruction loss', 'Canonical consistency loss',
- 'Pose similarity loss', 'Batch All triplet loss (HPM)',
- 'Batch All triplet loss (PartNet)'
- ], losses)), self.curr_iter)
+ 'Pose similarity loss'
+ ), ae_losses)), self.curr_iter)
+ self.writer.add_scalars('Loss/triplet loss', {
+ 'HPM': losses[3],
+ 'PartNet': losses[4]
+ }, self.curr_iter)
+ # None-zero losses in batch
+ if num_non_zero is not None:
+ self.writer.add_scalars('Loss/non-zero counts', {
+ 'HPM': num_non_zero[:self.rgb_pn.hpm_num_parts].mean(),
+ 'PartNet': num_non_zero[self.rgb_pn.hpm_num_parts:].mean()
+ }, self.curr_iter)
+ # Embedding distance
+ mean_hpm_dist = dist[:self.rgb_pn.hpm_num_parts].mean(0)
+ self._add_ranked_scalars(
+ 'Embedding/HPM distance', mean_hpm_dist,
+ num_pos_pairs, num_pairs, self.curr_iter
+ )
+ mean_pa_dist = dist[self.rgb_pn.hpm_num_parts:].mean(0)
+ self._add_ranked_scalars(
+ 'Embedding/ParNet distance', mean_pa_dist,
+ num_pos_pairs, num_pairs, self.curr_iter
+ )
+ # Embedding norm
+ mean_hpm_embedding = embedding[:self.rgb_pn.hpm_num_parts].mean(0)
+ mean_hpm_norm = mean_hpm_embedding.norm(dim=-1)
+ self._add_ranked_scalars(
+ 'Embedding/HPM norm', mean_hpm_norm,
+ self.k, self.pr * self.k, self.curr_iter
+ )
+ mean_pa_embedding = embedding[self.rgb_pn.hpm_num_parts:].mean(0)
+ mean_pa_norm = mean_pa_embedding.norm(dim=-1)
+ self._add_ranked_scalars(
+ 'Embedding/PartNet norm', mean_pa_norm,
+ self.k, self.pr * self.k, self.curr_iter
+ )
if self.curr_iter % 100 == 0:
lrs = self.scheduler.get_last_lr()
# Write learning rates
self.writer.add_scalar(
- 'Learning rate/Auto-encoder', lrs[0], self.curr_iter
- )
- self.writer.add_scalar(
- 'Learning rate/Others', lrs[1], self.curr_iter
+ 'Learning rate', lrs[0], self.curr_iter
)
# Write disentangled images
if self.image_log_on:
@@ -259,7 +310,7 @@ class Model:
print(f'{hour:02}:{minute:02}:{second:02}',
f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
'{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
- '{:.3e} {:.3e}'.format(lrs[0], lrs[1]))
+ f'{lrs[0]:.3e}')
running_loss.zero_()
# Step scheduler
@@ -278,6 +329,24 @@ class Model:
self.writer.close()
break
+ def _add_ranked_scalars(
+ self,
+ main_tag: str,
+ metric: torch.Tensor,
+ num_pos: int,
+ num_all: int,
+ global_step: int
+ ):
+ rank = metric.argsort()
+ pos_ile = 100 - (num_pos - 1) * 100 // num_all
+ self.writer.add_scalars(main_tag, {
+ '0%-ile': metric[rank[-1]],
+ f'{100 - pos_ile}%-ile': metric[rank[-num_pos]],
+ '50%-ile': metric[rank[num_all // 2 - 1]],
+ f'{pos_ile}%-ile': metric[rank[num_pos - 1]],
+ '100%-ile': metric[rank[0]]
+ }, global_step)
+
def predict_all(
self,
iters: Tuple[int],
@@ -317,6 +386,8 @@ class Model:
# Init models
model_hp: Dict = self.hp.get('model', {}).copy()
+ model_hp.pop('triplet_is_hard', True)
+ model_hp.pop('triplet_is_mean', True)
model_hp.pop('triplet_margins', None)
self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp)
# Try to accelerate computation using CUDA or others
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 4d7ba7f..310ef25 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -81,7 +81,7 @@ class RGBPartNet(nn.Module):
((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
# Decode features
x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p_ = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+ x_p_ = self._decode_pose_feature(f_p_, n, t, device)
x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4)
i_a, i_c, i_p = None, None, None
@@ -100,7 +100,7 @@ class RGBPartNet(nn.Module):
else: # evaluating
f_c_, f_p_ = self.ae(x_c1_t2)
x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p_ = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+ x_p_ = self._decode_pose_feature(f_p_, n, t, device)
x_p = x_p_.view(n, t, self.pn_in_channels, self.h // 4, self.w // 4)
return (x_c, x_p), None, None
@@ -125,7 +125,7 @@ class RGBPartNet(nn.Module):
)
return x_c
- def _decode_pose_feature(self, f_p_, n, t, c, h, w, device):
+ def _decode_pose_feature(self, f_p_, n, t, device):
# Decode pose features to images
x_p_ = self.ae.decoder(
torch.zeros((n * t, self.f_a_dim), device=device),
diff --git a/requirements.txt b/requirements.txt
index 4d30e17..926a587 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
torch~=1.7.1
torchvision~=0.8.0a0+ecf4e9c
numpy~=1.19.4
-tqdm~=4.57.0
+tqdm~=4.58.0
Pillow~=8.1.0
scikit-learn~=0.24.0 \ No newline at end of file
diff --git a/utils/configuration.py b/utils/configuration.py
index b9e6d92..376ae0f 100644
--- a/utils/configuration.py
+++ b/utils/configuration.py
@@ -43,6 +43,8 @@ class ModelHPConfiguration(TypedDict):
tfa_squeeze_ratio: int
tfa_num_parts: int
embedding_dims: int
+ triplet_is_hard: bool
+ triplet_is_mean: bool
triplet_margins: Tuple[float, float]
@@ -55,7 +57,6 @@ class SubOptimizerHPConfiguration(TypedDict):
class OptimizerHPConfiguration(TypedDict):
- start_iter: int
lr: int
betas: Tuple[float, float]
eps: float
@@ -68,8 +69,8 @@ class OptimizerHPConfiguration(TypedDict):
class SchedulerHPConfiguration(TypedDict):
- step_size: int
- gamma: float
+ start_step: int
+ final_gamma: float
class HyperparameterConfiguration(TypedDict):
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 6025bd3..ae899ec 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -1,34 +1,48 @@
-from typing import Tuple
+from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
-class BatchAllTripletLoss(nn.Module):
- def __init__(self, margin: float = 0.2):
+class BatchTripletLoss(nn.Module):
+ def __init__(
+ self,
+ is_hard: bool = True,
+ is_mean: bool = True,
+ margin: Optional[float] = 0.2,
+ ):
super().__init__()
+ self.is_hard = is_hard
+ self.is_mean = is_mean
self.margin = margin
def forward(self, x, y):
p, n, c = x.size()
-
dist = self._batch_distance(x)
- positive_negative_dist = self._hard_distance(dist, y, p, n)
- all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1)
- parted_loss_mean = self._none_zero_parted_mean(all_loss)
-
- return parted_loss_mean
-
- @staticmethod
- def _hard_distance(dist, y, p, n):
- hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
- hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
- all_hard_positive = dist[hard_positive_mask].view(p, n, -1, 1)
- all_hard_negative = dist[hard_negative_mask].view(p, n, 1, -1)
- positive_negative_dist = all_hard_positive - all_hard_negative
-
- return positive_negative_dist
+ flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device)
+ flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]]
+
+ if self.is_hard:
+ positive_negative_dist = self._hard_distance(dist, y, p, n)
+ else: # is_all
+ positive_negative_dist = self._all_distance(dist, y, p, n)
+
+ if self.margin:
+ losses = F.relu(self.margin + positive_negative_dist).view(p, -1)
+ non_zero_counts = (losses != 0).sum(1).float()
+ if self.is_mean:
+ loss_metric = self._none_zero_mean(losses, non_zero_counts)
+ else: # is_sum
+ loss_metric = losses.sum(1)
+ return loss_metric, flat_dist, non_zero_counts
+ else: # Soft margin
+ losses = F.softplus(positive_negative_dist).view(p, -1)
+ if self.is_mean:
+ loss_metric = losses.mean(1)
+ else: # is_sum
+ loss_metric = losses.sum(1)
+ return loss_metric, flat_dist, None
@staticmethod
def _batch_distance(x):
@@ -40,41 +54,74 @@ class BatchAllTripletLoss(nn.Module):
dist = torch.sqrt(
F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum)
)
-
return dist
@staticmethod
- def _none_zero_parted_mean(all_loss):
- # Non-zero parted mean
- non_zero_counts = (all_loss != 0).sum(1)
- parted_loss_mean = all_loss.sum(1) / non_zero_counts
- parted_loss_mean[non_zero_counts == 0] = 0
+ def _hard_distance(dist, y, p, n):
+ positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
+ negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
+ hard_positive = dist[positive_mask].view(p, n, -1).max(-1).values
+ hard_negative = dist[negative_mask].view(p, n, -1).min(-1).values
+ positive_negative_dist = hard_positive - hard_negative
+
+ return positive_negative_dist
+
+ @staticmethod
+ def _all_distance(dist, y, p, n):
+ # Unmask identical samples
+ positive_mask = torch.eye(
+ n, dtype=torch.bool, device=y.device
+ ) ^ (y.unsqueeze(1) == y.unsqueeze(2))
+ negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
+ all_positive = dist[positive_mask].view(p, n, -1, 1)
+ all_negative = dist[negative_mask].view(p, n, 1, -1)
+ positive_negative_dist = all_positive - all_negative
- return parted_loss_mean
+ return positive_negative_dist
+
+ @staticmethod
+ def _none_zero_mean(losses, non_zero_counts):
+ # Non-zero parted mean
+ non_zero_mean = losses.sum(1) / non_zero_counts
+ non_zero_mean[non_zero_counts == 0] = 0
+ return non_zero_mean
-class JointBatchAllTripletLoss(BatchAllTripletLoss):
+class JointBatchTripletLoss(BatchTripletLoss):
def __init__(
self,
hpm_num_parts: int,
+ is_hard: bool = True,
+ is_mean: bool = True,
margins: Tuple[float, float] = (0.2, 0.2)
):
- super().__init__()
+ super().__init__(is_hard, is_mean)
self.hpm_num_parts = hpm_num_parts
self.margin_hpm, self.margin_pn = margins
def forward(self, x, y):
p, n, c = x.size()
-
dist = self._batch_distance(x)
- positive_negative_dist = self._hard_distance(dist, y, p, n)
+ flat_dist_mask = torch.tril_indices(n, n, offset=-1, device=dist.device)
+ flat_dist = dist[:, flat_dist_mask[0], flat_dist_mask[1]]
+
+ if self.is_hard:
+ positive_negative_dist = self._hard_distance(dist, y, p, n)
+ else: # is_all
+ positive_negative_dist = self._all_distance(dist, y, p, n)
+
hpm_part_loss = F.relu(
self.margin_hpm + positive_negative_dist[:self.hpm_num_parts]
- ).view(self.hpm_num_parts, -1)
+ )
pn_part_loss = F.relu(
self.margin_pn + positive_negative_dist[self.hpm_num_parts:]
- ).view(p - self.hpm_num_parts, -1)
- all_loss = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1)
- parted_loss_mean = self._none_zero_parted_mean(all_loss)
+ )
+ losses = torch.cat((hpm_part_loss, pn_part_loss)).view(p, -1)
+
+ non_zero_counts = (losses != 0).sum(1).float()
+ if self.is_mean:
+ loss_metric = self._none_zero_mean(losses, non_zero_counts)
+ else: # is_sum
+ loss_metric = losses.sum(1)
- return parted_loss_mean
+ return loss_metric, flat_dist, non_zero_counts