diff --git a/miles/backends/megatron_utils/megatron_to_hf/processors/__init__.py b/miles/backends/megatron_utils/megatron_to_hf/processors/__init__.py index 0141c3548..fba2c4e16 100644 --- a/miles/backends/megatron_utils/megatron_to_hf/processors/__init__.py +++ b/miles/backends/megatron_utils/megatron_to_hf/processors/__init__.py @@ -1,8 +1,15 @@ from .padding_remover import remove_padding from .quantizer_compressed_tensors import quantize_params_compressed_tensors from .quantizer_fp8 import quantize_params_fp8 +from .quantizer_mxfp8 import quantize_params_mxfp8 -__all__ = ["remove_padding", "quantize_param", "quantize_params_fp8", "quantize_params_compressed_tensors"] +__all__ = [ + "remove_padding", + "quantize_param", + "quantize_params_fp8", + "quantize_params_mxfp8", + "quantize_params_compressed_tensors", +] def quantize_params(args, megatron_name, converted_named_params, quantization_config): @@ -10,6 +17,8 @@ def quantize_params(args, megatron_name, converted_named_params, quantization_co return converted_named_params elif quantization_config["quant_method"] == "fp8": return quantize_params_fp8(args, megatron_name, converted_named_params, quantization_config) + elif quantization_config["quant_method"] == "mxfp8": + return quantize_params_mxfp8(args, megatron_name, converted_named_params, quantization_config) elif quantization_config["quant_method"] == "compressed-tensors": # only int4 at the moment. 
return quantize_params_compressed_tensors(converted_named_params, quantization_config) diff --git a/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_mxfp8.py b/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_mxfp8.py new file mode 100644 index 000000000..94824dd1d --- /dev/null +++ b/miles/backends/megatron_utils/megatron_to_hf/processors/quantizer_mxfp8.py @@ -0,0 +1,92 @@ +import re + +from sglang.srt.layers.quantization.fp8_utils import mxfp8_group_quantize + + +def quantize_params_mxfp8(args, megatron_name, converted_named_params, quantization_config): + assert quantization_config["quant_method"] == "mxfp8" + + decoder_layers_pattern = r"decoder\.layers\.(\d+)\.(.+)" + match = re.search(decoder_layers_pattern, megatron_name) + + if not match: + # check mtp layers + mtp_layer_pattern = r"mtp\.layers\.(\d+)\.(.+)" + match = re.search(mtp_layer_pattern, megatron_name) + if not match: + return converted_named_params + layer_idx, rest = match.groups() + rest = rest.replace("transformer_layer.", "") + else: + layer_idx, rest = match.groups() + + # experts + expert_pattern = r"mlp.experts\.(.+)\.weight(\d+)" + match = re.match(expert_pattern, rest) + if match: + rest, expert_idx = match.groups() + if rest in [ + "linear_fc1", + "linear_fc2", + ]: + quantize_named_params = [] + for converted_name, param in converted_named_params: + # skip bf16 weight_scale and input_scale + # TODO: find a clearer way. 
+ if converted_name.endswith("_scale"): + continue + quantize_named_params.extend(_quantize_param(converted_name, param)) + + return quantize_named_params + + # shared expert + shared_expert_pattern = r"mlp.shared_experts\.(.+)" + match = re.match(shared_expert_pattern, rest) + if match: + rest = match.groups()[0] + if rest in [ + "linear_fc1.weight", + "linear_fc2.weight", + ]: + quantize_named_params = [] + for converted_name, param in converted_named_params: + quantize_named_params.extend(_quantize_param(converted_name, param)) + + return quantize_named_params + + if rest in [ + "self_attention.linear_proj.weight", + "self_attention.linear_qkv.weight", + "mlp.linear_fc1.weight", + "mlp.linear_fc2.weight", + # mla + "self_attention.linear_q_proj.weight", + "self_attention.linear_q_down_proj.weight", + "self_attention.linear_q_up_proj.weight", + "self_attention.linear_kv_down_proj.weight", + "self_attention.linear_kv_up_proj.weight", + ]: + quantize_named_params = [] + for converted_name, param in converted_named_params: + quantize_named_params.extend(_quantize_param(converted_name, param)) + + return quantize_named_params + + # for other parameters, we just return the original converted_named_params + return converted_named_params + + +def _quantize_param(name, weight): + if mxfp8_group_quantize is None: + raise RuntimeError("MXFP8 quantization requires sglang fp8_utils.mxfp8_group_quantize.") + assert name.endswith(".weight"), f"Expected weight parameter, got {name}" + weight = weight.contiguous() + k = weight.shape[-1] + if k % 32 != 0: + raise ValueError(f"Last dim {k} must be divisible by 32 for MXFP8.") + weight_flat = weight.view(-1, k).contiguous() + qweight, scale = mxfp8_group_quantize(weight_flat) + qweight = qweight.view_as(weight) + scale = scale.view(*weight.shape[:-1], k // 32).contiguous() + scale_name = name.replace(".weight", ".weight_scale_inv") + return [(name, qweight), (scale_name, scale)] diff --git a/miles/backends/megatron_utils/sglang.py 
b/miles/backends/megatron_utils/sglang.py index 97c82a31c..d6e7fbccc 100644 --- a/miles/backends/megatron_utils/sglang.py +++ b/miles/backends/megatron_utils/sglang.py @@ -1,8 +1,13 @@ # the file to manage all sglang deps in the megatron actor try: - from sglang.srt.layers.quantization.fp8_utils import quant_weight_ue8m0, transform_scale_ue8m0 + from sglang.srt.layers.quantization.fp8_utils import ( + mxfp8_group_quantize, + quant_weight_ue8m0, + transform_scale_ue8m0, + ) from sglang.srt.model_loader.utils import should_deepgemm_weight_requant_ue8m0 except ImportError: + mxfp8_group_quantize = None quant_weight_ue8m0 = None transform_scale_ue8m0 = None should_deepgemm_weight_requant_ue8m0 = None @@ -22,6 +27,7 @@ from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] __all__ = [ + "mxfp8_group_quantize", "quant_weight_ue8m0", "transform_scale_ue8m0", "should_deepgemm_weight_requant_ue8m0", diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py index f9c90bb1b..f3125781f 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -120,8 +120,11 @@ def update_weights(self) -> None: if dist.get_rank() == 0: ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) - # int4/fp4 post_process - if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: + # int4/fp4 post_process, mxfp8 post-process (swizzle MoE scales). 
+ if self.quantization_config and self.quantization_config["quant_method"] in [ + "compressed-tensors", + "mxfp8", + ]: post_process_weights( restore_weights_before_load=False, post_process_quantization=True, diff --git a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py index 1acfabba3..5fdf7323f 100644 --- a/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py +++ b/miles/backends/megatron_utils/update_weight/update_weight_from_tensor.py @@ -128,9 +128,12 @@ def update_weights(self) -> None: ray.get(refs) del long_lived_tensors - # int4/fp4 post_process + # int4/fp4 post_process, mxfp8 post-process (swizzle MoE scales). if rank == 0: - if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: + if self.quantization_config and self.quantization_config["quant_method"] in [ + "compressed-tensors", + "mxfp8", + ]: post_process_weights( restore_weights_before_load=False, post_process_quantization=True, diff --git a/scripts/run_qwen3_30b_a3b.py b/scripts/run_qwen3_30b_a3b.py index 26a374aa0..4322d94d0 100644 --- a/scripts/run_qwen3_30b_a3b.py +++ b/scripts/run_qwen3_30b_a3b.py @@ -13,12 +13,14 @@ class ScriptArgs(U.ExecuteTrainConfig): model_name: str = "Qwen3-30B-A3B" megatron_model_type: str = "qwen3-30B-A3B" num_gpus_per_node: int | None = None - hardware: Literal["H100", "GB200", "GB300"] = "H100" + hardware: Literal["H100", "B200", "B300", "GB200", "GB300"] = "H100" enable_eval: bool = True extra_args: str = "" rollout_fp8: bool = False + rollout_mxfp8: bool = False rollout_attn_fp8: bool = False train_fp8: bool = False + train_mxfp8: bool = False enable_megatron_bridge: bool = False enable_mis: bool = False # TODO improve, should be able to override more easily @@ -26,6 +28,13 @@ class ScriptArgs(U.ExecuteTrainConfig): def __post_init__(self): self.num_gpus_per_node = self.num_gpus_per_node or 
U.NUM_GPUS_OF_HARDWARE[self.hardware] + if self.rollout_mxfp8: + assert not self.rollout_fp8, "rollout_mxfp8 and rollout_fp8 cannot be enabled at the same time" + assert self.hardware in ("B200", "B300", "GB200", "GB300"), "rollout_mxfp8 only supports Blackwell GPUs" + if self.train_mxfp8: + assert not self.train_fp8, "train_mxfp8 and train_fp8 cannot be enabled at the same time" + assert self.hardware in ("B200", "B300", "GB200", "GB300"), "train_mxfp8 only supports Blackwell GPUs" + assert self.rollout_mxfp8, "train_mxfp8 requires rollout_mxfp8 to be enabled" def prepare(args: ScriptArgs): @@ -39,6 +48,11 @@ def prepare(args: ScriptArgs): f"huggingface-cli download Qwen/{args.model_name}-FP8 --local-dir /root/models/{args.model_name}-FP8" ) + if args.rollout_mxfp8: + U.exec_command( + f"python tools/convert_hf_to_mxfp8.py --model-dir /root/models/{args.model_name} --save-dir /root/models/{args.model_name}-MXFP8" + ) + if not args.enable_megatron_bridge: U.convert_checkpoint( model_name=args.model_name, @@ -57,8 +71,15 @@ def execute(args: ScriptArgs): else f"/root/models/{args.model_name}_torch_dist" ) load_save_path = f"/root/shared_data/{args.run_id}/checkpoints" + + if args.rollout_fp8: + hf_checkpoint = f"/root/models/{args.model_name}-FP8" + elif args.train_mxfp8: + hf_checkpoint = f"/root/models/{args.model_name}-MXFP8" + else: + hf_checkpoint = f"/root/models/{args.model_name}" ckpt_args = ( - f"--hf-checkpoint /root/models/{args.model_name}{'-FP8' if args.rollout_fp8 else ''}/ " + f"--hf-checkpoint {hf_checkpoint}/ " f"--ref-load {ref_load_path} " f"--load {load_save_path} " f"--save {load_save_path} " @@ -138,19 +159,16 @@ def execute(args: ScriptArgs): ) misc_env_vars = {} - if args.train_fp8: + if args.train_fp8 or args.train_mxfp8: match args.hardware: - case "GB200" | "GB300": - # It can run but accuracy is incorrect currently - raise NotImplementedError - # ref: Megatron-MoE-ModelZoo + case "B200" | "B300" | "GB200" | "GB300": misc_args += ( 
"--transformer-impl transformer_engine " "--bf16 " "--fp8-format e4m3 " "--fp8-recipe mxfp8 " - "--fp8-param-gather " - "--reuse-grad-buf-for-mxfp8-param-ag " + # "--fp8-param-gather " + # "--reuse-grad-buf-for-mxfp8-param-ag " # --moe-router-padding-for-quantization ) case "H100" | "H200": @@ -187,25 +205,22 @@ def execute(args: ScriptArgs): optimizer_args += ( "--optimizer-cpu-offload " "--overlap-cpu-optimizer-d2h-h2d " "--use-precision-aware-optimizer " ) - case ("GB200", 1) | ("GB300", 1) | ("GB200", 2) | ("GB300", 2) | ("GB200", 4) | ("GB300", 4): + case ("B200" | "B300" | "GB200" | "GB300", 1 | 2 | 4): perf_args += ( "--tensor-model-parallel-size 4 " "--sequence-parallel " "--pipeline-model-parallel-size 1 " "--context-parallel-size 1 " - "--expert-model-parallel-size 4 " + f"--expert-model-parallel-size {args.num_gpus_per_node if args.train_mxfp8 else 4} " "--expert-tensor-parallel-size 1 " ) - sglang_args = ( - f"--rollout-num-gpus-per-engine {2 if args.rollout_fp8 else 4} " - "--sglang-mem-fraction-static 0.7 " - "--sglang-attention-backend trtllm_mha " - ) + sglang_args = "--sglang-mem-fraction-static 0.7 " "--sglang-attention-backend trtllm_mha " if args.rollout_fp8: sglang_world_size = 2 sglang_attn_tp_size = 2 sglang_decode_max_bs = 256 sglang_args += ( + f"--rollout-num-gpus-per-engine 2 " f"--sglang-ep-size {sglang_world_size} " "--sglang-moe-runner-backend deep_gemm " "--sglang-moe-a2a-backend deepep " @@ -213,8 +228,24 @@ def execute(args: ScriptArgs): f"--sglang-chunked-prefill-size {sglang_world_size * sglang_decode_max_bs} " f"--sglang-cuda-graph-max-bs {sglang_decode_max_bs} " ) + elif args.rollout_mxfp8: + sglang_world_size = 1 + sglang_attn_tp_size = 1 + sglang_decode_max_bs = 256 + sglang_args += ( + f"--rollout-num-gpus-per-engine 1 " + "--sglang-fp8-gemm-backend triton " + # Currently, only cutlass moe runner is supported in sglang for mxfp8, which does not support ep + # f"--sglang-ep-size {sglang_world_size} " + 
"--sglang-moe-runner-backend cutlass " + # TODO: mxfp8 deepep and deepgemm are not supported in sglang yet + # "--sglang-moe-a2a-backend deepep " + f"--sglang-max-running-requests {sglang_world_size * sglang_decode_max_bs // sglang_attn_tp_size} " + f"--sglang-chunked-prefill-size {sglang_world_size * sglang_decode_max_bs} " + f"--sglang-cuda-graph-max-bs {sglang_decode_max_bs} " + ) else: - sglang_args += "--sglang-cuda-graph-max-bs 512 " + sglang_args += "--rollout-num-gpus-per-engine 4 " "--sglang-cuda-graph-max-bs 512 " case _: raise NotImplementedError diff --git a/tools/convert_hf_to_mxfp8.py b/tools/convert_hf_to_mxfp8.py new file mode 100644 index 000000000..3e334469e --- /dev/null +++ b/tools/convert_hf_to_mxfp8.py @@ -0,0 +1,205 @@ +""" +python tools/convert_hf_to_mxfp8.py [-h] [--model-dir MODEL_DIR] [--save-dir SAVE_DIR] [--device DEVICE] + +Convert a BF16/FP16 HF safetensors checkpoint to MXFP8 with UE8M0 scales. +The scale layout mirrors sglang _quantize_and_swizzle_with_triton_kernel, +but keeps the scales in unswizzled group layout for serialization. +""" + +import argparse +import gc +import json +import os +import shutil + +import safetensors +import safetensors.torch +import torch +from tqdm import tqdm + +try: + from sglang.srt.layers.quantization.fp8_utils import mxfp8_group_quantize +except ImportError as exc: + raise ImportError( + "Missing sglang dependency: mxfp8_group_quantize must be importable from sglang.srt.layers.quantization.fp8_utils."
+ ) from exc + + +SKIP_WEIGHT_SUBSTRINGS = ( + "layernorm", + "embed", + "router", + "mlp.gate.", + "norm", + "lm_head", + "eh_proj", + "weights_proj", +) + + +def should_quantize(name: str, weight: torch.Tensor) -> bool: + if not name.endswith(".weight"): + return False + if any(substr in name for substr in SKIP_WEIGHT_SUBSTRINGS): + return False + if weight.dtype not in (torch.float16, torch.bfloat16, torch.float32): + return False + if weight.dim() < 2: + return False + if weight.shape[-1] % 32 != 0: + return False + return True + + +def quantize_mxfp8(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Mirror sglang _quantize_and_swizzle_with_triton_kernel but do not swizzle scales. + Returns: + qweight: same shape as input, dtype float8_e4m3fn + scale: shape = (*weight.shape[:-1], weight.shape[-1] // 32), dtype uint8 + """ + weight = weight.contiguous() + k = weight.shape[-1] + if k % 32 != 0: + raise ValueError(f"Last dim {k} must be divisible by 32 for MXFP8.") + + weight_flat = weight.view(-1, k).contiguous() + qweight, scale = mxfp8_group_quantize(weight_flat) + qweight = qweight.view_as(weight) + scale = scale.view(*weight.shape[:-1], k // 32).contiguous() + return qweight, scale + + +class ConversionResult: + def __init__(self) -> None: + self.weight_map: dict[str, str] = {} + self.total_size: int = 0 + self.modules_to_not_convert: list[str] = [] + + def add_result( + self, + filename: str, + q_weights: dict[str, torch.Tensor], + module_names: list[str], + ) -> None: + for key, tensor in q_weights.items(): + self.weight_map[key] = filename + self.total_size += tensor.numel() * tensor.element_size() + self.modules_to_not_convert.extend(module_names) + + +def process_file( + input_path: str, + output_path: str, + filename: str, + result_collector: ConversionResult, + device: str, +) -> None: + if not filename.endswith(".safetensors"): + return + + weights: dict[str, torch.Tensor] = {} + q_weights: dict[str, torch.Tensor] = {} + + with 
safetensors.safe_open(os.path.join(input_path, filename), framework="pt", device=device) as f: + for key in f.keys(): + weights[key] = f.get_tensor(key) + + modules_to_not_convert: list[str] = [] + for key, tensor in weights.items(): + if should_quantize(key, tensor): + qweight, scale = quantize_mxfp8(tensor) + q_weights[key] = qweight + q_weights[key.replace(".weight", ".weight_scale_inv")] = scale + else: + if key.endswith(".weight"): + modules_to_not_convert.append(key.replace(".weight", "")) + q_weights[key] = tensor + + safetensors.torch.save_file(q_weights, os.path.join(output_path, filename), metadata={"format": "pt"}) + result_collector.add_result(filename, q_weights, modules_to_not_convert) + + +def convert_mxfp8(model_dir: str, save_dir: str, device: str) -> None: + input_path = os.path.abspath(model_dir) + output_path = os.path.abspath(save_dir) + os.makedirs(output_path, exist_ok=True) + + for filename in os.listdir(input_path): + if not filename.endswith(".safetensors") and not os.path.isdir(os.path.join(input_path, filename)): + shutil.copyfile(os.path.join(input_path, filename), os.path.join(output_path, filename)) + + safetensors_files = [f for f in os.listdir(input_path) if f.endswith(".safetensors")] + + result_collector = ConversionResult() + for filename in tqdm(safetensors_files, desc="Processing files"): + process_file(input_path, output_path, filename, result_collector, device) + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + quantization_config = { + "activation_scheme": "dynamic", + "fmt": "e4m3", + "quant_method": "mxfp8", + "weight_block_size": [1, 32], + "scale_fmt": "ue8m0", + } + if len(result_collector.modules_to_not_convert) > 0: + quantization_config["modules_to_not_convert"] = list(set(result_collector.modules_to_not_convert)) + + config_path = os.path.join(input_path, "config.json") + if os.path.exists(config_path): + cfg = json.load(open(config_path)) + cfg["quantization_config"] = quantization_config 
+ json.dump(cfg, open(os.path.join(output_path, "config.json"), "w"), indent=2) + + index_dict = { + "weight_map": result_collector.weight_map, + "metadata": {"total_size": result_collector.total_size}, + } + json.dump(index_dict, open(os.path.join(output_path, "model.safetensors.index.json"), "w"), indent=2) + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--model-dir", type=str, required=True, help="Path to HF safetensors model.") + parser.add_argument("--save-dir", type=str, required=True, help="Path to save converted model.") + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Torch device to run quantization on (default: cuda).", + ) + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available, cannot run MXFP8 quantization.") + + if isinstance(args.device, str) and args.device.isdigit(): + device = torch.device(f"cuda:{args.device}") + else: + device = torch.device(args.device) + + if device.type != "cuda": + raise RuntimeError("MXFP8 quantization requires a CUDA device.") + if device.index is None: + device = torch.device("cuda:0") + + torch.cuda.set_device(device) + + if not os.path.exists(args.save_dir): + print(f"Creating directory {args.save_dir}") + os.makedirs(args.save_dir) + elif not os.path.isdir(args.save_dir): + raise ValueError("The save_dir should be a directory.") + + convert_mxfp8(args.model_dir, args.save_dir, str(device)) + + +if __name__ == "__main__": + main()