summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--models/auto_encoder.py9
-rw-r--r--models/model.py24
-rw-r--r--models/rgb_part_net.py15
3 files changed, 36 insertions, 12 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 91071dd..9bfb365 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -110,7 +110,7 @@ class Decoder(nn.Module):
x = torch.cat((f_appearance, f_canonical, f_pose), dim=1)
x = self.fc(x)
x = x.view(-1, self.feature_channels * 8, self.h_0, self.w_0)
- x = F.relu(x, inplace=True)
+ x = F.leaky_relu(x, 0.2, inplace=True)
x = self.trans_conv1(x)
x = self.trans_conv2(x)
x = self.trans_conv3(x)
@@ -122,6 +122,7 @@ class Decoder(nn.Module):
class AutoEncoder(nn.Module):
def __init__(
self,
+ num_class: int,
channels: int = 3,
frame_size: tuple[int, int] = (64, 48),
feature_channels: int = 64,
@@ -132,8 +133,9 @@ class AutoEncoder(nn.Module):
feature_channels, embedding_dims)
self.decoder = Decoder(embedding_dims, feature_channels,
self.encoder.feature_size, channels)
+ self.classifier = BasicLinear(embedding_dims[1], num_class)
- def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None):
+ def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None, y=None):
n, t, c, h, w = x_c1_t2.size()
# x_c1_t2 is the frame for later module
x_c1_t2_ = x_c1_t2.view(n * t, c, h, w)
@@ -160,6 +162,9 @@ class AutoEncoder(nn.Module):
cano_cons_loss = torch.stack([
F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
+ F.mse_loss(f_c_c1_t2[:, i, :], f_c_c2_t2[:, i, :])
+ + F.cross_entropy(self.classifier(
+ F.leaky_relu(f_c_c1_t2[:, i, :], 0.2)
+ ), y)
for i in range(t)
]).mean()
diff --git a/models/model.py b/models/model.py
index 0829f33..d976f5a 100644
--- a/models/model.py
+++ b/models/model.py
@@ -54,6 +54,7 @@ class Model:
self.total_iters = self.meta.get('total_iters', (self.total_iter,))
self.is_train: bool = True
+ self.num_class: Optional[int] = None
self.in_channels: int = 3
self.in_size: tuple[int, int] = (64, 48)
self.batch_size: Optional[int] = None
@@ -160,8 +161,13 @@ class Model:
optim_hp: dict = self.hp.get('optimizer', {}).copy()
sched_hp = self.hp.get('scheduler', {})
- self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
- image_log_on=self.image_log_on)
+ self.rgb_pn = RGBPartNet(
+ self.num_class,
+ self.in_channels,
+ self.in_size,
+ **model_hp,
+ image_log_on=self.image_log_on
+ )
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
@@ -202,7 +208,8 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- losses, features, images = self.rgb_pn(x_c1, x_c2)
+ y = batch_c1['label'].to(self.device)
+ losses, features, images = self.rgb_pn(x_c1, x_c2, y)
loss = losses.sum()
loss.backward()
self.optimizer.step()
@@ -254,8 +261,9 @@ class Model:
batch_c1, batch_c2 = next(val_dataloader)
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
+ y = batch_c1['label'].to(self.device)
with torch.no_grad():
- losses, _, _ = self.rgb_pn(x_c1, x_c2)
+ losses, _, _ = self.rgb_pn(x_c1, x_c2, y)
loss = losses.sum()
self._write_stat('Val', loss, losses)
@@ -302,7 +310,12 @@ class Model:
# Init models
model_hp: dict = self.hp.get('model', {}).copy()
- self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp)
+ self.rgb_pn = RGBPartNet(
+ self.num_class,
+ self.in_channels,
+ self.in_size,
+ **model_hp
+ )
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
self.rgb_pn.eval()
@@ -419,6 +432,7 @@ class Model:
self,
dataset_config: DatasetConfiguration
) -> Union[CASIAB]:
+ self.num_class = dataset_config.get('train_size', 74)
self.in_channels = dataset_config.get('num_input_channels', 3)
self.in_size = dataset_config.get('frame_size', (64, 48))
self._dataset_sig = self._make_signature(
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index d3f8ade..d5eb619 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -8,6 +8,7 @@ from models.auto_encoder import AutoEncoder
class RGBPartNet(nn.Module):
def __init__(
self,
+ num_class: int,
ae_in_channels: int = 3,
ae_in_size: tuple[int, int] = (64, 48),
ae_feature_channels: int = 64,
@@ -20,11 +21,15 @@ class RGBPartNet(nn.Module):
self.image_log_on = image_log_on
self.ae = AutoEncoder(
- ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims
+ num_class,
+ ae_in_channels,
+ ae_in_size,
+ ae_feature_channels,
+ f_a_c_p_dims
)
- def forward(self, x_c1, x_c2=None):
- losses, features, images = self._disentangle(x_c1, x_c2)
+ def forward(self, x_c1, x_c2=None, y=None):
+ losses, features, images = self._disentangle(x_c1, x_c2, y)
if self.training:
losses = torch.stack(losses)
@@ -32,11 +37,11 @@ class RGBPartNet(nn.Module):
else:
return features
- def _disentangle(self, x_c1_t2, x_c2_t2=None):
+ def _disentangle(self, x_c1_t2, x_c2_t2=None, y=None):
n, t, c, h, w = x_c1_t2.size()
if self.training:
x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
- ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
+ ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2, y)
f_a = f_a_.view(n, t, -1)
f_c = f_c_.view(n, t, -1)
f_p = f_p_.view(n, t, -1)