
Conversation

@IwakuraRein (Contributor) commented on Dec 1, 2025

Purpose

  • Choose ReduceScatterSum over ReduceSum when the sizes are the same (see the reduce-scatter sketch after this list).

  • Avoid calling the fp8 quant kernel twice when using fp8 attention by creating a staging buffer for decode_ql_nope and decode_q_pe. This also eliminates the torch.cat in the FlashInfer MLA backend (see the staging-buffer sketch after this list).

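For context on the first bullet, here is a minimal torch.distributed sketch (not the vLLM code path; the function names and the even-divisibility assumption are mine) of why ReduceScatterSum can replace ReduceSum plus a local slice when the sizes line up: the summed values are identical, but each rank receives only its own shard instead of the full tensor.

```python
# Minimal sketch, assuming dim 0 is divisible by the world size and that the
# default process group is already initialized. Not the vLLM implementation.
import torch
import torch.distributed as dist

def summed_shard_via_allreduce(x: torch.Tensor) -> torch.Tensor:
    # ReduceSum-style path: every rank materializes the full sum, then slices.
    dist.all_reduce(x, op=dist.ReduceOp.SUM)
    rank, world = dist.get_rank(), dist.get_world_size()
    return x.chunk(world, dim=0)[rank]

def summed_shard_via_reduce_scatter(x: torch.Tensor) -> torch.Tensor:
    # ReduceScatterSum-style path: the sum and the scatter happen in a single
    # collective, so each rank communicates and stores only its own slice.
    world = dist.get_world_size()
    out = torch.empty(x.shape[0] // world, *x.shape[1:],
                      dtype=x.dtype, device=x.device)
    dist.reduce_scatter_tensor(out, x, op=dist.ReduceOp.SUM)
    return out
```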

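For the second bullet, a minimal sketch of the staging-buffer idea (plain PyTorch with illustrative names, shapes, and a per-tensor scale; the real path calls the fused fp8 quant kernel rather than a plain cast): copy both query pieces into adjacent slices of one preallocated buffer so a single quantization pass covers them, with no torch.cat.

```python
import torch

def quantize_query_staged(ql_nope: torch.Tensor,  # [tokens, heads, nope_dim]
                          q_pe: torch.Tensor,     # [tokens, heads, pe_dim]
                          scale: torch.Tensor) -> torch.Tensor:
    tokens, heads, nope_dim = ql_nope.shape
    pe_dim = q_pe.shape[-1]
    # Staging buffer holding both pieces back to back along the last dim,
    # so there is no torch.cat and no second quantization launch.
    staged = torch.empty(tokens, heads, nope_dim + pe_dim,
                         dtype=ql_nope.dtype, device=ql_nope.device)
    staged[..., :nope_dim].copy_(ql_nope)
    staged[..., nope_dim:].copy_(q_pe)
    # One quantization over the combined buffer; a simple fp8 cast stands in
    # for the fused quant kernel here.
    return (staged / scale).to(torch.float8_e4m3fn)
```
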
Test Plan

Launch command

  --dtype auto --kv-cache-dtype fp8 \
  --tensor-parallel-size 8 \
  --swap-space 16 --max-num-seqs 1024 --trust-remote-code --max-model-len 10240 --gpu-memory-utilization 0.95 \
  --max-num-batched-tokens 16384 --async-scheduling \
  --max-cudagraph-capture-size 1024 --compilation_config.cudagraph_mode FULL_DECODE_ONLY

Evaluation command

lm-eval --model local-completions --tasks gsm8k --model_args model=nvidia/DeepSeek-R1-0528-FP4-v2,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=64,max_retries=3,tokenized_requests=False --num_fewshot 20

Models

Test Result

  • Hopper, VLLM_ATTENTION_BACKEND=FLASHMLA

    |Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
    |-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
    |gsm8k|      3|flexible-extract|    20|exact_match|↑  |0.9530|±  |0.0058|
    |     |       |strict-match    |    20|exact_match|↑  |0.9522|±  |0.0059|
    
  • Blackwell, VLLM_ATTENTION_BACKEND=FLASHINFER_MLA

    |Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
    |-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
    |gsm8k|      3|flexible-extract|    20|exact_match|↑  |0.9538|±  |0.0058|
    |     |       |strict-match    |    20|exact_match|↑  |0.9522|±  |0.0059|
    

Perf gain: a projected 5% improvement in decode with the FlashInfer MLA backend when serving DeepSeek R1 FP4 on 4x GB200 with DEP4.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@IwakuraRein force-pushed the improve-attn-fp8-quant branch from 4c10b28 to b39aa3b on December 1, 2025 at 22:01
@IwakuraRein changed the title from "[Perf] Improve attn fp8 quant; replace ReduceSum with ReduceScatterSum" to "[Perf] Improve fp8 quant in mla; replace ReduceSum with ReduceScatterSum" on Dec 1, 2025
@IwakuraRein marked this pull request as ready for review on December 1, 2025 at 22:48

@heheda12345 (Collaborator) commented:

CC @MatthewBonanni @LucasWilkinson

@MatthewBonanni (Contributor) left a comment:

Thanks for this contribution! Can you include some benchmark results (or a snippet of a profile) to get a sense of the speedup?

@pavanimajety (Collaborator) left a comment:

LGTM, minor comments

@github-project-automation github-project-automation bot moved this to In review in NVIDIA Dec 4, 2025
@pavanimajety pavanimajety added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 5, 2025
@pavanimajety (Collaborator) commented:

@IwakuraRein Could you please add to the PR description why ReduceScatterSum is better than ReduceSum when the sizes are the same, or post some comparison results?

@pavanimajety pavanimajety merged commit 1fb632f into vllm-project:main Dec 8, 2025
58 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in NVIDIA Dec 8, 2025
therealnaveenkamal added a commit to therealnaveenkamal/vllm that referenced this pull request Dec 9, 2025
Signed-off-by: Naveenraj Kamalakannan <[email protected]>
mayoohee pushed a commit to mayoohee/vllm that referenced this pull request Dec 9, 2025
@IwakuraRein IwakuraRein deleted the improve-attn-fp8-quant branch December 9, 2025 17:54

Labels

nvidia · ready (ONLY add when PR is ready to merge/full CI is needed) · v1

Projects

Status: Done
