
Add mxfp8 support for online quantization, Triton dense linear, and CUTLASS MoE#17449

Merged
ispobock merged 40 commits into sgl-project:main from zianglih:mxfp8-no-dg
Jan 29, 2026

Conversation

@zianglih
Contributor

@zianglih zianglih commented Jan 21, 2026

Motivation

@HumansAnd

#17093
This PR adds mxfp8 quantization support to SGLang, using Triton for dense linear layer, and existing mxfp8 CUTLASS groped GEMM kernel in sgl-kernel from #13731 for MoE.

Online mxfp8 quantization from bf16 checkpoints and serving mxfp8 checkpoints directly are both supported.
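For reference, a minimal pure-Python sketch of the MX-style block quantization that online quantization performs: one shared power-of-two (e8m0) scale per block of 32 elements, with payload values kept within the e4m3 range (±448). The fp8 rounding step itself is elided, and the function names here are illustrative, not sglang APIs.

```python
import math

FP8_E4M3_MAX = 448.0   # max representable magnitude in e4m3
BLOCK = 32             # MX block size: one shared scale per 32 elements

def quantize_mxfp8_block(values):
    """Quantize one block of floats to an (e8m0 exponent, payload) pair.

    The shared scale is the smallest power of two such that the block's
    absolute max, once divided by it, fits in the e4m3 range.
    """
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        return 0, [0.0] * len(values)
    e = math.ceil(math.log2(amax / FP8_E4M3_MAX))  # e8m0 scale exponent
    scale = 2.0 ** e
    return e, [v / scale for v in values]

def dequantize(e, payload):
    """Invert the block quantization (exact here, since rounding is elided)."""
    scale = 2.0 ** e
    return [p * scale for p in payload]
```

Because the scale is a power of two, dividing and re-multiplying is lossless in this sketch; real mxfp8 additionally rounds each payload value to the nearest e4m3 code.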

Modifications

Accuracy Tests


# Eval:
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum

# bf16:
python -m sglang.launch_server --tp 2 --model Qwen/Qwen3-30B-A3B-Instruct-2507 --fp8-gemm-backend triton --moe-runner-backend cutlass  &
# Trial 1
Accuracy: 0.964
Invalid: 0.000
Latency: 14.543 s
Output throughput: 11779.580 token/s
# Trial 2
Accuracy: 0.964
Invalid: 0.000
Latency: 13.427 s
Output throughput: 12778.580 token/s
# Trial 3
Accuracy: 0.964
Invalid: 0.000
Latency: 13.049 s
Output throughput: 13128.875 token/s
# Online mxfp8:
python -m sglang.launch_server --tp 2 --model Qwen/Qwen3-30B-A3B-Instruct-2507 --quantization mxfp8 --fp8-gemm-backend triton --moe-runner-backend cutlass  &
# Trial 1
Accuracy: 0.964
Invalid: 0.000
Latency: 23.825 s
Output throughput: 7197.115 token/s
# Trial 2
Accuracy: 0.966
Invalid: 0.000
Latency: 18.390 s
Output throughput: 9310.458 token/s
# Trial 3
Accuracy: 0.966
Invalid: 0.000
Latency: 18.178 s
Output throughput: 9419.016 token/s
# Offline mxfp8:
python -m sglang.launch_server --tp 1 --model /data/models/Qwen3-30B-A3B-Instruct-2507-MXFP8 --fp8-gemm-backend triton --moe-runner-backend cutlass  &
# Trial 1
Accuracy: 0.963
Invalid: 0.000
Latency: 18.698 s
Output throughput: 9075.207 token/s
# Trial 2
Accuracy: 0.963
Invalid: 0.000
Latency: 18.297 s
Output throughput: 9273.994 token/s
# Trial 3
Accuracy: 0.963
Invalid: 0.000
Latency: 17.752 s
Output throughput: 9558.566 token/s

Benchmarking and Profiling

Next Steps

  • Latest DeepGEMM supports mxfp8 & mxfp4 kernels via recipe = (1, 1, 32). Add a DeepGEMM backend for mxfp8 once the sgl-kernel version is bumped.
  • Improve Triton mxfp8 kernel performance.
  • Blackwell mxfp8 RL integration.
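On the first bullet: my reading of recipe = (1, 1, 32) is the MX 1x32 scaling granularity, i.e. one scale per 32 contiguous elements along K for each row of an operand. A hypothetical shape helper makes the scale-tensor bookkeeping concrete (verify against the DeepGEMM docs before relying on it):

```python
def mxfp8_scale_shape(n, k, block=32):
    """For an [n, k] operand with 1x32 blocks, each row carries one
    e8m0 scale per 32 contiguous elements along k."""
    assert k % block == 0, "k must be a multiple of the MX block size"
    return (n, k // block)
```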

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.


/sgl-workspace/sglang# python python/sglang/test/test_block_fp8.py -k TestMXFP8DenseLinear
test_mxfp8_dense_linear (__main__.TestMXFP8DenseLinear.test_mxfp8_dense_linear) ... [CI Test Method] TestMXFP8DenseLinear.test_mxfp8_dense_linear
ok

----------------------------------------------------------------------
Ran 1 test in 1.012s

OK
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
BEFORE
Accuracy: 0.817
Invalid: 0.000
Latency: 25.296 s
Output throughput: 9041.457 token/s

AFTER
Accuracy: 0.849
Invalid: 0.000
Latency: 14.757 s
Output throughput: 14630.467 token/s
@zianglih changed the title from "Add mxfp8 support for online quantization, Triton linear, and CUTLASS MoE" to "Add mxfp8 support for online quantization, Triton dense linear, and CUTLASS MoE" on Jan 21, 2026.
@ispobock
Collaborator

/tag-and-rerun-ci

@zianglih
Contributor Author

zianglih commented Jan 26, 2026

Implemented a few minor fixes for /update_weights_from_disk.
Test for online mxfp8 quantization:

python -m sglang.launch_server --tp 2 --model Qwen/Qwen3-30B-A3B-Instruct-2507 --quantization mxfp8 --fp8-gemm-backend triton --moe-runner-backend cutlass  &
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.964
Invalid: 0.000
Latency: 20.583 s
Output throughput: 8330.881 token/s

Test for Qwen3-4B-Instruct-2507-MXFP8 (dense), with /update_weights_from_disk:

python -m sglang.launch_server --tp 1 --model /data/models/Qwen3-4B-Instruct-2507-MXFP8 --fp8-gemm-backend triton --moe-runner-backend cutlass  &
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.849
Invalid: 0.000
Latency: 16.115 s
Output throughput: 13397.779 token/s
curl -sS http://localhost:30000/update_weights_from_disk \
  -H 'Content-Type: application/json' \
  -d '{
    "model_path": "/data/models/Qwen3-4B-Instruct-2507-MXFP8",
    "flush_cache": true,
    "abort_all_requests": false
  }'
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.849
Invalid: 0.000
Latency: 15.149 s
Output throughput: 14251.889 token/s
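The curl call above can also be issued from Python with only the standard library; this helper builds the same POST request (the function name is mine, not an sglang API):

```python
import json
from urllib import request

def build_update_weights_request(base_url, model_path):
    """Build the POST to /update_weights_from_disk with the same JSON
    payload as the curl invocation above."""
    body = json.dumps({
        "model_path": model_path,
        "flush_cache": True,
        "abort_all_requests": False,
    }).encode()
    return request.Request(
        base_url + "/update_weights_from_disk",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

# To actually issue it against a running server:
# req = build_update_weights_request("http://localhost:30000",
#                                    "/data/models/Qwen3-4B-Instruct-2507-MXFP8")
# with request.urlopen(req) as resp:
#     print(resp.status)
```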

Test for Qwen3-30B-A3B-Instruct-2507-MXFP8 (MoE), with /update_weights_from_disk:


python -m sglang.launch_server --tp 1 --model /data/models/Qwen3-30B-A3B-Instruct-2507-MXFP8 --fp8-gemm-backend triton --moe-runner-backend cutlass  &
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.963
Invalid: 0.000
Latency: 18.686 s
Output throughput: 9081.105 token/s
curl -sS http://localhost:30000/update_weights_from_disk \
  -H 'Content-Type: application/json' \
  -d '{
    "model_path": "/data/models/Qwen3-30B-A3B-Instruct-2507-MXFP8",
    "flush_cache": true,
    "abort_all_requests": false
  }'
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum
Accuracy: 0.963
Invalid: 0.000
Latency: 19.230 s
Output throughput: 8823.926 token/s

Implementation is now fully complete.

@ispobock ispobock merged commit 3c9cc44 into sgl-project:main Jan 29, 2026
205 of 223 checks passed
Comment on lines +344 to +347
if self.use_mxfp8 and not self.is_checkpoint_fp8_serialized:
    raise ValueError(
        "MXFP8 requires fp8-serialized checkpoint for linear layers."
    )
Contributor

@fxmarty-amd fxmarty-amd Jan 29, 2026


This is surprising given the snippet below?

elif self.use_mxfp8:
    if not self.is_checkpoint_fp8_serialized:
        self._quantize_mxfp8_weights(layer)
    return

Should this rather say that it is simply untested? Or should this error be removed?

Contributor Author

mxfp8 online quantization from bf16 is tested, as shown in the Accuracy Tests section. This raise is never triggered, so it can be removed.
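If the raise is removed as suggested, the resolved control flow for mxfp8 linear weights reduces to two branches; a toy sketch (names are illustrative, not the actual fp8.py methods):

```python
def mxfp8_weight_path(is_checkpoint_fp8_serialized):
    """bf16 checkpoints are quantized online; fp8-serialized checkpoints
    are loaded directly. No error path remains."""
    if not is_checkpoint_fp8_serialized:
        return "quantize_online"
    return "load_preserialized"
```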

charlesHsuGG pushed a commit to charlesHsuGG/sglang that referenced this pull request Jan 30, 2026
Chen-0210 pushed a commit to Chen-0210/sglang that referenced this pull request Jan 30, 2026
sfiisf pushed a commit to sfiisf/sglang that referenced this pull request Feb 5, 2026
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026
Socratesa added a commit to Socratesa/sglang that referenced this pull request Feb 27, 2026
The per-tensor FP8 MoE path in process_weights_after_loading() replaced
Parameter objects with new ones via torch.nn.Parameter(), destroying
custom attributes (weight_loader, quant_method) needed by EPLB weight
hot-reload.  The block-quant path was already fixed (PR sgl-project#17449) but the
per-tensor path was missed.

Additionally, update_weights_from_disk() unconditionally called
process_weights_after_loading() on ALL modules even during partial
reloads (e.g. EPLB expert rebalancing), causing non-expert layers like
FP8 attention to be double-processed (double transpose -> shape mismatch).

Changes in fp8.py:
- Dynamic-quant path: use .data= for weight rebinding and
  .data[expert].fill_() for scale updates instead of Parameter replacement.
- Checkpoint-FP8 path: use .data.fill_() for input_scale merging;
  fill both columns of the [E,2] w13_weight_scale in-place instead of
  replacing with a new [E] Parameter.
- DeepGemm apply(): add ndim==2 guard to collapse w13_weight_scale
  from [E,2] to [E] via [:,0] before expanding to block shape.

Changes in model_runner.py:
- When weight_name_filter is set (EPLB expert rebalancing), split
  load_weights and process_weights_after_loading into two steps,
  only calling the latter on modules whose names match the filter.

Signed-off-by: Socratesa <lihaode@zju.edu.cn>
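A torch-free analogy of the failure mode this commit describes: replacing a Parameter-like object drops the attributes that were attached to it (weight_loader, quant_method), while rebinding .data in place preserves them. The Param class below is a stand-in, not sglang or torch code:

```python
class Param:
    """Stand-in for torch.nn.Parameter: a value plus loader metadata
    that sglang attaches after construction."""
    def __init__(self, data):
        self.data = data

p = Param([1.0, 2.0])
p.weight_loader = "custom_loader"   # attribute EPLB hot-reload relies on

# Replacing the object (the buggy per-tensor path) loses the attribute:
replaced = Param([3.0, 4.0])
assert not hasattr(replaced, "weight_loader")

# Rebinding .data in place (the fix) keeps it:
p.data = [3.0, 4.0]
assert p.weight_loader == "custom_loader"
```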
Socratesa added a commit to Socratesa/sglang that referenced this pull request Feb 27, 2026
Socratesa added a commit to Socratesa/sglang that referenced this pull request Feb 27, 2026
@zianglih zianglih deleted the mxfp8-no-dg branch April 6, 2026 08:07

Labels

documentation (Improvements or additions to documentation), run-ci

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants