Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
from .padding_remover import remove_padding
from .quantizer_compressed_tensors import quantize_params_compressed_tensors
from .quantizer_fp8 import quantize_params_fp8
from .quantizer_mxfp8 import quantize_params_mxfp8

__all__ = ["remove_padding", "quantize_param", "quantize_params_fp8", "quantize_params_compressed_tensors"]
__all__ = [
"remove_padding",
"quantize_param",
"quantize_params_fp8",
"quantize_params_mxfp8",
"quantize_params_compressed_tensors",
]


def quantize_params(args, megatron_name, converted_named_params, quantization_config):
    """Dispatch parameter quantization to the configured backend.

    Args:
        args: training/conversion arguments forwarded to the backend quantizers.
        megatron_name: megatron-side parameter name, used by backends to decide
            which tensors to quantize.
        converted_named_params: iterable of (hf_name, tensor) pairs.
        quantization_config: dict with at least a "quant_method" key, or None
            to disable quantization entirely.

    Returns:
        The (possibly quantized) list of (name, tensor) pairs. When
        quantization_config is None the input is returned untouched.

    Raises:
        NotImplementedError: for an unrecognized "quant_method". The original
            code fell through and implicitly returned None here, which only
            surfaced as a confusing failure in the caller.
    """
    if quantization_config is None:
        return converted_named_params

    quant_method = quantization_config["quant_method"]
    if quant_method == "fp8":
        return quantize_params_fp8(args, megatron_name, converted_named_params, quantization_config)
    if quant_method == "mxfp8":
        return quantize_params_mxfp8(args, megatron_name, converted_named_params, quantization_config)
    if quant_method == "compressed-tensors":
        # only int4 at the moment.
        return quantize_params_compressed_tensors(converted_named_params, quantization_config)
    # Fail loudly instead of silently returning None for unknown methods.
    raise NotImplementedError(f"Unsupported quant_method: {quant_method!r}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import re

from sglang.srt.layers.quantization.fp8_utils import mxfp8_group_quantize


def quantize_params_mxfp8(args, megatron_name, converted_named_params, quantization_config):
    """Quantize the linear weights of one decoder/MTP layer to MXFP8.

    Only the dense and MoE linear projections listed below are quantized;
    every other parameter is passed through unchanged.

    Args:
        args: unused here; kept for signature parity with the other quantizers.
        megatron_name: megatron-side parameter name used for routing.
        converted_named_params: iterable of (hf_name, tensor) pairs.
        quantization_config: must declare quant_method == "mxfp8".

    Returns:
        A list of (name, tensor) pairs where each quantized weight is replaced
        by a (qweight, weight_scale_inv) pair, or the original
        converted_named_params when nothing is quantized.
    """
    assert quantization_config["quant_method"] == "mxfp8"

    def _quantize_all(named_params, skip_scales=False):
        # Quantize every entry; optionally drop pre-existing bf16
        # weight_scale / input_scale entries (routed-expert case).
        # TODO: find a clearer way.
        quantize_named_params = []
        for converted_name, param in named_params:
            if skip_scales and converted_name.endswith("_scale"):
                continue
            quantize_named_params.extend(_quantize_param(converted_name, param))
        return quantize_named_params

    match = re.search(r"decoder\.layers\.(\d+)\.(.+)", megatron_name)
    if not match:
        # check mtp layers
        match = re.search(r"mtp\.layers\.(\d+)\.(.+)", megatron_name)
        if not match:
            return converted_named_params
        layer_idx, rest = match.groups()
        rest = rest.replace("transformer_layer.", "")
    else:
        layer_idx, rest = match.groups()

    # routed experts: "mlp.experts.linear_fc{1,2}.weight<idx>"
    # (dots escaped — the original "mlp.experts" pattern matched any character)
    match = re.match(r"mlp\.experts\.(.+)\.weight(\d+)", rest)
    if match:
        rest, expert_idx = match.groups()
        if rest in ("linear_fc1", "linear_fc2"):
            return _quantize_all(converted_named_params, skip_scales=True)

    # shared expert
    match = re.match(r"mlp\.shared_experts\.(.+)", rest)
    if match:
        rest = match.groups()[0]
        if rest in ("linear_fc1.weight", "linear_fc2.weight"):
            return _quantize_all(converted_named_params)

    if rest in (
        "self_attention.linear_proj.weight",
        "self_attention.linear_qkv.weight",
        "mlp.linear_fc1.weight",
        "mlp.linear_fc2.weight",
        # mla
        "self_attention.linear_q_proj.weight",
        "self_attention.linear_q_down_proj.weight",
        "self_attention.linear_q_up_proj.weight",
        "self_attention.linear_kv_down_proj.weight",
        "self_attention.linear_kv_up_proj.weight",
    ):
        return _quantize_all(converted_named_params)

    # for other parameters, we just return the original converted_named_params
    return converted_named_params


def _quantize_param(name, weight):
    """Turn one ".weight" tensor into an MXFP8 (qweight, scale) pair.

    The scale tensor covers groups of 32 along the last dimension and is
    emitted under the matching ".weight_scale_inv" name.
    """
    if mxfp8_group_quantize is None:
        raise RuntimeError("MXFP8 quantization requires sglang fp8_utils.mxfp8_group_quantize.")
    assert name.endswith(".weight"), f"Expected weight parameter, got {name}"
    weight = weight.contiguous()
    last_dim = weight.shape[-1]
    if last_dim % 32 != 0:
        raise ValueError(f"Last dim {last_dim} must be divisible by 32 for MXFP8.")
    # Collapse the leading dims so the group quantizer sees a 2-D matrix.
    two_d = weight.view(-1, last_dim).contiguous()
    qweight, scale = mxfp8_group_quantize(two_d)
    quantized = qweight.view_as(weight)
    grouped_scale = scale.view(*weight.shape[:-1], last_dim // 32).contiguous()
    return [
        (name, quantized),
        (name.replace(".weight", ".weight_scale_inv"), grouped_scale),
    ]
8 changes: 7 additions & 1 deletion miles/backends/megatron_utils/sglang.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
# the file to manage all sglang deps in the megatron actor
try:
from sglang.srt.layers.quantization.fp8_utils import quant_weight_ue8m0, transform_scale_ue8m0
from sglang.srt.layers.quantization.fp8_utils import (
mxfp8_group_quantize,
quant_weight_ue8m0,
transform_scale_ue8m0,
)
from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0
except ImportError:
mxfp8_group_quantize = None
quant_weight_ue8m0 = None
transform_scale_ue8m0 = None
should_deepgemm_weight_requant_ue8m0 = None
Expand All @@ -22,6 +27,7 @@
from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import]

__all__ = [
"mxfp8_group_quantize",
"quant_weight_ue8m0",
"transform_scale_ue8m0",
"should_deepgemm_weight_requant_ue8m0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,11 @@ def update_weights(self) -> None:

if dist.get_rank() == 0:
ray.get([engine.continue_generation.remote() for engine in self.rollout_engines])
# int4/fp4 post_process
if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]:
# int4/fp4 post_process, mxfp8 post-process (swizzle MoE scales).
if self.quantization_config and self.quantization_config["quant_method"] in [
"compressed-tensors",
"mxfp8",
]:
post_process_weights(
restore_weights_before_load=False,
post_process_quantization=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,12 @@ def update_weights(self) -> None:
ray.get(refs)
del long_lived_tensors

# int4/fp4 post_process
# int4/fp4 post_process, mxfp8 post-process (swizzle MoE scales).
if rank == 0:
if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]:
if self.quantization_config and self.quantization_config["quant_method"] in [
"compressed-tensors",
"mxfp8",
]:
post_process_weights(
restore_weights_before_load=False,
post_process_quantization=True,
Expand Down
65 changes: 48 additions & 17 deletions scripts/run_qwen3_30b_a3b.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,28 @@ class ScriptArgs(U.ExecuteTrainConfig):
model_name: str = "Qwen3-30B-A3B"
megatron_model_type: str = "qwen3-30B-A3B"
num_gpus_per_node: int | None = None
hardware: Literal["H100", "GB200", "GB300"] = "H100"
hardware: Literal["H100", "B200", "B300", "GB200", "GB300"] = "H100"
enable_eval: bool = True
extra_args: str = ""
rollout_fp8: bool = False
rollout_mxfp8: bool = False
rollout_attn_fp8: bool = False
train_fp8: bool = False
train_mxfp8: bool = False
enable_megatron_bridge: bool = False
enable_mis: bool = False
# TODO improve, should be able to override more easily
tis_use_rs: bool = True

def __post_init__(self) -> None:
    # Resolve the per-node GPU count from the hardware lookup table when the
    # caller did not set it explicitly.
    self.num_gpus_per_node = self.num_gpus_per_node or U.NUM_GPUS_OF_HARDWARE[self.hardware]
    if self.rollout_mxfp8:
        # MXFP8 rollout is mutually exclusive with plain FP8 rollout and is
        # restricted to Blackwell-class hardware.
        assert not self.rollout_fp8, "rollout_mxfp8 and rollout_fp8 cannot be enabled at the same time"
        assert self.hardware in ("B200", "B300", "GB200", "GB300"), "rollout_mxfp8 only supports Blackwell GPUs"
    if self.train_mxfp8:
        # Same constraints for MXFP8 training; additionally, training in MXFP8
        # requires the rollout engine to consume MXFP8 weights as well.
        assert not self.train_fp8, "train_mxfp8 and train_fp8 cannot be enabled at the same time"
        assert self.hardware in ("B200", "B300", "GB200", "GB300"), "train_mxfp8 only supports Blackwell GPUs"
        assert self.rollout_mxfp8, "train_mxfp8 requires rollout_mxfp8 to be enabled"


def prepare(args: ScriptArgs):
Expand All @@ -39,6 +48,11 @@ def prepare(args: ScriptArgs):
f"huggingface-cli download Qwen/{args.model_name}-FP8 --local-dir /root/models/{args.model_name}-FP8"
)

if args.rollout_mxfp8:
U.exec_command(
f"python tools/convert_hf_to_mxfp8.py --model-dir /root/models/{args.model_name} --save-dir /root/models/{args.model_name}-MXFP8"
)

if not args.enable_megatron_bridge:
U.convert_checkpoint(
model_name=args.model_name,
Expand All @@ -57,8 +71,15 @@ def execute(args: ScriptArgs):
else f"/root/models/{args.model_name}_torch_dist"
)
load_save_path = f"/root/shared_data/{args.run_id}/checkpoints"

if args.rollout_fp8:
hf_checkpoint = f"/root/models/{args.model_name}-FP8"
elif args.train_mxfp8:
hf_checkpoint = f"/root/models/{args.model_name}-MXFP8"
else:
hf_checkpoint = f"/root/models/{args.model_name}"
ckpt_args = (
f"--hf-checkpoint /root/models/{args.model_name}{'-FP8' if args.rollout_fp8 else ''}/ "
f"--hf-checkpoint {hf_checkpoint}/ "
f"--ref-load {ref_load_path} "
f"--load {load_save_path} "
f"--save {load_save_path} "
Expand Down Expand Up @@ -138,19 +159,16 @@ def execute(args: ScriptArgs):
)
misc_env_vars = {}

if args.train_fp8:
if args.train_fp8 or args.train_mxfp8:
match args.hardware:
case "GB200" | "GB300":
# It can run but accuracy is incorrect currently
raise NotImplementedError
# ref: Megatron-MoE-ModelZoo
case "B200" | "B300" | "GB200" | "GB300":
misc_args += (
"--transformer-impl transformer_engine "
"--bf16 "
"--fp8-format e4m3 "
"--fp8-recipe mxfp8 "
"--fp8-param-gather "
"--reuse-grad-buf-for-mxfp8-param-ag "
# "--fp8-param-gather "
# "--reuse-grad-buf-for-mxfp8-param-ag "
# --moe-router-padding-for-quantization
)
case "H100" | "H200":
Expand Down Expand Up @@ -187,34 +205,47 @@ def execute(args: ScriptArgs):
optimizer_args += (
"--optimizer-cpu-offload " "--overlap-cpu-optimizer-d2h-h2d " "--use-precision-aware-optimizer "
)
case ("GB200", 1) | ("GB300", 1) | ("GB200", 2) | ("GB300", 2) | ("GB200", 4) | ("GB300", 4):
case ("B200" | "B300" | "GB200" | "GB300", 1 | 2 | 4):
perf_args += (
"--tensor-model-parallel-size 4 "
"--sequence-parallel "
"--pipeline-model-parallel-size 1 "
"--context-parallel-size 1 "
"--expert-model-parallel-size 4 "
f"--expert-model-parallel-size {args.num_gpus_per_node if args.train_mxfp8 else 4} "
"--expert-tensor-parallel-size 1 "
)
sglang_args = (
f"--rollout-num-gpus-per-engine {2 if args.rollout_fp8 else 4} "
"--sglang-mem-fraction-static 0.7 "
"--sglang-attention-backend trtllm_mha "
)
sglang_args = "--sglang-mem-fraction-static 0.7 " "--sglang-attention-backend trtllm_mha "
if args.rollout_fp8:
sglang_world_size = 2
sglang_attn_tp_size = 2
sglang_decode_max_bs = 256
sglang_args += (
f"--rollout-num-gpus-per-engine 2 "
f"--sglang-ep-size {sglang_world_size} "
"--sglang-moe-runner-backend deep_gemm "
"--sglang-moe-a2a-backend deepep "
f"--sglang-max-running-requests {sglang_world_size * sglang_decode_max_bs // sglang_attn_tp_size} "
f"--sglang-chunked-prefill-size {sglang_world_size * sglang_decode_max_bs} "
f"--sglang-cuda-graph-max-bs {sglang_decode_max_bs} "
)
elif args.rollout_mxfp8:
sglang_world_size = 1
sglang_attn_tp_size = 1
sglang_decode_max_bs = 256
sglang_args += (
f"--rollout-num-gpus-per-engine 1 "
"--sglang-fp8-gemm-backend triton "
# Currently, only cutlass moe runner is supported in sglang for mxfp8, which does not support ep
# f"--sglang-ep-size {sglang_world_size} "
"--sglang-moe-runner-backend cutlass "
# TODO: mxfp8 deepep and deepgemm is not supported in sglang yet
# "--sglang-moe-a2a-backend deepep "
f"--sglang-max-running-requests {sglang_world_size * sglang_decode_max_bs // sglang_attn_tp_size} "
f"--sglang-chunked-prefill-size {sglang_world_size * sglang_decode_max_bs} "
f"--sglang-cuda-graph-max-bs {sglang_decode_max_bs} "
)
else:
sglang_args += "--sglang-cuda-graph-max-bs 512 "
sglang_args += "--rollout-num-gpus-per-engine 4 " "--sglang-cuda-graph-max-bs 512 "
case _:
raise NotImplementedError

Expand Down
Loading