summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--models/rgb_part_net.py4
-rw-r--r--preprocess.py2
2 files changed, 3 insertions, 3 deletions
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index bf52efe..c3954bc 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -120,8 +120,8 @@ class RGBPartNet(nn.Module):
f_a = f_a_.view(n, t, -1)
x_a = self.ae.decoder(
f_a.mean(1),
- torch.zeros((n * t, self.f_c_dim), device=device),
- torch.zeros((n * t, self.f_p_dim), device=device)
+ torch.zeros((n, self.f_c_dim), device=device),
+ torch.zeros((n, self.f_p_dim), device=device)
)
return x_a
diff --git a/preprocess.py b/preprocess.py
index 8af58f2..91fa8c2 100644
--- a/preprocess.py
+++ b/preprocess.py
@@ -51,7 +51,7 @@ def result_handler(frame_: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
mask_x1 = (mask_xc + mask_half_width).ceil().int()
# Skip incomplete frames
- if (height < 64 or width < 64 / MASK_BOX_RATIO) \
+ if (height < 64 or width < 64 / BOX_RATIO_THRESHOLD[1]) \
or (mask_x0 < 0 or mask_x1 > 320) \
or (mask_y0 < 0 or mask_y1 > 240):
continue