diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-03 20:16:16 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-01-03 20:16:16 +0800 |
commit | df3a8021b528cc7d585dc17d3e1f3c18a20ed963 (patch) | |
tree | 77e60bba2adf5b425812338fb58ec1bbb7bbb6d9 | |
parent | ca7119e677e14b209b224fafe4de57780113499f (diff) |
Implement weight initialization
-rw-r--r-- | models/model.py | 14 | ||||
-rw-r--r-- | models/rgb_part_net.py | 4 |
2 files changed, 15 insertions, 3 deletions
diff --git a/models/model.py b/models/model.py index c407d6c..ebd6aba 100644 --- a/models/model.py +++ b/models/model.py @@ -2,6 +2,7 @@ from typing import Union, Optional import numpy as np import torch +import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate @@ -67,6 +68,7 @@ class Model: self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9) self.rbg_pn.train() + self.rbg_pn.apply(self.init_weights) for iter_i, (x_c1, x_c2) in enumerate(dataloader): loss = self.rbg_pn(x_c1['clip'], x_c2['clip'], x_c1['label']) loss.backward() @@ -76,6 +78,18 @@ class Model: if iter_i == self.meta['total_iter']: break + @staticmethod + def init_weights(m): + if isinstance(m, nn.modules.conv._ConvNd): + nn.init.xavier_uniform_(m.weight) + elif isinstance(m, nn.modules.batchnorm._NormBase): + nn.init.normal_(m.weight, 1.0, 0.01) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + elif isinstance(m, nn.Parameter): + nn.init.xavier_uniform_(m) + def _parse_dataset_config( self, dataset_config: DatasetConfiguration diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 0f3b4f4..5012765 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -41,9 +41,7 @@ class RGBPartNet(nn.Module): ) total_parts = sum(hpm_scales) + tfa_num_parts empty_fc = torch.empty(total_parts, out_channels, embedding_dims) - self.fc_mat = nn.Parameter(nn.init.xavier_uniform_(empty_fc)) - - # TODO Weight inti here + self.fc_mat = nn.Parameter(empty_fc) def fc(self, x): return torch.matmul(x, self.fc_mat) |