summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-03 20:48:13 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-03 20:48:13 +0800
commit65c1f1a5a88344b67df42ca4f646763a5b80a691 (patch)
tree4c6bb2dc4e4b0bef26f7899319327aaf0c6587da
parent89d677873eb1c99070bd9a33a36f4c6415396756 (diff)
Separate last fc matrix from weight init function
Recursive apply will override other parameters too
-rw-r--r--models/model.py2
-rw-r--r--models/rgb_part_net.py2
2 files changed, 1 insertions, 3 deletions
diff --git a/models/model.py b/models/model.py
index ebd6aba..7dba949 100644
--- a/models/model.py
+++ b/models/model.py
@@ -87,8 +87,6 @@ class Model:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
- elif isinstance(m, nn.Parameter):
- nn.init.xavier_uniform_(m)
def _parse_dataset_config(
self,
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 5012765..ac76dbf 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -41,7 +41,7 @@ class RGBPartNet(nn.Module):
)
total_parts = sum(hpm_scales) + tfa_num_parts
empty_fc = torch.empty(total_parts, out_channels, embedding_dims)
- self.fc_mat = nn.Parameter(empty_fc)
+ self.fc_mat = nn.init.xavier_uniform_(nn.Parameter(empty_fc))
def fc(self, x):
return torch.matmul(x, self.fc_mat)