diff options
Diffstat (limited to 'models/rgb_part_net.py')
-rw-r--r-- | models/rgb_part_net.py | 4 |
1 files changed, 1 insertions, 3 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 0f3b4f4..5012765 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -41,9 +41,7 @@ class RGBPartNet(nn.Module): ) total_parts = sum(hpm_scales) + tfa_num_parts empty_fc = torch.empty(total_parts, out_channels, embedding_dims) - self.fc_mat = nn.Parameter(nn.init.xavier_uniform_(empty_fc)) - - # TODO Weight inti here + self.fc_mat = nn.Parameter(empty_fc) def fc(self, x): return torch.matmul(x, self.fc_mat) |