author    Jordan Gong <jordan.gong@protonmail.com>  2021-02-20 14:48:16 +0800
committer Jordan Gong <jordan.gong@protonmail.com>  2021-02-20 14:48:16 +0800
commit    9b1828be1db7fd1be8731a7cec66162de9145285 (patch)
tree      9efb5a37856f34e333457e9d7ab2aaa8ba811cf6 /models/model.py
parent    e33c22e556ed64e1c1fdb011d78a124d1489ad15 (diff)
parent    c538919cb69e35a46811aef0b23baefe6a4c499c (diff)

Merge branch 'python3.8' into data_parallel_py3.8
# Conflicts:
#	models/model.py

Diffstat (limited to 'models/model.py')
-rw-r--r--  models/model.py  24
1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/models/model.py b/models/model.py
index 2ef3b80..e2de476 100644
--- a/models/model.py
+++ b/models/model.py
@@ -18,6 +18,7 @@ from utils.configuration import DataloaderConfiguration, \
     SystemConfiguration
 from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
 from utils.sampler import TripletSampler
+from utils.triplet_loss import JointBatchAllTripletLoss
 
 
 class Model:
@@ -67,6 +68,7 @@ class Model:
         self._dataset_sig: str = 'undefined'
 
         self.rgb_pn: Optional[RGBPartNet] = None
+        self.ba_triplet_loss: Optional[JointBatchAllTripletLoss] = None
         self.optimizer: Optional[optim.Adam] = None
         self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
         self.writer: Optional[SummaryWriter] = None
@@ -140,7 +142,8 @@
         dataset = self._parse_dataset_config(dataset_config)
         dataloader = self._parse_dataloader_config(dataset, dataloader_config)
         # Prepare for model, optimizer and scheduler
-        model_hp = self.hp.get('model', {})
+        model_hp: Dict = self.hp.get('model', {}).copy()
+        triplet_margins = model_hp.pop('triplet_margins', (0.2, 0.2))
         optim_hp: Dict = self.hp.get('optimizer', {}).copy()
         start_iter = optim_hp.pop('start_iter', 0)
         ae_optim_hp = optim_hp.pop('auto_encoder', {})
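For reference, the hyperparameter layout these pops assume looks roughly like the sketch below. Only the key names visible in this diff ('model', 'triplet_margins', 'optimizer', 'start_iter', 'auto_encoder', 'scheduler') come from the source; the values, and any key not named here, are illustrative.

# Hypothetical hyperparameter dict consumed by the parsing code above;
# values are made up for demonstration.
hp = {
    'model': {
        'triplet_margins': (0.2, 0.2),  # popped before RGBPartNet(**model_hp)
        # ...remaining keys are forwarded to RGBPartNet as keyword arguments
    },
    'optimizer': {
        'start_iter': 0,               # popped: resume point, not an Adam argument
        'auto_encoder': {'lr': 1e-3},  # popped: per-submodule Adam options
        'lr': 1e-4,
    },
    'scheduler': {'step_size': 500, 'gamma': 0.9},
}

Popping the non-Adam keys out first is what lets the remaining optim_hp be splatted straight into optim.Adam, and the .copy() added in this hunk keeps those pops from mutating self.hp across calls.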
@@ -150,9 +153,14 @@
         sched_hp = self.hp.get('scheduler', {})
         self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
                                  image_log_on=self.image_log_on)
+        self.ba_triplet_loss = JointBatchAllTripletLoss(
+            self.rgb_pn.hpm_num_parts, triplet_margins
+        )
         # Try to accelerate computation using CUDA or others
         self.rgb_pn = nn.DataParallel(self.rgb_pn)
         self.rgb_pn = self.rgb_pn.to(self.device)
+        self.ba_triplet_loss = nn.DataParallel(self.ba_triplet_loss)
+        self.ba_triplet_loss = self.ba_triplet_loss.to(self.device)
         self.optimizer = optim.Adam([
             {'params': self.rgb_pn.module.ae.parameters(), **ae_optim_hp},
             {'params': self.rgb_pn.module.pn.parameters(), **pn_optim_hp},
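utils.triplet_loss.JointBatchAllTripletLoss itself is not shown in this diff. As a rough sketch only, a batch-all triplet loss matching the interface used above (constructed from hpm_num_parts plus a pair of margins, called with part-wise features and per-part labels, returning one loss per part) could look like the following; this is an illustration, not the repository's implementation:

import torch
import torch.nn as nn


class BatchAllTripletLossSketch(nn.Module):
    """Illustrative stand-in, NOT the real JointBatchAllTripletLoss."""

    def __init__(self, hpm_num_parts, margins=(0.2, 0.2)):
        super().__init__()
        self.hpm_num_parts = hpm_num_parts
        self.margins = margins

    def forward(self, feature, label):
        # feature: (P, N, C) part-wise embeddings; label: (P, N)
        p, n, _ = feature.size()
        dist = torch.cdist(feature, feature)             # (P, N, N) pairwise distances
        same = label.unsqueeze(1) == label.unsqueeze(2)  # (P, N, N) same-identity mask
        eye = torch.eye(n, dtype=torch.bool, device=feature.device)
        pos_mask = same & ~eye                           # anchor-positive pairs, a != p
        # Separate margins for HPM parts (first block) and PartNet parts (rest)
        margin = feature.new_full((p, 1, 1, 1), self.margins[1])
        margin[:self.hpm_num_parts] = self.margins[0]
        # "Batch all": hinge over every (anchor, positive, negative) combination
        hinge = (dist.unsqueeze(3) - dist.unsqueeze(2) + margin).clamp(min=0)
        valid = pos_mask.unsqueeze(3) & (~same).unsqueeze(2)
        # Average over the valid triplets of each part -> (P,) losses
        return (hinge * valid).sum((1, 2, 3)) / valid.sum((1, 2, 3)).clamp(min=1)

Since nn.DataParallel scatters inputs along dim 0, wrapping the loss module spreads the per-part computation across GPUs; a (P,) result gathered back along dim 0 is what the hpm_num_parts indexing in the training step below relies on.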
@@ -194,12 +202,18 @@
             # forward + backward + optimize
             x_c1 = batch_c1['clip'].to(self.device)
             x_c2 = batch_c2['clip'].to(self.device)
+            feature, ae_losses, images = self.rgb_pn(x_c1, x_c2)
             y = batch_c1['label'].to(self.device)
             # Duplicate labels for each part
-            y = y.unsqueeze(1).repeat(1, self.rgb_pn.module.num_total_parts)
-            losses, images = self.rgb_pn(x_c1, x_c2, y)
-            # Combine losses from different data splits
-            losses = losses.mean()
+            y = y.repeat(self.rgb_pn.module.num_total_parts, 1)
+            triplet_loss = self.ba_triplet_loss(feature, y)
+            losses = torch.cat((
+                ae_losses.mean(0),
+                torch.stack((
+                    triplet_loss[:self.rgb_pn.module.hpm_num_parts].mean(),
+                    triplet_loss[self.rgb_pn.module.hpm_num_parts:].mean()
+                ))
+            ))
             loss = losses.sum()
             loss.backward()
             self.optimizer.step()
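
As a quick shape check of the combined objective above (all sizes are made up, and the assumption that ae_losses gathers to a (replicas, 3) tensor of three autoencoder terms is not read from this diff):

import torch

n, hpm_parts, total_parts = 4, 16, 32     # illustrative sizes only
y = torch.arange(n)                       # (N,) one label per clip
y = y.repeat(total_parts, 1)              # (P, N) labels duplicated per part
triplet_loss = torch.rand(total_parts)    # (P,) per-part batch-all losses
ae_losses = torch.rand(2, 3)              # assumed (replicas, 3) AE loss terms
losses = torch.cat((
    ae_losses.mean(0),                    # (3,) averaged over replicas
    torch.stack((
        triplet_loss[:hpm_parts].mean(),  # scalar: HPM parts
        triplet_loss[hpm_parts:].mean(),  # scalar: PartNet parts
    )),                                   # (2,)
))                                        # (5,) separate loss terms
loss = losses.sum()                       # single scalar for backward()

Keeping the five terms separate until the final sum allows each component to be logged on its own (the SummaryWriter set up earlier in this file is the likely consumer) while still backpropagating through one scalar.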