author     Jordan Gong <jordan.gong@protonmail.com>  2021-01-12 11:29:02 +0800
committer  Jordan Gong <jordan.gong@protonmail.com>  2021-01-12 11:29:02 +0800
commit     966d4431c037b0c4641aa2a5fc22f05be064b331 (patch)
tree       0239ba89d31857a7f086acf627fc1bbf167855a9
parent     7825f978f198e56958703f0d08f7ccbd8cef49ca (diff)
parent     36cf502afe9b93efe31c244030270b0a62e644b8 (diff)

Merge branch 'master' into python3.8

# Conflicts:
#	models/model.py
-rw-r--r--  eval.py                  30
-rw-r--r--  models/auto_encoder.py    2
-rw-r--r--  models/model.py         249
-rw-r--r--  models/rgb_part_net.py   15
-rw-r--r--  requirements.txt          2
-rw-r--r--  test/model.py             4
-rw-r--r--  train.py                 26
-rw-r--r--  utils/misc.py            10
-rw-r--r--  utils/triplet_loss.py     9

9 files changed, 295 insertions(+), 52 deletions(-)
diff --git a/eval.py b/eval.py
new file mode 100644
index 0000000..7b68220
--- /dev/null
+++ b/eval.py
@@ -0,0 +1,30 @@
+import numpy as np
+
+from config import config
+from models import Model
+from utils.dataset import ClipConditions
+from utils.misc import set_visible_cuda
+
+set_visible_cuda(config['system'])
+model = Model(config['system'], config['model'], config['hyperparameter'])
+
+dataset_selectors = {
+ 'nm': {'conditions': ClipConditions({r'nm-0\d'})},
+ 'bg': {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})},
+ 'cl': {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},
+}
+
+accuracy = model.predict_all(config['model']['total_iter'], config['dataset'],
+ dataset_selectors, config['dataloader'])
+rank = 5
+np.set_printoptions(formatter={'float': '{:5.2f}'.format})
+for n in range(rank):
+ print(f'===Rank-{n + 1} Accuracy===')
+ for (condition, accuracy_c) in accuracy.items():
+ acc_excl_identical_view = accuracy_c[:, :, n].fill_diagonal_(0)
+ num_gallery_views = (acc_excl_identical_view != 0).sum()
+ acc_each_angle = acc_excl_identical_view.sum(0) / num_gallery_views
+ print('{0}: {1} mean: {2:5.2f}'.format(
+ condition, acc_each_angle.cpu().numpy() * 100,
+ acc_each_angle.mean() * 100)
+ )
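For reference, a minimal sketch of how the aggregation above behaves on a toy accuracy cube (hypothetical shapes and values). `fill_diagonal_` zeroes identical-view pairs in place; note that the denominator is the total non-zero count over the whole view matrix, not a per-angle count:

    import torch

    # (gallery_views, probe_views, ranks) = (3, 3, 1), all pairs at 0.8
    accuracy_c = torch.full((3, 3, 1), 0.8)

    acc_excl_identical_view = accuracy_c[:, :, 0].fill_diagonal_(0)
    num_gallery_views = (acc_excl_identical_view != 0).sum()   # scalar: 6
    acc_each_angle = acc_excl_identical_view.sum(0) / num_gallery_views

    print(acc_each_angle)  # tensor([0.2667, 0.2667, 0.2667])
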
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 1e7c323..64c52e3 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -130,7 +130,7 @@ class AutoEncoder(nn.Module):
BasicLinear(f_c_dim, num_class)
)
- def forward(self, x_c1_t1, x_c1_t2, x_c2_t2, y=None):
+ def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None, y=None):
# x_c1_t2 is the frame for later module
(f_a_c1_t2, f_c_c1_t2, f_p_c1_t2) = self.encoder(x_c1_t2)
diff --git a/models/model.py b/models/model.py
index 80d4499..1154d7f 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,19 +1,22 @@
import os
+from datetime import datetime
from typing import Union, Optional, Tuple, List, Dict, Set
import numpy as np
import torch
import torch.nn as nn
+import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
from models.rgb_part_net import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
SystemConfiguration
-from utils.dataset import CASIAB
+from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
from utils.sampler import TripletSampler
@@ -52,25 +55,52 @@ class Model:
self.pr: Optional[int] = None
self.k: Optional[int] = None
+ self._gallery_dataset_meta: Optional[Dict[str, List]] = None
+ self._probe_datasets_meta: Optional[Dict[str, Dict[str, List]]] = None
+
self._model_sig: str = self._make_signature(self.meta, ['restore_iter'])
self._hp_sig: str = self._make_signature(self.hp)
self._dataset_sig: str = 'undefined'
self._log_sig: str = '_'.join((self._model_sig, self._hp_sig))
- self.log_name: str = os.path.join(self.log_dir, self._log_sig)
+ self._log_name: str = os.path.join(self.log_dir, self._log_sig)
self.rgb_pn: Optional[RGBPartNet] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
+ self.CASIAB_GALLERY_SELECTOR = {
+ 'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})}
+ }
+ self.CASIAB_PROBE_SELECTORS = {
+ 'nm': {'selector': {'conditions': ClipConditions({r'nm-0[5-6]'})}},
+ 'bg': {'selector': {'conditions': ClipConditions({r'bg-0[1-2]'})}},
+ 'cl': {'selector': {'conditions': ClipConditions({r'cl-0[1-2]'})}},
+ }
+
@property
- def signature(self) -> str:
+ def _signature(self) -> str:
return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig,
self._dataset_sig, str(self.pr), str(self.k)))
@property
- def checkpoint_name(self) -> str:
- return os.path.join(self.checkpoint_dir, self.signature)
+ def _checkpoint_name(self) -> str:
+ return os.path.join(self.checkpoint_dir, self._signature)
+
+ def fit_all(
+ self,
+ dataset_config: DatasetConfiguration,
+ dataset_selectors: Dict[
+ str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
+ ],
+ dataloader_config: DataloaderConfiguration,
+ ):
+ for (condition, selector) in dataset_selectors.items():
+ print(f'Training model {condition} ...')
+ self.fit(
+ dict(**dataset_config, **{'selector': selector}),
+ dataloader_config
+ )
def fit(
self,
@@ -86,24 +116,23 @@ class Model:
self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp)
self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas)
self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9)
- self.writer = SummaryWriter(self.log_name)
-
- if not self.disable_acc:
- if torch.cuda.device_count() > 1:
- self.rgb_pn = nn.DataParallel(self.rgb_pn)
- self.rgb_pn = self.rgb_pn.to(self.device)
+ self.writer = SummaryWriter(self._log_name)
+ # Try to accelerate computation using CUDA or others
+ self._accelerate()
self.rgb_pn.train()
# Init weights at first iter
if self.curr_iter == 0:
self.rgb_pn.apply(self.init_weights)
else: # Load saved state dicts
- checkpoint = torch.load(self.checkpoint_name)
+ checkpoint = torch.load(self._checkpoint_name)
iter_, loss = checkpoint['iter'], checkpoint['loss']
print('{0:5d} loss: {1:.3f}'.format(iter_, loss))
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
+ # Training start
+ start_time = datetime.now()
for (batch_c1, batch_c2) in dataloader:
self.curr_iter += 1
# Zero the parameter gradients
@@ -126,7 +155,7 @@ class Model:
], metrics)), self.curr_iter)
if self.curr_iter % 100 == 0:
- print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss),
+ print('{0:5d} loss: {1:6.3f}'.format(self.curr_iter, loss),
'(xrecon = {:f}, pose_sim = {:f},'
' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics),
'lr:', self.scheduler.get_last_lr()[0])
@@ -137,12 +166,190 @@ class Model:
'model_state_dict': self.rgb_pn.state_dict(),
'optim_state_dict': self.optimizer.state_dict(),
'loss': loss,
- }, self.checkpoint_name)
+ }, self._checkpoint_name)
+ print(datetime.now() - start_time, 'used')
+ start_time = datetime.now()
if self.curr_iter == self.total_iter:
+ self.curr_iter = 0
self.writer.close()
break
+ def _accelerate(self):
+ if not self.disable_acc:
+ if torch.cuda.device_count() > 1:
+ self.rgb_pn = nn.DataParallel(self.rgb_pn)
+ self.rgb_pn = self.rgb_pn.to(self.device)
+
+ def predict_all(
+ self,
+ iter_: int,
+ dataset_config: DatasetConfiguration,
+ dataset_selectors: Dict[
+ str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
+ ],
+ dataloader_config: DataloaderConfiguration,
+ ) -> Dict[str, torch.Tensor]:
+ self.is_train = False
+ # Split gallery and probe dataset
+ gallery_dataloader, probe_dataloaders = self._split_gallery_probe(
+ dataset_config, dataloader_config
+ )
+ # Get pretrained models at iter_
+ checkpoints = self._load_pretrained(
+ iter_, dataset_config, dataset_selectors
+ )
+ # Init models
+ hp = self.hp.copy()
+ hp.pop('lr'), hp.pop('betas')
+ self.rgb_pn = RGBPartNet(ae_in_channels=self.in_channels, **hp)
+ # Try to accelerate computation using CUDA or others
+ self._accelerate()
+
+ self.rgb_pn.eval()
+ gallery_samples, probe_samples = [], {}
+
+ # Gallery
+ checkpoint = torch.load(list(checkpoints.values())[0])
+ self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
+ for sample in tqdm(gallery_dataloader,
+ desc='Transforming gallery', unit='clips'):
+ label = sample.pop('label').item()
+ clip = sample.pop('clip').to(self.device)
+ feature = self.rgb_pn(clip).detach()
+ gallery_samples.append({
+ **{'label': label},
+ **sample,
+ **{'feature': feature}
+ })
+ gallery_samples = default_collate(gallery_samples)
+
+ # Probe
+ for (condition, dataloader) in probe_dataloaders.items():
+ checkpoint = torch.load(checkpoints[condition])
+ self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
+ probe_samples[condition] = []
+ for sample in tqdm(dataloader,
+ desc=f'Transforming probe {condition}',
+ unit='clips'):
+ label = sample.pop('label').item()
+ clip = sample.pop('clip').to(self.device)
+ feature = self.rgb_pn(clip).detach()
+ probe_samples[condition].append({
+ **{'label': label},
+ **sample,
+ **{'feature': feature}
+ })
+ for (k, v) in probe_samples.items():
+ probe_samples[k] = default_collate(v)
+
+ return self._evaluate(gallery_samples, probe_samples)
+
+ def _evaluate(
+ self,
+ gallery_samples: Dict[str, Union[List[str], torch.Tensor]],
+ probe_samples: Dict[str, Dict[str, Union[List[str], torch.Tensor]]],
+ num_ranks: int = 5
+ ) -> Dict[str, torch.Tensor]:
+ probe_conditions = self._probe_datasets_meta.keys()
+ gallery_views_meta = self._gallery_dataset_meta['views']
+ probe_views_meta = list(self._probe_datasets_meta.values())[0]['views']
+ accuracy = {
+ condition: torch.empty(
+ len(gallery_views_meta), len(probe_views_meta), num_ranks
+ )
+ for condition in self._probe_datasets_meta.keys()
+ }
+
+ (labels_g, _, views_g, features_g) = gallery_samples.values()
+ views_g = np.asarray(views_g)
+ for (v_g_i, view_g) in enumerate(gallery_views_meta):
+ gallery_view_mask = (views_g == view_g)
+ f_g = features_g[gallery_view_mask]
+ y_g = labels_g[gallery_view_mask]
+ for condition in probe_conditions:
+ probe_samples_c = probe_samples[condition]
+ accuracy_c = accuracy[condition]
+ (labels_p, _, views_p, features_p) = probe_samples_c.values()
+ views_p = np.asarray(views_p)
+ for (v_p_i, view_p) in enumerate(probe_views_meta):
+ probe_view_mask = (views_p == view_p)
+ f_p = features_p[probe_view_mask]
+ y_p = labels_p[probe_view_mask]
+ # Euclidean distance
+ f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(1)
+ f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(0)
+ f_p_times_f_g_sum = f_p @ f_g.T
+ dist = torch.sqrt(F.relu(
+ f_p_squared_sum - 2*f_p_times_f_g_sum + f_g_squared_sum
+ ))
+ # Ranked accuracy
+ rank_mask = dist.argsort(1)[:, :num_ranks]
+ positive_mat = torch.eq(y_p.unsqueeze(1),
+ y_g[rank_mask]).cumsum(1).gt(0)
+ positive_counts = positive_mat.sum(0)
+ total_counts, _ = dist.size()
+ accuracy_c[v_g_i, v_p_i, :] = positive_counts / total_counts
+
+ return accuracy
+
+ def _load_pretrained(
+ self,
+ iter_: int,
+ dataset_config: DatasetConfiguration,
+ dataset_selectors: Dict[
+ str, Dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
+ ]
+ ) -> Dict[str, str]:
+ checkpoints = {}
+ self.curr_iter = iter_
+ for (k, v) in dataset_selectors.items():
+ self._dataset_sig = self._make_signature(
+ dict(**dataset_config, **v),
+ popped_keys=['root_dir', 'cache_on']
+ )
+ checkpoints[k] = self._checkpoint_name
+ return checkpoints
+
+ def _split_gallery_probe(
+ self,
+ dataset_config: DatasetConfiguration,
+ dataloader_config: DataloaderConfiguration,
+    ) -> Tuple[DataLoader, Dict[str, DataLoader]]:
+ dataset_name = dataset_config.get('name', 'CASIA-B')
+ if dataset_name == 'CASIA-B':
+ gallery_dataset = self._parse_dataset_config(
+ dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR)
+ )
+ self._gallery_dataset_meta = gallery_dataset.metadata
+ gallery_dataloader = self._parse_dataloader_config(
+ gallery_dataset, dataloader_config
+ )
+ probe_datasets = {
+ condition: self._parse_dataset_config(
+ dict(**dataset_config, **selector)
+ )
+ for (condition, selector) in self.CASIAB_PROBE_SELECTORS.items()
+ }
+ self._probe_datasets_meta = {
+ condition: dataset.metadata
+ for (condition, dataset) in probe_datasets.items()
+ }
+ probe_dataloaders = {
+                condition: self._parse_dataloader_config(
+ dataset, dataloader_config
+ )
+                for (condition, dataset) in probe_datasets.items()
+ }
+ elif dataset_name == 'FVG':
+ # TODO
+ gallery_dataloader = None
+ probe_dataloaders = None
+ else:
+ raise ValueError('Invalid dataset: {0}'.format(dataset_name))
+
+ return gallery_dataloader, probe_dataloaders
+
@staticmethod
def init_weights(m):
if isinstance(m, nn.modules.conv._ConvNd):
@@ -165,9 +372,9 @@ class Model:
dataset_config,
popped_keys=['root_dir', 'cache_on']
)
- self.log_name = '_'.join((self.log_name, self._dataset_sig))
+ self._log_name = '_'.join((self._log_name, self._dataset_sig))
config: Dict = dataset_config.copy()
- name = config.pop('name')
+ name = config.pop('name', 'CASIA-B')
if name == 'CASIA-B':
return CASIAB(**config, is_train=self.is_train)
elif name == 'FVG':
@@ -181,16 +388,16 @@ class Model:
dataloader_config: DataloaderConfiguration
) -> DataLoader:
config: Dict = dataloader_config.copy()
+ (self.pr, self.k) = config.pop('batch_size')
if self.is_train:
- (self.pr, self.k) = config.pop('batch_size')
- self.log_name = '_'.join((self.log_name, str(self.pr), str(self.k)))
+ self._log_name = '_'.join(
+ (self._log_name, str(self.pr), str(self.k)))
triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
return DataLoader(dataset,
batch_sampler=triplet_sampler,
collate_fn=self._batch_splitter,
**config)
else: # is_test
- config.pop('batch_size')
return DataLoader(dataset, **config)
def _batch_splitter(
@@ -225,8 +432,10 @@ class Model:
for v in values:
if isinstance(v, str):
strings.append(v)
- elif isinstance(v, (Tuple, List, Set)):
+ elif isinstance(v, (Tuple, List)):
strings.append(self._gen_sig(v))
+ elif isinstance(v, Set):
+ strings.append(self._gen_sig(sorted(list(v))))
elif isinstance(v, Dict):
strings.append(self._gen_sig(list(v.values())))
else:
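The pairwise distance in the new `_evaluate` uses the expansion ||p - g||^2 = ||p||^2 - 2 p·g + ||g||^2, with `F.relu` clamping the tiny negative values that floating-point cancellation can produce before the square root. A standalone sketch with hypothetical feature sizes; it should agree with `torch.cdist` up to float error:

    import torch
    import torch.nn.functional as F

    f_p = torch.randn(4, 256)  # probe features   (n_p, d), hypothetical sizes
    f_g = torch.randn(6, 256)  # gallery features (n_g, d)

    # ||p||^2 - 2 p.g + ||g||^2, relu-clamped so sqrt never sees a negative
    dist = torch.sqrt(F.relu(
        torch.sum(f_p ** 2, dim=1).unsqueeze(1)
        - 2 * (f_p @ f_g.T)
        + torch.sum(f_g ** 2, dim=1).unsqueeze(0)
    ))

    print((dist - torch.cdist(f_p, f_g)).abs().max())  # ~0, up to float error
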
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 39cbed6..95a3f2e 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -52,10 +52,12 @@ class RGBPartNet(nn.Module):
def fc(self, x):
return x @ self.fc_mat
- def forward(self, x_c1, x_c2, y=None):
+ def forward(self, x_c1, x_c2=None, y=None):
# Step 0: Swap batch_size and time dimensions for next step
# n, t, c, h, w
- x_c1, x_c2 = x_c1.transpose(0, 1), x_c2.transpose(0, 1)
+ x_c1 = x_c1.transpose(0, 1)
+ if self.training:
+ x_c2 = x_c2.transpose(0, 1)
# Step 1: Disentanglement
# t, n, c, h, w
@@ -83,9 +85,9 @@ class RGBPartNet(nn.Module):
loss = torch.sum(torch.stack(losses))
return loss, [loss.item() for loss in losses]
else:
- return x
+ return x.unsqueeze(1).view(-1)
- def _disentangle(self, x_c1, x_c2, y):
+ def _disentangle(self, x_c1, x_c2=None, y=None):
num_frames = len(x_c1)
# Decoded canonical features and Pose images
x_c_c1, x_p_c1 = [], []
@@ -95,7 +97,7 @@ class RGBPartNet(nn.Module):
xrecon_loss, cano_cons_loss = [], []
for t2 in range(num_frames):
t1 = random.randrange(num_frames)
- output = self.ae(x_c1[t1], x_c1[t2], x_c2[t2], y)
+ output = self.ae(x_c1[t2], x_c1[t1], x_c2[t2], y)
(x_c1_t2, f_p_t2, losses) = output
# Decoded features or image
@@ -128,8 +130,7 @@ class RGBPartNet(nn.Module):
else: # evaluating
for t2 in range(num_frames):
- t1 = random.randrange(num_frames)
- x_c1_t2 = self.ae(x_c1[t1], x_c1[t2], x_c2[t2])
+ x_c1_t2 = self.ae(x_c1[t2])
# Decoded features or image
(x_c_c1_t2, x_p_c1_t2) = x_c1_t2
# Canonical Features for HPM
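Taken together with the `AutoEncoder.forward` reordering above, these changes make the second-condition input optional: training still consumes paired clips and returns losses, while evaluation encodes a single clip. A toy module illustrating the same calling convention (not the real `RGBPartNet`):

    import torch
    import torch.nn as nn

    class TwoModeNet(nn.Module):
        """Toy stand-in for the train/eval branching in RGBPartNet.forward."""
        def forward(self, x_c1, x_c2=None, y=None):
            if self.training:
                assert x_c2 is not None  # paired clips required for the losses
                return ((x_c1 - x_c2) ** 2).mean()  # placeholder loss
            return x_c1.flatten(1)  # evaluation: single clip in, features out

    net = TwoModeNet()
    net.train()
    loss = net(torch.randn(2, 8), torch.randn(2, 8))
    net.eval()
    feature = net(torch.randn(2, 8))
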
diff --git a/requirements.txt b/requirements.txt
index bcc6597..deaa0d4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
torch~=1.7.1
torchvision~=0.8.0a0+ecf4e9c
numpy~=1.19.4
-tqdm~=4.55.0
+tqdm~=4.56.0
Pillow~=8.1.0
scikit-learn~=0.23.2
\ No newline at end of file
diff --git a/test/model.py b/test/model.py
index 472c210..f679908 100644
--- a/test/model.py
+++ b/test/model.py
@@ -11,11 +11,11 @@ def test_default_signature():
model = Model(conf['system'], conf['model'], conf['hyperparameter'])
casiab = model._parse_dataset_config(conf['dataset'])
model._parse_dataloader_config(casiab, conf['dataloader'])
- assert model.log_name == os.path.join(
+ assert model._log_name == os.path.join(
'runs', 'logs', 'RGB-GaitPart_80000_64_128_128_64_1_2_4_True_True_32_5_'
'3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_256_0.2_0.0001_0.9_'
'0.999_CASIA-B_74_30_15_3_64_32_8_16')
- assert model.signature == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_'
+ assert model._signature == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_'
'True_32_5_3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_256_'
'0.2_0.0001_0.9_0.999_CASIA-B_74_30_15_3_64_32_'
'8_16')
diff --git a/train.py b/train.py
index d921839..d91dcd0 100644
--- a/train.py
+++ b/train.py
@@ -1,24 +1,16 @@
-import os
-
from config import config
from models import Model
from utils.dataset import ClipConditions
+from utils.misc import set_visible_cuda
-# Set environment variable CUDA device(s)
-CUDA_VISIBLE_DEVICES = config['system'].get('CUDA_VISIBLE_DEVICES', None)
-if CUDA_VISIBLE_DEVICES:
- os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
-
+set_visible_cuda(config['system'])
model = Model(config['system'], config['model'], config['hyperparameter'])
# 3 models for different conditions
-dataset_selectors = [
- {'conditions': ClipConditions({r'nm-0\d'})},
- {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})},
- {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},
-]
-for selector in dataset_selectors:
- model.fit(
- dict(**config['dataset'], **{'selector': selector}),
- config['dataloader']
- )
+dataset_selectors = {
+ 'nm': {'conditions': ClipConditions({r'nm-0\d'})},
+ 'bg': {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})},
+ 'cl': {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},
+}
+
+model.fit_all(config['dataset'], dataset_selectors, config['dataloader'])
diff --git a/utils/misc.py b/utils/misc.py
new file mode 100644
index 0000000..b850830
--- /dev/null
+++ b/utils/misc.py
@@ -0,0 +1,10 @@
+import os
+
+from utils.configuration import SystemConfiguration
+
+
+def set_visible_cuda(config: SystemConfiguration):
+ """Set environment variable CUDA device(s)"""
+ CUDA_VISIBLE_DEVICES = config.get('CUDA_VISIBLE_DEVICES', None)
+ if CUDA_VISIBLE_DEVICES:
+ os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
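One caveat worth noting: `CUDA_VISIBLE_DEVICES` only takes effect if it is set before the CUDA runtime initializes, which is why both `train.py` and `eval.py` call `set_visible_cuda` before constructing `Model`. A minimal usage sketch (the device string is hypothetical):

    from utils.misc import set_visible_cuda

    # Must run before the first torch.cuda call initializes the driver
    set_visible_cuda({'CUDA_VISIBLE_DEVICES': '0,1'})

    import torch
    print(torch.cuda.device_count())  # counts only the devices left visible
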
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 1d63a0e..8c143d6 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -15,8 +15,8 @@ class BatchAllTripletLoss(nn.Module):
# Euclidean distance p x n x n
x_squared_sum = torch.sum(x ** 2, dim=2)
- x1_squared_sum = x_squared_sum.unsqueeze(1)
- x2_squared_sum = x_squared_sum.unsqueeze(2)
+ x1_squared_sum = x_squared_sum.unsqueeze(2)
+ x2_squared_sum = x_squared_sum.unsqueeze(1)
x1_times_x2_sum = x @ x.transpose(1, 2)
dist = torch.sqrt(
F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum)
@@ -30,8 +30,9 @@ class BatchAllTripletLoss(nn.Module):
all_loss = F.relu(self.margin + positive_negative_dist).view(p, -1)
# Non-zero parted mean
- parted_loss_mean = all_loss.sum(1) / (all_loss != 0).sum(1)
- parted_loss_mean[parted_loss_mean == float('Inf')] = 0
+ non_zero_counts = (all_loss != 0).sum(1)
+ parted_loss_mean = all_loss.sum(1) / non_zero_counts
+ parted_loss_mean[non_zero_counts == 0] = 0
loss = parted_loss_mean.sum()
return loss
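
This last hunk replaces the Inf-mask with an explicit guard on the non-zero count: when a part has no active triplets, the old `0 / 0` produced NaN, which the `== float('Inf')` check never matched (NaN compares unequal to everything). A toy reproduction of both behaviors:

    import torch

    all_loss = torch.tensor([[0.5, 0.0, 1.5],   # part with active triplets
                             [0.0, 0.0, 0.0]])  # part with none

    # Old: 0/0 -> NaN, and NaN != Inf, so the mask misses it
    old = all_loss.sum(1) / (all_loss != 0).sum(1)
    old[old == float('Inf')] = 0
    print(old)  # tensor([1., nan])

    # New: zero out parts with no non-zero losses directly
    counts = (all_loss != 0).sum(1)
    new = all_loss.sum(1) / counts
    new[counts == 0] = 0
    print(new)  # tensor([1., 0.])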