summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-12 20:32:21 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-12 20:32:21 +0800
commit544cd7d3408191a9cabb5e0f2e6e83e2a2a7782e (patch)
treeea47424bf7c87224c2a0e03d7104d1fdf2566d11 /models/model.py
parent4495e10e5159ee54814e73ae7bfb27db2754ae69 (diff)
Set default learning rate
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/models/model.py b/models/model.py
index aad0a99..b86a050 100644
--- a/models/model.py
+++ b/models/model.py
@@ -118,8 +118,8 @@ class Model:
# Prepare for optimizer and scheduler
optim_hp = self.hp.get('optimizer', {})
# Scale learning rate to world size
- if optim_hp['lr']:
- optim_hp['lr'] *= xm.xrt_world_size()
+ lr = optim_hp.get('lr', '1-e3')
+ optim_hp['lr'] = lr * xm.xrt_world_size()
sched_hp = self.hp.get('scheduler', {})
device = xm.xla_device()
rgb_pn = wrapped_rgb_pn.to(device)