10 changes: 7 additions & 3 deletions modelopt/torch/export/quant_utils.py
@@ -1019,11 +1019,15 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
         module.weight_quantizer.amax = weight_amax


-def get_quant_config(named_modules: nn.Module | dict[str, nn.Module]) -> dict[str, Any]:
-    """Generate quantization config for a torch model.
+def get_quant_config(
+    named_modules: Generator[tuple[str, nn.Module]] | dict[str, nn.Module],
+) -> dict[str, Any]:
+    """Generate quantization config for a set of named modules.

+    It should be the named_modules of a model or a subset of it.
+
     Args:
-        model: The PyTorch model to analyze
+        named_modules: The set of PyTorch named modules

     Returns:
         Dictionary containing the quantization configuration
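To make the new signature concrete, here is a minimal usage sketch (not part of the PR): it assumes a quantized Hugging Face model bound to `model` and shows the two accepted argument forms, a `named_modules()` generator for the whole model and a `{name: module}` dict for a subset. The layer name used for the subset is illustrative.

```python
from modelopt.torch.export.quant_utils import get_quant_config

# Whole model: pass the named_modules() generator directly.
quant_config = get_quant_config(model.named_modules())

# Subset: pass only the modules of interest as a {name: module} dict,
# e.g. the first decoder layer (name prefix is illustrative).
subset = {
    name: module
    for name, module in model.named_modules()
    if name.startswith("model.layers.0")
}
quant_config_subset = get_quant_config(subset)
```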
33 changes: 7 additions & 26 deletions modelopt/torch/export/unified_export_hf.py
@@ -27,6 +27,7 @@

 import torch
 import torch.nn as nn
+from accelerate import Accelerator
 from safetensors.torch import save_file
 from torch.distributed.fsdp import FSDPModule

@@ -74,8 +75,6 @@

 __all__ = ["export_hf_checkpoint"]

-SPECULATIVE_DECODING_MODULE_NAMES = ["medusa_heads", "eagle_module", "drafter"]
-

 def _is_enabled_quantizer(quantizer):
     if hasattr(quantizer, "is_enabled") and quantizer.is_enabled:
@@ -367,16 +366,14 @@ def _export_quantized_weight(


 def _export_hf_checkpoint(
-    model: nn.Module,
-    dtype: torch.dtype | None = None,
-    **kwargs,
+    model: nn.Module, dtype: torch.dtype | None = None, accelerator: Accelerator | None = None
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Exports the torch model to the packed checkpoint with original HF naming.

     The packed checkpoint will be consumed by the TensorRT-LLM unified converter.

     Args:
-        model: the torch model.
+        model: the full torch model to export. The actual quantized model may be a submodule.
         dtype: the weights data type to export the unquantized layers or the default model data type if None.
         accelerator: the accelerator instance in case of distributed export setup.

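A hedged call sketch for the updated signature (the variable names and the return unpacking are illustrative, not taken from the PR): the accelerator is now passed as an explicit keyword argument instead of being pulled out of `**kwargs`.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# `model` is assumed to be an already-quantized HF model in the caller's scope.
exported_state_dict, exported_quant_config = _export_hf_checkpoint(
    model,
    dtype=torch.bfloat16,
    accelerator=accelerator,
)
```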
@@ -392,17 +389,8 @@
             f"({dtype}), which may lead to numerical errors."
         )

-    accelerator = kwargs.get("accelerator")
-
-    # Create a model layer pool
-    # If `model.model` exists use that, otherwise use `model` itself, e.g., Nemotron-H
-    root = getattr(model, "model", model)
-    # If that has a `.layers`, use it, otherwise fall back to the object itself
-    root = getattr(root, "layers", root)
-    layer_pool = {f"model.layers.{name}": sub_module for name, sub_module in root.named_modules()}
-
     # Handle input quantizers of experts that are not calibrated
-    for name, sub_module in model.named_modules():
+    for _, sub_module in model.named_modules():
         if is_moe(sub_module) and hasattr(sub_module, "experts"):
             expert_linear_names = get_expert_linear_names(sub_module)
             for linear_name in expert_linear_names:
@@ -455,13 +443,6 @@
             f"Please file an issue or add support for this model architecture."
         )

-    # NOTE: Speculative decoding models have extra modules that may be quantized
-    # Need to add these modules to the layer_pool
-    for key in SPECULATIVE_DECODING_MODULE_NAMES:
-        if hasattr(model, key):
-            for name, sub_module in getattr(model, key).named_modules():
-                layer_pool.update({f"{key}.{name}": sub_module})
-
     # Resmooth and requantize fused layers
     # TODO: Handle mixed precision
     requantize_resmooth_fused_llm_layers(model)
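The dedicated `SPECULATIVE_DECODING_MODULE_NAMES` bookkeeping can go away because `named_modules()` already visits every registered submodule. A small illustrative check, not from the PR (the prefixes come from the removed constant; the traversal is standard `torch.nn` behavior):

```python
import torch.nn as nn


def list_speculative_modules(model: nn.Module) -> list[str]:
    """Illustrative helper: speculative-decoding heads show up in a plain traversal."""
    prefixes = ("medusa_heads", "eagle_module", "drafter")
    return [name for name, _ in model.named_modules() if name.startswith(prefixes)]
```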
@@ -474,7 +455,7 @@
     except ImportError:
         warnings.warn("accelerate is not installed, hooks will not be removed")

-    quant_config = get_quant_config(layer_pool)
+    quant_config = get_quant_config(model.named_modules())

     kv_cache_max_bound = 0
     kv_cache_format = quant_config["quantization"]["kv_cache_quant_algo"]
@@ -493,7 +474,7 @@
     has_quantized_layers = False
     fsdp_module_to_reshard = None

-    for name, sub_module in layer_pool.items():
+    for _, sub_module in model.named_modules():
         # Optimization to perform resharding only once per decoder layer to avoid extra communication overhead
         if isinstance(sub_module, FSDPModule):
             # Every time we encounter a new FSDPModule, the previous decoder layer is fully processed.
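For readers unfamiliar with the FSDP2 resharding optimization referenced in the comments above, here is a minimal sketch of the traversal pattern under the stated assumptions; the helper name and the elided export step are hypothetical, and `FSDPModule.reshard()` is the FSDP2 API.

```python
import torch.nn as nn
from torch.distributed.fsdp import FSDPModule


def _walk_and_reshard(model: nn.Module) -> None:
    """Illustrative only: reshard each FSDP-wrapped decoder layer once per traversal."""
    fsdp_module_to_reshard = None
    for _, sub_module in model.named_modules():
        if isinstance(sub_module, FSDPModule):
            # A new FSDPModule means the previous decoder layer is fully processed.
            if fsdp_module_to_reshard is not None:
                fsdp_module_to_reshard.reshard()
            fsdp_module_to_reshard = sub_module
        # ... per-module quantized-weight export would happen here ...
    if fsdp_module_to_reshard is not None:
        fsdp_module_to_reshard.reshard()
```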
@@ -555,7 +536,7 @@ def export_hf_checkpoint(
     """Exports the torch model to unified checkpoint and saves to export_dir.

     Args:
-        model: the torch model.
+        model: the full torch model to export. The actual quantized model may be a submodule.
         dtype: the weights data type to export the unquantized layers or the default model data type if None.
         export_dir: the target export path.
         save_modelopt_state: whether to save the modelopt state_dict.
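A hedged end-to-end sketch of the public entry point, using only the arguments listed in the docstring; `model` and the export directory are placeholders, and the import targets the module touched by this diff (a shorter re-exported path may also exist).

```python
import torch
from modelopt.torch.export.unified_export_hf import export_hf_checkpoint

# `model` is assumed to be a quantized HF model produced by modelopt calibration.
export_hf_checkpoint(
    model,
    dtype=torch.bfloat16,
    export_dir="exported_checkpoint",
    save_modelopt_state=False,
)
```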