diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-03 20:48:13 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-03 20:48:13 +0800 |
commit | 65c1f1a5a88344b67df42ca4f646763a5b80a691 (patch) | |
tree | 4c6bb2dc4e4b0bef26f7899319327aaf0c6587da | |
parent | 89d677873eb1c99070bd9a33a36f4c6415396756 (diff) |
Separate last fc matrix from weight init function
Recursive apply will override other parameters too
-rw-r--r-- | models/model.py | 2 | ||||
-rw-r--r-- | models/rgb_part_net.py | 2 |
2 files changed, 1 insertions, 3 deletions
diff --git a/models/model.py b/models/model.py index ebd6aba..7dba949 100644 --- a/models/model.py +++ b/models/model.py @@ -87,8 +87,6 @@ class Model: nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) - elif isinstance(m, nn.Parameter): - nn.init.xavier_uniform_(m) def _parse_dataset_config( self, diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 5012765..ac76dbf 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -41,7 +41,7 @@ class RGBPartNet(nn.Module): ) total_parts = sum(hpm_scales) + tfa_num_parts empty_fc = torch.empty(total_parts, out_channels, embedding_dims) - self.fc_mat = nn.Parameter(empty_fc) + self.fc_mat = nn.init.xavier_uniform_(nn.Parameter(empty_fc)) def fc(self, x): return torch.matmul(x, self.fc_mat) |