summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 73d5952..3037da0 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -4,7 +4,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
-from models import AutoEncoder, HorizontalPyramidMatching, PartNet
+from models.auto_encoder import AutoEncoder
+from models.hpm import HorizontalPyramidMatching
+from models.part_net import PartNet
from utils.triplet_loss import BatchAllTripletLoss
@@ -117,7 +119,7 @@ class RGBPartNet(nn.Module):
# Losses
xrecon_loss = torch.sum(torch.stack(xrecon_loss))
- pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2)
+ pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) * 10
cano_cons_loss = torch.mean(torch.stack(cano_cons_loss))
return ((x_c_c1, x_p_c1),