diff --git a/tests/test_helpers/jit_utils.py b/tests/test_helpers/jit_utils.py index 6a462bc366..c30a9c7964 100644 --- a/tests/test_helpers/jit_utils.py +++ b/tests/test_helpers/jit_utils.py @@ -160,6 +160,9 @@ def gen_prefill_attention_modules( dtype_q=q_dtype, dtype_kv=kv_dtype, ): + if q_dtype != kv_dtype: + continue # fa3 template do not support mixed precision + jit_specs.append( flashinfer.prefill.gen_single_prefill_module( "fa3",