From 544ccd8f183e66200aff00bc4492cfe94d0817a6 Mon Sep 17 00:00:00 2001 From: zhanghaoc <2272055687@qq.com> Date: Mon, 20 Oct 2025 10:26:36 -0700 Subject: [PATCH 1/5] add basic support for kv cache quantization Signed-off-by: zhanghaoc <2272055687@qq.com> --- modelopt/onnx/quantization/__main__.py | 22 ++- modelopt/onnx/quantization/int4.py | 43 +++++- modelopt/onnx/quantization/kv_cache.py | 204 +++++++++++++++++++++++++ modelopt/onnx/quantization/quantize.py | 15 ++ 4 files changed, 282 insertions(+), 2 deletions(-) create mode 100644 modelopt/onnx/quantization/kv_cache.py diff --git a/modelopt/onnx/quantization/__main__.py b/modelopt/onnx/quantization/__main__.py index 55cca6ee5..d118af60c 100644 --- a/modelopt/onnx/quantization/__main__.py +++ b/modelopt/onnx/quantization/__main__.py @@ -41,7 +41,7 @@ def get_parser() -> argparse.ArgumentParser: argparser.add_argument( "--calibration_method", type=str, - choices=["max", "entropy", "awq_clip", "rtn_dq"], + choices=["max", "entropy", "awq_clip", "rtn_dq", "awq_full", "awq_lite"], help=( "Calibration method choices for int8/fp8: {entropy (default), max}, " "int4: {awq_clip (default), rtn_dq}." @@ -255,6 +255,24 @@ def get_parser() -> argparse.ArgumentParser: "The currently supported precisions are {fp16, int8, fp8}." ), ) + argparser.add_argument( + "--kv_quant_mode", + type=str, + choices=["NONE", "PER_TENSOR", "PER_CHANNEL"], + default="NONE", + help=( + "Quantization mode for kv cache in GQA. NONE (default) means no quantization for kv cache, " + ), + ) + argparser.add_argument( + "--kv_cache_type", + type=str, + choices=["fp8", "int8"], + default="NONE", + help=( + "Quantization type for kv cache in GQA. fp8 is default." + ), + ) return argparser @@ -298,6 +316,8 @@ def main(): simplify=args.simplify, calibrate_per_node=args.calibrate_per_node, direct_io_types=args.direct_io_types, + kv_quant_mode=args.kv_quant_mode, + kv_cache_type=args.kv_cache_type, ) diff --git a/modelopt/onnx/quantization/int4.py b/modelopt/onnx/quantization/int4.py index 1086b5a4d..6e9a3a558 100644 --- a/modelopt/onnx/quantization/int4.py +++ b/modelopt/onnx/quantization/int4.py @@ -57,6 +57,10 @@ update_block_size, ) from modelopt.onnx.utils import save_onnx +from modelopt.onnx.quantization.kv_cache import ( + save_kv_cache_calib_data, + save_kv_cache_calib_data_rtn, +) __all__ = ["quantize"] @@ -444,6 +448,8 @@ def _quantize_awq_clip( force_fp16: bool = False, nodes_to_exclude: list[str] = [], input_shapes_profile: Sequence[dict[str, str]] | None = None, + intermediate_generated_files: list[str] = [], + kv_quant_mode: str = "NONE", **kwargs: Any, ) -> onnx.ModelProto: """Quantizes `onnx_model` using the Activation aware quantization a.k.a AWQ algorithm.""" @@ -482,10 +488,19 @@ def _quantize_awq_clip( # Apply AWQ clip on selected weights t = time.time() alphas = {} + + if kv_quant_mode != "NONE": + save_kv_cache_calib_data( + onnx_model, + session=session, + inputs=inputs, + intermediate_generated_files=intermediate_generated_files, + ) + for i in tqdm(range(len(wa_pack)), desc="Running clip search..."): act_tensor, weight_tensor, do_transpose, gemm_io_type = wa_pack[i] - # First capture all the activation values after calibration data sweep + # First capture all the activation values after calibration data sweep output_dicts = {} for inp_d in inputs: np_inp_d = {name: numpy.asarray(tensor) for name, tensor in inp_d.items()} @@ -968,6 +983,8 @@ def _quantize_awq_lite( use_zero_point: bool = False, nodes_to_exclude: list[str] = [], input_shapes_profile: 
Sequence[dict[str, str]] | None = None, + intermediate_generated_files: list[str] = [], + kv_quant_mode: str = "NONE", **kwargs: Any, ) -> onnx.ModelProto: """Quantizes `onnx_model` using the Activation aware quantization a.k.a AWQ algorithm.""" @@ -1025,6 +1042,14 @@ def _quantize_awq_lite( gc.collect() + if kv_quant_mode != "NONE": + save_kv_cache_calib_data( + onnx_model, + session=session, + inputs=inputs, + intermediate_generated_files=intermediate_generated_files, + ) + output_data = [] if enable_fast_path_using_high_sysram: @@ -1328,6 +1353,8 @@ def quantize( nodes_to_exclude: list[str] | None = [r"/lm_head"], log_level: str = "INFO", input_shapes_profile: Sequence[dict[str, str]] | None = None, + intermediate_generated_files: list[str] = [], + kv_quant_mode: str = "NONE", **kwargs: Any, ) -> onnx.ModelProto: """Applies INT4 Weight-Only-Quantization (WoQ) to an ONNX model. @@ -1421,6 +1448,16 @@ def quantize( qdq.use_trt_qdq_ops() if calibration_method in ["rtn", "rtn_dq", "rtn_trt", "rtn_trt_dq"]: + # Save kv-cache calibration data if kv_quant_mode is not NONE + if kv_quant_mode != "NONE": + save_kv_cache_calib_data_rtn( + onnx_model, + intermediate_generated_files=intermediate_generated_files, + data_reader=calibration_data_reader, + calibration_eps=calibration_eps, + input_shapes_profile=input_shapes_profile, + use_external_data_format=use_external_data_format, + ) onnx_model = quantize_rtn( onnx_model, block_size, @@ -1445,6 +1482,8 @@ def quantize( use_zero_point=use_zero_point, enable_weight_clipping=do_weight_clipping, input_shapes_profile=input_shapes_profile, + kv_quant_mode=kv_quant_mode, + intermediate_generated_files=intermediate_generated_files, **kwargs, ) elif calibration_method in ["awq_clip", "awq_clip_trt"]: @@ -1456,6 +1495,8 @@ def quantize( block_size, nodes_to_exclude=nodes_to_exclude, input_shapes_profile=input_shapes_profile, + kv_quant_mode=kv_quant_mode, + intermediate_generated_files=intermediate_generated_files, **kwargs, ) else: diff --git a/modelopt/onnx/quantization/kv_cache.py b/modelopt/onnx/quantization/kv_cache.py new file mode 100644 index 000000000..3b60a25db --- /dev/null +++ b/modelopt/onnx/quantization/kv_cache.py @@ -0,0 +1,204 @@ +import onnx +import pickle +import numpy as np +import onnxruntime as ort +import tempfile +import os +import copy +from collections.abc import Sequence +from tqdm import tqdm +from pathlib import Path +from modelopt.onnx.logging_config import logger +from onnxruntime.quantization.calibrate import CalibrationDataReader +from modelopt.onnx.quantization.ort_utils import create_inference_session +from modelopt.onnx.utils import save_onnx + + +# using fp8 as default quantization mode +def kv_cache_quantize( + onnx_model: onnx.ModelProto, + kv_cache_type: str = "fp8", # only support fp8 and int8 for now, will support fp16, uint8 later + kv_quant_mode: str = "NONE", # NONE / PER_TENSOR / PER_CHANNEL + kv_cache_bit_width: int = 0, # only used for uint8, available options: 2, 4, 8 + intermediate_generated_files: list[str] = [], + calibration_method: str | None = None, # "awq_clip", "awq_lite", "rtn_dq" +) -> onnx.ModelProto: + + logger.info(f"Start kv cache quantization with kv_cache_type {kv_cache_type}, " + f"kv_quant_mode {kv_quant_mode}, calibration_method {calibration_method}") + + logger.info(f"intermediate_generated_files: {intermediate_generated_files}") + + kv_tensor_names_list = [] + + # replace each tensor starting with past_key_values + for input in onnx_model.graph.input: + if "past_key_values" in input.name: 
+ if kv_cache_type == "fp8": + input.type.tensor_type.elem_type = onnx.TensorProto.FLOAT8E4M3FN + elif kv_cache_type == "int8": + input.type.tensor_type.elem_type = onnx.TensorProto.INT8 + else: + raise ValueError(f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization") + + # Update graph output similarly, at the sametime add all output names, + # it will be used to collect calibration data later + for output in onnx_model.graph.output: + if "present" in output.name: + kv_tensor_names_list.append(output.name) + if kv_cache_type == "fp8": + output.type.tensor_type.elem_type = onnx.TensorProto.FLOAT8E4M3FN + elif kv_cache_type == "int8": + output.type.tensor_type.elem_type = onnx.TensorProto.INT8 + else: + raise ValueError(f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization") + + kv_tensor_names_list.sort() + + # loop through all nodes and find GQA node + group_query_attention_nodes = [ + node for node in onnx_model.graph.node if node.op_type == "GroupQueryAttention" + ] + + tensor_range = None + # both scale is a TensorData, scale's shape depends on kv_quant_mode + k_scale = None + v_scale = None + + # look for file named tmp_calib_data.json in intermediate_generated_files + for intermediate_file in intermediate_generated_files: + # if a file end with .calib_data, use it as calibration data + if intermediate_file.endswith("tmp_calib_data.json"): + # load calibration data from file + with open(intermediate_file, "rb") as f: + tensor_range = pickle.load(f) + logger.info(f"Using calibration data from {intermediate_file} for kv cache quantization") + break + + if calibration_method in ["awq_clip", "awq_lite", "rtn_dq"]: + logger.info(f"Using {calibration_method} calibration method for kv cache quantization") + # parse tensor_range + for node in group_query_attention_nodes: + # calculate k_scale based on input and output range + k_max = 0 + v_max = 0 + for output in node.output: + if "key" in output: + index = kv_tensor_names_list.index(output) + for data_range in tensor_range: + k_max = max(k_max, np.abs(np.asarray(data_range[index]).max())) + if "value" in output: + index = kv_tensor_names_list.index(output) + for data_range in tensor_range: + v_max = max(v_max, np.abs(np.asarray(data_range[index]).max())) + if kv_quant_mode == "PER_TENSOR": + Qmax = 0 + if kv_cache_type == "fp8": + Qmax = 448 # max fp value for E4M3 + elif kv_cache_type == "int8": + Qmax = 127 # max int8 value + else: + raise ValueError(f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization") + # create onnx tensor data as fp16 and assign to k_scale and v_scale + k_scale = onnx.helper.make_tensor( + name=node.name + "_k_scale", + data_type=onnx.TensorProto.FLOAT16, + dims=[1], + vals=[k_max / Qmax] if k_max != 0 else [1.0], + ) + v_scale = onnx.helper.make_tensor( + name=node.name + "_v_scale", + data_type=onnx.TensorProto.FLOAT16, + dims=[1], + vals=[v_max / Qmax] if v_max != 0 else [1.0], + ) + onnx_model.graph.initializer.append(k_scale) + onnx_model.graph.initializer.append(v_scale) + # add scale to input, use empty string to pad the input to 12, + # insert k_scale at index 12 and v_scale at index 13 + while len(node.input) < 12: + node.input.append("") + node.input.append(k_scale.name) + node.input.append(v_scale.name) + + # add attributes to GQA node + for node in group_query_attention_nodes: + # add attribute for quantization type + node.attribute.append(onnx.helper.make_attribute("k_quant_type", kv_quant_mode)) + 
node.attribute.append(onnx.helper.make_attribute("v_quant_type", kv_quant_mode)) + # set bit width attribute, only used for uint8, not supported currently + if kv_cache_type == "uint8": + node.attribute.append(onnx.helper.make_attribute("kv_cache_bit_width", kv_cache_bit_width)) + logger.info(f"kv cache quantization done") + + return onnx_model + +def save_kv_cache_calib_data( + onnx_model: onnx.ModelProto, + session: ort.InferenceSession | None = None, + inputs: list[dict] = [], + intermediate_generated_files: list[str] = [], +): + kv_tensor_data = [] + kv_tensor_names_list = [] + for output in onnx_model.graph.output: + if "present" in output.name: + kv_tensor_names_list.append(output.name) + + kv_tensor_names_list.sort() + + for i in tqdm(range(len(inputs)), desc="Caching activations..."): + inp_d = inputs[i] + np_inp_d = {name: np.asarray(tensor) for name, tensor in inp_d.items()} + output = session.run(kv_tensor_names_list, np_inp_d) + kv_tensor_data.append(output) + + # save to tmp file named tmp_calib_data.json + tmp_dir = tempfile.mkdtemp() + calib_data_path = Path(tmp_dir).joinpath("tmp_calib_data.json").as_posix() + # call to_dict and save to json + with open(calib_data_path, "wb") as f: + pickle.dump(kv_tensor_data, f) + intermediate_generated_files.append(calib_data_path) + + +def save_kv_cache_calib_data_rtn( + onnx_model: onnx.ModelProto, + data_reader: CalibrationDataReader | None = None, + calibration_eps: list[str] = [], + input_shapes_profile: Sequence[dict[str, str]] | None = None, + intermediate_generated_files: list[str] = [], + use_external_data_format: bool = False, +): + augmented_model = copy.deepcopy(onnx_model) + + # save model in augmented_onnx_path for creating inference session + augmented_onnx_file, augmented_onnx_path = tempfile.mkstemp(suffix=".onnx") + os.close(augmented_onnx_file) + + save_onnx(augmented_model, augmented_onnx_path, use_external_data_format) + + # Creating inference session and preparing inputs for calibration + session = create_inference_session(augmented_onnx_path, calibration_eps, input_shapes_profile) + inputs = [] + for inp_d in data_reader: + inputs.append(inp_d) + assert isinstance(inp_d, dict) + save_kv_cache_calib_data( + onnx_model, + session=session, + inputs=inputs, + intermediate_generated_files=intermediate_generated_files, + ) + + logger.info("Saved kv-cache calibration data for RTN quantization") + + del session + + try: + os.remove(augmented_onnx_path) + if use_external_data_format: + os.remove(augmented_onnx_path + "_data") + except OSError: + logger.warn("Augmented ONNX model or external data file was not found") + \ No newline at end of file diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py index 2d23b875a..66f36bfea 100755 --- a/modelopt/onnx/quantization/quantize.py +++ b/modelopt/onnx/quantization/quantize.py @@ -68,6 +68,7 @@ ) from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag, load_onnx_model from modelopt.onnx.utils import duplicate_shared_constants, name_onnx_nodes, save_onnx +from modelopt.onnx.quantization.kv_cache import kv_cache_quantize __all__ = ["quantize"] @@ -238,6 +239,8 @@ def quantize( calibrate_per_node: bool = False, input_shapes_profile: Sequence[dict[str, str]] | None = None, direct_io_types: bool = False, + kv_quant_mode: str = "NONE", + kv_cache_type: str = "fp8", **kwargs: Any, ) -> None: """Quantizes the provided ONNX model. 
@@ -511,12 +514,24 @@ def quantize( use_zero_point=use_zero_point, log_level=log_level, input_shapes_profile=input_shapes_profile, + intermediate_generated_files=intermediate_generated_files, + kv_quant_mode=kv_quant_mode, **kwargs, ) else: raise RuntimeError(f"Invalid quantization mode choice: {quantize_mode}") if onnx_model: + if quantize_mode == "int4" and kv_quant_mode != "NONE" and calibration_method in ["awq_clip", "awq_lite", "rtn_dq"]: + logger.info(f"Quantization mode for KV cache: {kv_quant_mode}, kv_cache_type: {kv_cache_type}") + onnx_model = kv_cache_quantize( + onnx_model, + kv_quant_mode=kv_quant_mode, + kv_cache_type=kv_cache_type, + intermediate_generated_files=intermediate_generated_files, + calibration_method=calibration_method, + ) + # Fuse Q nodes for INT8/FP8 mode if quantize_mode in ["int8", "fp8"]: if dq_only: From c062c087a45c33ad04a5e58945ecceac9b38bea8 Mon Sep 17 00:00:00 2001 From: zhanghaoc <2272055687@qq.com> Date: Wed, 22 Oct 2025 10:13:37 -0700 Subject: [PATCH 2/5] Add kv cache support Signed-off-by: zhanghaoc <2272055687@qq.com> --- modelopt/onnx/quantization/kv_cache.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modelopt/onnx/quantization/kv_cache.py b/modelopt/onnx/quantization/kv_cache.py index 3b60a25db..4affa9754 100644 --- a/modelopt/onnx/quantization/kv_cache.py +++ b/modelopt/onnx/quantization/kv_cache.py @@ -23,6 +23,8 @@ def kv_cache_quantize( intermediate_generated_files: list[str] = [], calibration_method: str | None = None, # "awq_clip", "awq_lite", "rtn_dq" ) -> onnx.ModelProto: + if kv_cache_type == "NONE": + kv_cache_type = "fp8" logger.info(f"Start kv cache quantization with kv_cache_type {kv_cache_type}, " f"kv_quant_mode {kv_quant_mode}, calibration_method {calibration_method}") From 7428db485b1f053c2c90e409afb7ae025d2e1511 Mon Sep 17 00:00:00 2001 From: zhanghaoc <2272055687@qq.com> Date: Thu, 30 Oct 2025 16:08:11 -0700 Subject: [PATCH 3/5] Add kv cache support for int8/fp8 Signed-off-by: zhanghaoc <2272055687@qq.com> --- modelopt/onnx/quantization/fp8.py | 3 + modelopt/onnx/quantization/int8.py | 3 + modelopt/onnx/quantization/kv_cache.py | 94 +++++++++++----------- modelopt/onnx/quantization/ort_patching.py | 13 +++ modelopt/onnx/quantization/quantize.py | 1 + 5 files changed, 68 insertions(+), 46 deletions(-) diff --git a/modelopt/onnx/quantization/fp8.py b/modelopt/onnx/quantization/fp8.py index ce7d56a26..8db397ce5 100755 --- a/modelopt/onnx/quantization/fp8.py +++ b/modelopt/onnx/quantization/fp8.py @@ -181,6 +181,7 @@ def quantize( calibrate_per_node: bool = False, custom_ops_to_quantize: list[str] = [], direct_io_types: bool = False, + kv_quant_mode: str = "NONE", **kwargs, ) -> onnx.ModelProto: """Applies FP8 GEMM only quantization to an ONNX file. 
@@ -295,6 +296,8 @@ def quantize( # With ActivationSymmetric as True, MinMax calibration is equivalent to max calibration else CalibrationMethod.MinMax ), + intermediate_generated_files=intermediate_generated_files, + kv_quant_mode=kv_quant_mode, ) intermediate_generated_files.append(tmp_onnx_path) if use_external_data_format: diff --git a/modelopt/onnx/quantization/int8.py b/modelopt/onnx/quantization/int8.py index baf5a4383..6f4f9ef44 100755 --- a/modelopt/onnx/quantization/int8.py +++ b/modelopt/onnx/quantization/int8.py @@ -131,6 +131,7 @@ def quantize( calibrate_per_node: bool = False, custom_ops_to_quantize: list[str] = [], direct_io_types: bool = False, + kv_quant_mode: str = "NONE", **kwargs, ) -> onnx.ModelProto: """Applies INT8 quantization to an ONNX file using the compiler friendly heuristics. @@ -257,6 +258,8 @@ def quantize( # With ActivationSymmetric as True, MinMax calibration is equivalent to max calibration else CalibrationMethod.MinMax ), + intermediate_generated_files=intermediate_generated_files, + kv_quant_mode=kv_quant_mode, ) intermediate_generated_files.append(tmp_onnx_path) diff --git a/modelopt/onnx/quantization/kv_cache.py b/modelopt/onnx/quantization/kv_cache.py index 4affa9754..af1e0613a 100644 --- a/modelopt/onnx/quantization/kv_cache.py +++ b/modelopt/onnx/quantization/kv_cache.py @@ -77,51 +77,49 @@ def kv_cache_quantize( logger.info(f"Using calibration data from {intermediate_file} for kv cache quantization") break - if calibration_method in ["awq_clip", "awq_lite", "rtn_dq"]: - logger.info(f"Using {calibration_method} calibration method for kv cache quantization") - # parse tensor_range - for node in group_query_attention_nodes: - # calculate k_scale based on input and output range - k_max = 0 - v_max = 0 - for output in node.output: - if "key" in output: - index = kv_tensor_names_list.index(output) - for data_range in tensor_range: - k_max = max(k_max, np.abs(np.asarray(data_range[index]).max())) - if "value" in output: - index = kv_tensor_names_list.index(output) - for data_range in tensor_range: - v_max = max(v_max, np.abs(np.asarray(data_range[index]).max())) - if kv_quant_mode == "PER_TENSOR": - Qmax = 0 - if kv_cache_type == "fp8": - Qmax = 448 # max fp value for E4M3 - elif kv_cache_type == "int8": - Qmax = 127 # max int8 value - else: - raise ValueError(f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization") - # create onnx tensor data as fp16 and assign to k_scale and v_scale - k_scale = onnx.helper.make_tensor( - name=node.name + "_k_scale", - data_type=onnx.TensorProto.FLOAT16, - dims=[1], - vals=[k_max / Qmax] if k_max != 0 else [1.0], - ) - v_scale = onnx.helper.make_tensor( - name=node.name + "_v_scale", - data_type=onnx.TensorProto.FLOAT16, - dims=[1], - vals=[v_max / Qmax] if v_max != 0 else [1.0], - ) - onnx_model.graph.initializer.append(k_scale) - onnx_model.graph.initializer.append(v_scale) - # add scale to input, use empty string to pad the input to 12, - # insert k_scale at index 12 and v_scale at index 13 - while len(node.input) < 12: - node.input.append("") - node.input.append(k_scale.name) - node.input.append(v_scale.name) + # parse tensor_range + for node in group_query_attention_nodes: + # calculate k_scale based on input and output range + k_max = 0 + v_max = 0 + for output in node.output: + if "key" in output: + index = kv_tensor_names_list.index(output) + for data_range in tensor_range: + k_max = max(k_max, np.abs(np.asarray(data_range[index]).max())) + if "value" in output: + index = 
kv_tensor_names_list.index(output) + for data_range in tensor_range: + v_max = max(v_max, np.abs(np.asarray(data_range[index]).max())) + if kv_quant_mode == "PER_TENSOR": + Qmax = 0 + if kv_cache_type == "fp8": + Qmax = 448 # max fp value for E4M3 + elif kv_cache_type == "int8": + Qmax = 127 # max int8 value + else: + raise ValueError(f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization") + # create onnx tensor data as fp16 and assign to k_scale and v_scale + k_scale = onnx.helper.make_tensor( + name=node.name + "_k_scale", + data_type=onnx.TensorProto.FLOAT16, + dims=[1], + vals=[k_max / Qmax] if k_max != 0 else [1.0], + ) + v_scale = onnx.helper.make_tensor( + name=node.name + "_v_scale", + data_type=onnx.TensorProto.FLOAT16, + dims=[1], + vals=[v_max / Qmax] if v_max != 0 else [1.0], + ) + onnx_model.graph.initializer.append(k_scale) + onnx_model.graph.initializer.append(v_scale) + # add scale to input, use empty string to pad the input to 12, + # insert k_scale at index 12 and v_scale at index 13 + while len(node.input) < 12: + node.input.append("") + node.input.append(k_scale.name) + node.input.append(v_scale.name) # add attributes to GQA node for node in group_query_attention_nodes: @@ -136,13 +134,17 @@ def kv_cache_quantize( return onnx_model def save_kv_cache_calib_data( - onnx_model: onnx.ModelProto, + onnx_model: str | Path | onnx.ModelProto, session: ort.InferenceSession | None = None, inputs: list[dict] = [], intermediate_generated_files: list[str] = [], ): kv_tensor_data = [] kv_tensor_names_list = [] + + if not isinstance(onnx_model, onnx.ModelProto): + onnx_model = onnx.load(onnx_model) + for output in onnx_model.graph.output: if "present" in output.name: kv_tensor_names_list.append(output.name) diff --git a/modelopt/onnx/quantization/ort_patching.py b/modelopt/onnx/quantization/ort_patching.py index e177b5804..0d231a690 100755 --- a/modelopt/onnx/quantization/ort_patching.py +++ b/modelopt/onnx/quantization/ort_patching.py @@ -1559,6 +1559,8 @@ def _quantize_static( use_external_data_format=False, calibrate_method=CalibrationMethod.MinMax, extra_options=None, + intermediate_generated_files: list[str] = [], + kv_quant_mode: str = "NONE", ): """Modification: enables TRT custom ops in the calibrator via 'TrtExtraPluginLibraryPaths' in extra_options. @@ -1653,6 +1655,17 @@ def _quantize_static( raise TypeError( f"Unexpected type {type(tensors_range)} for tensors_range and calibrator={type(calibrator)}." 
) + + if kv_quant_mode != "NONE": + from modelopt.onnx.quantization.kv_cache import save_kv_cache_calib_data + + save_kv_cache_calib_data( + Path(model_input), + session=calibrator.infer_session, + inputs=[inp_d for inp_d in calibration_data_reader], + intermediate_generated_files=intermediate_generated_files, + ) + del calibrator check_static_quant_arguments(quant_format, activation_type, weight_type) diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py index 66f36bfea..6a4623506 100755 --- a/modelopt/onnx/quantization/quantize.py +++ b/modelopt/onnx/quantization/quantize.py @@ -500,6 +500,7 @@ def quantize( calibrate_per_node=calibrate_per_node, custom_ops_to_quantize=list(custom_ops_to_quantize.keys()), direct_io_types=direct_io_types, + kv_quant_mode=kv_quant_mode, **kwargs, ) elif "int4" in quantize_mode: From 91154e689f45e812ba20c521fac7c601b6978da4 Mon Sep 17 00:00:00 2001 From: zhanghaoc <2272055687@qq.com> Date: Thu, 30 Oct 2025 16:25:30 -0700 Subject: [PATCH 4/5] Fix small bug for kv cache Signed-off-by: zhanghaoc <2272055687@qq.com> --- modelopt/onnx/quantization/quantize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py index 6a4623506..87b7a8002 100755 --- a/modelopt/onnx/quantization/quantize.py +++ b/modelopt/onnx/quantization/quantize.py @@ -523,7 +523,7 @@ def quantize( raise RuntimeError(f"Invalid quantization mode choice: {quantize_mode}") if onnx_model: - if quantize_mode == "int4" and kv_quant_mode != "NONE" and calibration_method in ["awq_clip", "awq_lite", "rtn_dq"]: + if kv_quant_mode != "NONE": logger.info(f"Quantization mode for KV cache: {kv_quant_mode}, kv_cache_type: {kv_cache_type}") onnx_model = kv_cache_quantize( onnx_model, From a877d022827a2acd185841603339e79fd18d4359 Mon Sep 17 00:00:00 2001 From: zhanghaoc <2272055687@qq.com> Date: Thu, 30 Oct 2025 17:48:26 -0700 Subject: [PATCH 5/5] Style fix Signed-off-by: zhanghaoc <2272055687@qq.com> --- modelopt/onnx/quantization/__main__.py | 6 +- modelopt/onnx/quantization/int4.py | 8 +- modelopt/onnx/quantization/kv_cache.py | 125 +++++++++++++-------- modelopt/onnx/quantization/ort_patching.py | 4 +- modelopt/onnx/quantization/quantize.py | 10 +- 5 files changed, 93 insertions(+), 60 deletions(-) diff --git a/modelopt/onnx/quantization/__main__.py b/modelopt/onnx/quantization/__main__.py index d118af60c..1573eb57d 100644 --- a/modelopt/onnx/quantization/__main__.py +++ b/modelopt/onnx/quantization/__main__.py @@ -41,7 +41,7 @@ def get_parser() -> argparse.ArgumentParser: argparser.add_argument( "--calibration_method", type=str, - choices=["max", "entropy", "awq_clip", "rtn_dq", "awq_full", "awq_lite"], + choices=["max", "entropy", "awq_clip", "rtn_dq", "awq_full", "awq_lite", "rtn"], help=( "Calibration method choices for int8/fp8: {entropy (default), max}, " "int4: {awq_clip (default), rtn_dq}." @@ -269,9 +269,7 @@ def get_parser() -> argparse.ArgumentParser: type=str, choices=["fp8", "int8"], default="NONE", - help=( - "Quantization type for kv cache in GQA. fp8 is default." - ), + help=("Quantization type for kv cache in GQA. 
fp8 is default."), ) return argparser diff --git a/modelopt/onnx/quantization/int4.py b/modelopt/onnx/quantization/int4.py index 6e9a3a558..a7b07f532 100644 --- a/modelopt/onnx/quantization/int4.py +++ b/modelopt/onnx/quantization/int4.py @@ -45,6 +45,10 @@ get_tensor_producer_nodes, ) from modelopt.onnx.quantization.gs_patching import patch_gs_modules +from modelopt.onnx.quantization.kv_cache import ( + save_kv_cache_calib_data, + save_kv_cache_calib_data_rtn, +) from modelopt.onnx.quantization.ort_utils import create_inference_session from modelopt.onnx.quantization.quant_utils import ( _pad, @@ -57,10 +61,6 @@ update_block_size, ) from modelopt.onnx.utils import save_onnx -from modelopt.onnx.quantization.kv_cache import ( - save_kv_cache_calib_data, - save_kv_cache_calib_data_rtn, -) __all__ = ["quantize"] diff --git a/modelopt/onnx/quantization/kv_cache.py b/modelopt/onnx/quantization/kv_cache.py index af1e0613a..6d7219d6b 100644 --- a/modelopt/onnx/quantization/kv_cache.py +++ b/modelopt/onnx/quantization/kv_cache.py @@ -1,38 +1,60 @@ -import onnx +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Performs kv cache quantization, and returns the ONNX ModelProto.""" + +import copy +import os import pickle -import numpy as np -import onnxruntime as ort import tempfile -import os -import copy from collections.abc import Sequence -from tqdm import tqdm from pathlib import Path -from modelopt.onnx.logging_config import logger + +import numpy as np +import onnx +import onnxruntime as ort from onnxruntime.quantization.calibrate import CalibrationDataReader +from tqdm import tqdm + +from modelopt.onnx.logging_config import logger from modelopt.onnx.quantization.ort_utils import create_inference_session from modelopt.onnx.utils import save_onnx # using fp8 as default quantization mode def kv_cache_quantize( - onnx_model: onnx.ModelProto, - kv_cache_type: str = "fp8", # only support fp8 and int8 for now, will support fp16, uint8 later - kv_quant_mode: str = "NONE", # NONE / PER_TENSOR / PER_CHANNEL - kv_cache_bit_width: int = 0, # only used for uint8, available options: 2, 4, 8 + onnx_model: onnx.ModelProto, + kv_cache_type: str = "fp8", # only support fp8 and int8 for now, will support fp16, uint8 later + kv_quant_mode: str = "NONE", # NONE / PER_TENSOR / PER_CHANNEL + kv_cache_bit_width: int = 0, # only used for uint8, available options: 2, 4, 8 intermediate_generated_files: list[str] = [], - calibration_method: str | None = None, # "awq_clip", "awq_lite", "rtn_dq" + calibration_method: str | None = None, # "awq_clip", "awq_lite", "rtn_dq" ) -> onnx.ModelProto: + """Perform kv cache quantization on the given ONNX model.""" if kv_cache_type == "NONE": kv_cache_type = "fp8" - - logger.info(f"Start kv cache quantization with kv_cache_type {kv_cache_type}, " - f"kv_quant_mode {kv_quant_mode}, calibration_method {calibration_method}") - + + logger.info( + f"Start 
kv cache quantization with kv_cache_type {kv_cache_type}, " + f"kv_quant_mode {kv_quant_mode}, calibration_method {calibration_method}" + ) + logger.info(f"intermediate_generated_files: {intermediate_generated_files}") - + kv_tensor_names_list = [] - + # replace each tensor starting with past_key_values for input in onnx_model.graph.input: if "past_key_values" in input.name: @@ -41,19 +63,23 @@ def kv_cache_quantize( elif kv_cache_type == "int8": input.type.tensor_type.elem_type = onnx.TensorProto.INT8 else: - raise ValueError(f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization") - + raise ValueError( + f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization" + ) + # Update graph output similarly, at the sametime add all output names, # it will be used to collect calibration data later for output in onnx_model.graph.output: if "present" in output.name: - kv_tensor_names_list.append(output.name) + kv_tensor_names_list.append(output.name) if kv_cache_type == "fp8": output.type.tensor_type.elem_type = onnx.TensorProto.FLOAT8E4M3FN elif kv_cache_type == "int8": output.type.tensor_type.elem_type = onnx.TensorProto.INT8 else: - raise ValueError(f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization") + raise ValueError( + f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization" + ) kv_tensor_names_list.sort() @@ -62,11 +88,11 @@ def kv_cache_quantize( node for node in onnx_model.graph.node if node.op_type == "GroupQueryAttention" ] - tensor_range = None + tensor_range = [] # both scale is a TensorData, scale's shape depends on kv_quant_mode k_scale = None v_scale = None - + # look for file named tmp_calib_data.json in intermediate_generated_files for intermediate_file in intermediate_generated_files: # if a file end with .calib_data, use it as calibration data @@ -74,17 +100,19 @@ def kv_cache_quantize( # load calibration data from file with open(intermediate_file, "rb") as f: tensor_range = pickle.load(f) - logger.info(f"Using calibration data from {intermediate_file} for kv cache quantization") + logger.info( + f"Using calibration data from {intermediate_file} for kv cache quantization" + ) break - + # parse tensor_range for node in group_query_attention_nodes: # calculate k_scale based on input and output range k_max = 0 v_max = 0 for output in node.output: - if "key" in output: - index = kv_tensor_names_list.index(output) + if "key" in output: + index = kv_tensor_names_list.index(output) for data_range in tensor_range: k_max = max(k_max, np.abs(np.asarray(data_range[index]).max())) if "value" in output: @@ -92,25 +120,27 @@ def kv_cache_quantize( for data_range in tensor_range: v_max = max(v_max, np.abs(np.asarray(data_range[index]).max())) if kv_quant_mode == "PER_TENSOR": - Qmax = 0 + tmax = 0 if kv_cache_type == "fp8": - Qmax = 448 # max fp value for E4M3 + tmax = 448 # max fp value for E4M3 elif kv_cache_type == "int8": - Qmax = 127 # max int8 value + tmax = 127 # max int8 value else: - raise ValueError(f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization") + raise ValueError( + f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization" + ) # create onnx tensor data as fp16 and assign to k_scale and v_scale k_scale = onnx.helper.make_tensor( name=node.name + "_k_scale", data_type=onnx.TensorProto.FLOAT16, dims=[1], - vals=[k_max / Qmax] if k_max != 0 else [1.0], + vals=[k_max / tmax] if k_max != 0 else [1.0], ) v_scale = onnx.helper.make_tensor( name=node.name + "_v_scale", 
data_type=onnx.TensorProto.FLOAT16, dims=[1], - vals=[v_max / Qmax] if v_max != 0 else [1.0], + vals=[v_max / tmax] if v_max != 0 else [1.0], ) onnx_model.graph.initializer.append(k_scale) onnx_model.graph.initializer.append(v_scale) @@ -120,35 +150,38 @@ def kv_cache_quantize( node.input.append("") node.input.append(k_scale.name) node.input.append(v_scale.name) - + # add attributes to GQA node - for node in group_query_attention_nodes: + for node in group_query_attention_nodes: # add attribute for quantization type node.attribute.append(onnx.helper.make_attribute("k_quant_type", kv_quant_mode)) node.attribute.append(onnx.helper.make_attribute("v_quant_type", kv_quant_mode)) # set bit width attribute, only used for uint8, not supported currently if kv_cache_type == "uint8": - node.attribute.append(onnx.helper.make_attribute("kv_cache_bit_width", kv_cache_bit_width)) - logger.info(f"kv cache quantization done") + node.attribute.append( + onnx.helper.make_attribute("kv_cache_bit_width", kv_cache_bit_width) + ) + logger.info("kv cache quantization done") return onnx_model - + + def save_kv_cache_calib_data( onnx_model: str | Path | onnx.ModelProto, session: ort.InferenceSession | None = None, inputs: list[dict] = [], intermediate_generated_files: list[str] = [], ): + """Save kv cache calibration data for quantization.""" kv_tensor_data = [] - kv_tensor_names_list = [] if not isinstance(onnx_model, onnx.ModelProto): onnx_model = onnx.load(onnx_model) - for output in onnx_model.graph.output: - if "present" in output.name: - kv_tensor_names_list.append(output.name) - + kv_tensor_names_list = [ + output.name for output in onnx_model.graph.output if "present" in output.name + ] + kv_tensor_names_list.sort() for i in tqdm(range(len(inputs)), desc="Caching activations..."): @@ -174,8 +207,9 @@ def save_kv_cache_calib_data_rtn( intermediate_generated_files: list[str] = [], use_external_data_format: bool = False, ): + """Save kv cache calibration data for RTN quantization. Create inference session internally.""" augmented_model = copy.deepcopy(onnx_model) - + # save model in augmented_onnx_path for creating inference session augmented_onnx_file, augmented_onnx_path = tempfile.mkstemp(suffix=".onnx") os.close(augmented_onnx_file) @@ -205,4 +239,3 @@ def save_kv_cache_calib_data_rtn( os.remove(augmented_onnx_path + "_data") except OSError: logger.warn("Augmented ONNX model or external data file was not found") - \ No newline at end of file diff --git a/modelopt/onnx/quantization/ort_patching.py b/modelopt/onnx/quantization/ort_patching.py index 0d231a690..2c29fec2e 100755 --- a/modelopt/onnx/quantization/ort_patching.py +++ b/modelopt/onnx/quantization/ort_patching.py @@ -1655,14 +1655,14 @@ def _quantize_static( raise TypeError( f"Unexpected type {type(tensors_range)} for tensors_range and calibrator={type(calibrator)}." 
) - + if kv_quant_mode != "NONE": from modelopt.onnx.quantization.kv_cache import save_kv_cache_calib_data save_kv_cache_calib_data( Path(model_input), session=calibrator.infer_session, - inputs=[inp_d for inp_d in calibration_data_reader], + inputs=list(calibration_data_reader), intermediate_generated_files=intermediate_generated_files, ) diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py index 87b7a8002..a5bb16807 100755 --- a/modelopt/onnx/quantization/quantize.py +++ b/modelopt/onnx/quantization/quantize.py @@ -60,6 +60,7 @@ ) from modelopt.onnx.quantization.int4 import quantize as quantize_int4 from modelopt.onnx.quantization.int8 import quantize as quantize_int8 +from modelopt.onnx.quantization.kv_cache import kv_cache_quantize from modelopt.onnx.quantization.ort_utils import update_trt_ep_support from modelopt.onnx.quantization.qdq_utils import ( qdq_to_dq, @@ -68,7 +69,6 @@ ) from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag, load_onnx_model from modelopt.onnx.utils import duplicate_shared_constants, name_onnx_nodes, save_onnx -from modelopt.onnx.quantization.kv_cache import kv_cache_quantize __all__ = ["quantize"] @@ -239,7 +239,7 @@ def quantize( calibrate_per_node: bool = False, input_shapes_profile: Sequence[dict[str, str]] | None = None, direct_io_types: bool = False, - kv_quant_mode: str = "NONE", + kv_quant_mode: str = "NONE", kv_cache_type: str = "fp8", **kwargs: Any, ) -> None: @@ -516,7 +516,7 @@ def quantize( log_level=log_level, input_shapes_profile=input_shapes_profile, intermediate_generated_files=intermediate_generated_files, - kv_quant_mode=kv_quant_mode, + kv_quant_mode=kv_quant_mode, **kwargs, ) else: @@ -524,7 +524,9 @@ def quantize( if onnx_model: if kv_quant_mode != "NONE": - logger.info(f"Quantization mode for KV cache: {kv_quant_mode}, kv_cache_type: {kv_cache_type}") + logger.info( + f"Quantization mode for KV cache: {kv_quant_mode}, kv_cache_type: {kv_cache_type}" + ) onnx_model = kv_cache_quantize( onnx_model, kv_quant_mode=kv_quant_mode,