diff options
Diffstat (limited to 'models/rgb_part_net.py')
| -rw-r--r-- | models/rgb_part_net.py | 15 | 
1 files changed, 10 insertions, 5 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index 1c7a1a2..6be6b0a 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -10,6 +10,7 @@ from models.auto_encoder import AutoEncoder  class RGBPartNet(nn.Module):      def __init__(              self, +            num_class: int,              ae_in_channels: int = 3,              ae_in_size: Tuple[int, int] = (64, 48),              ae_feature_channels: int = 64, @@ -22,11 +23,15 @@ class RGBPartNet(nn.Module):          self.image_log_on = image_log_on          self.ae = AutoEncoder( -            ae_in_channels, ae_in_size, ae_feature_channels, f_a_c_p_dims +            num_class, +            ae_in_channels, +            ae_in_size, +            ae_feature_channels, +            f_a_c_p_dims          ) -    def forward(self, x_c1, x_c2=None): -        losses, features, images = self._disentangle(x_c1, x_c2) +    def forward(self, x_c1, x_c2=None, y=None): +        losses, features, images = self._disentangle(x_c1, x_c2, y)          if self.training:              losses = torch.stack(losses) @@ -34,11 +39,11 @@ class RGBPartNet(nn.Module):          else:              return features -    def _disentangle(self, x_c1_t2, x_c2_t2=None): +    def _disentangle(self, x_c1_t2, x_c2_t2=None, y=None):          n, t, c, h, w = x_c1_t2.size()          if self.training:              x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :] -            ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2) +            ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2, y)              f_a = f_a_.view(n, t, -1)              f_c = f_c_.view(n, t, -1)              f_p = f_p_.view(n, t, -1)  | 
