diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-07 18:37:43 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-07 18:37:43 +0800 |
commit | 4a284084c253b9114fc02e1782962556ff113761 (patch) | |
tree | d6ceff8da68b224186d84772ee6153353675bcfe /models/rgb_part_net.py | |
parent | a27af5dfd58e7b48cf3bd063fa2b4b51ed1e0277 (diff) |
Add typical training script and some bug fixes
1. Resolve deprecated scheduler stepping issue
2. Make losses in the same scale(replace mean with sum in separate triplet loss, enlarge pose similarity loss 10x)
3. Add ReLU when compute distance in triplet loss
4. Remove classes except Model from `models` package init
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 73d5952..3037da0 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -4,7 +4,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from models import AutoEncoder, HorizontalPyramidMatching, PartNet +from models.auto_encoder import AutoEncoder +from models.hpm import HorizontalPyramidMatching +from models.part_net import PartNet from utils.triplet_loss import BatchAllTripletLoss @@ -117,7 +119,7 @@ class RGBPartNet(nn.Module): # Losses xrecon_loss = torch.sum(torch.stack(xrecon_loss)) - pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) + pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) * 10 cano_cons_loss = torch.mean(torch.stack(cano_cons_loss)) return ((x_c_c1, x_p_c1), |