55#
66# ----------------------------------------------------------------------------
77
8- from typing import Optional , Tuple
8+ from typing import Any , Dict , List , Optional , Tuple
99
1010import numpy as np
11+ import onnx
12+ import onnxslim
13+ import torch
1114from onnx import ModelProto , external_data_helper , numpy_helper
1215
1316
@@ -100,9 +103,130 @@ def apply(
100103 external_data_helper .set_external_data (tensor , f"{ model_name } _{ file_num } .onnx.data" )
101104 return model , transformed
102105
106+
107+ class OnnxSlimTransform (OnnxTransform ):
108+ """
109+ Applies onnx-slim transformations on the given ONNX graph.
110+ """
111+
112+ @classmethod
113+ def apply (
114+ cls ,
115+ model : ModelProto ,
116+ * ,
117+ onnx_base_dir : Optional [str ] = None ,
118+ ** kwargs ,
119+ ) -> Tuple [ModelProto , bool ]:
120+ """
121+ :param enable_onnx_slim_transform: If True, applies onnx-slim transformations.
122+ :param temp_onnx_path: Path to save the slimmed ONNX model.
123+ """
124+ transformed = False
125+ onnx_slim_transform = True # kwargs.get("enable_onnx_slim_transform", False)
126+ temp_onnx_path = kwargs .get ("temp_onnx_path" , None )
127+ if not temp_onnx_path :
128+ err_str = "temp_onnx_path is required for onnx-slim transform."
129+ raise RuntimeError (err_str )
130+ if onnx_slim_transform :
131+ transformed = True
132+ slimmed_model = onnxslim .slim (model )
133+ onnx .save (slimmed_model , temp_onnx_path )
134+ return slimmed_model , transformed
135+ return model , transformed
136+
137+
138+ class CustomOpTransform (OnnxTransform ):
139+ """
140+ Transform to register custom operations and add their function protos to the ONNX model.
141+ """
142+
143+ # Registry of custom operations
144+ _custom_ops : Dict [str , Tuple [Any , Any ]] = {} # op_name -> (func_class, onnxscript_func)
145+
146+ @classmethod
147+ def register_custom_op (cls , op_name : str , func_class : Any , onnxscript_func : Any ):
148+ """Register a custom operation."""
149+ cls ._custom_ops [op_name ] = (func_class , onnxscript_func )
150+
151+ @classmethod
152+ def apply (cls , model : ModelProto , * , opset_version : int = 17 , ** kwargs ) -> Tuple [ModelProto , bool ]:
153+ """
154+ Apply custom op registration and add function protos to the model.
155+
156+ :param model: The ONNX model to transform
157+ :param opset_version: ONNX opset version for symbolic registration
158+ :returns: Transformed model and success flag
159+ """
160+ transformed = False
161+
162+ # Register all custom op symbolic functions with torch.onnx
163+ for op_name , (func_class , _ ) in cls ._custom_ops .items ():
164+ if hasattr (func_class , "symbolic" ):
165+ torch .onnx .register_custom_op_symbolic (f"::{ op_name } " , func_class .symbolic , opset_version )
166+
167+ # Add function protos for custom ops that are used in the model
168+ used_protos = cls ._get_function_protos_for_model (model )
169+
170+ for proto in used_protos :
171+ # Check if proto already exists to avoid duplicates
172+ proto_name = proto .name
173+ if not any (func .name == proto_name for func in model .functions ):
174+ model .functions .append (proto )
175+ transformed = True
176+
177+ return model , transformed
178+
179+ @classmethod
180+ def _get_function_protos_for_model (cls , model : ModelProto ) -> List [Any ]:
181+ """Get function protos for custom ops that are actually used in the model."""
182+ used_protos = []
183+
184+ # Get all node op_types in the model
185+ used_op_types = set ()
186+ for node in model .graph .node :
187+ used_op_types .add (node .op_type )
188+
189+ # Also check function calls
190+ for func in model .functions :
191+ for node in func .node :
192+ used_op_types .add (node .op_type )
193+
194+ # Check which custom ops are actually used
195+ for op_name , (func_class , onnxscript_func ) in cls ._custom_ops .items ():
196+ # Check if the custom op is referenced in the model
197+ if cls ._is_custom_op_used (model , op_name , used_op_types ):
198+ proto = onnxscript_func .to_function_proto ()
199+ used_protos .append (proto )
200+
201+ return used_protos
202+
203+ @classmethod
204+ def _is_custom_op_used (cls , model : ModelProto , op_name : str , used_op_types : set ) -> bool :
205+ """Check if a custom op is used in the model."""
206+ # Check if the op_name appears in node op_types
207+ if op_name in used_op_types :
208+ return True
209+
210+ # Check for domain-specific ops (e.g., "com.qti.aisw.onnx::CustomRMSNorm")
211+ custom_op_pattern = f"com.qti.aisw.onnx::{ op_name .replace ('Func' , '' )} "
212+ if custom_op_pattern in used_op_types :
213+ return True
214+
215+ # Heuristic checks based on op type
216+ if "RMSNorm" in op_name :
217+ # Check if any RMSNorm-related ops are present
218+ return any ("RMSNorm" in op_type for op_type in used_op_types )
219+
220+ if "Ctx" in op_name :
221+ # Check if Gather/Scatter operations are present (indicating KV cache usage)
222+ return any (op_type in ["Gather" , "GatherND" , "Scatter" , "ScatterND" ] for op_type in used_op_types )
223+
224+ return False
225+
226+
103227def rename_function_outputs (model ):
104228 graph = model .graph
105- op_type_to_func_map = {func .name :func for func in model .functions }
229+ op_type_to_func_map = {func .name : func for func in model .functions }
106230 decoder_layer_patterns = ["DecoderLayer" , "Block" , "Layer" ]
107231 transformed = False
108232 model_graph_outputs = [val .name for val in model .graph .output ]
@@ -117,11 +241,11 @@ def rename_function_outputs(model):
117241 if "key" in func .output [i ]:
118242 new_name = f"past_key.{ node_count } _RetainedState"
119243 elif "value" in func .output [i ]:
120- new_name = f"past_value.{ node_count } _RetainedState"
244+ new_name = f"past_value.{ node_count } _RetainedState"
121245 else :
122246 raise NotImplementedError ()
123247 print (f"renaming { node .output [i ]} to { new_name } " )
124248 node .output [i ] = new_name
125249 model .graph .output [model_graph_outputs .index (tmp )].name = new_name
126- node_count += 1
127- return model , transformed
250+ node_count += 1
251+ return model , transformed
0 commit comments