
Commit 6116377

Incorporated all subfunction changes
1 parent 19a9b91 commit 6116377

File tree

6 files changed (+145, -13 lines changed)

QEfficient/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@
 # TODO: Find a better way to do this, this is temp. fix.
 apply_torch_patches()
 
+
 def check_qaic_sdk():
     """Check if QAIC SDK is installed"""
     try:
QEfficient/base/modeling_qeff.py
Lines changed: 4 additions & 3 deletions

@@ -251,7 +251,7 @@ def _export(
        CustomOpTransform.register_custom_op("CtxGatherFunc", CtxGatherFunc, CtxGather)
        decoder_layer_classes = get_decoder_layer_classes_for_export(self.model)
        export_kwargs = {} if export_kwargs is None else export_kwargs
-
+
        torch.onnx.export(
            self.model,
            (example_inputs,),
@@ -269,10 +269,11 @@ def _export(
 
        _ = self._offload_model_weights(offload_pt_weights)
        model = onnx.load(tmp_onnx_path, load_external_data=False)
-        model,transformed = rename_function_outputs(model)
-
+        model, transformed = rename_function_outputs(model)
+
        transform_kwargs = {
            "onnx_base_dir": str(tmp_onnx_dir),
+            "temp_onnx_path": tmp_onnx_path,
            "model_name": self.model_name,
        }
        if onnx_transform_kwargs is not None:
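For orientation, here is a minimal sketch of how the `transform_kwargs` dict assembled above could be fanned out to the ONNX transform classes. The driver function and its name are assumptions for illustration, not code from this commit; only the kwargs keys and the `apply(model, **kwargs)` classmethod convention come from the diff.

    from onnx import ModelProto

    def run_onnx_transforms(model: ModelProto, transforms, tmp_onnx_dir, tmp_onnx_path, model_name):
        # Hypothetical driver: mirrors how _export shares one kwargs dict
        # across every transform's classmethod apply(model, **kwargs).
        transform_kwargs = {
            "onnx_base_dir": str(tmp_onnx_dir),
            "temp_onnx_path": tmp_onnx_path,  # new key, consumed by OnnxSlimTransform
            "model_name": model_name,
        }
        for transform_cls in transforms:
            model, transformed = transform_cls.apply(model, **transform_kwargs)
        return model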

QEfficient/base/onnx_transforms.py
Lines changed: 129 additions & 5 deletions

@@ -5,9 +5,12 @@
 #
 # ----------------------------------------------------------------------------
 
-from typing import Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
+import onnx
+import onnxslim
+import torch
 from onnx import ModelProto, external_data_helper, numpy_helper
 
 
@@ -100,9 +103,130 @@ def apply(
         external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
         return model, transformed
 
+
+class OnnxSlimTransform(OnnxTransform):
+    """
+    Applies onnx-slim transformations on the given ONNX graph.
+    """
+
+    @classmethod
+    def apply(
+        cls,
+        model: ModelProto,
+        *,
+        onnx_base_dir: Optional[str] = None,
+        **kwargs,
+    ) -> Tuple[ModelProto, bool]:
+        """
+        :param enable_onnx_slim_transform: If True, applies onnx-slim transformations.
+        :param temp_onnx_path: Path to save the slimmed ONNX model.
+        """
+        transformed = False
+        onnx_slim_transform = True  # kwargs.get("enable_onnx_slim_transform", False)
+        temp_onnx_path = kwargs.get("temp_onnx_path", None)
+        if not temp_onnx_path:
+            err_str = "temp_onnx_path is required for onnx-slim transform."
+            raise RuntimeError(err_str)
+        if onnx_slim_transform:
+            transformed = True
+            slimmed_model = onnxslim.slim(model)
+            onnx.save(slimmed_model, temp_onnx_path)
+            return slimmed_model, transformed
+        return model, transformed
+
+
+class CustomOpTransform(OnnxTransform):
+    """
+    Transform to register custom operations and add their function protos to the ONNX model.
+    """
+
+    # Registry of custom operations
+    _custom_ops: Dict[str, Tuple[Any, Any]] = {}  # op_name -> (func_class, onnxscript_func)
+
+    @classmethod
+    def register_custom_op(cls, op_name: str, func_class: Any, onnxscript_func: Any):
+        """Register a custom operation."""
+        cls._custom_ops[op_name] = (func_class, onnxscript_func)
+
+    @classmethod
+    def apply(cls, model: ModelProto, *, opset_version: int = 17, **kwargs) -> Tuple[ModelProto, bool]:
+        """
+        Apply custom op registration and add function protos to the model.
+
+        :param model: The ONNX model to transform
+        :param opset_version: ONNX opset version for symbolic registration
+        :returns: Transformed model and success flag
+        """
+        transformed = False
+
+        # Register all custom op symbolic functions with torch.onnx
+        for op_name, (func_class, _) in cls._custom_ops.items():
+            if hasattr(func_class, "symbolic"):
+                torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, opset_version)
+
+        # Add function protos for custom ops that are used in the model
+        used_protos = cls._get_function_protos_for_model(model)
+
+        for proto in used_protos:
+            # Check if proto already exists to avoid duplicates
+            proto_name = proto.name
+            if not any(func.name == proto_name for func in model.functions):
+                model.functions.append(proto)
+                transformed = True
+
+        return model, transformed
+
+    @classmethod
+    def _get_function_protos_for_model(cls, model: ModelProto) -> List[Any]:
+        """Get function protos for custom ops that are actually used in the model."""
+        used_protos = []
+
+        # Get all node op_types in the model
+        used_op_types = set()
+        for node in model.graph.node:
+            used_op_types.add(node.op_type)
+
+        # Also check function calls
+        for func in model.functions:
+            for node in func.node:
+                used_op_types.add(node.op_type)
+
+        # Check which custom ops are actually used
+        for op_name, (func_class, onnxscript_func) in cls._custom_ops.items():
+            # Check if the custom op is referenced in the model
+            if cls._is_custom_op_used(model, op_name, used_op_types):
+                proto = onnxscript_func.to_function_proto()
+                used_protos.append(proto)
+
+        return used_protos
+
+    @classmethod
+    def _is_custom_op_used(cls, model: ModelProto, op_name: str, used_op_types: set) -> bool:
+        """Check if a custom op is used in the model."""
+        # Check if the op_name appears in node op_types
+        if op_name in used_op_types:
+            return True
+
+        # Check for domain-specific ops (e.g., "com.qti.aisw.onnx::CustomRMSNorm")
+        custom_op_pattern = f"com.qti.aisw.onnx::{op_name.replace('Func', '')}"
+        if custom_op_pattern in used_op_types:
+            return True
+
+        # Heuristic checks based on op type
+        if "RMSNorm" in op_name:
+            # Check if any RMSNorm-related ops are present
+            return any("RMSNorm" in op_type for op_type in used_op_types)
+
+        if "Ctx" in op_name:
+            # Check if Gather/Scatter operations are present (indicating KV cache usage)
+            return any(op_type in ["Gather", "GatherND", "Scatter", "ScatterND"] for op_type in used_op_types)
+
+        return False
+
+
 def rename_function_outputs(model):
     graph = model.graph
-    op_type_to_func_map = {func.name:func for func in model.functions}
+    op_type_to_func_map = {func.name: func for func in model.functions}
     decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"]
     transformed = False
     model_graph_outputs = [val.name for val in model.graph.output]
@@ -117,11 +241,11 @@ def rename_function_outputs(model):
             if "key" in func.output[i]:
                 new_name = f"past_key.{node_count}_RetainedState"
             elif "value" in func.output[i]:
-                new_name= f"past_value.{node_count}_RetainedState"
+                new_name = f"past_value.{node_count}_RetainedState"
             else:
                 raise NotImplementedError()
             print(f"renaming {node.output[i]} to {new_name}")
             node.output[i] = new_name
             model.graph.output[model_graph_outputs.index(tmp)].name = new_name
-            node_count+=1
-    return model, transformed
+            node_count += 1
+    return model, transformed
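A usage sketch of the two new transforms follows; the model path is a placeholder, not a file from this repo. Note that `OnnxSlimTransform.apply` raises `RuntimeError` without `temp_onnx_path`, since it writes the slimmed graph back to disk as a side effect, and that `CustomOpTransform.apply` is a no-op until ops are registered (the commit registers "CtxGatherFunc" inside `_export`).

    import onnx

    from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxSlimTransform

    model = onnx.load("model.onnx", load_external_data=False)  # placeholder path

    # Adds function protos only for registered ops the graph actually uses,
    # skipping any already present in model.functions.
    model, ops_added = CustomOpTransform.apply(model, opset_version=17)

    # Slims the graph via onnxslim and saves the result to temp_onnx_path.
    model, slimmed = OnnxSlimTransform.apply(model, temp_onnx_path="model.onnx")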

QEfficient/transformers/models/modeling_auto.py
Lines changed: 8 additions & 3 deletions

@@ -26,7 +26,12 @@
 
 import QEfficient
 from QEfficient.base.modeling_qeff import QEFFBaseModel
-from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
+from QEfficient.base.onnx_transforms import (
+    CustomOpTransform,
+    FP16ClipTransform,
+    OnnxSlimTransform,
+    SplitTensorsTransform,
+)
 from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.generation.text_generation_inference import (
@@ -347,7 +352,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
             dynamic_axes,
             export_dir=export_dir,
         )
-
+
     def compile(
         self,
         onnx_path: Optional[str] = None,
@@ -2037,7 +2042,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
         SplitGateUpWeightsTransform,
         KVCacheExternalModuleMapperTransform,
     ]
-    _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+    _onnx_transforms = [FP16ClipTransform, CustomOpTransform, OnnxSlimTransform, SplitTensorsTransform]
 
     def __init__(
         self,
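The two new transforms slot between FP16 clipping and tensor splitting in `_onnx_transforms`, presumably applied in list order during export. An illustrative end-to-end call (the model card is a placeholder):

    from QEfficient import QEFFAutoModelForCausalLM

    # export() runs the ONNX transform pipeline: FP16ClipTransform,
    # CustomOpTransform, OnnxSlimTransform, then SplitTensorsTransform.
    model = QEFFAutoModelForCausalLM.from_pretrained("gpt2")
    onnx_path = model.export()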

QEfficient/transformers/models/pytorch_transforms.py
Lines changed: 2 additions & 1 deletion

@@ -789,6 +789,7 @@ def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Module, bool]:
         warnings.warn("Pooling is applied to the model.")
         return model, transformed
 
+
 def get_decoder_layer_classes_for_export(model: nn.Module) -> set:
     """
     Dynamically determine which DecoderLayer classes should be exported as functions
@@ -812,4 +813,4 @@ def get_decoder_layer_classes_for_export(model: nn.Module) -> set:
         if module.__class__ in decoder_layer_classes:
             model_decoder_classes.add(module.__class__)
 
-    return model_decoder_classes
+    return model_decoder_classes
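Read together with the `_export` diff above, the returned class set plausibly feeds `torch.onnx.export`'s `export_modules_as_functions`, so each decoder layer becomes a single ONNX function (the functions that `rename_function_outputs` later walks). A sketch under that assumption; the wrapper function is hypothetical:

    import torch

    from QEfficient.transformers.models.pytorch_transforms import (
        get_decoder_layer_classes_for_export,
    )

    def export_layers_as_functions(pt_model: torch.nn.Module, example_inputs, onnx_path: str):
        # Assumption: the class set is passed to export_modules_as_functions,
        # which emits one ONNX function per matching nn.Module class.
        decoder_layer_classes = get_decoder_layer_classes_for_export(pt_model)
        torch.onnx.export(
            pt_model,
            (example_inputs,),
            onnx_path,
            export_modules_as_functions=decoder_layer_classes,
            opset_version=17,
        )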

QEfficient/utils/patches.py
Lines changed: 1 addition & 1 deletion

@@ -117,4 +117,4 @@ def apply_torch_patches():
 
 def is_patched():
     """Check if patches have been applied."""
-    return onnx_utils._setup_trace_module_map == _setup_trace_module_map_patched
+    return onnx_utils._setup_trace_module_map == _setup_trace_module_map_patched
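For completeness, the pair of helpers can be exercised as below; `is_patched` simply compares the live `onnx_utils` hook against the patched function object.

    from QEfficient.utils.patches import apply_torch_patches, is_patched

    apply_torch_patches()  # installs the torch.onnx tracing patches
    assert is_patched()    # equality check against the patched hook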
