summaryrefslogtreecommitdiff
path: root/models/rgb_part_net.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-03 20:16:16 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-03 20:16:16 +0800
commitdf3a8021b528cc7d585dc17d3e1f3c18a20ed963 (patch)
tree77e60bba2adf5b425812338fb58ec1bbb7bbb6d9 /models/rgb_part_net.py
parentca7119e677e14b209b224fafe4de57780113499f (diff)
Implement weight initialization
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r--models/rgb_part_net.py4
1 files changed, 1 insertions, 3 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 0f3b4f4..5012765 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -41,9 +41,7 @@ class RGBPartNet(nn.Module):
)
total_parts = sum(hpm_scales) + tfa_num_parts
empty_fc = torch.empty(total_parts, out_channels, embedding_dims)
- self.fc_mat = nn.Parameter(nn.init.xavier_uniform_(empty_fc))
-
- # TODO Weight inti here
+ self.fc_mat = nn.Parameter(empty_fc)
def fc(self, x):
return torch.matmul(x, self.fc_mat)