summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore3
-rw-r--r--main.py0
-rw-r--r--models/__init__.py5
-rw-r--r--models/model.py6
-rw-r--r--models/rgb_part_net.py6
-rw-r--r--test/cuda.py2
-rw-r--r--test/hpm.py2
-rw-r--r--test/rgb_part_net.py2
-rw-r--r--train.py12
-rw-r--r--utils/triplet_loss.py6
10 files changed, 29 insertions, 15 deletions
diff --git a/.gitignore b/.gitignore
index 08f887c..f4b5746 100644
--- a/.gitignore
+++ b/.gitignore
@@ -146,4 +146,5 @@ dmypy.json
# Dataset
data/
-
+# Runtime
+runs/
diff --git a/main.py b/main.py
deleted file mode 100644
index e69de29..0000000
--- a/main.py
+++ /dev/null
diff --git a/models/__init__.py b/models/__init__.py
index c1b9fe8..7040c63 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,4 +1 @@
-from .auto_encoder import AutoEncoder
-from .hpm import HorizontalPyramidMatching
-from .part_net import PartNet
-from .rgb_part_net import RGBPartNet
+from .model import Model \ No newline at end of file
diff --git a/models/model.py b/models/model.py
index bf8b5fb..1dc0f23 100644
--- a/models/model.py
+++ b/models/model.py
@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
-from models import RGBPartNet
+from models.rgb_part_net import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
SystemConfiguration
@@ -116,7 +116,7 @@ class Model:
loss.backward()
self.optimizer.step()
# Step scheduler
- self.scheduler.step(self.curr_iter)
+ self.scheduler.step()
# Write losses to TensorBoard
self.writer.add_scalar('Loss/all', loss.item(), self.curr_iter)
@@ -129,7 +129,7 @@ class Model:
print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss),
'(xrecon = {:f}, pose_sim = {:f},'
' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics),
- 'lr:', self.scheduler.get_last_lr())
+ 'lr:', self.scheduler.get_last_lr()[0])
if self.curr_iter % 1000 == 0:
torch.save({
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 73d5952..3037da0 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -4,7 +4,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
-from models import AutoEncoder, HorizontalPyramidMatching, PartNet
+from models.auto_encoder import AutoEncoder
+from models.hpm import HorizontalPyramidMatching
+from models.part_net import PartNet
from utils.triplet_loss import BatchAllTripletLoss
@@ -117,7 +119,7 @@ class RGBPartNet(nn.Module):
# Losses
xrecon_loss = torch.sum(torch.stack(xrecon_loss))
- pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2)
+ pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) * 10
cano_cons_loss = torch.mean(torch.stack(cano_cons_loss))
return ((x_c_c1, x_p_c1),
diff --git a/test/cuda.py b/test/cuda.py
index ef0ea36..b1418c4 100644
--- a/test/cuda.py
+++ b/test/cuda.py
@@ -1,6 +1,6 @@
import torch
-from models import RGBPartNet
+from models.rgb_part_net import RGBPartNet
P, K = 2, 4
N, T, C, H, W = P * K, 10, 3, 64, 32
diff --git a/test/hpm.py b/test/hpm.py
index a68337d..0aefbb8 100644
--- a/test/hpm.py
+++ b/test/hpm.py
@@ -1,6 +1,6 @@
import torch
-from models import HorizontalPyramidMatching
+from models.hpm import HorizontalPyramidMatching
T, N, C, H, W = 15, 4, 256, 32, 16
diff --git a/test/rgb_part_net.py b/test/rgb_part_net.py
index 1d754a0..d0d4e91 100644
--- a/test/rgb_part_net.py
+++ b/test/rgb_part_net.py
@@ -1,6 +1,6 @@
import torch
-from models import RGBPartNet
+from models.rgb_part_net import RGBPartNet
P, K = 2, 4
N, T, C, H, W = P * K, 10, 3, 64, 32
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..17cd0f6
--- /dev/null
+++ b/train.py
@@ -0,0 +1,12 @@
+import os
+
+from config import config
+from models import Model
+
+# Set environment variable CUDA device(s)
+CUDA_VISIBLE_DEVICES = config['system'].get('CUDA_VISIBLE_DEVICES', None)
+if CUDA_VISIBLE_DEVICES:
+ os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
+
+model = Model(config['system'], config['model'], config['hyperparameter'])
+model.fit(config['dataset'], config['dataloader'])
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 242be45..1d63a0e 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -18,7 +18,9 @@ class BatchAllTripletLoss(nn.Module):
x1_squared_sum = x_squared_sum.unsqueeze(1)
x2_squared_sum = x_squared_sum.unsqueeze(2)
x1_times_x2_sum = x @ x.transpose(1, 2)
- dist = torch.sqrt(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum)
+ dist = torch.sqrt(
+ F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum)
+ )
hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
@@ -31,5 +33,5 @@ class BatchAllTripletLoss(nn.Module):
parted_loss_mean = all_loss.sum(1) / (all_loss != 0).sum(1)
parted_loss_mean[parted_loss_mean == float('Inf')] = 0
- loss = parted_loss_mean.mean()
+ loss = parted_loss_mean.sum()
return loss