diff --git a/modelopt/onnx/autocast/__main__.py b/modelopt/onnx/autocast/__main__.py
index 57d3d3133..ceb6a44c9 100644
--- a/modelopt/onnx/autocast/__main__.py
+++ b/modelopt/onnx/autocast/__main__.py
@@ -162,6 +162,19 @@ def get_parser() -> argparse.ArgumentParser:
             "libraries are in the PATH or LD_LIBRARY_PATH variables."
         ),
     )
+    parser.add_argument(
+        "--trt_plugins_precision",
+        type=str,
+        default=[],
+        nargs="+",
+        help=(
+            "A space-separated list indicating the precision for each custom op. "
+            "Each item should have the format <op_type>:<precision> (all inputs and outputs have the same precision) "
+            "or <op_type>:[<inp1_precision>,<inp2_precision>,...]:[<out1_precision>,<out2_precision>,...] "
+            "(inputs and outputs can have different precisions), where precision can be fp32 (default) or fp16. "
+            "For example: op_type_1:fp16 op_type_2:[fp16,fp32]:[fp16]."
+        ),
+    )
     return parser
 
 
@@ -192,6 +205,7 @@ def main(argv=None):
         init_conversion_max_bytes=args.init_conversion_max_bytes,
         providers=args.providers,
         trt_plugins=args.trt_plugins,
+        trt_plugins_precision=args.trt_plugins_precision,
         max_depth_of_reduction=args.max_depth_of_reduction,
     )
 
diff --git a/modelopt/onnx/autocast/convert.py b/modelopt/onnx/autocast/convert.py
index baf0dc60e..e3db4cf0c 100644
--- a/modelopt/onnx/autocast/convert.py
+++ b/modelopt/onnx/autocast/convert.py
@@ -58,6 +58,7 @@ def convert_to_mixed_precision(
     init_conversion_max_bytes: int | None = None,
     providers: list[str] = ["cpu"],
     trt_plugins: list[str] = [],
+    trt_plugins_precision: list[str] = [],
     max_depth_of_reduction: int | None = None,
 ) -> onnx.ModelProto:
     """Convert model to mixed precision.
@@ -78,6 +79,7 @@ def convert_to_mixed_precision(
             runtime.
         providers: List of ORT execution providers.
         trt_plugins: List of TensorRT plugin library paths in .so format (compiled shared library).
+        trt_plugins_precision: List indicating the precision for each custom op.
         max_depth_of_reduction: Maximum depth of reduction for node classification.
 
     Returns:
@@ -92,7 +94,11 @@ def convert_to_mixed_precision(
     # Otherwise, prefer to keep the original opset version unless it's very old
     min_opset = 22 if low_precision_type == "bf16" else 13
     graph_sanitizer = GraphSanitizer(
-        model, min_opset, trt_plugins=trt_plugins, max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT
+        model,
+        min_opset,
+        trt_plugins=trt_plugins,
+        trt_plugins_precision=trt_plugins_precision,
+        max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT,
     )
     graph_sanitizer.sanitize()
     model = graph_sanitizer.model
@@ -118,6 +124,7 @@ def convert_to_mixed_precision(
         init_max=init_max,
         custom_rule=custom_rule,
         max_depth_of_reduction=max_depth_of_reduction,
+        custom_ops_low_precision_nodes=graph_sanitizer.custom_ops_low_precision_nodes or [],
     )
 
     precision_converter = PrecisionConverter(
diff --git a/modelopt/onnx/autocast/graphsanitizer.py b/modelopt/onnx/autocast/graphsanitizer.py
index 9a4ee041b..d27379760 100644
--- a/modelopt/onnx/autocast/graphsanitizer.py
+++ b/modelopt/onnx/autocast/graphsanitizer.py
@@ -23,6 +23,8 @@
 import modelopt.onnx.autocast.utils as utils
 import modelopt.onnx.utils as onnx_utils
 from modelopt.onnx.autocast.logging_config import logger
+from modelopt.onnx.quantization.graph_utils import cast_custom_ops
+from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag
 
 
 class GraphSanitizer:
@@ -34,6 +36,7 @@ def __init__(
         min_opset: int = 13,
         max_ir_version: int | None = None,
         trt_plugins: list[str] | None = [],
+        trt_plugins_precision: list[str] | None = [],
     ) -> None:
         """Initialize GraphSanitizer.
 
@@ -48,7 +51,9 @@ def __init__(
         self.max_ir_version = max_ir_version
         self.standard_ops = {schema.name for schema in onnx.defs.get_all_schemas()}
         self.custom_ops = None
+        self.custom_ops_low_precision_nodes = []
         self.trt_plugins = trt_plugins
+        self.trt_plugins_precision = trt_plugins_precision or []
 
     def sanitize(self) -> None:
         """Sanitize the model graph.
@@ -67,6 +72,7 @@ def sanitize(self) -> None:
         self.cleanup_model()
         self.set_ir_version(self.max_ir_version)
         self.convert_fp64_to_fp32()
+        self.ensure_custom_ops_precision()
 
     def convert_fp64_to_fp32(self) -> None:
         """Convert FP64 initializers, I/O types, and specific nodes to FP32."""
@@ -88,6 +94,19 @@ def convert_fp64_to_fp32(self) -> None:
         logger.info("Converted FP64 initializers, I/O types, and nodes to FP32")
         self.model = onnx_utils.infer_shapes(self.model, strict_mode=True)
 
+    def ensure_custom_ops_precision(self) -> None:
+        """Ensure that custom ops run in the requested precision."""
+        custom_ops_to_cast, _ = interpret_trt_plugins_precision_flag(
+            self.model,
+            self.trt_plugins_precision,
+        )
+        if custom_ops_to_cast.get("fp16", {}):
+            self.model = cast_custom_ops(self.model, custom_ops_to_cast["fp16"])
+            self.custom_ops_low_precision_nodes = [
+                n.name for n in self.model.graph.node if n.op_type in custom_ops_to_cast["fp16"]
+            ]
+        logger.info("Ensured custom ops precision")
+
     def find_custom_nodes(self) -> None:
         """Find custom nodes in the model.
 
diff --git a/modelopt/onnx/autocast/nodeclassifier.py b/modelopt/onnx/autocast/nodeclassifier.py
index dda4b36ed..0a7638429 100644
--- a/modelopt/onnx/autocast/nodeclassifier.py
+++ b/modelopt/onnx/autocast/nodeclassifier.py
@@ -360,6 +360,7 @@ def __init__(
         data_max: float | None = 1000.0,
         init_max: float | None = np.finfo(np.float16).max,
         max_depth_of_reduction: int | None = None,
+        custom_ops_low_precision_nodes: list[str] | None = None,
     ):
         """Initialize the node classifier.
 
@@ -375,6 +376,7 @@ def __init__(
             data_max: Maximum absolute value allowed for node I/O.
             init_max: Maximum absolute value allowed for initializers.
             max_depth_of_reduction: Maximum depth of reduction allowed in low precision.
+            custom_ops_low_precision_nodes: List of custom op node names to convert to low precision.
         """
         self.model = model
         self.node_to_init_map = node_to_init_map
@@ -387,6 +389,7 @@ def __init__(
         self.data_max = data_max
         self.init_max = init_max
         self.max_depth_of_reduction = max_depth_of_reduction
+        self.custom_ops_low_precision_nodes = custom_ops_low_precision_nodes
 
     def _gen_exclude_node_rules(self, reference_data):
         """Generate list of rules for blocking nodes from precision conversion.
@@ -446,12 +449,14 @@ def run(self, ref_outputs_dict=None):
         """
         exclude_node_rules = self._gen_exclude_node_rules(ref_outputs_dict)
         include_node_rules = self._gen_include_node_rules()
-        low_precision_nodes = []
+        low_precision_nodes = self.custom_ops_low_precision_nodes or []
         high_precision_nodes = []
         for node in self.model.graph.node:
             # If any condition is met - node will be executed in high precision
-            if any(rule.check(node) for rule in exclude_node_rules) and not any(
-                rule.check(node) for rule in include_node_rules
+            if (
+                node.name not in low_precision_nodes
+                and any(rule.check(node) for rule in exclude_node_rules)
+                and not any(rule.check(node) for rule in include_node_rules)
             ):
                 high_precision_nodes.append(node.name)
             else:
diff --git a/modelopt/onnx/trt_utils.py b/modelopt/onnx/trt_utils.py
index fe01d672f..a27f15230 100644
--- a/modelopt/onnx/trt_utils.py
+++ b/modelopt/onnx/trt_utils.py
@@ -349,7 +349,7 @@ def load_onnx_model(
 def interpret_trt_plugins_precision_flag(
     onnx_model: onnx.ModelProto,
     trt_plugins_precision: list[str],
-    quantize_mode: str,
+    quantize_mode: str = "int8",
 ) -> tuple[dict, dict]:
     """Convert custom ops precision flag to dictionaries with custom op and I/O indices to be cast/quantized.
 
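
For reviewers, a minimal sketch of how the new flag is consumed. The op types
CustomOpA/CustomOpB and the model path are hypothetical placeholders; the
interpret_trt_plugins_precision_flag signature, the flag grammar, and the
"fp16" keying of the returned dict are taken from the diff above. Omitting
quantize_mode here is exactly what the new "int8" default enables for AutoCast:

    import onnx

    from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag

    model = onnx.load("model_with_custom_ops.onnx")  # hypothetical path

    # Same grammar as --trt_plugins_precision: <op_type>:<precision> for a
    # uniform precision, or <op_type>:[<inp precisions>]:[<out precisions>].
    flags = ["CustomOpA:fp16", "CustomOpB:[fp16,fp32]:[fp16]"]

    # quantize_mode is omitted, relying on the new default added in this diff.
    custom_ops_to_cast, custom_ops_to_quantize = interpret_trt_plugins_precision_flag(
        model, flags
    )

    # GraphSanitizer.ensure_custom_ops_precision() casts the fp16 group with
    # cast_custom_ops() and records the matching node names, which
    # NodeClassifier.run() then pins to low precision.
    fp16_ops = custom_ops_to_cast.get("fp16", {})
    print(sorted(fp16_ops))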