diff --git a/modelopt/onnx/quantization/__main__.py b/modelopt/onnx/quantization/__main__.py
index 55cca6ee5..1573eb57d 100644
--- a/modelopt/onnx/quantization/__main__.py
+++ b/modelopt/onnx/quantization/__main__.py
@@ -41,7 +41,7 @@ def get_parser() -> argparse.ArgumentParser:
     argparser.add_argument(
         "--calibration_method",
         type=str,
-        choices=["max", "entropy", "awq_clip", "rtn_dq"],
+        choices=["max", "entropy", "awq_clip", "rtn_dq", "awq_full", "awq_lite", "rtn"],
         help=(
             "Calibration method choices for int8/fp8: {entropy (default), max}, "
             "int4: {awq_clip (default), rtn_dq}."
         ),
     )
@@ -255,6 +255,22 @@ def get_parser() -> argparse.ArgumentParser:
             "The currently supported precisions are {fp16, int8, fp8}."
         ),
     )
+    argparser.add_argument(
+        "--kv_quant_mode",
+        type=str,
+        choices=["NONE", "PER_TENSOR", "PER_CHANNEL"],
+        default="NONE",
+        help=(
+            "Quantization mode for kv cache in GQA. NONE (default) means no quantization for kv cache."
+        ),
+    )
+    argparser.add_argument(
+        "--kv_cache_type",
+        type=str,
+        choices=["fp8", "int8"],
+        default="fp8",
+        help=("Quantization type for kv cache in GQA. fp8 is default."),
+    )
     return argparser
@@ -298,6 +314,8 @@ def main():
         simplify=args.simplify,
         calibrate_per_node=args.calibrate_per_node,
         direct_io_types=args.direct_io_types,
+        kv_quant_mode=args.kv_quant_mode,
+        kv_cache_type=args.kv_cache_type,
     )
diff --git a/modelopt/onnx/quantization/fp8.py b/modelopt/onnx/quantization/fp8.py
index ce7d56a26..8db397ce5 100755
--- a/modelopt/onnx/quantization/fp8.py
+++ b/modelopt/onnx/quantization/fp8.py
@@ -181,6 +181,7 @@ def quantize(
     calibrate_per_node: bool = False,
     custom_ops_to_quantize: list[str] = [],
     direct_io_types: bool = False,
+    kv_quant_mode: str = "NONE",
     **kwargs,
 ) -> onnx.ModelProto:
     """Applies FP8 GEMM only quantization to an ONNX file.
@@ -295,6 +296,8 @@ def quantize(
             # With ActivationSymmetric as True, MinMax calibration is equivalent to max calibration
             else CalibrationMethod.MinMax
         ),
+        intermediate_generated_files=intermediate_generated_files,
+        kv_quant_mode=kv_quant_mode,
     )
     intermediate_generated_files.append(tmp_onnx_path)
     if use_external_data_format:
diff --git a/modelopt/onnx/quantization/int4.py b/modelopt/onnx/quantization/int4.py
index 1086b5a4d..a7b07f532 100644
--- a/modelopt/onnx/quantization/int4.py
+++ b/modelopt/onnx/quantization/int4.py
@@ -45,6 +45,10 @@
     get_tensor_producer_nodes,
 )
 from modelopt.onnx.quantization.gs_patching import patch_gs_modules
+from modelopt.onnx.quantization.kv_cache import (
+    save_kv_cache_calib_data,
+    save_kv_cache_calib_data_rtn,
+)
 from modelopt.onnx.quantization.ort_utils import create_inference_session
 from modelopt.onnx.quantization.quant_utils import (
     _pad,
@@ -444,6 +448,8 @@
     force_fp16: bool = False,
     nodes_to_exclude: list[str] = [],
     input_shapes_profile: Sequence[dict[str, str]] | None = None,
+    intermediate_generated_files: list[str] = [],
+    kv_quant_mode: str = "NONE",
     **kwargs: Any,
 ) -> onnx.ModelProto:
     """Quantizes `onnx_model` using the Activation aware quantization a.k.a AWQ algorithm."""
@@ -482,10 +488,19 @@ def _quantize_awq_clip(
     # Apply AWQ clip on selected weights
     t = time.time()
     alphas = {}
+
+    if kv_quant_mode != "NONE":
+        save_kv_cache_calib_data(
+            onnx_model,
+            session=session,
+            inputs=inputs,
+            intermediate_generated_files=intermediate_generated_files,
+        )
+
     for i in tqdm(range(len(wa_pack)), desc="Running clip search..."):
         act_tensor, weight_tensor, do_transpose, gemm_io_type = wa_pack[i]
 
-        # First capture all the activation values after calibration data sweep
+        # First capture all the activation values after calibration data sweep
         output_dicts = {}
         for inp_d in inputs:
             np_inp_d = {name: numpy.asarray(tensor) for name, tensor in inp_d.items()}
@@ -968,6 +983,8 @@
     use_zero_point: bool = False,
     nodes_to_exclude: list[str] = [],
     input_shapes_profile: Sequence[dict[str, str]] | None = None,
+    intermediate_generated_files: list[str] = [],
+    kv_quant_mode: str = "NONE",
     **kwargs: Any,
 ) -> onnx.ModelProto:
     """Quantizes `onnx_model` using the Activation aware quantization a.k.a AWQ algorithm."""
@@ -1025,6 +1042,14 @@ def _quantize_awq_lite(
 
     gc.collect()
 
+    if kv_quant_mode != "NONE":
+        save_kv_cache_calib_data(
+            onnx_model,
+            session=session,
+            inputs=inputs,
+            intermediate_generated_files=intermediate_generated_files,
+        )
+
     output_data = []
 
     if enable_fast_path_using_high_sysram:
@@ -1328,6 +1353,8 @@
     nodes_to_exclude: list[str] | None = [r"/lm_head"],
     log_level: str = "INFO",
     input_shapes_profile: Sequence[dict[str, str]] | None = None,
+    intermediate_generated_files: list[str] = [],
+    kv_quant_mode: str = "NONE",
     **kwargs: Any,
 ) -> onnx.ModelProto:
     """Applies INT4 Weight-Only-Quantization (WoQ) to an ONNX model.
@@ -1421,6 +1448,16 @@ def quantize(
         qdq.use_trt_qdq_ops()
 
     if calibration_method in ["rtn", "rtn_dq", "rtn_trt", "rtn_trt_dq"]:
+        # Save kv-cache calibration data if kv_quant_mode is not NONE
+        if kv_quant_mode != "NONE":
+            save_kv_cache_calib_data_rtn(
+                onnx_model,
+                intermediate_generated_files=intermediate_generated_files,
+                data_reader=calibration_data_reader,
+                calibration_eps=calibration_eps,
+                input_shapes_profile=input_shapes_profile,
+                use_external_data_format=use_external_data_format,
+            )
         onnx_model = quantize_rtn(
             onnx_model,
             block_size,
@@ -1445,6 +1482,8 @@ def quantize(
             use_zero_point=use_zero_point,
             enable_weight_clipping=do_weight_clipping,
             input_shapes_profile=input_shapes_profile,
+            kv_quant_mode=kv_quant_mode,
+            intermediate_generated_files=intermediate_generated_files,
             **kwargs,
         )
     elif calibration_method in ["awq_clip", "awq_clip_trt"]:
@@ -1456,6 +1495,8 @@ def quantize(
             block_size,
             nodes_to_exclude=nodes_to_exclude,
             input_shapes_profile=input_shapes_profile,
+            kv_quant_mode=kv_quant_mode,
+            intermediate_generated_files=intermediate_generated_files,
             **kwargs,
         )
     else:
diff --git a/modelopt/onnx/quantization/int8.py b/modelopt/onnx/quantization/int8.py
index baf5a4383..6f4f9ef44 100755
--- a/modelopt/onnx/quantization/int8.py
+++ b/modelopt/onnx/quantization/int8.py
@@ -131,6 +131,7 @@ def quantize(
     calibrate_per_node: bool = False,
     custom_ops_to_quantize: list[str] = [],
     direct_io_types: bool = False,
+    kv_quant_mode: str = "NONE",
     **kwargs,
 ) -> onnx.ModelProto:
     """Applies INT8 quantization to an ONNX file using the compiler friendly heuristics.
@@ -257,6 +258,8 @@ def quantize(
             # With ActivationSymmetric as True, MinMax calibration is equivalent to max calibration
             else CalibrationMethod.MinMax
         ),
+        intermediate_generated_files=intermediate_generated_files,
+        kv_quant_mode=kv_quant_mode,
     )
     intermediate_generated_files.append(tmp_onnx_path)
diff --git a/modelopt/onnx/quantization/kv_cache.py b/modelopt/onnx/quantization/kv_cache.py
new file mode 100644
index 000000000..6d7219d6b
--- /dev/null
+++ b/modelopt/onnx/quantization/kv_cache.py
@@ -0,0 +1,241 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Performs kv cache quantization, and returns the ONNX ModelProto."""
+
+import copy
+import os
+import pickle
+import tempfile
+from collections.abc import Sequence
+from pathlib import Path
+
+import numpy as np
+import onnx
+import onnxruntime as ort
+from onnxruntime.quantization.calibrate import CalibrationDataReader
+from tqdm import tqdm
+
+from modelopt.onnx.logging_config import logger
+from modelopt.onnx.quantization.ort_utils import create_inference_session
+from modelopt.onnx.utils import save_onnx
+
+
+# fp8 is used as the default kv cache quantization type
+def kv_cache_quantize(
+    onnx_model: onnx.ModelProto,
+    kv_cache_type: str = "fp8",  # only supports fp8 and int8 for now; fp16 and uint8 will be supported later
+    kv_quant_mode: str = "NONE",  # NONE / PER_TENSOR / PER_CHANNEL
+    kv_cache_bit_width: int = 0,  # only used for uint8, available options: 2, 4, 8
+    intermediate_generated_files: list[str] = [],
+    calibration_method: str | None = None,  # "awq_clip", "awq_lite", "rtn_dq"
+) -> onnx.ModelProto:
+    """Perform kv cache quantization on the given ONNX model."""
+    if kv_cache_type == "NONE":
+        kv_cache_type = "fp8"
+
+    logger.info(
+        f"Start kv cache quantization with kv_cache_type {kv_cache_type}, "
+        f"kv_quant_mode {kv_quant_mode}, calibration_method {calibration_method}"
+    )
+
+    logger.info(f"intermediate_generated_files: {intermediate_generated_files}")
+
+    kv_tensor_names_list = []
+
+    # update the element type of every graph input whose name contains past_key_values
+    for input in onnx_model.graph.input:
+        if "past_key_values" in input.name:
+            if kv_cache_type == "fp8":
+                input.type.tensor_type.elem_type = onnx.TensorProto.FLOAT8E4M3FN
+            elif kv_cache_type == "int8":
+                input.type.tensor_type.elem_type = onnx.TensorProto.INT8
+            else:
+                raise ValueError(
+                    f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization"
+                )
+
+    # Update the graph outputs similarly; at the same time collect all output names,
+    # which will be used to look up the calibration data later
+    for output in onnx_model.graph.output:
+        if "present" in output.name:
+            kv_tensor_names_list.append(output.name)
+            if kv_cache_type == "fp8":
+                output.type.tensor_type.elem_type = onnx.TensorProto.FLOAT8E4M3FN
+            elif kv_cache_type == "int8":
+                output.type.tensor_type.elem_type = onnx.TensorProto.INT8
+            else:
+                raise ValueError(
+                    f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization"
+                )
+
+    kv_tensor_names_list.sort()
+
+    # loop through all nodes and collect the GQA nodes
+    group_query_attention_nodes = [
+        node for node in onnx_model.graph.node if node.op_type == "GroupQueryAttention"
+    ]
+
+    tensor_range = []
+    # both scales are TensorData; the scale's shape depends on kv_quant_mode
+    k_scale = None
+    v_scale = None
+
+    # look for a file named tmp_calib_data.json in intermediate_generated_files
+    for intermediate_file in intermediate_generated_files:
+        # if a file ends with tmp_calib_data.json, use it as calibration data
+        if intermediate_file.endswith("tmp_calib_data.json"):
+            # load calibration data from file
+            with open(intermediate_file, "rb") as f:
+                tensor_range = pickle.load(f)
+            logger.info(
+                f"Using calibration data from {intermediate_file} for kv cache quantization"
+            )
+            break
+
+    # parse tensor_range
+    for node in group_query_attention_nodes:
+        # calculate k_scale and v_scale based on the captured output ranges
+        k_max = 0
+        v_max = 0
+        for output in node.output:
+            if "key" in output:
+                index = kv_tensor_names_list.index(output)
+                for data_range in tensor_range:
+                    k_max = max(k_max, np.abs(np.asarray(data_range[index])).max())
+            if "value" in output:
+                index = kv_tensor_names_list.index(output)
+                for data_range in tensor_range:
+                    v_max = max(v_max, np.abs(np.asarray(data_range[index])).max())
+        if kv_quant_mode == "PER_TENSOR":
+            tmax = 0
+            if kv_cache_type == "fp8":
+                tmax = 448  # max fp value for E4M3
+            elif kv_cache_type == "int8":
+                tmax = 127  # max int8 value
+            else:
+                raise ValueError(
+                    f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization"
+                )
+            # create onnx tensor data as fp16 and assign to k_scale and v_scale
+            k_scale = onnx.helper.make_tensor(
+                name=node.name + "_k_scale",
+                data_type=onnx.TensorProto.FLOAT16,
+                dims=[1],
+                vals=[k_max / tmax] if k_max != 0 else [1.0],
+            )
+            v_scale = onnx.helper.make_tensor(
+                name=node.name + "_v_scale",
+                data_type=onnx.TensorProto.FLOAT16,
+                dims=[1],
+                vals=[v_max / tmax] if v_max != 0 else [1.0],
+            )
+            onnx_model.graph.initializer.append(k_scale)
+            onnx_model.graph.initializer.append(v_scale)
+            # add the scales to the node inputs: pad the inputs to 12 with empty strings,
+            # then insert k_scale at index 12 and v_scale at index 13
+            while len(node.input) < 12:
+                node.input.append("")
+            node.input.append(k_scale.name)
+            node.input.append(v_scale.name)
+
+    # add attributes to the GQA nodes
+    for node in group_query_attention_nodes:
+        # add attributes for the quantization mode
+        node.attribute.append(onnx.helper.make_attribute("k_quant_type", kv_quant_mode))
+        node.attribute.append(onnx.helper.make_attribute("v_quant_type", kv_quant_mode))
+        # set the bit width attribute, only used for uint8, not supported currently
+        if kv_cache_type == "uint8":
+            node.attribute.append(
+                onnx.helper.make_attribute("kv_cache_bit_width", kv_cache_bit_width)
+            )
+    logger.info("kv cache quantization done")
+
+    return onnx_model
+
+
+def save_kv_cache_calib_data(
+    onnx_model: str | Path | onnx.ModelProto,
+    session: ort.InferenceSession | None = None,
+    inputs: list[dict] = [],
+    intermediate_generated_files: list[str] = [],
+):
+    """Save kv cache calibration data for quantization."""
+    kv_tensor_data = []
+
+    if not isinstance(onnx_model, onnx.ModelProto):
+        onnx_model = onnx.load(onnx_model)
+
+    kv_tensor_names_list = [
+        output.name for output in onnx_model.graph.output if "present" in output.name
+    ]
+
+    kv_tensor_names_list.sort()
+
+    for i in tqdm(range(len(inputs)), desc="Caching activations..."):
+        inp_d = inputs[i]
+        np_inp_d = {name: np.asarray(tensor) for name, tensor in inp_d.items()}
+        output = session.run(kv_tensor_names_list, np_inp_d)
+        kv_tensor_data.append(output)
+
+    # save to a tmp file named tmp_calib_data.json
+    tmp_dir = tempfile.mkdtemp()
+    calib_data_path = Path(tmp_dir).joinpath("tmp_calib_data.json").as_posix()
+    # serialize the captured kv outputs with pickle
+    with open(calib_data_path, "wb") as f:
+        pickle.dump(kv_tensor_data, f)
+    intermediate_generated_files.append(calib_data_path)
+
+
+def save_kv_cache_calib_data_rtn(
+    onnx_model: onnx.ModelProto,
+    data_reader: CalibrationDataReader | None = None,
+    calibration_eps: list[str] = [],
+    input_shapes_profile: Sequence[dict[str, str]] | None = None,
+    intermediate_generated_files: list[str] = [],
+    use_external_data_format: bool = False,
+):
+    """Save kv cache calibration data for RTN quantization.
+
+    Creates the inference session internally.
+    """
+    augmented_model = copy.deepcopy(onnx_model)
+
+    # save the model to augmented_onnx_path for creating the inference session
+    augmented_onnx_file, augmented_onnx_path = tempfile.mkstemp(suffix=".onnx")
+    os.close(augmented_onnx_file)
+
+    save_onnx(augmented_model, augmented_onnx_path, use_external_data_format)
+
+    # Creating inference session and preparing inputs for calibration
+    session = create_inference_session(augmented_onnx_path, calibration_eps, input_shapes_profile)
+    inputs = []
+    for inp_d in data_reader:
+        inputs.append(inp_d)
+        assert isinstance(inp_d, dict)
+    save_kv_cache_calib_data(
+        onnx_model,
+        session=session,
+        inputs=inputs,
+        intermediate_generated_files=intermediate_generated_files,
+    )
+
+    logger.info("Saved kv-cache calibration data for RTN quantization")
+
+    del session
+
+    try:
+        os.remove(augmented_onnx_path)
+        if use_external_data_format:
+            os.remove(augmented_onnx_path + "_data")
+    except OSError:
+        logger.warning("Augmented ONNX model or external data file was not found")
diff --git a/modelopt/onnx/quantization/ort_patching.py b/modelopt/onnx/quantization/ort_patching.py
index e177b5804..2c29fec2e 100755
--- a/modelopt/onnx/quantization/ort_patching.py
+++ b/modelopt/onnx/quantization/ort_patching.py
@@ -1559,6 +1559,8 @@ def _quantize_static(
     use_external_data_format=False,
     calibrate_method=CalibrationMethod.MinMax,
     extra_options=None,
+    intermediate_generated_files: list[str] = [],
+    kv_quant_mode: str = "NONE",
 ):
     """Modification: enables TRT custom ops in the calibrator via 'TrtExtraPluginLibraryPaths' in extra_options.
@@ -1653,6 +1655,17 @@ def _quantize_static(
         raise TypeError(
             f"Unexpected type {type(tensors_range)} for tensors_range and calibrator={type(calibrator)}."
         )
+
+    if kv_quant_mode != "NONE":
+        from modelopt.onnx.quantization.kv_cache import save_kv_cache_calib_data
+
+        save_kv_cache_calib_data(
+            Path(model_input),
+            session=calibrator.infer_session,
+            inputs=list(calibration_data_reader),
+            intermediate_generated_files=intermediate_generated_files,
+        )
+
     del calibrator
 
     check_static_quant_arguments(quant_format, activation_type, weight_type)
diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py
index 2d23b875a..a5bb16807 100755
--- a/modelopt/onnx/quantization/quantize.py
+++ b/modelopt/onnx/quantization/quantize.py
@@ -60,6 +60,7 @@
 )
 from modelopt.onnx.quantization.int4 import quantize as quantize_int4
 from modelopt.onnx.quantization.int8 import quantize as quantize_int8
+from modelopt.onnx.quantization.kv_cache import kv_cache_quantize
 from modelopt.onnx.quantization.ort_utils import update_trt_ep_support
 from modelopt.onnx.quantization.qdq_utils import (
     qdq_to_dq,
@@ -238,6 +239,8 @@
     calibrate_per_node: bool = False,
     input_shapes_profile: Sequence[dict[str, str]] | None = None,
     direct_io_types: bool = False,
+    kv_quant_mode: str = "NONE",
+    kv_cache_type: str = "fp8",
     **kwargs: Any,
 ) -> None:
     """Quantizes the provided ONNX model.
@@ -497,6 +500,7 @@ def quantize(
             calibrate_per_node=calibrate_per_node,
             custom_ops_to_quantize=list(custom_ops_to_quantize.keys()),
             direct_io_types=direct_io_types,
+            kv_quant_mode=kv_quant_mode,
             **kwargs,
         )
     elif "int4" in quantize_mode:
@@ -511,12 +515,26 @@ def quantize(
             use_zero_point=use_zero_point,
             log_level=log_level,
             input_shapes_profile=input_shapes_profile,
+            intermediate_generated_files=intermediate_generated_files,
+            kv_quant_mode=kv_quant_mode,
             **kwargs,
         )
     else:
         raise RuntimeError(f"Invalid quantization mode choice: {quantize_mode}")
 
     if onnx_model:
+        if kv_quant_mode != "NONE":
+            logger.info(
+                f"Quantization mode for KV cache: {kv_quant_mode}, kv_cache_type: {kv_cache_type}"
+            )
+            onnx_model = kv_cache_quantize(
+                onnx_model,
+                kv_quant_mode=kv_quant_mode,
+                kv_cache_type=kv_cache_type,
+                intermediate_generated_files=intermediate_generated_files,
+                calibration_method=calibration_method,
+            )
+
         # Fuse Q nodes for INT8/FP8 mode
         if quantize_mode in ["int8", "fp8"]:
             if dq_only:
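
Usage sketch (not part of the patch): a minimal example of how the new kv cache options might be driven through the Python API once this change lands. Only kv_quant_mode and kv_cache_type are introduced by this diff; the surrounding parameter names (onnx_path, quantize_mode, calibration_method, output_path) and the example file names are assumptions based on the existing quantize() entry point.

    # Hypothetical usage sketch, not part of this patch.
    # Assumes a decoder ONNX model that contains GroupQueryAttention nodes and that
    # onnx_path / quantize_mode / calibration_method / output_path keep their existing
    # meaning in modelopt.onnx.quantization.quantize; only kv_quant_mode and
    # kv_cache_type are added by this change.
    from modelopt.onnx.quantization import quantize

    quantize(
        onnx_path="decoder_model.onnx",          # hypothetical input model
        quantize_mode="int4",
        calibration_method="awq_lite",
        kv_quant_mode="PER_TENSOR",              # new: per-tensor kv cache scales
        kv_cache_type="fp8",                     # new: kv cache element type
        output_path="decoder_model_quant.onnx",  # hypothetical output path
    )

The same options are exposed on the CLI through the --kv_quant_mode and --kv_cache_type flags added to __main__.py above.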