summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--models/model.py4
-rw-r--r--models/rgb_part_net.py2
2 files changed, 4 insertions, 2 deletions
diff --git a/models/model.py b/models/model.py
index 7dba949..4354b35 100644
--- a/models/model.py
+++ b/models/model.py
@@ -81,12 +81,14 @@ class Model:
@staticmethod
def init_weights(m):
if isinstance(m, nn.modules.conv._ConvNd):
- nn.init.xavier_uniform_(m.weight)
+ nn.init.normal_(m.weight, 0.0, 0.01)
elif isinstance(m, nn.modules.batchnorm._NormBase):
nn.init.normal_(m.weight, 1.0, 0.01)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
+ elif isinstance(m, RGBPartNet):
+ nn.init.xavier_uniform_(m.fc_mat)
def _parse_dataset_config(
self,
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index ac76dbf..5012765 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -41,7 +41,7 @@ class RGBPartNet(nn.Module):
)
total_parts = sum(hpm_scales) + tfa_num_parts
empty_fc = torch.empty(total_parts, out_channels, embedding_dims)
- self.fc_mat = nn.init.xavier_uniform_(nn.Parameter(empty_fc))
+ self.fc_mat = nn.Parameter(empty_fc)
def fc(self, x):
return torch.matmul(x, self.fc_mat)