Add mxfp8 support for online quantization, Triton dense linear, and CUTLASS MoE#17449
ispobock merged 40 commits into sgl-project:main from
Conversation
```
/sgl-workspace/sglang# python python/sglang/test/test_block_fp8.py -k TestMXFP8DenseLinear
test_mxfp8_dense_linear (__main__.TestMXFP8DenseLinear.test_mxfp8_dense_linear) ... [CI Test Method] TestMXFP8DenseLinear.test_mxfp8_dense_linear ok
----------------------------------------------------------------------
Ran 1 test in 1.012s

OK
```
```
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1209 --parallel 1209 --platinum

BEFORE
Accuracy: 0.817
Invalid: 0.000
Latency: 25.296 s
Output throughput: 9041.457 token/s

AFTER
Accuracy: 0.849
Invalid: 0.000
Latency: 14.757 s
Output throughput: 14630.467 token/s
```
/tag-and-rerun-ci
Implemented a few minor fixes. Tested Qwen3-4B-Instruct-2507-MXFP8 (dense) and Qwen3-30B-A3B-Instruct-2507-MXFP8 (MoE). Implementation is now fully complete.
```python
if self.use_mxfp8 and not self.is_checkpoint_fp8_serialized:
    raise ValueError(
        "MXFP8 requires fp8-serialized checkpoint for linear layers."
    )
```
This is surprising given the snippet below?
sglang/python/sglang/srt/layers/quantization/fp8.py
Lines 409 to 412 in 0769de9
Should this rather say that it is simply untested? Or should this error be removed?
mxfp8 online quantization from bf16 is tested, as shown in the Accuracy Tests section. This raise is never triggered, so it can be removed.
The per-tensor FP8 MoE path in `process_weights_after_loading()` replaced Parameter objects with new ones via `torch.nn.Parameter()`, destroying custom attributes (`weight_loader`, `quant_method`) needed by EPLB weight hot-reload. The block-quant path was already fixed (PR sgl-project#17449) but the per-tensor path was missed. Additionally, `update_weights_from_disk()` unconditionally called `process_weights_after_loading()` on ALL modules even during partial reloads (e.g. EPLB expert rebalancing), causing non-expert layers like FP8 attention to be double-processed (double transpose -> shape mismatch).

Changes in fp8.py:
- Dynamic-quant path: use `.data =` for weight rebinding and `.data[expert].fill_()` for scale updates instead of Parameter replacement.
- Checkpoint-FP8 path: use `.data.fill_()` for input_scale merging; fill both columns of the [E,2] w13_weight_scale in-place instead of replacing with a new [E] Parameter.
- DeepGemm apply(): add ndim==2 guard to collapse w13_weight_scale from [E,2] to [E] via `[:, 0]` before expanding to block shape.

Changes in model_runner.py:
- When weight_name_filter is set (EPLB expert rebalancing), split load_weights and process_weights_after_loading into two steps, only calling the latter on modules whose names match the filter.

Signed-off-by: Socratesa <lihaode@zju.edu.cn>
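The attribute-loss pitfall can be illustrated with a minimal sketch. The `weight_loader` attribute below stands in for the custom attributes SGLang attaches to parameters at load time; the layer itself is a hypothetical example, not SGLang code:

```python
import torch


class Layer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(4, 4))
        # Custom attribute attached to the Parameter object (as SGLang does)
        self.weight.weight_loader = lambda *args: None


bad = Layer()
# Replacing the Parameter creates a fresh object: custom attributes are gone
bad.weight = torch.nn.Parameter(torch.ones(4, 4))
assert not hasattr(bad.weight, "weight_loader")

good = Layer()
# Rebinding .data keeps the same Parameter object, so attributes survive
good.weight.data = torch.ones(4, 4)
assert hasattr(good.weight, "weight_loader")
assert good.weight.data.sum().item() == 16.0
```

This is why the fix rebinds via `.data =` and mutates scales with in-place ops like `.fill_()` instead of constructing new `torch.nn.Parameter` objects.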
Motivation
@HumansAnd
#17093
This PR adds mxfp8 quantization support to SGLang, using Triton for dense linear layers and the existing mxfp8 CUTLASS grouped GEMM kernel in sgl-kernel from #13731 for MoE.
Online mxfp8 quantization from bf16 checkpoints and serving mxfp8 checkpoints directly are both supported.
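For intuition, MXFP8 quantization groups values into blocks of 32 that share a single power-of-two (E8M0) scale, and each element is stored in FP8 E4M3. The following is a minimal pure-Python sketch of the idea, not the Triton/CUTLASS implementation; it only clamps into the E4M3 range, whereas real kernels also round each element to the E4M3 grid:

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value in float8 e4m3
BLOCK = 32            # MX block size: 32 elements share one scale


def mxfp8_quantize_block(block):
    """Quantize one block: return (e8m0 scale exponent, scaled values)."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 0, [0.0] * len(block)
    # E8M0 scale is a bare power-of-two exponent, chosen so the block's
    # max magnitude maps into the E4M3 representable range
    exp = math.floor(math.log2(amax)) - math.floor(math.log2(FP8_E4M3_MAX))
    scale = 2.0 ** exp
    # Scale and clamp; a real kernel would also round to the E4M3 grid
    q = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in block]
    return exp, q


def mxfp8_dequantize_block(exp, q):
    scale = 2.0 ** exp
    return [v * scale for v in q]
```

For example, a block with max magnitude 100.0 gets exponent -2 (scale 0.25), so its largest scaled element is 400.0, safely inside the E4M3 range.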
Modifications
Accuracy Tests
Benchmarking and Profiling
Next Steps
recipe = (1, 1, 32). Add DeepGEMM backend for mxfp8 when the sgl-kernel version is bumped.

Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci