summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-14 23:43:29 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-14 23:43:29 +0800
commit6ffc1c06f66277d37877fc13fb1ffa585598d6d7 (patch)
tree3c5c12c2d2a695a3dba015bb2b09db3ffab061ee /models/model.py
parentdb5a58b1db9875afbc2a4c7e6e5d190b6c28ee34 (diff)
Enable optimizer fine tuning
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/models/model.py b/models/model.py
index 740cdf3..8797636 100644
--- a/models/model.py
+++ b/models/model.py
@@ -133,12 +133,21 @@ class Model:
dataloader = self._parse_dataloader_config(dataset, dataloader_config)
# Prepare for model, optimizer and scheduler
model_hp = self.hp.get('model', {})
- optim_hp = self.hp.get('optimizer', {})
+ optim_hp: dict = self.hp.get('optimizer', {}).copy()
+ ae_optim_hp = optim_hp.pop('auto_encoder', {})
+ pn_optim_hp = optim_hp.pop('part_net', {})
+ hpm_optim_hp = optim_hp.pop('hpm', {})
+ fc_optim_hp = optim_hp.pop('fc', {})
sched_hp = self.hp.get('scheduler', {})
self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **model_hp)
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
- self.optimizer = optim.Adam(self.rgb_pn.parameters(), **optim_hp)
+ self.optimizer = optim.Adam([
+ {'params': self.rgb_pn.ae.parameters(), **ae_optim_hp},
+ {'params': self.rgb_pn.pn.parameters(), **pn_optim_hp},
+ {'params': self.rgb_pn.hpm.parameters(), **hpm_optim_hp},
+ {'params': self.rgb_pn.fc_mat, **fc_optim_hp},
+ ], **optim_hp)
self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, **sched_hp)
self.writer = SummaryWriter(self._log_name)