diff --git a/.github/workflows/example_tests.yml b/.github/workflows/example_tests.yml index f93bf891b16..721c5c3dc75 100644 --- a/.github/workflows/example_tests.yml +++ b/.github/workflows/example_tests.yml @@ -55,7 +55,7 @@ jobs: strategy: fail-fast: false matrix: - example: [llm_ptq, vlm_ptq] + example: [llm_ptq] uses: ./.github/workflows/_example_tests_runner.yml secrets: inherit with: @@ -69,7 +69,7 @@ jobs: strategy: fail-fast: false matrix: - example: [llm_autodeploy, llm_eval, llm_ptq, vlm_ptq] + example: [llm_autodeploy, llm_eval, llm_ptq] uses: ./.github/workflows/_example_tests_runner.yml secrets: inherit with: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d3d0ec160ec..43098b04668 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,6 +4,11 @@ Changelog 0.46 (2026-xx-xx) ^^^^^^^^^^^^^^^^^ +**Deprecations** + +- Consolidated ``examples/vlm_ptq`` into ``examples/llm_ptq``. Vision-language model PTQ now shares the ``hf_ptq.py`` entry point and ``scripts/huggingface_example.sh``; pass ``--vlm`` to run the TensorRT-LLM multimodal quickstart smoke test. The ``examples/vlm_ptq/scripts/huggingface_example.sh`` entry point is deprecated: it now prints a warning and forwards to the ``llm_ptq`` script with ``--vlm``, and will be removed in a future release. See `examples/llm_ptq/README.md <https://github.com/NVIDIA/Model-Optimizer/blob/main/examples/llm_ptq/README.md>`__. +- Dropped VILA / NVILA vision-language model support in ``examples/llm_ptq``. VILA's modeling code requires ``transformers<=4.50.0``, which conflicts with ModelOpt's minimum supported ``transformers`` version. The VILA-specific bootstrap (repo clone, ``requirements-vila.txt``) and loading paths in ``example_utils.py`` have been removed. 
+ **New Features** - Add the ``day0-release`` agent skill (``.agents/skills/day0-release/``), a deterministic end-to-end driver that chains the PTQ → evaluation → comparison skills (the evaluation stage deploys the checkpoint itself) with an enforced gate after each stage and returns a publish decision (ACCEPT / REGRESSION / ANOMALOUS / INFEASIBLE). Ships three GPU-free, unit-tested gate scripts (``gate_ptq.py``, ``gate_run.py``, ``gate_compare.py``) that validate checkpoint coverage, evaluation-run completeness, and baseline-vs-candidate accuracy threshold. v1 reports and stops on regression; the recipe-search loop is deferred. diff --git a/README.md b/README.md index 487cb9e535e..259bf75421b 100644 --- a/README.md +++ b/README.md @@ -131,7 +131,7 @@ more fine-grained control on installed dependencies or for alternative docker im | Model Type | Support Matrix | |------------|----------------| | LLM Quantization | [View Support Matrix](./examples/llm_ptq/README.md#support-matrix) | -| VLM Quantization | [View Support Matrix](./examples/vlm_ptq/README.md#support-matrix) | +| VLM Quantization | [View Support Matrix](./examples/llm_ptq/README.md#hugging-face-supported-models) | | Diffusers Quantization | [View Support Matrix](./examples/diffusers/README.md#support-matrix) | | ONNX Quantization | [View Support Matrix](./examples/torch_onnx/README.md#onnx-export-supported-llm-models) | | Windows Quantization | [View Support Matrix](./examples/windows/README.md#support-matrix) | diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md index 64ef6deaa01..0c22d0b45fb 100755 --- a/examples/llm_ptq/README.md +++ b/examples/llm_ptq/README.md @@ -118,6 +118,11 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http | T5 | ✅ | ✅ | ✅ | ✅ | - | | Whisper<sup>9</sup> | ✅ | ❌ | ❌ | ❌ | - | | Nemotron-3 | ✅ | ❌ | ❌ | ❌ | ✅ | +| Llava (VLM)<sup>11</sup> | ✅ | ✅<sup>12</sup> | ✅ | ✅ | - | +| Phi-3-vision, Phi-4-multimodal (VLM)<sup>11</sup> | ✅ | ✅<sup>12</sup> | ✅ | ✅ | ✅ | +| Qwen2, 2.5-VL (VLM)<sup>11</sup> 
| ✅ | ✅<sup>12</sup> | ✅ | ✅ | ✅ | +| Gemma 3 (VLM)<sup>11</sup> | ✅ | - | - | - | - | +| Nemotron VL (VLM)<sup>11,13</sup> | ✅ | - | - | - | ✅ | > *This is a subset of the models supported. For the full list please check the [TensorRT-LLM support matrix](https://nvidia.github.io/TensorRT-LLM/reference/precision.html#support-matrix)* @@ -130,12 +135,21 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http > *7.[PTQ for DeepSeek](../deepseek/README.md)* \ > *8.GLM-4.7 has MTP (Multi-Token Prediction) layers that are automatically loaded and excluded from quantization.* \ > *9.Running Whisper model with transformers>=5.0 requires [torchcodec](https://github.com/meta-pytorch/torchcodec?tab=readme-ov-file#installing-cuda-enabled-torchcodec) and other system packages (e.g. ffmpeg).* \ -> *10.GPT-OSS ships with native MXFP4 weights; NVFP4 export is produced via the closed-form `--cast_mxfp4_to_nvfp4` cast (see [MXFP4 → NVFP4 cast](#mxfp4--nvfp4-cast-for-gpt-oss)).* +> *10.GPT-OSS ships with native MXFP4 weights; NVFP4 export is produced via the closed-form `--cast_mxfp4_to_nvfp4` cast (see [MXFP4 → NVFP4 cast](#mxfp4--nvfp4-cast-for-gpt-oss)).* \ +> *11.Vision-language model (VLM): only the language model is quantized while the vision encoder is kept in high precision. Pass `--vlm` to the shell script (see [VLM quantization](#vlm-quantization)).* \ +> *12.For VLMs, `int8_sq` only supports TensorRT-LLM checkpoint export and is not compatible with the TensorRT-LLM torch backend.* \ +> *13.Nemotron VL automatically calibrates with image-text pairs; see [VLM calibration with image-text pairs](#vlm-calibration-with-image-text-pairs-eg-nemotron-vl).* > *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. 
If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](./hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead. For NVFP4 quantization specifically, we recommend `nvfp4_mlp_only`, `nvfp4_experts_only`, or `nvfp4_omlp_only` to achieve higher accuracy by restricting quantization to the MLP/expert layers (and optionally the `o_proj` layer) while keeping the attention QKV projections unquantized.* > You can also create your own custom config using [this](https://nvidia.github.io/Model-Optimizer/guides/_pytorch_quantization.html#custom-calibration-algorithm) guide. +> *Vision-language models (VLMs) are listed in the support matrix above (rows marked `(VLM)`). PTQ for +> VLMs is handled by the same `hf_ptq.py` entry point and shell script as LLMs — the language model is +> quantized while the vision encoder is kept in high precision. Pass `--vlm` to the shell script (see +> [VLM quantization](#vlm-quantization)). For detailed TensorRT-LLM torch backend multimodal support, +> please refer to [this doc](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/models/supported-models.md#multimodal-feature-support-matrix-pytorch-backend).* + ## Framework Scripts ### Hugging Face Example [Script](./scripts/huggingface_example.sh) @@ -243,6 +257,20 @@ The cast pins each NVFP4 block's `scale_2 = 2^(k_max - 8)` and `_amax = 6 * 2^k_ [PTQ for DeepSeek](../deepseek/README.md) shows how to quantize the DeepSeek model with FP4 and export to TensorRT-LLM. +#### VLM quantization + +Vision-language models are quantized through the same script. 
Add `--vlm` so the script runs the +TensorRT-LLM multimodal quickstart as the deploy smoke test instead of the text-only one: + +```bash +scripts/huggingface_example.sh --model <huggingface_model_card> --quant fp8 --vlm +``` + +Supported `--quant` values for VLMs are `fp8`, `nvfp4`, `int8_sq`, `int4_awq`, and `w4a8_awq` (see +the `(VLM)` rows in the [Support Matrix](#hugging-face-supported-models)). + +> *This consolidates the former `examples/vlm_ptq` example, which now forwards here.* + #### VLM calibration with image-text pairs (e.g., Nemotron VL) For vision-language models, calibration quality can likely improve by using image-text pairs instead of text-only data, especially on visual understanding tasks: @@ -257,6 +285,12 @@ python hf_ptq.py \ --calib_size 512 ``` +The same flag is exposed by the shell script: + +```bash +scripts/huggingface_example.sh --model <huggingface_model_card> --quant nvfp4 --vlm --calib_with_images --trust_remote_code +``` + > Note: when `--calib_with_images` is set, `--calib_size` must be a single value, and the calibration dataset is nvidia/nemotron_vlm_dataset_v2. This functionality is currently in beta and has been tested on `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16`. 
diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index d36754a8d42..9fc1a9e68ab 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -21,7 +21,6 @@ import logging import os import shutil -import sys import warnings from collections.abc import Callable, Iterable from pathlib import Path @@ -298,9 +297,6 @@ def is_speculative(hf_config): def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs) -> PreTrainedTokenizerBase: print(f"Initializing tokenizer from {ckpt_path}") - if "vila" in ckpt_path.lower(): - ckpt_path += "/llm" - tokenizer = AutoTokenizer.from_pretrained( ckpt_path, trust_remote_code=trust_remote_code, **kwargs ) @@ -616,13 +612,6 @@ def get_model( if device == "cpu": device_map = "cpu" - # Add VILA to sys.path before loading config if needed - if "vila" in ckpt_path.lower(): - vila_path = os.path.join(ckpt_path, "..", "VILA") - if vila_path not in sys.path: - sys.path.append(vila_path) - from llava.model import LlavaLlamaConfig, LlavaLlamaModel # noqa: F401 - # Prepare config kwargs for loading config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {} @@ -644,147 +633,133 @@ def get_model( # Note: Forcibly converting the model precision between bf16 and fp16 may introduce accuracy drop model_kwargs = config_kwargs.copy() - # Don't set torch_dtype for VILA models as they handle it explicitly in their builder - if "vila" not in ckpt_path.lower(): - model_kwargs.setdefault("dtype", "auto") - - if "vila" in ckpt_path.lower(): - hf_vila = AutoModel.from_pretrained( + model_kwargs.setdefault("dtype", "auto") + + if use_seq_device_map: + device_map = "sequential" + # If we use sequential, set max_memory limit to ensure that the model does not occupy the full GPU + max_memory = get_max_memory() + max_memory = {key: value * gpu_mem_percentage for key, value in max_memory.items()} + model_kwargs["max_memory"] = max_memory + + if hf_config.model_type == 
"bart": + # device_map "auto" and "cuda" triggers error regarding meta tensor from safetensors + device_map = None + + if hf_config.model_type == "t5": + # device_map "auto" can naively shard T5's tied encoder/decoder embeddings and + # position-bias buffers across GPUs, which non-deterministically produces NaN + # activations during calibration on multi-GPU machines (see HF transformers #21093). + device_map = None + + # Helper function to check if model has pack-quantized config + def has_pack_quantized_config(config): + # Check top-level quantization_config + if hasattr(config, "quantization_config"): + if config.quantization_config.get("format", None) == "pack-quantized": + return True + # Check nested text_config.quantization_config (for multi-modal models like kimi k2.5) + if hasattr(config, "text_config") and hasattr(config.text_config, "quantization_config"): + if config.text_config.quantization_config.get("format", None) == "pack-quantized": + return True + return False + + if is_speculative(hf_config): + model = AutoModelForCausalLM.from_pretrained( ckpt_path, device_map=device_map, **model_kwargs, ) - model = hf_vila.llm - else: - if use_seq_device_map: - device_map = "sequential" - # If we use sequential, set max_memory limit to ensure that the model does not occupy the full GPU - max_memory = get_max_memory() - max_memory = {key: value * gpu_mem_percentage for key, value in max_memory.items()} - model_kwargs["max_memory"] = max_memory - - if hf_config.model_type == "bart": - # device_map "auto" and "cuda" triggers error regarding meta tensor from safetensors - device_map = None - - if hf_config.model_type == "t5": - # device_map "auto" can naively shard T5's tied encoder/decoder embeddings and - # position-bias buffers across GPUs, which non-deterministically produces NaN - # activations during calibration on multi-GPU machines (see HF transformers #21093). 
- device_map = None + elif has_pack_quantized_config(hf_config): + from modelopt.torch.quantization.plugins.huggingface import patch_compressed_linear_loading - # Helper function to check if model has pack-quantized config - def has_pack_quantized_config(config): - # Check top-level quantization_config - if hasattr(config, "quantization_config"): - if config.quantization_config.get("format", None) == "pack-quantized": - return True - # Check nested text_config.quantization_config (for multi-modal models like kimi k2.5) - if hasattr(config, "text_config") and hasattr( - config.text_config, "quantization_config" - ): - if config.text_config.quantization_config.get("format", None) == "pack-quantized": - return True - return False - - if is_speculative(hf_config): + with patch_compressed_linear_loading(): model = AutoModelForCausalLM.from_pretrained( ckpt_path, - device_map=device_map, - **model_kwargs, - ) - elif has_pack_quantized_config(hf_config): - from modelopt.torch.quantization.plugins.huggingface import ( - patch_compressed_linear_loading, + device_map="auto", + trust_remote_code=trust_remote_code, + dtype="auto", ) + elif get_original_hf_quant_method(hf_config) == "mxfp4": + # Native MXFP4 checkpoints (e.g. openai/gpt-oss-*) must be dequantized to + # plain BF16 experts (``GptOssExperts``) so ModelOpt can insert and export + # quantizers: the packed-kernel experts wrapper (``Mxfp4GptOssExperts``, + # used when the optional ``kernels`` package is present) is not supported by + # the unified HF export. Force dequantization regardless of whether + # ``kernels`` is installed. + # Local import: ``Mxfp4Config`` only exists in newer Transformers (gpt-oss support); + # importing it at module scope would break example_utils for users on older + # Transformers running unrelated (non-MXFP4) models. 
+ from transformers import Mxfp4Config + + # Load with a *sequential* device map (not "auto"): the MXFP4->BF16 dequant + # runs inside Transformers' threaded weight loader, and an "auto"/balanced + # split across multiple GPUs trips a CUDA illegal-memory access during dequant + # materialization. Sequential keeps each shard's dequant on a single device + # (the whole model lands on one GPU when it fits there). + model_kwargs["quantization_config"] = Mxfp4Config(dequantize=True) + model = AutoModelForCausalLM.from_pretrained( + ckpt_path, + device_map="cpu" if device == "cpu" else "sequential", + **model_kwargs, + ) + else: + architecture = hf_config.architectures[0] - with patch_compressed_linear_loading(): - model = AutoModelForCausalLM.from_pretrained( - ckpt_path, - device_map="auto", - trust_remote_code=trust_remote_code, - dtype="auto", + if not hasattr(transformers, architecture) or "Deepseek" in architecture: + if not hasattr(transformers, architecture): + warnings.warn( + f"Architecture {architecture} not found in transformers: {transformers.__version__}. " + "Falling back to AutoModelForCausalLM (or AutoModel for non-causal architectures)." ) - elif get_original_hf_quant_method(hf_config) == "mxfp4": - # Native MXFP4 checkpoints (e.g. openai/gpt-oss-*) must be dequantized to - # plain BF16 experts (``GptOssExperts``) so ModelOpt can insert and export - # quantizers: the packed-kernel experts wrapper (``Mxfp4GptOssExperts``, - # used when the optional ``kernels`` package is present) is not supported by - # the unified HF export. Force dequantization regardless of whether - # ``kernels`` is installed. - # Local import: ``Mxfp4Config`` only exists in newer Transformers (gpt-oss support); - # importing it at module scope would break example_utils for users on older - # Transformers running unrelated (non-MXFP4) models. 
- from transformers import Mxfp4Config - - # Load with a *sequential* device map (not "auto"): the MXFP4->BF16 dequant - # runs inside Transformers' threaded weight loader, and an "auto"/balanced - # split across multiple GPUs trips a CUDA illegal-memory access during dequant - # materialization. Sequential keeps each shard's dequant on a single device - # (the whole model lands on one GPU when it fits there). - model_kwargs["quantization_config"] = Mxfp4Config(dequantize=True) - model = AutoModelForCausalLM.from_pretrained( - ckpt_path, - device_map="cpu" if device == "cpu" else "sequential", - **model_kwargs, + assert trust_remote_code, ( + "Please set trust_remote_code to True if you want to use this architecture" ) - else: - architecture = hf_config.architectures[0] - - if not hasattr(transformers, architecture) or "Deepseek" in architecture: - if not hasattr(transformers, architecture): - warnings.warn( - f"Architecture {architecture} not found in transformers: {transformers.__version__}. " - "Falling back to AutoModelForCausalLM (or AutoModel for non-causal architectures)." - ) - assert trust_remote_code, ( - "Please set trust_remote_code to True if you want to use this architecture" - ) - # Use AutoModelForCausalLM for causal LMs, AutoModel for encoder-decoder models - if getattr(hf_config, "is_encoder_decoder", False): - auto_model_module = AutoModel - else: - auto_model_module = AutoModelForCausalLM - from_config = auto_model_module.from_config + # Use AutoModelForCausalLM for causal LMs, AutoModel for encoder-decoder models + if getattr(hf_config, "is_encoder_decoder", False): + auto_model_module = AutoModel else: - auto_model_module = getattr(transformers, architecture) - from_config = auto_model_module._from_config - - with init_empty_weights(include_buffers=True): - # When computing the device_map, assuming bfloat16 precision by default, - # unless specified by the hf_config. 
- torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) - model_kwargs2 = model_kwargs.copy() - if auto_model_module not in [AutoModelForCausalLM, AutoModel]: - model_kwargs2.pop("trust_remote_code", None) - model_kwargs2["dtype"] = torch_dtype - model_kwargs2.pop("max_memory", None) - model = from_config(hf_config, **model_kwargs2) - - max_memory = get_max_memory() - inferred_device_map = infer_auto_device_map(model, max_memory=max_memory) - - on_cpu = "cpu" in inferred_device_map.values() - - if on_cpu: - for _device in max_memory: - if isinstance(_device, int): - max_memory[_device] *= gpu_mem_percentage - - print( - "Model does not fit to the GPU mem. " - f"We apply the following memory limit for calibration: \n{max_memory}\n" - "If you hit GPU OOM issue, please adjust `gpu_mem_percentage` or " - "reduce the calibration `batch_size` manually." - ) - model_kwargs["max_memory"] = max_memory + auto_model_module = AutoModelForCausalLM + from_config = auto_model_module.from_config + else: + auto_model_module = getattr(transformers, architecture) + from_config = auto_model_module._from_config + + with init_empty_weights(include_buffers=True): + # When computing the device_map, assuming bfloat16 precision by default, + # unless specified by the hf_config. 
+ torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) + model_kwargs2 = model_kwargs.copy() + if auto_model_module not in [AutoModelForCausalLM, AutoModel]: + model_kwargs2.pop("trust_remote_code", None) + model_kwargs2["dtype"] = torch_dtype + model_kwargs2.pop("max_memory", None) + model = from_config(hf_config, **model_kwargs2) + + max_memory = get_max_memory() + inferred_device_map = infer_auto_device_map(model, max_memory=max_memory) + + on_cpu = "cpu" in inferred_device_map.values() + + if on_cpu: + for _device in max_memory: + if isinstance(_device, int): + max_memory[_device] *= gpu_mem_percentage - model = auto_model_module.from_pretrained( - ckpt_path, - device_map=device_map, - **model_kwargs, + print( + "Model does not fit to the GPU mem. " + f"We apply the following memory limit for calibration: \n{max_memory}\n" + "If you hit GPU OOM issue, please adjust `gpu_mem_percentage` or " + "reduce the calibration `batch_size` manually." ) + model_kwargs["max_memory"] = max_memory + + model = auto_model_module.from_pretrained( + ckpt_path, + device_map=device_map, + **model_kwargs, + ) model.eval() if has_pack_quantized_config(hf_config): _unpack_compressed_linear_weights(model, ckpt_path) diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh index 3f51e5b73f3..a073fb7fecb 100755 --- a/examples/llm_ptq/scripts/huggingface_example.sh +++ b/examples/llm_ptq/scripts/huggingface_example.sh @@ -86,6 +86,10 @@ fi PTQ_ARGS="" +if $CALIB_WITH_IMAGES; then + PTQ_ARGS+=" --calib_with_images " +fi + if [ "$LOW_MEMORY_MODE" = "true" ]; then PTQ_ARGS+=" --low_memory_mode " fi @@ -223,7 +227,21 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH # Only run the deploy+generate smoke test when "quant" is explicitly requested. Eval tasks # (lm_eval/mmlu/simple_eval) deploy the checkpoint themselves, so it is redundant there. 
if [[ $TASKS =~ "quant" ]]; then - python run_tensorrt_llm.py --checkpoint_dir=$SAVE_PATH $RUN_ARGS + if $VLM; then + # VLMs use the TRT-LLM multimodal quickstart for the deploy smoke test. + if [ -z "$TRT_LLM_CODE_PATH" ]; then + TRT_LLM_CODE_PATH=/app/tensorrt_llm # default path for the TRT-LLM release docker image + echo "Setting default TRT_LLM_CODE_PATH to $TRT_LLM_CODE_PATH." + fi + QUICK_START_MULTIMODAL=$TRT_LLM_CODE_PATH/examples/llm-api/quickstart_multimodal.py + if [ -f "$QUICK_START_MULTIMODAL" ]; then + python3 "$QUICK_START_MULTIMODAL" --model_dir "$SAVE_PATH" --modality image + else + echo "Warning: $QUICK_START_MULTIMODAL cannot be found. Please set TRT_LLM_CODE_PATH to the TRT-LLM code path or test the quantized checkpoint $SAVE_PATH with the TRT-LLM repo directly." + fi + else + python run_tensorrt_llm.py --checkpoint_dir="$SAVE_PATH" $RUN_ARGS + fi fi fi diff --git a/examples/llm_ptq/scripts/parser.sh b/examples/llm_ptq/scripts/parser.sh index 3efed91bc32..06b440e5731 100644 --- a/examples/llm_ptq/scripts/parser.sh +++ b/examples/llm_ptq/scripts/parser.sh @@ -37,9 +37,11 @@ parse_options() { VERBOSE=true USE_SEQ_DEVICE_MAP=false CAST_MXFP4_TO_NVFP4=false + VLM=false + CALIB_WITH_IMAGES=false # Parse command-line options - ARGS=$(getopt -o "" -l "model:,quant:,recipe:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,simple_eval_limit:,mmlu_limit:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@") + ARGS=$(getopt -o "" -l 
"model:,quant:,recipe:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,simple_eval_limit:,mmlu_limit:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4,vlm,calib_with_images" -n "$0" -- "$@") eval set -- "$ARGS" while true; do @@ -76,6 +78,8 @@ parse_options() { --auto_quantize_checkpoint ) AUTO_QUANTIZE_CHECKPOINT="$2"; shift 2;; --moe_calib_experts_ratio ) MOE_CALIB_EXPERTS_RATIO="$2"; shift 2;; --cast_mxfp4_to_nvfp4 ) CAST_MXFP4_TO_NVFP4=true; shift;; + --vlm ) VLM=true; shift;; + --calib_with_images ) CALIB_WITH_IMAGES=true; shift;; -- ) shift; break ;; * ) break ;; esac @@ -176,5 +180,7 @@ parse_options() { echo "auto_quantize_checkpoint: $AUTO_QUANTIZE_CHECKPOINT" echo "moe_calib_experts_ratio: $MOE_CALIB_EXPERTS_RATIO" echo "cast_mxfp4_to_nvfp4: $CAST_MXFP4_TO_NVFP4" + echo "vlm: $VLM" + echo "calib_with_images: $CALIB_WITH_IMAGES" echo "=================" } diff --git a/examples/vlm_ptq/README.md b/examples/vlm_ptq/README.md index 8b9c31aa429..b7f5a30f1b8 100644 --- a/examples/vlm_ptq/README.md +++ b/examples/vlm_ptq/README.md @@ -1,80 +1,31 @@ -# Post-training quantization (PTQ) for Vision Language Models +# [Deprecated] Post-training quantization (PTQ) for Vision Language Models -To learn more about the quantization feature, please refer to the [documentation](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html). +> **This example has been consolidated into [`examples/llm_ptq`](../llm_ptq/README.md) and is +> deprecated.** It will be removed in a future release. VLM PTQ now shares the same entry point +> (`hf_ptq.py`) and shell script as LLM PTQ. 
-Quantization is an effective model optimization technique that compresses your models. Quantization with Model Optimizer can compress model size by 2x-4x, speeding up inference while preserving model quality. \ -Model Optimizer enables highly performant quantization formats including NVFP4, FP8, INT8, INT4 and supports advanced algorithms such as SmoothQuant, AWQ, SVDQuant, and Double Quantization with easy-to-use Python APIs. +## Migration -This section focuses on Post-training quantization for VLM (Vision Language Models), a technique that reduces model precision after training to improve inference efficiency without requiring retraining. - -
- -| **Section** | **Description** | **Link** | **Docs** | -| :------------: | :------------: | :------------: | :------------: | -| Pre-Requisites | Required & optional packages to use this technique | \[[Link](#pre-requisites)\] | | -| Getting Started | Learn how to optimize your models using PTQ to reduce precision and improve inference efficiency | \[[Link](#getting-started)\] | \[[docs](https://nvidia.github.io/Model-Optimizer/guides/1_quantization.html)\] | -| Support Matrix | View the support matrix to see quantization compatibility and feature availability across different models | \[[Link](#support-matrix)\] | | -| Framework Scripts | Example scripts demonstrating quantization techniques for optimizing Hugging Face / Megatron-Bridge / Megatron-LM models | \[[Link](#framework-scripts)\] | | -| Pre-Quantized Checkpoints | Ready to deploy Hugging Face pre-quantized checkpoints | \[[Link](#pre-quantized-checkpoints)\] | | -| Resources | Extra links to relevant resources | \[[Link](#resources)\] | | - -
- -## Pre-Requisites - -Please refer to the [llm_ptq/README.md](../llm_ptq/README.md#pre-requisites) for the pre-requisites. - -## Getting Started - -Please refer to the [llm_ptq/README.md](../llm_ptq/README.md#getting-started) for the getting-started. - -## Support Matrix - -### Supported Models - -| Model | fp8 | int8_sq1 | int4_awq | w4a8_awq2 | nvfp43 | -| :---: | :---: | :---: | :---: | :---: | :---: | -| Llava | ✅ | ✅ | ✅ | ✅ | - | -| VILA | ✅ | ✅ | ✅ | ✅ | - | -| Phi-3-vision, Phi-4-multimodal | ✅ | ✅ | ✅ | ✅ | ✅ | -| Qwen2, 2.5-VL | ✅ | ✅ | ✅ | ✅ | ✅ | -| Gemma3 | ✅ | - | - | - | - | - -> *1.Only TensorRT-LLM checkpoint export is supported. Not compatible with the TensorRT-LLM torch backend* \ -> *2.The w4a8_awq is an experimental quantization scheme that may result in a higher accuracy penalty.* \ -> *3.A selective set of the popular models are internally tested. The actual model support list may be longer. NVFP4 inference requires Blackwell GPUs and TensorRT-LLM v0.17 or later.* - -> *For detailed TensorRT-LLM torch backend multimodal support, please refer to [this doc](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/models/supported-models.md#multimodal-feature-support-matrix-pytorch-backend)* - -> *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](../llm_ptq/hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead.* - -## Framework Scripts - -Please refer to the [llm_ptq/README.md](../llm_ptq/README.md) about the details of model quantization. - -The following scripts provide an all-in-one and step-by-step model quantization example for the supported Hugging Face multi-modal models. 
The quantization format and the number of GPUs will be supplied as inputs to these scripts. - -### Hugging Face Example [Script](./scripts/huggingface_example.sh) +Use the `llm_ptq` script with the `--vlm` flag: ```bash -scripts/huggingface_example.sh --model --quant [fp8|nvfp4|int8_sq|int4_awq|w4a8_awq] +cd examples/llm_ptq +scripts/huggingface_example.sh --model <huggingface_model_card> --quant [fp8|nvfp4|int8_sq|int4_awq|w4a8_awq] --vlm ``` -### Megatron-Bridge Example - -Please refer to the [examples/megatron_bridge/](../megatron_bridge/README.md) for example scripts for PTQ with Megatron-Bridge. +The previous `examples/vlm_ptq/scripts/huggingface_example.sh` entry point still works: it now +prints a deprecation warning and forwards to the command above. -## Pre-Quantized Checkpoints +## Where things moved -- Ready-to-deploy checkpoints \[[🤗 Hugging Face - Nvidia Model Optimizer Collection](https://huggingface.co/collections/nvidia/inference-optimized-checkpoints-with-model-optimizer)\] -- Deployable on [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [vLLM](https://github.com/vllm-project/vllm) and [SGLang](https://github.com/sgl-project/sglang) -- More models coming soon! 
+| Topic | New location | +| :--- | :--- | +| Supported VLMs / support matrix | [llm_ptq/README.md#hugging-face-supported-models](../llm_ptq/README.md#hugging-face-supported-models) | +| VLM quantization workflow (`--vlm`) | [llm_ptq/README.md#vlm-quantization](../llm_ptq/README.md#vlm-quantization) | +| Image-text calibration (`--calib_with_images`) | [llm_ptq/README.md#vlm-calibration-with-image-text-pairs-eg-nemotron-vl](../llm_ptq/README.md#vlm-calibration-with-image-text-pairs-eg-nemotron-vl) | +| Megatron-Bridge VLM PTQ | [examples/megatron_bridge/](../megatron_bridge/README.md) | ## Resources -- 📅 [Roadmap](https://github.com/NVIDIA/Model-Optimizer/issues/146) - 📖 [Documentation](https://nvidia.github.io/Model-Optimizer) -- 🎯 [Benchmarks](../benchmark.md) - 💡 [Release Notes](https://nvidia.github.io/Model-Optimizer/reference/0_changelog.html) -- 🐛 [File a bug](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=1_bug_report.md) -- ✨ [File a Feature Request](https://github.com/NVIDIA/Model-Optimizer/issues/new?template=2_feature_request.md) diff --git a/examples/vlm_ptq/scripts/huggingface_example.sh b/examples/vlm_ptq/scripts/huggingface_example.sh index eada1c137f8..2e42a3c768d 100755 --- a/examples/vlm_ptq/scripts/huggingface_example.sh +++ b/examples/vlm_ptq/scripts/huggingface_example.sh @@ -14,130 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -set -e - -script_dir="$(dirname "$(readlink -f "$0")")" - -source $script_dir/../../llm_ptq/scripts/parser.sh -parse_options "$@" - -set -x - -# This will prevent the script from hanging on Selene/EOS due to the MPI support. 
-echo "********** unset all SLURM_, PMI_, PMIX_ Variables **********" -for i in $(env | grep ^SLURM_ | cut -d"=" -f 1); do unset -v $i; done -for i in $(env | grep ^PMI_ | cut -d"=" -f 1); do unset -v $i; done -for i in $(env | grep ^PMIX_ | cut -d"=" -f 1); do unset -v $i; done +# DEPRECATED: examples/vlm_ptq has been consolidated into examples/llm_ptq. +# This shim forwards all arguments to the llm_ptq script with the --vlm flag so existing +# commands keep working. Please migrate to: +# +# cd examples/llm_ptq +# scripts/huggingface_example.sh --model <huggingface_model_card> --quant <quant_format> --vlm +# +# See examples/llm_ptq/README.md#vlm-quantization for details. -if [ -z "$MODEL_PATH" ]; then - echo "Unsupported model argument: Expected a huggingface model path or model name" >&2 - exit 1 -fi +set -e -case $QFORMAT in - fp8|int8_sq|int4_awq|w4a8_awq|nvfp4) - ;; - *) - echo "Unknown quant argument: Expected one of: [fp8, int8_sq, int4_awq, w4a8_awq, nvfp4]" >&2 - exit 1 -esac +echo "WARNING: examples/vlm_ptq is deprecated and will be removed in a future release." >&2 +echo " Forwarding to examples/llm_ptq/scripts/huggingface_example.sh --vlm" >&2 +echo " See examples/llm_ptq/README.md#vlm-quantization" >&2 script_dir="$(dirname "$(readlink -f "$0")")" -pushd $script_dir/.. 
- -if [ -z "$ROOT_SAVE_PATH" ]; then - ROOT_SAVE_PATH=$(pwd) -fi - -MODEL_NAME=$(basename $MODEL_PATH | sed 's/[^0-9a-zA-Z\-]/_/g')_${QFORMAT}${KV_CACHE_QUANT:+_kv_${KV_CACHE_QUANT}} -SAVE_PATH=${ROOT_SAVE_PATH}/saved_models_${MODEL_NAME} - -MODEL_CONFIG=${SAVE_PATH}/config.json - -if [ "${REMOVE_EXISTING_MODEL_CONFIG,,}" = "true" ]; then - rm -f $MODEL_CONFIG -fi - -PTQ_ARGS="" - -if [ -n "$AUTO_QUANTIZE_BITS" ]; then - PTQ_ARGS+=" --auto_quantize_bits $AUTO_QUANTIZE_BITS " -fi - -if $TRUST_REMOTE_CODE; then - PTQ_ARGS+=" --trust_remote_code " -fi - -if [ -n "$KV_CACHE_QUANT" ]; then - PTQ_ARGS+=" --kv_cache_qformat=$KV_CACHE_QUANT " -fi - -if [[ "${MODEL_NAME,,}" == *"vila"* ]]; then - # Check transformers version - must be <= 4.50.0 - CURRENT_TRANSFORMERS_VERSION=$(pip show transformers | grep Version | cut -d' ' -f2) - if [ "$(printf '%s\n' "4.50.0" "$CURRENT_TRANSFORMERS_VERSION" | sort -V | head -n1)" = "4.50.0" ] && [ "$CURRENT_TRANSFORMERS_VERSION" != "4.50.0" ]; then - echo "ERROR: transformers version $CURRENT_TRANSFORMERS_VERSION is not supported." >&2 - echo "VILA requires transformers<=4.50.0" >&2 - echo "Please refer to examples/vlm_ptq/requirements-vila.txt for the supported versions." >&2 - echo "You also need to download VILA repository from https://github.com/Efficient-Large-Model/VILA.git and checkout ec7fb2c264920bf004fd9fa37f1ec36ea0942db5" >&2 - exit 1 - fi - - pip install -r ../vlm_ptq/requirements-vila.txt - # Clone original VILA repo - if [ ! -d "$(dirname "$MODEL_PATH")/VILA" ]; then - echo "VILA repository is needed until it is added to HF model zoo. Cloning the repository parallel to $MODEL_PATH..." - git clone https://github.com/Efficient-Large-Model/VILA.git "$(dirname "$MODEL_PATH")/VILA" && \ - cd "$(dirname "$MODEL_PATH")/VILA" && \ - git checkout ec7fb2c264920bf004fd9fa37f1ec36ea0942db5 && \ - cd "$script_dir/.." - fi -fi - -if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH) ]]; then - if ! 
[ -f $MODEL_CONFIG ]; then - echo "Quantizing original model..." - python ../llm_ptq/hf_ptq.py \ - --pyt_ckpt_path=$MODEL_PATH \ - --export_path=$SAVE_PATH \ - --qformat=$QFORMAT \ - --calib_size=$CALIB_SIZE \ - --batch_size=$CALIB_BATCH_SIZE \ - --inference_tensor_parallel=$TP \ - --inference_pipeline_parallel=$PP \ - $PTQ_ARGS - else - echo "Quantized model config $MODEL_CONFIG exists, skipping the quantization stage" - fi -fi - -if [[ "$QFORMAT" != "fp8" ]]; then - echo "For quant format $QFORMAT, please refer to the TensorRT-LLM documentation for deployment. Checkpoint saved to $SAVE_PATH." - exit 0 -fi - -if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]]; then - cuda_major=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader -i 0 | cut -d. -f1) - - if [ "$cuda_major" -lt 10 ]; then - echo "Please deploy the NVFP4 checkpoint on a Blackwell GPU. Checkpoint export_path: $SAVE_PATH" - exit 0 - fi -fi - -# Prepare datasets for TRT-LLM benchmark -if [ -z "$TRT_LLM_CODE_PATH" ]; then - TRT_LLM_CODE_PATH=/app/tensorrt_llm # default path for the TRT-LLM release docker image - echo "Setting default TRT_LLM_CODE_PATH to $TRT_LLM_CODE_PATH." -fi - -QUICK_START_MULTIMODAL=$TRT_LLM_CODE_PATH/examples/llm-api/quickstart_multimodal.py - -if [ -f "$QUICK_START_MULTIMODAL" ]; then - python3 $QUICK_START_MULTIMODAL --model_dir $SAVE_PATH --modality image -else - echo "Warning: $QUICK_START_MULTIMODAL cannot be found. Please set TRT_LLM_CODE_PATH to the TRT-LLM code path or test the quantized checkpoint $SAVE_PATH with the TRT-LLM repo directly." 
def run_llm_ptq_command(*, model: str, quant: str, vlm: bool = False, **kwargs):
    """Run ``scripts/huggingface_example.sh`` for the ``llm_ptq`` example.

    ``model``, ``quant`` and any extra keyword arguments are handed to
    ``extend_cmd_parts`` to build the command line. ``tasks`` falls back to
    ``"quant"`` and ``calib`` to ``16`` when the caller does not supply them.
    When ``vlm`` is truthy, ``--vlm`` is inserted right after ``--no-verbose``.
    """
    # Key order matches the previous behaviour: caller options first, then
    # model and quant, then whichever defaults were missing.
    options = {**kwargs, "model": model, "quant": quant}
    for key, fallback in (("tasks", "quant"), ("calib", 16)):
        options.setdefault(key, fallback)

    # VLM PTQ shares the llm_ptq entry point; --vlm runs the multimodal deploy smoke test.
    flags = ["--no-verbose", "--vlm"] if vlm else ["--no-verbose"]
    cmd_parts = extend_cmd_parts(["scripts/huggingface_example.sh", *flags], **options)
    run_example_command(cmd_parts, "llm_ptq")
import pytest
from _test_utils.examples.models import QWEN_VL_PATH
from _test_utils.examples.run_command import run_llm_ptq_command


@pytest.mark.parametrize(
    "quant",
    [
        "fp8",
        "int8_sq",
        "nvfp4",
    ],
)
def test_qwen_vl(quant):
    """Run the llm_ptq example on the Qwen VL model with ``vlm=True`` for one quant format."""
    run_llm_ptq_command(vlm=True, quant=quant, model=QWEN_VL_PATH)