summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-05 13:21:16 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-05 13:21:16 +0800
commit7158a18a5f2789b2b2902c4b918ed80002970249 (patch)
tree5f9f70fddd82e3d2b6930ab4c2f81b29cc2bdc5a /models/model.py
parent65c1f1a5a88344b67df42ca4f646763a5b80a691 (diff)
Change and improve weight initialization
1. Change initial weights for Conv layers 2. Find a way to init last fc in init_weights
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/models/model.py b/models/model.py
index 7dba949..4354b35 100644
--- a/models/model.py
+++ b/models/model.py
@@ -81,12 +81,14 @@ class Model:
@staticmethod
def init_weights(m):
if isinstance(m, nn.modules.conv._ConvNd):
- nn.init.xavier_uniform_(m.weight)
+ nn.init.normal_(m.weight, 0.0, 0.01)
elif isinstance(m, nn.modules.batchnorm._NormBase):
nn.init.normal_(m.weight, 1.0, 0.01)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
+ elif isinstance(m, RGBPartNet):
+ nn.init.xavier_uniform_(m.fc_mat)
def _parse_dataset_config(
self,