diff options
-rw-r--r-- | .gitignore | 3 | ||||
-rw-r--r-- | main.py | 0 | ||||
-rw-r--r-- | models/__init__.py | 5 | ||||
-rw-r--r-- | models/model.py | 6 | ||||
-rw-r--r-- | models/rgb_part_net.py | 6 | ||||
-rw-r--r-- | test/cuda.py | 2 | ||||
-rw-r--r-- | test/hpm.py | 2 | ||||
-rw-r--r-- | test/rgb_part_net.py | 2 | ||||
-rw-r--r-- | train.py | 12 | ||||
-rw-r--r-- | utils/triplet_loss.py | 6 |
10 files changed, 29 insertions, 15 deletions
@@ -146,4 +146,5 @@ dmypy.json # Dataset data/ - +# Runtime +runs/ diff --git a/main.py b/main.py deleted file mode 100644 index e69de29..0000000 --- a/main.py +++ /dev/null diff --git a/models/__init__.py b/models/__init__.py index c1b9fe8..7040c63 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,4 +1 @@ -from .auto_encoder import AutoEncoder -from .hpm import HorizontalPyramidMatching -from .part_net import PartNet -from .rgb_part_net import RGBPartNet +from .model import Model
\ No newline at end of file diff --git a/models/model.py b/models/model.py index bf8b5fb..1dc0f23 100644 --- a/models/model.py +++ b/models/model.py @@ -9,7 +9,7 @@ from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate from torch.utils.tensorboard import SummaryWriter -from models import RGBPartNet +from models.rgb_part_net import RGBPartNet from utils.configuration import DataloaderConfiguration, \ HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \ SystemConfiguration @@ -116,7 +116,7 @@ class Model: loss.backward() self.optimizer.step() # Step scheduler - self.scheduler.step(self.curr_iter) + self.scheduler.step() # Write losses to TensorBoard self.writer.add_scalar('Loss/all', loss.item(), self.curr_iter) @@ -129,7 +129,7 @@ class Model: print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss), '(xrecon = {:f}, pose_sim = {:f},' ' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics), - 'lr:', self.scheduler.get_last_lr()) + 'lr:', self.scheduler.get_last_lr()[0]) if self.curr_iter % 1000 == 0: torch.save({ diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 73d5952..3037da0 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -4,7 +4,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from models import AutoEncoder, HorizontalPyramidMatching, PartNet +from models.auto_encoder import AutoEncoder +from models.hpm import HorizontalPyramidMatching +from models.part_net import PartNet from utils.triplet_loss import BatchAllTripletLoss @@ -117,7 +119,7 @@ class RGBPartNet(nn.Module): # Losses xrecon_loss = torch.sum(torch.stack(xrecon_loss)) - pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) + pose_sim_loss = self._pose_sim_loss(f_p_c1, f_p_c2) * 10 cano_cons_loss = torch.mean(torch.stack(cano_cons_loss)) return ((x_c_c1, x_p_c1), diff --git a/test/cuda.py b/test/cuda.py index ef0ea36..b1418c4 100644 --- a/test/cuda.py +++ b/test/cuda.py @@ -1,6 +1,6 @@ import torch -from models import RGBPartNet +from models.rgb_part_net import RGBPartNet P, K = 2, 4 N, T, C, H, W = P * K, 10, 3, 64, 32 diff --git a/test/hpm.py b/test/hpm.py index a68337d..0aefbb8 100644 --- a/test/hpm.py +++ b/test/hpm.py @@ -1,6 +1,6 @@ import torch -from models import HorizontalPyramidMatching +from models.hpm import HorizontalPyramidMatching T, N, C, H, W = 15, 4, 256, 32, 16 diff --git a/test/rgb_part_net.py b/test/rgb_part_net.py index 1d754a0..d0d4e91 100644 --- a/test/rgb_part_net.py +++ b/test/rgb_part_net.py @@ -1,6 +1,6 @@ import torch -from models import RGBPartNet +from models.rgb_part_net import RGBPartNet P, K = 2, 4 N, T, C, H, W = P * K, 10, 3, 64, 32 diff --git a/train.py b/train.py new file mode 100644 index 0000000..17cd0f6 --- /dev/null +++ b/train.py @@ -0,0 +1,12 @@ +import os + +from config import config +from models import Model + +# Set environment variable CUDA device(s) +CUDA_VISIBLE_DEVICES = config['system'].get('CUDA_VISIBLE_DEVICES', None) +if CUDA_VISIBLE_DEVICES: + os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES + +model = Model(config['system'], config['model'], config['hyperparameter']) +model.fit(config['dataset'], config['dataloader']) diff --git a/utils/triplet_loss.py b/utils/triplet_loss.py index 242be45..1d63a0e 100644 --- a/utils/triplet_loss.py +++ b/utils/triplet_loss.py @@ -18,7 +18,9 @@ class BatchAllTripletLoss(nn.Module): x1_squared_sum = x_squared_sum.unsqueeze(1) x2_squared_sum = x_squared_sum.unsqueeze(2) x1_times_x2_sum = x @ x.transpose(1, 2) - dist = torch.sqrt(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum) + dist = torch.sqrt( + F.relu(x1_squared_sum - 2 * x1_times_x2_sum + x2_squared_sum) + ) hard_positive_mask = y.unsqueeze(1) == y.unsqueeze(2) hard_negative_mask = y.unsqueeze(1) != y.unsqueeze(2) @@ -31,5 +33,5 @@ class BatchAllTripletLoss(nn.Module): parted_loss_mean = all_loss.sum(1) / (all_loss != 0).sum(1) parted_loss_mean[parted_loss_mean == float('Inf')] = 0 - loss = parted_loss_mean.mean() + loss = parted_loss_mean.sum() return loss |