Qualcomm AI Engine Direct - Delegate mutable buffer and fix the mutable buffer issue #11782

Merged · 3 commits · Jun 20, 2025
4 changes: 2 additions & 2 deletions backends/qualcomm/_passes/__init__.py
@@ -8,6 +8,7 @@
from .annotate_quant_attrs import AnnotateQuantAttrs
from .annotate_stack import AnnotateStack
from .annotate_unbind import AnnotateUnbind
from .convert_bmm_to_matmul import ConvertBmmToMatmul
from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
from .convert_square_to_pow import ConvertSquareToPow
from .decompose_any import DecomposeAny
@@ -35,7 +36,6 @@
from .remove_0d_tensor import Remove0DTensor
from .remove_redundancy import RemoveRedundancy
from .replace_arange_args import ReplaceArangeArgs
from .replace_index_put_input import ReplaceIndexPutInput
from .replace_inf_values import ReplaceInfValues
from .tag_quant_io import TagQuantIO

@@ -45,6 +45,7 @@
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
ConvertBmmToMatmul,
ConvertConv1dToConv2d,
ConvertSquareToPow,
DecomposeAny,
@@ -72,7 +73,6 @@
Remove0DTensor,
RemoveRedundancy,
ReplaceArangeArgs,
ReplaceIndexPutInput,
ReplaceInfValues,
TagQuantIO,
]
76 changes: 76 additions & 0 deletions backends/qualcomm/_passes/convert_bmm_to_matmul.py
@@ -0,0 +1,76 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import operator
from collections import Counter
from typing import List

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class ConvertBmmToMatmul(ExportPass):
"""
Replace bmm with matmul, since bmm is equivalent to matmul in QNN.
Handle the missing quantization tag for the bmm op.
"""

view_copy = exir_ops.edge.aten.view_copy.default
expand_copy = exir_ops.edge.aten.expand_copy.default
clone = exir_ops.edge.aten.clone.default
bmm = exir_ops.edge.aten.bmm.default
matmul = exir_ops.edge.aten.matmul.default
patterns = [
{expand_copy: 2, view_copy: 3, bmm: 1},
{expand_copy: 2, view_copy: 3, bmm: 1, clone: 1},
{bmm: 1},
]

def __init__(self):
super(ConvertBmmToMatmul, self).__init__()

def _get_ordered_inputs(
self, inputs: List[torch.fx.Node], output: torch.fx.Node
) -> List[torch.fx.Node]:
bmm_inputs = []
for arg in output.args:
while arg not in inputs:
arg = arg.args[0]
bmm_inputs.append(arg)
return bmm_inputs

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
partitions = get_source_partitions(
graph,
[operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default],
)
for _, src_partitions in partitions.items():
for src_partition in src_partitions:
op_cnt = Counter([n.target for n in src_partition.nodes])
if op_cnt not in self.patterns:
raise AssertionError(
"Found a new pattern needed be converted to linear op"
)

inputs = src_partition.input_nodes
bmm_node = [n for n in src_partition.nodes if n.target == self.bmm][0]
output = src_partition.output_nodes[0]
# The order of src_partition.input_nodes is not guaranteed.
lhs, rhs = self._get_ordered_inputs(inputs, bmm_node)
with graph_module.graph.inserting_before(output):
# Replace bmm with matmul, since bmm is equivalent to matmul in QNN.
matmul_node = graph.create_node(
"call_function", self.matmul, (lhs, rhs)
)
matmul_node.meta = output.meta
for user in output.users.copy():
user.replace_input_with(output, matmul_node)

graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
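
A minimal usage sketch for this pass, assuming a toy module and the standard export/to_edge flow (the module, shapes, and driver code here are illustrative, not part of this PR):

import torch
from executorch.backends.qualcomm._passes.convert_bmm_to_matmul import ConvertBmmToMatmul
from executorch.exir import to_edge

class BmmModule(torch.nn.Module):
    def forward(self, x, y):
        return torch.bmm(x, y)  # batched matmul over 3-D operands

x, y = torch.randn(2, 3, 4), torch.randn(2, 4, 5)
edge = to_edge(torch.export.export(BmmModule(), (x, y)))
gm = edge.exported_program().graph_module
result = ConvertBmmToMatmul()(gm)  # ExportPass instances are callable and return a PassResult
print(result.graph_module.graph)  # the bmm node is now edge aten.matmul

This corresponds to the {bmm: 1} entry in `patterns`; the expand_copy/view_copy variants arise when matmul is traced over operands that require broadcasting.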
10 changes: 8 additions & 2 deletions backends/qualcomm/_passes/insert_io_qdq.py
@@ -9,7 +9,10 @@

from executorch.backends.qualcomm.builders.node_visitor import q_ops

from executorch.backends.qualcomm.builders.utils import is_parameter
from executorch.backends.qualcomm.builders.utils import (
is_mutable_buffer_input,
is_parameter,
)
from executorch.backends.qualcomm.utils.constants import (
QCOM_ENCODING,
QCOM_QUANT_ATTRS,
@@ -124,7 +127,10 @@ def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
if (
n.op == "placeholder"
and n.meta.get(QCOM_QUANT_ATTRS)
and not is_parameter(n, self.edge_program)
and (
not is_parameter(n, self.edge_program)
or is_mutable_buffer_input(n, self.edge_program)
)
):
self._insert_quant_node(
graph_module, n, n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]
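For context, `is_mutable_buffer_input` lives in backends/qualcomm/builders/utils; a sketch of the idea, assuming it is derived from the graph signature (not the exact implementation):

import torch
from torch.export import ExportedProgram

def is_mutable_buffer_input_sketch(node: torch.fx.Node, edge_program: ExportedProgram) -> bool:
    # A placeholder is a mutable-buffer input when it maps to a buffer
    # that the graph also mutates, per the graph signature.
    if node.op != "placeholder":
        return False
    sig = edge_program.graph_signature
    fqn = sig.inputs_to_buffers.get(node.target)
    return fqn is not None and fqn in sig.buffers_to_mutate.values()

With this, the condition above still skips plain parameters (weights and biases) but no longer skips buffers the graph mutates, so those inputs receive I/O Q/DQ nodes.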
13 changes: 10 additions & 3 deletions backends/qualcomm/_passes/qnn_pass_manager.py
@@ -13,6 +13,7 @@
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
ConvertBmmToMatmul,
ConvertConv1dToConv2d,
ConvertSquareToPow,
DecomposeAny,
@@ -40,7 +41,6 @@
Remove0DTensor,
RemoveRedundancy,
ReplaceArangeArgs,
ReplaceIndexPutInput,
ReplaceInfValues,
TagQuantIO,
)
@@ -80,6 +80,7 @@ def get_capture_program_passes():
(AnnotateQuantAttrs, True),
(AnnotateStack, True),
(AnnotateUnbind, True),
(ConvertBmmToMatmul, False),
(ConvertConv1dToConv2d, True),
(DecomposeAny, True),
(DecomposeColIm, True),
@@ -92,7 +93,6 @@
(RecomposeRmsNorm, False),
(Remove0DTensor, True),
(RemoveRedundancy, True),
(ReplaceIndexPutInput, True),
(TagQuantIO, False),
]

@@ -224,4 +224,11 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
self.add_pass(LayoutTransform(exported_program, insert_permute=True))
self.add_pass(FuseConsecutiveCast())
self.add_pass(FuseConsecutiveTranspose())
return self._transform(exported_program.graph_module)
self._transform(exported_program.graph_module)
# Update inputs_to_buffers and buffers_to_mutate in the graph signature for mutable buffers.
# Since Q/DQ nodes are inserted at the I/O, the output node names no longer map to the buffers.
exported_program._graph_signature = _get_updated_graph_signature(
exported_program.graph_signature,
exported_program.graph_module,
)
return exported_program.graph_module
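
To illustrate why the signature must be rebuilt, a hypothetical before/after (node and buffer names are illustrative):

# Before preprocessing, buffers_to_mutate is keyed by output node name:
#   buffers_to_mutate == {"index_put": "k_cache"}
# After InsertIOQDQ, the value reaching that output is produced by a new
# quantize node (e.g. "quantize_per_tensor_default_1"), so the stale key
# no longer names any output node and runtime buffer mapping would fail.
# _get_updated_graph_signature re-derives both mappings from the new graph.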
54 changes: 0 additions & 54 deletions backends/qualcomm/_passes/replace_index_put_input.py

This file was deleted.

7 changes: 4 additions & 3 deletions backends/qualcomm/_passes/utils.py
@@ -64,6 +64,7 @@ def get_passes_dependency_for_capture_program():
AnnotateQuantAttrs,
AnnotateStack,
AnnotateUnbind,
ConvertBmmToMatmul,
ConvertConv1dToConv2d,
DecomposeAny,
DecomposeColIm,
@@ -76,18 +77,19 @@
RecomposePixelUnshuffle,
RecomposeRmsNorm,
RemoveRedundancy,
ReplaceIndexPutInput,
TagQuantIO,
)

return {
AnnotateAdaptiveAvgPool1D: [RemoveRedundancy],
AnnotateQuantAttrs: [
ConvertBmmToMatmul,
RecomposePixelUnshuffle,
RemoveRedundancy,
],
AnnotateStack: [RemoveRedundancy],
AnnotateUnbind: [RemoveRedundancy],
ConvertBmmToMatmul: [RecomposePixelUnshuffle],
DecomposeAny: [RemoveRedundancy],
DecomposeColIm: [FoldQDQ],
DecomposeLinalgVectorNorm: [RemoveRedundancy],
@@ -103,8 +105,7 @@
],
RecomposePixelUnshuffle: [RemoveRedundancy],
RecomposeRmsNorm: [RemoveRedundancy],
ReplaceIndexPutInput: [LayoutTransform],
TagQuantIO: [ReplaceIndexPutInput],
TagQuantIO: [LayoutTransform],
}


44 changes: 33 additions & 11 deletions backends/qualcomm/builders/node_visitor.py
@@ -41,6 +41,8 @@
get_parameter,
is_graph_input,
is_graph_output,
is_mutable_buffer_input,
is_mutable_buffer_output,
is_parameter,
)

@@ -307,7 +309,9 @@ def get_tensor_type(
node: torch.fx.Node,
tensor_type: PyQnnWrapper.Qnn_TensorType_t,
) -> PyQnnWrapper.Qnn_TensorType_t:
is_input = is_graph_input(node, self.edge_program)
is_input = is_graph_input(node, self.edge_program) or is_mutable_buffer_input(
node, self.edge_program
)
is_output = is_graph_output(node)
# handle logic for input/output tensors
if is_input or is_output:
@@ -352,6 +356,33 @@ def get_dynamic_dimension(self, dims):

return dynamic_dims if any(dynamic_dims) else [], nominal_dims

def get_tensor_name(
self,
node: torch.fx.Node,
wrapper_idx: int = 0,
):
tensor_name = f"{node.name}_{wrapper_idx}"
# The `input_{id}` prefix is used to sort inputs at runtime. Due to the multiple passes in qnn_preprocess,
# the input order between QNN and the original graph's forward function may differ.
# The `mutbuf_{id}` prefix maps the I/O of a mutable buffer at runtime.
# The `output_` prefix identifies a graph output at runtime, to avoid confusion with per_tensor_dump.
if is_mutable_buffer_input(node, self.edge_program):
fqn = self.edge_program.graph_signature.inputs_to_buffers[node.target]
position_index = list(
self.edge_program.graph_signature.buffers_to_mutate.values()
).index(fqn)
tensor_name = f"input_{str(self.external_ids[node])}_mutbuf_{str(position_index)}_{tensor_name}"
elif is_graph_input(node, self.edge_program):
tensor_name = f"input_{str(self.external_ids[node])}_{tensor_name}"
elif is_mutable_buffer_output(node, self.edge_program):
position_index = list(
self.edge_program.graph_signature.buffers_to_mutate.keys()
).index(node.name)
tensor_name = f"output_mutbuf_{position_index}_{tensor_name}"
elif is_graph_output(node):
tensor_name = f"output_{tensor_name}"
return tensor_name

def define_custom_tensor_wrapper(
self,
node_name: str,
@@ -413,16 +444,7 @@ def define_tensor(
if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
return cached

tensor_name = f"{tensor_source_node.name}_{wrapper_idx}"
if is_graph_input(tensor_source_node, self.edge_program):
tensor_name = (
"input_"
+ str(self.external_ids[tensor_source_node])
+ "_"
+ tensor_name
)
if is_graph_output(tensor_source_node):
tensor_name = "output_" + tensor_name
tensor_name = self.get_tensor_name(tensor_source_node, wrapper_idx)
dims = torch.Size([1]) if len(tensor.size()) == 0 else tensor.size()
dynamic_dims, nominal_dims = self.get_dynamic_dimension(dims)
tensor_type = self.get_tensor_type(tensor_source_node, tensor_type)
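As a concrete illustration of the naming scheme in get_tensor_name above (all node and buffer names here are hypothetical): a plain graph input x may become `input_0_x_0`; a mutable-buffer input backing a KV cache may become `input_1_mutbuf_0_k_cache_0`; the matching mutable-buffer output may become `output_mutbuf_0_index_put_0`; and an ordinary output `output_add_0`. The runtime parses these prefixes to order inputs and to pair each mutable buffer's input with its output.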
6 changes: 4 additions & 2 deletions backends/qualcomm/builders/node_visitor_manager.py
@@ -13,7 +13,7 @@

from .node_visitor import NodeVisitor
from .op_custom_op import CustomOp
from .utils import is_graph_input, is_graph_output
from .utils import is_graph_input, is_graph_output, is_mutable_buffer_input


# This will hold mapping of all node names to the visitor class
@@ -39,7 +39,9 @@ def generate_node_to_external_map(
# The order in which we visit the placeholder nodes is the same as the *args
# order of the forward(*args) signature for this gm. Use the order of
# the nodes as the external_id to extract the right arg from *args at runtime.
if is_graph_input(node, edge_program):
if is_graph_input(node, edge_program) or is_mutable_buffer_input(
node, edge_program
):
node_to_external_map[node] = len(node_to_external_map)
for node in edge_program.graph_module.graph.nodes:
if is_graph_output(node):
7 changes: 6 additions & 1 deletion backends/qualcomm/builders/op_index_put.py
@@ -1,9 +1,10 @@
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS

from .node_visitor import NodeVisitor
from .node_visitor_manager import register_node_visitor
from .qnn_constants import OpScatterNd, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -22,6 +23,10 @@ def define_node(
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
input_node = self.get_node(node.args[0])
# Because args[0] of the index_put op is not annotated, fill in the input node's quant_attrs from this node here.
if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
quant_attrs = quant_attrs.copy()
input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
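A sketch of the situation this handles, with hypothetical node names from a quantized KV-cache graph:

# placeholder: k_cache                               <- no QCOM_QUANT_ATTRS (args[0] is never annotated)
# call_function: index_put(k_cache, indices, values) <- carries QCOM_QUANT_ATTRS
#
# Copying the op's quant attrs onto `k_cache` lets define_tensor emit a
# quantized tensor for the mutable-buffer input instead of a float one.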