From e922f14584fb50219db44e52f03ee29cf70f1380 Mon Sep 17 00:00:00 2001 From: shewu-quic <quic_shewu@quicinc.com> Date: Thu, 12 Jun 2025 15:06:04 +0800 Subject: [PATCH 1/3] Qualcomm AI Engine Direct - Delegate mutable buffer and fix the mutable buffer issue Summary: - Add a parameter to support mutable buffer delegation in QNN Backend - Set the same memory address for I/O of mutable buffer at runtime - Avoid annotating the input node because mutable buffers will be folded during the convert_pt2e process. - Deprecated use_legacy_export in executorch llama --- backends/qualcomm/_passes/__init__.py | 2 - backends/qualcomm/_passes/insert_io_qdq.py | 10 +- backends/qualcomm/_passes/qnn_pass_manager.py | 11 +- .../_passes/replace_index_put_input.py | 54 ------- backends/qualcomm/_passes/utils.py | 4 +- backends/qualcomm/builders/node_visitor.py | 44 ++++-- .../qualcomm/builders/node_visitor_manager.py | 6 +- backends/qualcomm/builders/op_index_put.py | 7 +- backends/qualcomm/builders/utils.py | 35 ++++- .../qualcomm/partition/qnn_partitioner.py | 17 ++- backends/qualcomm/quantizer/annotators.py | 24 ++- .../qualcomm/runtime/QnnExecuTorchBackend.cpp | 38 ++--- backends/qualcomm/runtime/QnnManager.cpp | 33 +++++ backends/qualcomm/tests/models.py | 22 ++- backends/qualcomm/tests/test_qnn_delegate.py | 139 ++++++++++++++++-- backends/qualcomm/tests/utils.py | 2 + backends/qualcomm/utils/utils.py | 4 + examples/models/llama/export_llama_lib.py | 1 - .../llama/source_transformation/attention.py | 2 - .../llama/source_transformation/sdpa.py | 2 - extension/llm/export/builder.py | 76 +++------- extension/llm/export/partitioner_lib.py | 1 + 22 files changed, 357 insertions(+), 177 deletions(-) delete mode 100644 backends/qualcomm/_passes/replace_index_put_input.py diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index 681ea2ee534..2b743454bf9 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -35,7 +35,6 @@ from .remove_0d_tensor import Remove0DTensor from .remove_redundancy import RemoveRedundancy from .replace_arange_args import ReplaceArangeArgs -from .replace_index_put_input import ReplaceIndexPutInput from .replace_inf_values import ReplaceInfValues from .tag_quant_io import TagQuantIO @@ -72,7 +71,6 @@ Remove0DTensor, RemoveRedundancy, ReplaceArangeArgs, - ReplaceIndexPutInput, ReplaceInfValues, TagQuantIO, ] diff --git a/backends/qualcomm/_passes/insert_io_qdq.py b/backends/qualcomm/_passes/insert_io_qdq.py index e5b15f2d12c..caecae64fa8 100644 --- a/backends/qualcomm/_passes/insert_io_qdq.py +++ b/backends/qualcomm/_passes/insert_io_qdq.py @@ -9,7 +9,10 @@ from executorch.backends.qualcomm.builders.node_visitor import q_ops -from executorch.backends.qualcomm.builders.utils import is_parameter +from executorch.backends.qualcomm.builders.utils import ( + is_mutable_buffer_input, + is_parameter, +) from executorch.backends.qualcomm.utils.constants import ( QCOM_ENCODING, QCOM_QUANT_ATTRS, @@ -124,7 +127,10 @@ def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: if ( n.op == "placeholder" and n.meta.get(QCOM_QUANT_ATTRS) - and not is_parameter(n, self.edge_program) + and ( + not is_parameter(n, self.edge_program) + or is_mutable_buffer_input(n, self.edge_program) + ) ): self._insert_quant_node( graph_module, n, n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING] diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 2b68f9b6c09..9d2726eaaed 
100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -40,7 +40,6 @@ Remove0DTensor, RemoveRedundancy, ReplaceArangeArgs, - ReplaceIndexPutInput, ReplaceInfValues, TagQuantIO, ) @@ -92,7 +91,6 @@ def get_capture_program_passes(): (RecomposeRmsNorm, False), (Remove0DTensor, True), (RemoveRedundancy, True), - (ReplaceIndexPutInput, True), (TagQuantIO, False), ] @@ -224,4 +222,11 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram): self.add_pass(LayoutTransform(exported_program, insert_permute=True)) self.add_pass(FuseConsecutiveCast()) self.add_pass(FuseConsecutiveTranspose()) - return self._transform(exported_program.graph_module) + self._transform(exported_program.graph_module) + # Update inputs_to_buffers and buffers_to_mutate in graph signature for mutable buffer + # Since I/O will be inserted Q/DQ, it results in failed to mapping output node names and buffer + exported_program._graph_signature = _get_updated_graph_signature( + exported_program.graph_signature, + exported_program.graph_module, + ) + return exported_program.graph_module diff --git a/backends/qualcomm/_passes/replace_index_put_input.py b/backends/qualcomm/_passes/replace_index_put_input.py deleted file mode 100644 index 93ee21bfc7c..00000000000 --- a/backends/qualcomm/_passes/replace_index_put_input.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import torch -from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING, QCOM_QUANT_ATTRS -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult - - -class ReplaceIndexPutInput(ExportPass): - """ - Index put input workaround for quantized module - """ - - dq_q_map = { - # per tensor - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor: exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, - # per channel - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - } - - def __init__(self): - super(ReplaceIndexPutInput, self).__init__() - - def call(self, graph_module: torch.fx.GraphModule): - graph = graph_module.graph - for node in graph.nodes: - if node.target == exir_ops.edge.aten.index_put.default: - if ( - copy_node := list(node.users)[0] - ) and copy_node.target == exir_ops.edge.aten.copy.default: - m_buffer_node = copy_node.args[0] - dq_node = node.args[0] - bad_frozen_node = dq_node.args[0] - if QCOM_QUANT_ATTRS in bad_frozen_node.meta: - m_buffer_node.meta[QCOM_QUANT_ATTRS] = bad_frozen_node.meta[ - QCOM_QUANT_ATTRS - ] - m_buffer_node.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING] = ( - self.dq_q_map[ - m_buffer_node.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING] - ] - ) - with graph.inserting_after(dq_node): - node.replace_input_with(dq_node, m_buffer_node) - else: - continue - - graph.eliminate_dead_code() - graph_module.recompile() - return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index ef52d2c190a..b23efc8714f 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -76,7 +76,6 @@ def 
get_passes_dependency_for_capture_program(): RecomposePixelUnshuffle, RecomposeRmsNorm, RemoveRedundancy, - ReplaceIndexPutInput, TagQuantIO, ) @@ -103,8 +102,7 @@ def get_passes_dependency_for_capture_program(): ], RecomposePixelUnshuffle: [RemoveRedundancy], RecomposeRmsNorm: [RemoveRedundancy], - ReplaceIndexPutInput: [LayoutTransform], - TagQuantIO: [ReplaceIndexPutInput], + TagQuantIO: [LayoutTransform], } diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 37fe3615268..8d77a5f47aa 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -41,6 +41,8 @@ get_parameter, is_graph_input, is_graph_output, + is_mutable_buffer_input, + is_mutable_buffer_output, is_parameter, ) @@ -307,7 +309,9 @@ def get_tensor_type( node: torch.fx.Node, tensor_type: PyQnnWrapper.Qnn_TensorType_t, ) -> PyQnnWrapper.Qnn_TensorType_t: - is_input = is_graph_input(node, self.edge_program) + is_input = is_graph_input(node, self.edge_program) or is_mutable_buffer_input( + node, self.edge_program + ) is_output = is_graph_output(node) # handle logic for input/output tensors if is_input or is_output: @@ -352,6 +356,33 @@ def get_dynamic_dimension(self, dims): return dynamic_dims if any(dynamic_dims) else [], nominal_dims + def get_tensor_name( + self, + node: torch.fx.Node, + wrapper_idx: int = 0, + ): + tensor_name = f"{node.name}_{wrapper_idx}" + # The `input_{id}` is utilized for sorting at runtime. Due to multiple passes in qnn_preprocess, + # the input order between QNN and the original graph’s forward function may differ. + # The `mutbuf_{id}` is utilized for mapping I/O of mutable buffer at runtime. + # The `output_` is identified as the graph’s output at runtime to prevent confusion with per_tensor_dump. 
+ if is_mutable_buffer_input(node, self.edge_program): + fqn = self.edge_program.graph_signature.inputs_to_buffers[node.target] + position_index = list( + self.edge_program.graph_signature.buffers_to_mutate.values() + ).index(fqn) + tensor_name = f"input_{str(self.external_ids[node])}_mutbuf_{str(position_index)}_{tensor_name}" + elif is_graph_input(node, self.edge_program): + tensor_name = f"input_{str(self.external_ids[node])}_{tensor_name}" + elif is_mutable_buffer_output(node, self.edge_program): + position_index = list( + self.edge_program.graph_signature.buffers_to_mutate.keys() + ).index(node.name) + tensor_name = f"output_mutbuf_{position_index}_{tensor_name}" + elif is_graph_output(node): + tensor_name = f"output_{tensor_name}" + return tensor_name + def define_custom_tensor_wrapper( self, node_name: str, @@ -413,16 +444,7 @@ def define_tensor( if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None): return cached - tensor_name = f"{tensor_source_node.name}_{wrapper_idx}" - if is_graph_input(tensor_source_node, self.edge_program): - tensor_name = ( - "input_" - + str(self.external_ids[tensor_source_node]) - + "_" - + tensor_name - ) - if is_graph_output(tensor_source_node): - tensor_name = "output_" + tensor_name + tensor_name = self.get_tensor_name(tensor_source_node, wrapper_idx) dims = torch.Size([1]) if len(tensor.size()) == 0 else tensor.size() dynamic_dims, nominal_dims = self.get_dynamic_dimension(dims) tensor_type = self.get_tensor_type(tensor_source_node, tensor_type) diff --git a/backends/qualcomm/builders/node_visitor_manager.py b/backends/qualcomm/builders/node_visitor_manager.py index fa9d51db1ad..8c1733fcec3 100644 --- a/backends/qualcomm/builders/node_visitor_manager.py +++ b/backends/qualcomm/builders/node_visitor_manager.py @@ -13,7 +13,7 @@ from .node_visitor import NodeVisitor from .op_custom_op import CustomOp -from .utils import is_graph_input, is_graph_output +from .utils import is_graph_input, is_graph_output, is_mutable_buffer_input # This will hold mapping of all node names to the visitor class @@ -39,7 +39,9 @@ def generate_node_to_external_map( # The order in which we visit the placeholder node is same as the *args # order for the forward(*args) signature for this gm. Using the order of # the nodes as external_id to extract the right arg from *args at runtime - if is_graph_input(node, edge_program): + if is_graph_input(node, edge_program) or is_mutable_buffer_input( + node, edge_program + ): node_to_external_map[node] = len(node_to_external_map) for node in edge_program.graph_module.graph.nodes: if is_graph_output(node): diff --git a/backends/qualcomm/builders/op_index_put.py b/backends/qualcomm/builders/op_index_put.py index a58075bf06c..de59b1a0489 100644 --- a/backends/qualcomm/builders/op_index_put.py +++ b/backends/qualcomm/builders/op_index_put.py @@ -1,9 +1,10 @@ from typing import Dict import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper - import torch +from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS + from .node_visitor import NodeVisitor from .node_visitor_manager import register_node_visitor from .qnn_constants import OpScatterNd, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -22,6 +23,10 @@ def define_node( nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: input_node = self.get_node(node.args[0]) + # Because the args[0] of index_put op doesn't annotate, need to fill in the quant_attr with the node here. 
+ if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): + quant_attrs = quant_attrs.copy() + input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, diff --git a/backends/qualcomm/builders/utils.py b/backends/qualcomm/builders/utils.py index c82ebaf1bb3..3345f2e1fc9 100755 --- a/backends/qualcomm/builders/utils.py +++ b/backends/qualcomm/builders/utils.py @@ -75,6 +75,20 @@ def is_graph_input( return tensor.op == "placeholder" and not is_parameter(tensor, edge_program) +def is_mutable_buffer_input( + tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram +) -> bool: + """ + Check if the given tensor is a mutable buffer input + Args: + tensor: EdgeIR Tensor that is being checked for mutable buffer input + """ + if tensor.op == "placeholder" and is_buffer(edge_program, tensor): + fqn = edge_program.graph_signature.inputs_to_buffers[tensor.target] + # if the buffer is mutated then record that + return fqn in edge_program.graph_signature.buffers_to_mutate.values() + + def is_graph_output(node: torch.fx.Node) -> bool: """ Check if the given tensor is used as a graph output @@ -83,7 +97,7 @@ def is_graph_output(node: torch.fx.Node) -> bool: tensor: EdgeIR Tensor that is being checked for graph input """ for user in node.users.keys(): - # getitem node is skiped, check the op_skip_ops.py + # getitem node is skipped, check the op_skip_ops.py if user.op == "output" or ( user.target.__name__ == "getitem" and is_graph_output(user) ): @@ -91,6 +105,25 @@ def is_graph_output(node: torch.fx.Node) -> bool: return False +def is_mutable_buffer_output( + tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram +) -> bool: + """ + Check if the given tensor is a mutable buffer output + Args: + tensor: EdgeIR Tensor that is being checked for mutable buffer output + """ + return ( + any( + user.op == "output" + or user.target.__name__ == "getitem" + and is_graph_output(user) + for user in tensor.users.keys() + ) + and tensor.name in edge_program.graph_signature.buffers_to_mutate.keys() + ) + + def is_constant( tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram ) -> bool: diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index 776923a1493..9a8ce92e739 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
import copy +import logging from collections import defaultdict from typing import Any, Callable, Dict, List, Optional, Tuple @@ -29,7 +30,7 @@ Partitioner, PartitionResult, ) -from executorch.exir.backend.utils import tag_constant_data +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import Partition from torch.fx.passes.operator_support import OperatorSupportBase @@ -42,6 +43,9 @@ ) from .utils import filter_fn, generate_qnn_executorch_option, get_skip_decomp_table +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + class QnnOperatorSupport(OperatorSupportBase): def __init__( @@ -124,6 +128,7 @@ def __init__( compiler_specs: List[CompileSpec], skip_node_id_set: set = None, skip_node_op_set: set = None, + skip_mutable_buffer: bool = False, ): self.compiler_specs_snapshot = copy.deepcopy(compiler_specs) @@ -133,6 +138,7 @@ def __init__( self.partition_tags: Dict[str, DelegationSpec] = {} self.skip_node_id_set = set() if skip_node_id_set is None else skip_node_id_set self.skip_node_op_set = set() if skip_node_op_set is None else skip_node_op_set + self.skip_mutable_buffer = skip_mutable_buffer def generate_partitions( self, edge_program: torch.export.ExportedProgram @@ -178,6 +184,15 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResu if len(partitions) != 0: self.tag_nodes(partitions, edge_program) tag_constant_data(edge_program) + if not self.skip_mutable_buffer: + logger.info( + "Qnn partitioner will delegate torch mutable buffer with the same I/O address during the runtime, " + "so if your model contains mutable buffer, " + "then you can get the better performance with skip_mutable_buffer=False. " + "If you encounter accuracy issue during the runtime, " + "then please set `skip_mutable_buffer=True` and try again." + ) + tag_mutated_buffer(edge_program) for node in edge_program.graph_module.graph.nodes: if hasattr(node, "meta"): # pop certain keys in meta for not affecting the passes in compilation diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index eca889a1610..e1e2ca6dff6 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -817,16 +817,32 @@ def annotate_index(node: Node, quantization_config: QuantizationConfig) -> None: [torch.ops.aten.index_put.default, torch.ops.aten.index_put_.default] ) def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None: - input = node.args[0] + # Avoid annotating the input node because mutable buffers will be folded during the convert_pt2e process. value = node.args[2] input_qspec_map = {} - input_qspec_map[input] = quantization_config.input_activation - input_qspec_map[value] = SharedQuantizationSpec((input, node)) + input_qspec_map[value] = quantization_config.input_activation node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, - output_qspec=SharedQuantizationSpec((input, node)), + output_qspec=SharedQuantizationSpec((value, node)), + _annotated=True, + ) + + +@register_annotator( + [torch.ops.aten.index_copy.default, torch.ops.aten.index_copy_.default] +) +def annotate_index_copy(node: Node, quantization_config: QuantizationConfig) -> None: + # Avoid annotating the input node because mutable buffers will be folded during the convert_pt2e process. 
+ value = node.args[3] + + input_qspec_map = {} + input_qspec_map[value] = quantization_config.input_activation + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=SharedQuantizationSpec((value, node)), _annotated=True, ) diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp index ab038404582..01bf13603d6 100644 --- a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp +++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp @@ -129,33 +129,37 @@ Error QnnExecuTorchBackend::execute( std::vector<Qnn_Tensor_t> input_tensor_structs; std::vector<Qnn_Tensor_t> output_tensor_structs; + int args_index = 0; input_tensor_structs.reserve(input_tensors.size()); - for (int i = 0; i < input_tensors.size(); ++i) { - if (qnn_manager->RegisterMem( - args[i]->toTensor().mutable_data_ptr(), input_tensors[i]) != - Error::Ok) { - // update data ptr only should be fine - input_tensors[i]->FillDataBuffer( - args[i]->toTensor().const_data_ptr(), false /* copy_data */); + for (const auto& input_tensor : input_tensors) { + if (input_tensor->GetName().find("mutbuf_") == std::string::npos) { + if (qnn_manager->RegisterMem( + args[args_index]->toTensor().mutable_data_ptr(), input_tensor) != + Error::Ok) { + // update data ptr only should be fine + input_tensor->FillDataBuffer( + args[args_index]->toTensor().const_data_ptr(), + false /* copy_data */); + // use the real input shape instead of nominal one to make sure + // dynamic shape is functional + auto dims = args[args_index]->toTensor().sizes(); + input_tensor->SetDims(dims.data(), dims.size()); + } + args_index++; } - // use the real input shape instead of nominal one to make sure - // dynamic shape is functional - auto dims = args[i]->toTensor().sizes(); - input_tensors[i]->SetDims(dims.data(), dims.size()); - input_tensor_structs.emplace_back(input_tensors[i]->CloneTensorStruct()); + input_tensor_structs.emplace_back(input_tensor->CloneTensorStruct()); } - int output_index = input_tensors.size(); for (const auto& output_tensor : output_tensors) { // pos=0 limits the search to the prefix - if (output_tensor->GetName().rfind("output_", 0) == 0) { - void* mutable_data_ptr = - args[output_index]->toTensor().mutable_data_ptr(); + if (output_tensor->GetName().rfind("output_", 0) == 0 && + output_tensor->GetName().find("mutbuf_") == std::string::npos) { + void* mutable_data_ptr = args[args_index]->toTensor().mutable_data_ptr(); if (qnn_manager->RegisterMem(mutable_data_ptr, output_tensor) != Error::Ok) { output_tensor->FillDataBuffer(mutable_data_ptr, false /* copy_data */); } - output_index++; + args_index++; } output_tensor_structs.push_back(output_tensor->CloneTensorStruct()); } diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp index 0f64e8b9cce..0dd0470a2b0 100644 --- a/backends/qualcomm/runtime/QnnManager.cpp +++ b/backends/qualcomm/runtime/QnnManager.cpp @@ -18,6 +18,7 @@ #include <cstring> #include <fstream> #include <string> +#include <unordered_map> namespace executorch { namespace backends { @@ -35,6 +36,16 @@ bool CompareExportedInput( return numA < numB; } +int ExtractMutableBufferNumber(const std::string& name) { + std::string prefix = "mutbuf_"; + size_t startPos = name.find(prefix); + if (startPos != std::string::npos) { + startPos += prefix.length(); + return std::stoi(name.substr(startPos)); + } + return -1; +} + QnnManager::~QnnManager() { Destroy(); } @@ -363,9 +374,21 @@ Error 
QnnManager::AllocateTensor(const std::string& graph_name) { std::vector<Qnn_Tensor_t> output_tensors = backend_params_ptr_->qnn_context_ptr_->GetGraphOutputs(graph_name); + // Mapping memory address for the input and output of mutable buffer + std::unordered_map<int, const void*> mutable_buffer_id_to_memory_map; + for (auto& tensor : input_tensors) { std::shared_ptr<TensorWrapper> tensor_wrapper = CreateTensorWrapper(tensor); tensor_wrapper->UpdateQnnTensorMeta(tensor); + + int mutable_buffer_id = + ExtractMutableBufferNumber(tensor_wrapper->GetName()); + if (mutable_buffer_id != -1) { + // Delegate maintains the memory for mutable buffer + tensor_wrapper->AllocateDataBuffer(); + mutable_buffer_id_to_memory_map[mutable_buffer_id] = + tensor_wrapper->GetStaticTensorData(); + } input_tensors_[graph_name].emplace_back(std::move(tensor_wrapper)); } if (!options_->is_from_context_binary()) { @@ -388,6 +411,16 @@ Error QnnManager::AllocateTensor(const std::string& graph_name) { if (IsTensorDump()) { tensor_wrapper->AllocateDataBuffer(); } + int mutable_buffer_id = + ExtractMutableBufferNumber(tensor_wrapper->GetName()); + if (mutable_buffer_id != -1 && + mutable_buffer_id_to_memory_map.find(mutable_buffer_id) != + mutable_buffer_id_to_memory_map.end()) { + // Fill the same memory for I/O of mutable buffer + tensor_wrapper->FillDataBuffer( + mutable_buffer_id_to_memory_map[mutable_buffer_id], + false /* copy_data */); + } output_tensors_[graph_name].emplace_back(std::move(tensor_wrapper)); } return Error::Ok; diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 091c2d94cd0..8be05d46688 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -909,17 +909,35 @@ def forward(self, x): return self.dispatcher[self.axis](x) +class IndexCopy(torch.nn.Module): + def __init__(self, skip_mutable_buffer=False): + super().__init__() + self.skip_mutable_buffer = skip_mutable_buffer + self.register_buffer( + "k_cache", + torch.zeros((1, 1024, 12, 64), dtype=torch.float32), + persistent=True, + ) + + def forward(self, input_pos, k_val): + k_out = self.k_cache + k_out.index_copy_(1, input_pos, k_val) + return k_out + 0 + + class IndexPut(torch.nn.Module): - def __init__(self): + def __init__(self, skip_mutable_buffer=False): super().__init__() + self.skip_mutable_buffer = skip_mutable_buffer self.register_buffer( "k_cache", torch.zeros((1, 1024, 12, 64), dtype=torch.float32), + persistent=True, ) def forward(self, input_pos, k_val): k_out = torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val) - return k_out + return k_out + 0 class InstanceNorm2d(torch.nn.Module): diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 7666e36b985..747a6804957 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -618,13 +618,55 @@ def test_qnn_backend_index(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_index_copy(self): + test_comb = [ + { + QCOM_MODULE: IndexCopy(skip_mutable_buffer=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int64), + torch.randn([1, 1, 12, 64]), + ), + }, + { + QCOM_MODULE: IndexCopy(skip_mutable_buffer=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int64), + torch.randn([1, 1, 12, 64]), + ), + }, + ] + for i, test in enumerate(test_comb): + with self.subTest(i=i): + 
self.lower_module_and_test_output( + test[QCOM_MODULE], + test[QCOM_SAMPLE_INPUTS], + skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer, + ) + def test_qnn_backend_index_put(self): - module = IndexPut() # noqa: F405 - sample_input = ( - torch.tensor([2], dtype=torch.int32), - torch.randn([1, 1, 12, 64]), - ) - self.lower_module_and_test_output(module, sample_input) + test_comb = [ + { + QCOM_MODULE: IndexPut(skip_mutable_buffer=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int32), + torch.randn([1, 1, 12, 64]), + ), + }, + { + QCOM_MODULE: IndexPut(skip_mutable_buffer=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int32), + torch.randn([1, 1, 12, 64]), + ), + }, + ] + for i, test in enumerate(test_comb): + with self.subTest(i=i): + self.lower_module_and_test_output( + test[QCOM_MODULE], + test[QCOM_SAMPLE_INPUTS], + skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer, + ) def test_qnn_backend_instance_norm_2d(self): modules = [InstanceNorm2d(32), InstanceNorm2d(32, affine=False)] # noqa: F405 @@ -1860,14 +1902,61 @@ def test_qnn_backend_index(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_index_copy(self): + test_comb = [ + { + QCOM_MODULE: IndexCopy(skip_mutable_buffer=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int64), + torch.randn([1, 1, 12, 64]), + ), + }, + { + QCOM_MODULE: IndexCopy(skip_mutable_buffer=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int64), + torch.randn([1, 1, 12, 64]), + ), + }, + ] + for i, test in enumerate(test_comb): + with self.subTest(i=i): + module = self.get_qdq_module( + test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] + ) + self.lower_module_and_test_output( + module, + test[QCOM_SAMPLE_INPUTS], + skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer, + ) + def test_qnn_backend_index_put(self): - module = IndexPut() # noqa: F405 - sample_input = ( - torch.tensor([2], dtype=torch.int32), - torch.randn([1, 1, 12, 64]), - ) - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) + test_comb = [ + { + QCOM_MODULE: IndexPut(skip_mutable_buffer=False), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int32), + torch.randn([1, 1, 12, 64]), + ), + }, + { + QCOM_MODULE: IndexPut(skip_mutable_buffer=True), # noqa: F405 + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int32), + torch.randn([1, 1, 12, 64]), + ), + }, + ] + for i, test in enumerate(test_comb): + with self.subTest(i=i): + module = self.get_qdq_module( + test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] + ) + self.lower_module_and_test_output( + module, + test[QCOM_SAMPLE_INPUTS], + skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer, + ) def test_qnn_backend_instance_norm_2d(self): modules = [InstanceNorm2d(32), InstanceNorm2d(32, affine=False)] # noqa: F405 @@ -3030,7 +3119,17 @@ def test_qnn_backend_generate_optrace(self): for _, (optrace, qhas) in binaries_trace.items(): with open(optrace, "r") as optrace_file: optrace_data = json.load(optrace_file) - for row in optrace_data: + # { + # header: + # { + # 'header_version': {'major': x, 'minor': y, 'patch': z}, + # 'version': {'major': x, 'minor': y, 'patch': z}, + # 'artifact_type': 'OP_TRACE' + # } + # traceEvents: + # {...} + # } + for row in optrace_data["traceEvents"]: self.assertIn("pid", row) with open(qhas, "r") as qhas_file: qhas_data = 
json.load(qhas_file) @@ -3726,7 +3825,17 @@ def test_qnn_backend_generate_optrace(self): for _, (optrace, qhas) in binaries_trace.items(): with open(optrace, "r") as optrace_file: optrace_data = json.load(optrace_file) - for row in optrace_data: + # { + # header: + # { + # 'header_version': {'major': x, 'minor': y, 'patch': z}, + # 'version': {'major': x, 'minor': y, 'patch': z}, + # 'artifact_type': 'OP_TRACE' + # } + # traceEvents: + # {...} + # } + for row in optrace_data["traceEvents"]: self.assertIn("pid", row) with open(qhas, "r") as qhas_file: qhas_data = json.load(qhas_file) diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 2968086d7a5..2e923b92250 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -462,6 +462,7 @@ def lower_module_and_test_output( passes_job: Optional[OrderedDict] = None, skip_node_id_set: set = None, skip_node_op_set: set = None, + skip_mutable_buffer: bool = False, dynamic_shapes: Dict = None, ): delegated_program = to_edge_transform_and_lower_to_qnn( @@ -472,6 +473,7 @@ def lower_module_and_test_output( passes_job=passes_job, skip_node_id_set=skip_node_id_set, skip_node_op_set=skip_node_op_set, + skip_mutable_buffer=skip_mutable_buffer, ) # this is needed for the ETRecord as lowering modifies the graph in-place diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 45f08c8f2c1..3471b0155bd 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -332,6 +332,7 @@ def to_edge_transform_and_lower_to_qnn( passes_job: Optional[Union[OrderedDict, Dict[str, OrderedDict]]] = None, skip_node_id_set: Optional[set] = None, skip_node_op_set: Optional[set] = None, + skip_mutable_buffer: bool = False, ) -> EdgeProgramManager: """ Transforms and lowers a given PyTorch module to the QNN backend. @@ -356,6 +357,8 @@ def to_edge_transform_and_lower_to_qnn( Set of node IDs to skip during partitioning. skip_node_op_set (Optional[set]): Set of node operations to skip during partitioning. + skip_mutable_buffer (Optional[set]): + Whether to skip delegating the mutable buffer in QNN backend. 
Returns: EdgeProgramManager: @@ -407,6 +410,7 @@ def ensure_graph_specific_dict(value, graph_names): compiler_specs[graph_name], skip_node_id_set=skip_node_id_set, skip_node_op_set=skip_node_op_set, + skip_mutable_buffer=skip_mutable_buffer, ) ] for graph_name in graph_names diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 1f055d65822..f79a25b6077 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -1187,7 +1187,6 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager": calibration_seq_length=llm_config.quantization.calibration_seq_length, calibration_data=llm_config.quantization.calibration_data, tokenizer_path=llm_config.base.tokenizer_path, - use_legacy_export=llm_config.backend.qnn.enabled, save_exported_program=llm_config.export.export_only, verbose=llm_config.debug.verbose, metadata=_load_llama_model_metadata( diff --git a/examples/models/llama/source_transformation/attention.py b/examples/models/llama/source_transformation/attention.py index d5f065550d2..3c37cddee69 100644 --- a/examples/models/llama/source_transformation/attention.py +++ b/examples/models/llama/source_transformation/attention.py @@ -45,12 +45,10 @@ def __init__( self.register_buffer( f"past_k_caches_{i}", torch.zeros(cache_shape, dtype=dtype, device="cpu"), - persistent=False, ) self.register_buffer( f"past_v_caches_{i}", torch.zeros(cache_shape, dtype=dtype, device="cpu"), - persistent=False, ) def update( diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index 1fb3d97a9c7..59823b533a3 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -493,12 +493,10 @@ def __init__( self.register_buffer( "past_k_caches", torch.zeros(cache_shape, dtype=dtype, device="cpu"), - persistent=False, ) self.register_buffer( "past_v_caches", torch.zeros(cache_shape, dtype=dtype, device="cpu"), - persistent=False, ) def update( diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 4128bfd8198..079d89848dd 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -10,11 +10,9 @@ # pyre-unsafe -import contextlib import logging from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple -from unittest.mock import patch import torch from executorch.backends.transforms.duplicate_dynamic_quant_chain import ( @@ -96,7 +94,6 @@ def __init__( verbose: bool = False, metadata: Optional[dict] = None, dynamic_shapes: Optional[Any] = None, - use_legacy_export: bool = False, save_exported_program: bool = False, ): # Store necessary constructor arguments. @@ -117,7 +114,6 @@ def __init__( self.verbose = verbose self.metadata = metadata self.dynamic_shapes = dynamic_shapes - self.use_legacy_export = use_legacy_export self.save_exported_program = save_exported_program # Note: treat this as the source of truth for the result of @@ -229,39 +225,20 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram: # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up) with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): - if self.use_legacy_export: - # TODO: for use cases such as qnn, which does not work with new, non-functional export IR. 
- # See issue: https://github.com/pytorch/executorch/issues/7373 - - with patch.object( - torch._utils_internal, - "export_training_ir_rollout_check", - return_value=False, - ): - # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a - # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details - exported_module = torch.export.export( - self.model if not module else module, - self.example_inputs, - self.example_kwarg_inputs, - dynamic_shapes=dynamic_shape, - strict=True, - ) + if module: + logging.info("Re-exporting with:") else: - if module: - logging.info("Re-exporting with:") - else: - logging.info("Exporting with:") - logging.info(f"inputs: {self.example_inputs}") - logging.info(f"kwargs: {self.example_kwarg_inputs}") - logging.info(f"dynamic shapes: {dynamic_shape}") - exported_module = export_for_training( - self.model if not module else module, - self.example_inputs, - kwargs=self.example_kwarg_inputs, - dynamic_shapes=dynamic_shape, - strict=True, - ) + logging.info("Exporting with:") + logging.info(f"inputs: {self.example_inputs}") + logging.info(f"kwargs: {self.example_kwarg_inputs}") + logging.info(f"dynamic shapes: {dynamic_shape}") + exported_module = export_for_training( + self.model if not module else module, + self.example_inputs, + kwargs=self.example_kwarg_inputs, + dynamic_shapes=dynamic_shape, + strict=True, + ) return exported_module def export(self) -> "LLMEdgeManager": @@ -446,24 +423,15 @@ def export_to_edge(self) -> "LLMEdgeManager": # Run export() if it didn't run self.export() - override_export_behaviour = contextlib.nullcontext() - if self.use_legacy_export: - override_export_behaviour = patch.object( - torch._utils_internal, - "export_training_ir_rollout_check", - return_value=False, - ) - - with override_export_behaviour: - self.edge_manager = export_to_edge( - self.pre_autograd_graph_module, # pyre-fixme[6] - self.example_inputs, - example_kwarg_inputs=self.example_kwarg_inputs, - dynamic_shapes=dynamic_shape, - edge_constant_methods=self.metadata, - edge_compile_config=edge_config, - verbose=self.verbose, - ) + self.edge_manager = export_to_edge( + self.pre_autograd_graph_module, # pyre-fixme[6] + self.example_inputs, + example_kwarg_inputs=self.example_kwarg_inputs, + dynamic_shapes=dynamic_shape, + edge_constant_methods=self.metadata, + edge_compile_config=edge_config, + verbose=self.verbose, + ) return self def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager": diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index 20604bbf635..e35358f56b7 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -216,4 +216,5 @@ def get_qnn_partitioner( ), skip_node_id_set={}, skip_node_op_set=skip_node_op_set, + skip_mutable_buffer=False, ) From 19c5aa13cb232eca0b0d385b41c93c5ebe4cbad5 Mon Sep 17 00:00:00 2001 From: shewu-quic <quic_shewu@quicinc.com> Date: Thu, 19 Jun 2025 12:14:53 +0800 Subject: [PATCH 2/3] Fixed the CI for meta's llama --- backends/qualcomm/_passes/__init__.py | 2 + .../qualcomm/_passes/convert_bmm_to_matmul.py | 76 +++++++++++++++++++ backends/qualcomm/_passes/qnn_pass_manager.py | 2 + backends/qualcomm/_passes/utils.py | 3 + .../qualcomm/quantizer/custom_annotation.py | 9 ++- examples/models/llama/export_llama_lib.py | 2 + 6 files changed, 90 insertions(+), 4 deletions(-) create mode 100644 backends/qualcomm/_passes/convert_bmm_to_matmul.py diff --git 
a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index 2b743454bf9..01710aa8d80 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -8,6 +8,7 @@ from .annotate_quant_attrs import AnnotateQuantAttrs from .annotate_stack import AnnotateStack from .annotate_unbind import AnnotateUnbind +from .convert_bmm_to_matmul import ConvertBmmToMatmul from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d from .convert_square_to_pow import ConvertSquareToPow from .decompose_any import DecomposeAny @@ -44,6 +45,7 @@ AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind, + ConvertBmmToMatmul, ConvertConv1dToConv2d, ConvertSquareToPow, DecomposeAny, diff --git a/backends/qualcomm/_passes/convert_bmm_to_matmul.py b/backends/qualcomm/_passes/convert_bmm_to_matmul.py new file mode 100644 index 00000000000..84e1ff26aa1 --- /dev/null +++ b/backends/qualcomm/_passes/convert_bmm_to_matmul.py @@ -0,0 +1,76 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import operator +from collections import Counter +from typing import List + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + + +class ConvertBmmToMatmul(ExportPass): + """ + Replace bmm to matmul, because bmm is eqaul to matmul in QNN. + Handle missing quantization tag for bmm op. + """ + + view_copy = exir_ops.edge.aten.view_copy.default + expand_copy = exir_ops.edge.aten.expand_copy.default + clone = exir_ops.edge.aten.clone.default + bmm = exir_ops.edge.aten.bmm.default + matmul = exir_ops.edge.aten.matmul.default + patterns = [ + {expand_copy: 2, view_copy: 3, bmm: 1}, + {expand_copy: 2, view_copy: 3, bmm: 1, clone: 1}, + {bmm: 1}, + ] + + def __init__(self): + super(ConvertBmmToMatmul, self).__init__() + + def _get_ordered_inputs( + self, inputs: List[torch.fx.Node], output: torch.fx.Node + ) -> List[torch.fx.Node]: + bmm_inputs = [] + for arg in output.args: + while arg not in inputs: + arg = arg.args[0] + bmm_inputs.append(arg) + return bmm_inputs + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + partitions = get_source_partitions( + graph, + [operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default], + ) + for _, src_partitions in partitions.items(): + for src_partition in src_partitions: + op_cnt = Counter([n.target for n in src_partition.nodes]) + if op_cnt not in self.patterns: + raise AssertionError( + "Found a new pattern needed be converted to linear op" + ) + + inputs = src_partition.input_nodes + bmm_node = [n for n in src_partition.nodes if n.target == self.bmm][0] + output = src_partition.output_nodes[0] + # the order of src_partition.inputs is not guaranteed. + lhs, rhs = self._get_ordered_inputs(inputs, bmm_node) + with graph_module.graph.inserting_before(output): + # replace bmm to matmul, because bmm is eqaul to matmul in qnn. 
+ matmul_node = graph.create_node( + "call_function", self.matmul, (lhs, rhs) + ) + matmul_node.meta = output.meta + for user in output.users.copy(): + user.replace_input_with(output, matmul_node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 9d2726eaaed..8340fa6209e 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -13,6 +13,7 @@ AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind, + ConvertBmmToMatmul, ConvertConv1dToConv2d, ConvertSquareToPow, DecomposeAny, @@ -79,6 +80,7 @@ def get_capture_program_passes(): (AnnotateQuantAttrs, True), (AnnotateStack, True), (AnnotateUnbind, True), + (ConvertBmmToMatmul, False), (ConvertConv1dToConv2d, True), (DecomposeAny, True), (DecomposeColIm, True), diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py index b23efc8714f..ae11ba7b325 100755 --- a/backends/qualcomm/_passes/utils.py +++ b/backends/qualcomm/_passes/utils.py @@ -64,6 +64,7 @@ def get_passes_dependency_for_capture_program(): AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind, + ConvertBmmToMatmul, ConvertConv1dToConv2d, DecomposeAny, DecomposeColIm, @@ -82,11 +83,13 @@ def get_passes_dependency_for_capture_program(): return { AnnotateAdaptiveAvgPool1D: [RemoveRedundancy], AnnotateQuantAttrs: [ + ConvertBmmToMatmul, RecomposePixelUnshuffle, RemoveRedundancy, ], AnnotateStack: [RemoveRedundancy], AnnotateUnbind: [RemoveRedundancy], + ConvertBmmToMatmul: [RecomposePixelUnshuffle], DecomposeAny: [RemoveRedundancy], DecomposeColIm: [FoldQDQ], DecomposeLinalgVectorNorm: [RemoveRedundancy], diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index b5531cd492f..9e9370ef697 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -292,14 +292,15 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig): ) def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None: - input = node.args[0] + # Avoid annotating the input node because mutable buffers will be folded during the convert_pt2e process. 
value = node.args[2] + input_qspec_map = {} - input_qspec_map[input] = quantization_config.input_activation - input_qspec_map[value] = SharedQuantizationSpec((input, node)) + input_qspec_map[value] = quantization_config.input_activation + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( input_qspec_map=input_qspec_map, - output_qspec=SharedQuantizationSpec((input, node)), + output_qspec=SharedQuantizationSpec((value, node)), _annotated=True, ) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index f79a25b6077..db2b781825f 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -914,6 +914,7 @@ def _to_edge_and_lower_llama( # noqa: C901 # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm._passes` from executorch.backends.qualcomm._passes import ( AnnotateStack, + ConvertBmmToMatmul, FoldQDQ, RecomposeRmsNorm, TagQuantIO, @@ -956,6 +957,7 @@ def _to_edge_and_lower_llama( # noqa: C901 passes_job = get_capture_program_passes() dep_table = get_passes_dependency_for_capture_program() passes_job[AnnotateStack][QCOM_PASS_ACTIVATE_KEY] = True + passes_job[ConvertBmmToMatmul][QCOM_PASS_ACTIVATE_KEY] = True passes_job[RecomposeRmsNorm][QCOM_PASS_ACTIVATE_KEY] = True passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ From 5fde193fa730e0a9fa7d20aa10c037108da4f1ae Mon Sep 17 00:00:00 2001 From: shewu-quic <quic_shewu@quicinc.com> Date: Fri, 20 Jun 2025 10:11:52 +0800 Subject: [PATCH 3/3] revert change --- examples/models/llama/export_llama_lib.py | 1 + extension/llm/export/builder.py | 76 ++++++++++++++++------- extension/llm/export/partitioner_lib.py | 3 +- 3 files changed, 57 insertions(+), 23 deletions(-) diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index db2b781825f..9d5fcfdba25 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -1189,6 +1189,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager": calibration_seq_length=llm_config.quantization.calibration_seq_length, calibration_data=llm_config.quantization.calibration_data, tokenizer_path=llm_config.base.tokenizer_path, + use_legacy_export=llm_config.backend.qnn.enabled, save_exported_program=llm_config.export.export_only, verbose=llm_config.debug.verbose, metadata=_load_llama_model_metadata( diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 079d89848dd..4128bfd8198 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -10,9 +10,11 @@ # pyre-unsafe +import contextlib import logging from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple +from unittest.mock import patch import torch from executorch.backends.transforms.duplicate_dynamic_quant_chain import ( @@ -94,6 +96,7 @@ def __init__( verbose: bool = False, metadata: Optional[dict] = None, dynamic_shapes: Optional[Any] = None, + use_legacy_export: bool = False, save_exported_program: bool = False, ): # Store necessary constructor arguments. 
@@ -114,6 +117,7 @@ def __init__( self.verbose = verbose self.metadata = metadata self.dynamic_shapes = dynamic_shapes + self.use_legacy_export = use_legacy_export self.save_exported_program = save_exported_program # Note: treat this as the source of truth for the result of @@ -225,20 +229,39 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram: # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up) with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad(): - if module: - logging.info("Re-exporting with:") + if self.use_legacy_export: + # TODO: for use cases such as qnn, which does not work with new, non-functional export IR. + # See issue: https://github.com/pytorch/executorch/issues/7373 + + with patch.object( + torch._utils_internal, + "export_training_ir_rollout_check", + return_value=False, + ): + # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a + # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details + exported_module = torch.export.export( + self.model if not module else module, + self.example_inputs, + self.example_kwarg_inputs, + dynamic_shapes=dynamic_shape, + strict=True, + ) else: - logging.info("Exporting with:") - logging.info(f"inputs: {self.example_inputs}") - logging.info(f"kwargs: {self.example_kwarg_inputs}") - logging.info(f"dynamic shapes: {dynamic_shape}") - exported_module = export_for_training( - self.model if not module else module, - self.example_inputs, - kwargs=self.example_kwarg_inputs, - dynamic_shapes=dynamic_shape, - strict=True, - ) + if module: + logging.info("Re-exporting with:") + else: + logging.info("Exporting with:") + logging.info(f"inputs: {self.example_inputs}") + logging.info(f"kwargs: {self.example_kwarg_inputs}") + logging.info(f"dynamic shapes: {dynamic_shape}") + exported_module = export_for_training( + self.model if not module else module, + self.example_inputs, + kwargs=self.example_kwarg_inputs, + dynamic_shapes=dynamic_shape, + strict=True, + ) return exported_module def export(self) -> "LLMEdgeManager": @@ -423,15 +446,24 @@ def export_to_edge(self) -> "LLMEdgeManager": # Run export() if it didn't run self.export() - self.edge_manager = export_to_edge( - self.pre_autograd_graph_module, # pyre-fixme[6] - self.example_inputs, - example_kwarg_inputs=self.example_kwarg_inputs, - dynamic_shapes=dynamic_shape, - edge_constant_methods=self.metadata, - edge_compile_config=edge_config, - verbose=self.verbose, - ) + override_export_behaviour = contextlib.nullcontext() + if self.use_legacy_export: + override_export_behaviour = patch.object( + torch._utils_internal, + "export_training_ir_rollout_check", + return_value=False, + ) + + with override_export_behaviour: + self.edge_manager = export_to_edge( + self.pre_autograd_graph_module, # pyre-fixme[6] + self.example_inputs, + example_kwarg_inputs=self.example_kwarg_inputs, + dynamic_shapes=dynamic_shape, + edge_constant_methods=self.metadata, + edge_compile_config=edge_config, + verbose=self.verbose, + ) return self def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager": diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index e35358f56b7..7b093a7f1a3 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -216,5 +216,6 @@ 
def get_qnn_partitioner( ), skip_node_id_set={}, skip_node_op_set=skip_node_op_set, - skip_mutable_buffer=False, + # TODO: if deprecated legacy export, skip_mutable_buffer can be set False + skip_mutable_buffer=True, )
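
---

Usage note (not part of the patches above): with patch 1 applied, a module whose `forward()` mutates a persistent registered buffer — e.g. a KV cache updated through `index_put_` or `index_copy_` — can be delegated to QNN instead of leaving the mutation on CPU, and the new `skip_mutable_buffer` flag restores the old behaviour if accuracy regresses. The sketch below is a minimal, hedged example: it mirrors the `IndexPut` test model added in `backends/qualcomm/tests/models.py` and calls `to_edge_transform_and_lower_to_qnn` with the new keyword. The compiler-spec helpers (`generate_htp_compiler_spec`, `generate_qnn_executorch_compiler_spec`), the `QcomChipset` import path, and the `SM8650` target are assumptions taken from the existing Qualcomm backend utilities and may need adjusting to the local layout.

```python
import torch

from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,             # assumed existing helper
    generate_qnn_executorch_compiler_spec,  # assumed existing helper
    to_edge_transform_and_lower_to_qnn,
)
# Import path for the chipset enum is an assumption; older trees use
# executorch.backends.qualcomm.serialization.qnn_compile_spec_schema.
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset


class KVCache(torch.nn.Module):
    """Mirrors the IndexPut test model: a persistent mutable buffer updated in forward()."""

    def __init__(self):
        super().__init__()
        # persistent=True keeps the buffer in the graph signature
        # (inputs_to_buffers / buffers_to_mutate), which is what the
        # partitioner and get_tensor_name() rely on.
        self.register_buffer(
            "k_cache",
            torch.zeros((1, 1024, 12, 64), dtype=torch.float32),
            persistent=True,
        )

    def forward(self, input_pos, k_val):
        k_out = torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val)
        return k_out + 0


module = KVCache()
sample_inputs = (
    torch.tensor([2], dtype=torch.int32),
    torch.randn(1, 1, 12, 64),
)

backend_options = generate_htp_compiler_spec(use_fp16=True)
compiler_specs = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8650,  # hypothetical target device
    backend_options=backend_options,
)

# skip_mutable_buffer=False (the default) lets the partitioner tag the mutated
# buffer via tag_mutated_buffer(); at runtime the delegate binds the buffer's
# input and output tensors to one memory address. Set it to True to keep the
# mutable buffer out of the delegate if accuracy issues show up.
edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
    module,
    sample_inputs,
    compiler_specs,
    skip_mutable_buffer=False,
)
executorch_prog = edge_prog_mgr.to_executorch()
```

The trailing `+ 0` mirrors the updated test models, which return a value derived from the cache rather than the mutated buffer node itself. At runtime the pairing of the buffer's input and output tensors is purely name-based: `get_tensor_name()` stamps both sides with a shared `mutbuf_{id}` marker, and `QnnManager::AllocateTensor` extracts that id to point the output wrapper at the memory the delegate allocated for the corresponding input.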