summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/models/model.py b/models/model.py
index 740cdf3..8797636 100644
--- a/models/model.py
+++ b/models/model.py
@@ -133,12 +133,21 @@ class Model:
dataloader = self._parse_dataloader_config(dataset, dataloader_config)
# Prepare for model, optimizer and scheduler
model_hp = self.hp.get('model', {})
- optim_hp = self.hp.get('optimizer', {})
+ optim_hp: dict = self.hp.get('optimizer', {}).copy()
+ ae_optim_hp = optim_hp.pop('auto_encoder', {})
+ pn_optim_hp = optim_hp.pop('part_net', {})
+ hpm_optim_hp = optim_hp.pop('hpm', {})
+ fc_optim_hp = optim_hp.pop('fc', {})
sched_hp = self.hp.get('scheduler', {})
self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **model_hp)
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
- self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp)
+ self.optimizer = optim.Adam([
+ {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
+ {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
+ {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
+ {'params': self.rgb_pn.fc_mat, **fc_optim_hp},
+ ], **optim_hp)
self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, **sched_hp)
self.writer = SummaryWriter(self._log_name)