summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--models/model.py14
-rw-r--r--models/rgb_part_net.py4
2 files changed, 15 insertions, 3 deletions
diff --git a/models/model.py b/models/model.py
index c407d6c..ebd6aba 100644
--- a/models/model.py
+++ b/models/model.py
@@ -2,6 +2,7 @@ from typing import Union, Optional
import numpy as np
import torch
+import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
@@ -67,6 +68,7 @@ class Model:
self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9)
self.rbg_pn.train()
+ self.rbg_pn.apply(self.init_weights)
for iter_i, (x_c1, x_c2) in enumerate(dataloader):
loss = self.rbg_pn(x_c1['clip'], x_c2['clip'], x_c1['label'])
loss.backward()
@@ -76,6 +78,18 @@ class Model:
if iter_i == self.meta['total_iter']:
break
+ @staticmethod
+ def init_weights(m):
+ if isinstance(m, nn.modules.conv._ConvNd):
+ nn.init.xavier_uniform_(m.weight)
+ elif isinstance(m, nn.modules.batchnorm._NormBase):
+ nn.init.normal_(m.weight, 1.0, 0.01)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ elif isinstance(m, nn.Parameter):
+ nn.init.xavier_uniform_(m)
+
def _parse_dataset_config(
self,
dataset_config: DatasetConfiguration
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index 0f3b4f4..5012765 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -41,9 +41,7 @@ class RGBPartNet(nn.Module):
)
total_parts = sum(hpm_scales) + tfa_num_parts
empty_fc = torch.empty(total_parts, out_channels, embedding_dims)
- self.fc_mat = nn.Parameter(nn.init.xavier_uniform_(empty_fc))
-
- # TODO Weight inti here
+ self.fc_mat = nn.Parameter(empty_fc)
def fc(self, x):
return torch.matmul(x, self.fc_mat)