summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-07 18:37:43 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-07 18:37:43 +0800
commit4a284084c253b9114fc02e1782962556ff113761 (patch)
treed6ceff8da68b224186d84772ee6153353675bcfe
parenta27af5dfd58e7b48cf3bd063fa2b4b51ed1e0277 (diff)
Add typical training script and some bug fixes
1. Resolve deprecated scheduler stepping issue 2. Make losses in the same scale(replace mean with sum in separate triplet loss, enlarge pose similarity loss 10x) 3. Add ReLU when compute distance in triplet loss 4. Remove classes except Model from `models` package init
-rw-r--r--.gitignore3
-rw-r--r--main.py0
-rw-r--r--models/__init__.py5
-rw-r--r--models/model.py6
-rw-r--r--models/rgb_part_net.py6
-rw-r--r--test/cuda.py2
-rw-r--r--test/hpm.py2
-rw-r--r--test/rgb_part_net.py2
-rw-r--r--train.py12
-rw-r--r--utils/triplet_loss.py6
10 files changed, 29 insertions, 15 deletions
diff --git a/.gitignore b/.gitignore
index 08f887c..f4b5746 100644
--- a/.gitignore
+++ b/.gitignore
@@ -146,4 +146,5 @@ dmypy.json
# Dataset
data/
-
+# Runtime
+runs/
diff --git a/main.py b/main.py
deleted file mode 100644
index e69de29..0000000
--- a/main.py
+++ /dev/null
diff --git a/models/__init__.py b/models/__init__.py
index c1b9fe8..7040c63 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,4 +1 @@
-from .auto_encoder import AutoEncoder
-from .hpm import HorizontalPyramidMatching
-from .part_net import PartNet
-from .rgb_part_net import RGBPartNet
+from .model import Model \ No newline at end of file
diff --git a/models/model.py b/models/model.py
index bf8b5fb..1dc0f23 100644
--- a/models/model.py
+++ b/models/model.py
@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
-from models import RGBPartNet
+from models.rgb_part_net import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
SystemConfiguration
@@ -116,7 +116,7 @@ class Model:
loss.backward()
self.optimizer.step()
# Step scheduler
- self.scheduler.step(self.curr_iter)
+ self.scheduler.step()
# Write losses to TensorBoard
self.writer.add_scalar('Loss/all', loss.item(), self.curr_iter)
@@ -129,7 +129,7 @@ class Model:
print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss),
'(xrecon = {:f}, pose_sim = {:f},'
' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics),
- 'lr:', self.scheduler.get_last_lr())
+ 'lr:', self.scheduler.get_last_lr()[0])
if self.curr_iter % 1000 == 0:
torch.save({
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 73d5952..3037da0 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -4,7 +4,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
-from models import AutoEncoder, HorizontalPyramidMatching, PartNet
+from models.auto_encoder import AutoEncoder
+from models.hpm import HorizontalPyramidMatching
+from models.part_net import PartNet
from utils.triplet_loss import BatchAllTripletLoss
@@ -117,7 +119,7 @@ class RGBPartNet(nn.Module):
# Losses
xrecon_loss = torch.sum(torch.stack(xrecon_loss))
- pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2)
+ pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) * 10
cano_cons_loss = torch.mean(torch.stack(cano_cons_loss))
return ((x_c_c1, x_p_c1),
diff --git a/test/cuda.py b/test/cuda.py
index ef0ea36..b1418c4 100644
--- a/test/cuda.py
+++ b/test/cuda.py
@@ -1,6 +1,6 @@
import torch
-from models import RGBPartNet
+from models.rgb_part_net import RGBPartNet
P, K = 2, 4
N, T, C, H, W = P * K, 10, 3, 64, 32
diff --git a/test/hpm.py b/test/hpm.py
index a68337d..0aefbb8 100644
--- a/test/hpm.py
+++ b/test/hpm.py
@@ -1,6 +1,6 @@
import torch
-from models import HorizontalPyramidMatching
+from models.hpm import HorizontalPyramidMatching
T, N, C, H, W = 15, 4, 256, 32, 16
diff --git a/test/rgb_part_net.py b/test/rgb_part_net.py
index 1d754a0..d0d4e91 100644
--- a/test/rgb_part_net.py
+++ b/test/rgb_part_net.py
@@ -1,6 +1,6 @@
import torch
-from models import RGBPartNet
+from models.rgb_part_net import RGBPartNet
P, K = 2, 4
N, T, C, H, W = P * K, 10, 3, 64, 32
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..17cd0f6
--- /dev/null
+++ b/train.py
@@ -0,0 +1,12 @@
+import os
+
+from config import config
+from models import Model
+
+# Set environment variable CUDA device(s)
+CUDA_VISIBLE_DEVICES = config['system'].get('CUDA_VISIBLE_DEVICES', None)
+if CUDA_VISIBLE_DEVICES:
+ os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
+
+model = Model(config['system'], config['model'], config['hyperparameter'])
+model.fit(config['dataset'], config['dataloader'])
diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py
index 242be45..1d63a0e 100644
--- a/utils/triplet_loss.py
+++ b/utils/triplet_loss.py
@@ -18,7 +18,9 @@ class BatchAllTripletLoss(nn.Module):
x1_squared_sum = x_squared_sum.unsqueeze(1)
x2_squared_sum = x_squared_sum.unsqueeze(2)
x1_times_x2_sum = x @ x.transpose(1, 2)
- dist = torch.sqrt(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum)
+ dist = torch.sqrt(
+ F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum)
+ )
hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2)
hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2)
@@ -31,5 +33,5 @@ class BatchAllTripletLoss(nn.Module):
parted_loss_mean = all_loss.sum(1) / (all_loss != 0).sum(1)
parted_loss_mean[parted_loss_mean == float('Inf')] = 0
- loss = parted_loss_mean.mean()
+ loss = parted_loss_mean.sum()
return loss