diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-18 16:01:26 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2021-02-18 17:13:34 +0800 |
commit | fd77ad26d5c4ede79e3406e736fcdaa29eb1c7c9 (patch) | |
tree | 163889734f97e4c2c94e607e9dd2312d1b8fc43d | |
parent | 2a9507204ae2dd14504556ab5885c4f39bddd89a (diff) |
Decode mean appearance feature
-rw-r--r-- | models/rgb_part_net.py | 4 | ||||
-rw-r--r-- | preprocess.py | 2 |
2 files changed, 3 insertions, 3 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index bf52efe..c3954bc 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -120,8 +120,8 @@ class RGBPartNet(nn.Module): f_a = f_a_.view(n, t, -1) x_a = self.ae.decoder( f_a.mean(1), - torch.zeros((n * t, self.f_c_dim), device=device), - torch.zeros((n * t, self.f_p_dim), device=device) + torch.zeros((n, self.f_c_dim), device=device), + torch.zeros((n, self.f_p_dim), device=device) ) return x_a diff --git a/preprocess.py b/preprocess.py index 8af58f2..91fa8c2 100644 --- a/preprocess.py +++ b/preprocess.py @@ -51,7 +51,7 @@ def result_handler(frame_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: mask_x1 = (mask_xc + mask_half_width).ceil().int() # Skip incomplete frames - if (height < 64 or width < 64 / MASK_BOX_RATIO) \ + if (height < 64 or width < 64 / BOX_RATIO_THRESHOLD[1]) \ or (mask_x0 < 0 or mask_x1 > 320) \ or (mask_y0 < 0 or mask_y1 > 240): continue |