From a2541296d05ca4f02f7ee371424eb9864d8f260c Mon Sep 17 00:00:00 2001
From: Shengliang Xu
Date: Tue, 28 Oct 2025 00:45:41 +0000
Subject: [PATCH] [OMNIML-2917] export layer config using actual prefix instead
 of hardcoded model.layers

This is change set 1 of the work on OMNIML-2917.

When we export a quantized model to the HF unified format, we hard-code a
"model.layers" prefix onto module names. This is unnecessary, and the main
problem is that the exported quant config can contain entries with the wrong
prefixes, such as in exclude_modules. For example, the Qwen3-VL models have
two transformer blocks, language_model and vision. Before this change, for
language_model we would output:

    model.layers.language_model.layers.0.xxx
    model.layers.language_model.layers.1.xxx

These prefixes are wrong, so inference systems such as vllm fail when they
try to read the quant config.

Fix this by using the prefixes obtained from iterating over the model itself.

Signed-off-by: Shengliang Xu
---
 modelopt/torch/export/quant_utils.py       | 10 +++++--
 modelopt/torch/export/unified_export_hf.py | 33 +++++-----------------
 2 files changed, 14 insertions(+), 29 deletions(-)

diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 73d3c44e6..99967c5f5 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -1019,11 +1019,15 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
             module.weight_quantizer.amax = weight_amax
 
 
-def get_quant_config(named_modules: nn.Module | dict[str, nn.Module]) -> dict[str, Any]:
-    """Generate quantization config for a torch model.
+def get_quant_config(
+    named_modules: Generator[tuple[str, nn.Module]] | dict[str, nn.Module],
+) -> dict[str, Any]:
+    """Generate quantization config for a set of named modules.
+
+    It should be the named_modules of a model or a subset of it.
 
     Args:
-        model: The PyTorch model to analyze
+        named_modules: The set of PyTorch named modules to analyze
 
     Returns:
         Dictionary containing the quantization configuration
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index 36520f9fc..250ce5f6e 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -27,6 +27,7 @@
 
 import torch
 import torch.nn as nn
+from accelerate import Accelerator
 from safetensors.torch import save_file
 from torch.distributed.fsdp import FSDPModule
 
@@ -74,8 +75,6 @@
 
 __all__ = ["export_hf_checkpoint"]
 
-SPECULATIVE_DECODING_MODULE_NAMES = ["medusa_heads", "eagle_module", "drafter"]
-
 
 def _is_enabled_quantizer(quantizer):
     if hasattr(quantizer, "is_enabled") and quantizer.is_enabled:
@@ -367,16 +366,14 @@ def _export_quantized_weight(
 
 
 def _export_hf_checkpoint(
-    model: nn.Module,
-    dtype: torch.dtype | None = None,
-    **kwargs,
+    model: nn.Module, dtype: torch.dtype | None = None, accelerator: Accelerator | None = None
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Exports the torch model to the packed checkpoint with original HF naming.
 
     The packed checkpoint will be consumed by the TensorRT-LLM unified converter.
 
     Args:
-        model: the torch model.
+        model: the full torch model to export. The actual quantized model may be a submodule.
         dtype: the weights data type to export the unquantized layers or the default model data type if None.
         accelerator: the accelerator instance in case of distributed export setup.
 
@@ -392,17 +389,8 @@ def _export_hf_checkpoint(
             f"({dtype}), which may lead to numerical errors."
         )
 
-    accelerator = kwargs.get("accelerator")
-
-    # Create a model layer pool
-    # If `model.model` exists use that, otherwise use `model` itself, e.g., Nemotron-H
-    root = getattr(model, "model", model)
-    # If that has a `.layers`, use it, otherwise fall back to the object itself
-    root = getattr(root, "layers", root)
-    layer_pool = {f"model.layers.{name}": sub_module for name, sub_module in root.named_modules()}
-
     # Handle input quantizers of experts that are not calibrated
-    for name, sub_module in model.named_modules():
+    for _, sub_module in model.named_modules():
         if is_moe(sub_module) and hasattr(sub_module, "experts"):
             expert_linear_names = get_expert_linear_names(sub_module)
             for linear_name in expert_linear_names:
@@ -455,13 +443,6 @@ def _export_hf_checkpoint(
             f"Please file an issue or add support for this model architecture."
         )
 
-    # NOTE: Speculative decoding models have extra modules that may be quantized
-    # Need to add these modules to the layer_pool
-    for key in SPECULATIVE_DECODING_MODULE_NAMES:
-        if hasattr(model, key):
-            for name, sub_module in getattr(model, key).named_modules():
-                layer_pool.update({f"{key}.{name}": sub_module})
-
     # Resmooth and requantize fused layers
     # TODO: Handle mixed precision
     requantize_resmooth_fused_llm_layers(model)
@@ -474,7 +455,7 @@ def _export_hf_checkpoint(
     except ImportError:
         warnings.warn("accelerate is not installed, hooks will not be removed")
 
-    quant_config = get_quant_config(layer_pool)
+    quant_config = get_quant_config(model.named_modules())
 
     kv_cache_max_bound = 0
     kv_cache_format = quant_config["quantization"]["kv_cache_quant_algo"]
@@ -493,7 +474,7 @@ def _export_hf_checkpoint(
     has_quantized_layers = False
     fsdp_module_to_reshard = None
 
-    for name, sub_module in layer_pool.items():
+    for _, sub_module in model.named_modules():
         # Optimization to perform resharding only once per decoder layer to avoid extra communication overhead
         if isinstance(sub_module, FSDPModule):
             # Every time we encounter a new FSDPModule, the previous decoder layer is fully processed.
@@ -555,7 +536,7 @@ def export_hf_checkpoint(
     """Exports the torch model to unified checkpoint and saves to export_dir.
 
     Args:
-        model: the torch model.
+        model: the full torch model to export. The actual quantized model may be a submodule.
         dtype: the weights data type to export the unquantized layers or the default model data type if None.
         export_dir: the target export path.
         save_modelopt_state: whether to save the modelopt state_dict.
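
Note: the following is a minimal runnable sketch of the naming change on a toy
module tree. The ToyTower/ToyVLModel classes are hypothetical stand-ins for a
multi-tower model such as Qwen3-VL, and the old-prefix logic is simplified from
the removed layer_pool code; it is not the actual modelopt implementation.

import torch.nn as nn

# Hypothetical stand-ins: the attribute names "language_model" and "vision"
# mirror the example in the commit message above.
class ToyTower(nn.Module):
    def __init__(self, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(num_layers)])

class ToyVLModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = ToyTower(2)
        self.vision = ToyTower(1)

model = ToyVLModel()

# Old behavior (simplified): pick a root and force a hardcoded "model.layers."
# prefix onto every submodule name.
root = getattr(model, "model", model)
root = getattr(root, "layers", root)
old_names = [f"model.layers.{name}" for name, _ in root.named_modules() if name]
# -> ['model.layers.language_model', 'model.layers.language_model.layers',
#     'model.layers.language_model.layers.0', ...]

# New behavior: take the prefixes directly from the model itself.
new_names = [name for name, _ in model.named_modules() if name]
# -> ['language_model', 'language_model.layers', 'language_model.layers.0', ...]

print(old_names[:3])
print(new_names[:3])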