From fd77ad26d5c4ede79e3406e736fcdaa29eb1c7c9 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Thu, 18 Feb 2021 16:01:26 +0800 Subject: Decode mean appearance feature --- models/rgb_part_net.py | 4 ++-- preprocess.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py index bf52efe..c3954bc 100644 --- a/models/rgb_part_net.py +++ b/models/rgb_part_net.py @@ -120,8 +120,8 @@ class RGBPartNet(nn.Module): f_a = f_a_.view(n, t, -1) x_a = self.ae.decoder( f_a.mean(1), - torch.zeros((n * t, self.f_c_dim), device=device), - torch.zeros((n * t, self.f_p_dim), device=device) + torch.zeros((n, self.f_c_dim), device=device), + torch.zeros((n, self.f_p_dim), device=device) ) return x_a diff --git a/preprocess.py b/preprocess.py index 8af58f2..91fa8c2 100644 --- a/preprocess.py +++ b/preprocess.py @@ -51,7 +51,7 @@ def result_handler(frame_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: mask_x1 = (mask_xc + mask_half_width).ceil().int() # Skip incomplete frames - if (height < 64 or width < 64 / MASK_BOX_RATIO) \ + if (height < 64 or width < 64 / BOX_RATIO_THRESHOLD[1]) \ or (mask_x0 < 0 or mask_x1 > 320) \ or (mask_y0 < 0 or mask_y1 > 240): continue -- cgit v1.2.3