diff --git a/tests/kernels/test_grpo_loss.py b/tests/kernels/test_grpo_loss.py index c7cbd077..7cc4672a 100644 --- a/tests/kernels/test_grpo_loss.py +++ b/tests/kernels/test_grpo_loss.py @@ -94,7 +94,7 @@ def grpo_loss_kernel(kl_type="unbias"): T = 32 H = 256 V = 1024 - BLOCK_SIZE_T = 8 + BLOCK_SIZE_T = 4 torch.manual_seed(42) advantages = torch.randn((T,), dtype=torch.float32, device=device_, requires_grad=True)