[Bug]: TRTLLM attention + full cudagraph produces incorrect output at long context length (>128k) on Blackwell #1968

@xinli-sw

Description

This was discovered in vLLM, thanks to @mgoin's findings.

The original issue was found because Qwen3-VL models completely lost accuracy (1% vs. 86% on GSM8K) on B200 GPUs. The issue happens when all of the following hold:

  • the trtllm attention backend is used
  • the model's original max_seq_len is > 128k
  • full cudagraph is used (i.e. the attention kernel is captured and executed inside a CUDA graph)
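
The workaround referenced in the reproduction steps below presumably guards against this combination of conditions. A minimal sketch of that kind of check, using hypothetical names rather than the actual code behind the link below:

# Hypothetical guard illustrating the condition under which full CUDA graph
# capture has to be avoided; names are illustrative, not vLLM's real API.
def should_disable_full_cudagraph(attn_backend: str,
                                  max_model_len: int,
                                  cudagraph_mode: str) -> bool:
    # TRTLLM-gen attention produces incorrect output when its kernel is
    # captured in a full CUDA graph and max_model_len exceeds 128k.
    return (attn_backend == "TRTLLM"
            and max_model_len > 128 * 1024
            and cudagraph_mode in ("FULL", "FULL_AND_PIECEWISE"))

attn_backend = "TRTLLM"
max_model_len = 1_048_576  # e.g. Llama-3-8B-Instruct-Gradient-1048k
cudagraph_mode = "FULL_AND_PIECEWISE"

if should_disable_full_cudagraph(attn_backend, max_model_len, cudagraph_mode):
    cudagraph_mode = "PIECEWISE"  # keep attention out of the captured graph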

How to reproduce:

On vllm main:

vLLM introduced a workaround, but it has performance implications.

First, comment out the workaround at https://github.com/vllm-project/vllm/blob/main/vllm/config/vllm.py#L379-L393

vllm serve gradientai/Llama-3-8B-Instruct-Gradient-1048k  -O.cudagraph_mode=FULL_AND_PIECEWISE

On a separate terminal:

git clone https://github.com/vllm-project/vllm.git && cd vllm 
python3 tests/evals/gsm8k/gsm8k_eval.py

The results are broken:

Results:
Accuracy: 0.004
Invalid responses: 0.867
Total latency: 55.738 s
Questions per second: 23.664
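
The same failure should also be reproducible without the server through vLLM's offline LLM API. A hedged sketch, assuming compilation_config accepts a cudagraph_mode field that mirrors the -O.cudagraph_mode CLI flag used above:

# Offline reproduction sketch (assumption: compilation_config/cudagraph_mode
# mirror the -O.cudagraph_mode CLI flag; adjust for your vLLM version).
from vllm import LLM, SamplingParams

llm = LLM(
    model="gradientai/Llama-3-8B-Instruct-Gradient-1048k",
    compilation_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
)
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["What is 12 * 7? Answer with a number."], params)
# When the bug triggers, this output is garbled/incorrect on Blackwell.
print(outputs[0].outputs[0].text)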

Launch without TRTLLM-gen attention:

VLLM_USE_TRTLLM_ATTENTION=0 vllm serve gradientai/Llama-3-8B-Instruct-Gradient-1048k  -O.cudagraph_mode=FULL_AND_PIECEWISE 

On a separate terminal:

python3 tests/evals/gsm8k/gsm8k_eval.py

The results are OK:

Results:
Accuracy: 0.578
Invalid responses: 0.005
Total latency: 16.689 s
Questions per second: 79.036
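
The same env-var toggle should carry over to the offline sketch above; setting it before vLLM is imported is the safest way to apply it:

# Disable the TRTLLM-gen attention path for an offline run, mirroring the
# VLLM_USE_TRTLLM_ATTENTION=0 server launch above.
import os
os.environ["VLLM_USE_TRTLLM_ATTENTION"] = "0"

from vllm import LLM  # import after setting the env var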
