From e922f14584fb50219db44e52f03ee29cf70f1380 Mon Sep 17 00:00:00 2001
From: shewu-quic <quic_shewu@quicinc.com>
Date: Thu, 12 Jun 2025 15:06:04 +0800
Subject: [PATCH 1/3] Qualcomm AI Engine Direct - Delegate mutable buffers and
 fix the mutable buffer issue

Summary:
- Add a parameter to support mutable buffer delegation in the QNN backend (see the usage sketch below)
  - Set the same memory address for the I/O of each mutable buffer at runtime
- Avoid annotating the input node of index_put/index_copy, because mutable buffers are folded during the convert_pt2e process
- Deprecate use_legacy_export in the ExecuTorch llama export
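
A minimal usage sketch of the new flag; the leading arguments (model, example_inputs,
compiler_specs) are placeholders for the function's existing parameters, which this
patch does not change:

    from executorch.backends.qualcomm.utils.utils import (
        to_edge_transform_and_lower_to_qnn,
    )

    # skip_mutable_buffer defaults to False, so mutable buffers are delegated and
    # the QNN runtime binds each buffer's input and output to the same memory.
    edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
        model,
        example_inputs,
        compiler_specs,
        skip_mutable_buffer=False,
    )

If accuracy issues show up at runtime, skip_mutable_buffer=True leaves mutable
buffers undelegated.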
---
 backends/qualcomm/_passes/__init__.py         |   2 -
 backends/qualcomm/_passes/insert_io_qdq.py    |  10 +-
 backends/qualcomm/_passes/qnn_pass_manager.py |  11 +-
 .../_passes/replace_index_put_input.py        |  54 -------
 backends/qualcomm/_passes/utils.py            |   4 +-
 backends/qualcomm/builders/node_visitor.py    |  44 ++++--
 .../qualcomm/builders/node_visitor_manager.py |   6 +-
 backends/qualcomm/builders/op_index_put.py    |   7 +-
 backends/qualcomm/builders/utils.py           |  35 ++++-
 .../qualcomm/partition/qnn_partitioner.py     |  17 ++-
 backends/qualcomm/quantizer/annotators.py     |  24 ++-
 .../qualcomm/runtime/QnnExecuTorchBackend.cpp |  38 ++---
 backends/qualcomm/runtime/QnnManager.cpp      |  33 +++++
 backends/qualcomm/tests/models.py             |  22 ++-
 backends/qualcomm/tests/test_qnn_delegate.py  | 139 ++++++++++++++++--
 backends/qualcomm/tests/utils.py              |   2 +
 backends/qualcomm/utils/utils.py              |   4 +
 examples/models/llama/export_llama_lib.py     |   1 -
 .../llama/source_transformation/attention.py  |   2 -
 .../llama/source_transformation/sdpa.py       |   2 -
 extension/llm/export/builder.py               |  76 +++-------
 extension/llm/export/partitioner_lib.py       |   1 +
 22 files changed, 357 insertions(+), 177 deletions(-)
 delete mode 100644 backends/qualcomm/_passes/replace_index_put_input.py

diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index 681ea2ee534..2b743454bf9 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -35,7 +35,6 @@
 from .remove_0d_tensor import Remove0DTensor
 from .remove_redundancy import RemoveRedundancy
 from .replace_arange_args import ReplaceArangeArgs
-from .replace_index_put_input import ReplaceIndexPutInput
 from .replace_inf_values import ReplaceInfValues
 from .tag_quant_io import TagQuantIO
 
@@ -72,7 +71,6 @@
     Remove0DTensor,
     RemoveRedundancy,
     ReplaceArangeArgs,
-    ReplaceIndexPutInput,
     ReplaceInfValues,
     TagQuantIO,
 ]
diff --git a/backends/qualcomm/_passes/insert_io_qdq.py b/backends/qualcomm/_passes/insert_io_qdq.py
index e5b15f2d12c..caecae64fa8 100644
--- a/backends/qualcomm/_passes/insert_io_qdq.py
+++ b/backends/qualcomm/_passes/insert_io_qdq.py
@@ -9,7 +9,10 @@
 
 from executorch.backends.qualcomm.builders.node_visitor import q_ops
 
-from executorch.backends.qualcomm.builders.utils import is_parameter
+from executorch.backends.qualcomm.builders.utils import (
+    is_mutable_buffer_input,
+    is_parameter,
+)
 from executorch.backends.qualcomm.utils.constants import (
     QCOM_ENCODING,
     QCOM_QUANT_ATTRS,
@@ -124,7 +127,10 @@ def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
             if (
                 n.op == "placeholder"
                 and n.meta.get(QCOM_QUANT_ATTRS)
-                and not is_parameter(n, self.edge_program)
+                and (
+                    not is_parameter(n, self.edge_program)
+                    or is_mutable_buffer_input(n, self.edge_program)
+                )
             ):
                 self._insert_quant_node(
                     graph_module, n, n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]
diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py
index 2b68f9b6c09..9d2726eaaed 100644
--- a/backends/qualcomm/_passes/qnn_pass_manager.py
+++ b/backends/qualcomm/_passes/qnn_pass_manager.py
@@ -40,7 +40,6 @@
     Remove0DTensor,
     RemoveRedundancy,
     ReplaceArangeArgs,
-    ReplaceIndexPutInput,
     ReplaceInfValues,
     TagQuantIO,
 )
@@ -92,7 +91,6 @@ def get_capture_program_passes():
         (RecomposeRmsNorm, False),
         (Remove0DTensor, True),
         (RemoveRedundancy, True),
-        (ReplaceIndexPutInput, True),
         (TagQuantIO, False),
     ]
 
@@ -224,4 +222,11 @@ def transform_for_preprocess_pipeline(self, exported_program: ExportedProgram):
         self.add_pass(LayoutTransform(exported_program, insert_permute=True))
         self.add_pass(FuseConsecutiveCast())
         self.add_pass(FuseConsecutiveTranspose())
-        return self._transform(exported_program.graph_module)
+        self._transform(exported_program.graph_module)
+        # Update inputs_to_buffers and buffers_to_mutate in the graph signature for mutable buffers.
+        # Since Q/DQ nodes are inserted at the I/O, output node names would otherwise fail to map to their buffers.
+        exported_program._graph_signature = _get_updated_graph_signature(
+            exported_program.graph_signature,
+            exported_program.graph_module,
+        )
+        return exported_program.graph_module
diff --git a/backends/qualcomm/_passes/replace_index_put_input.py b/backends/qualcomm/_passes/replace_index_put_input.py
deleted file mode 100644
index 93ee21bfc7c..00000000000
--- a/backends/qualcomm/_passes/replace_index_put_input.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright (c) Qualcomm Innovation Center, Inc.
-# All rights reserved
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-import torch
-from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING, QCOM_QUANT_ATTRS
-from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass, PassResult
-
-
-class ReplaceIndexPutInput(ExportPass):
-    """
-    Index put input workaround for quantized module
-    """
-
-    dq_q_map = {
-        # per tensor
-        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
-        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor: exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
-        # per channel
-        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
-    }
-
-    def __init__(self):
-        super(ReplaceIndexPutInput, self).__init__()
-
-    def call(self, graph_module: torch.fx.GraphModule):
-        graph = graph_module.graph
-        for node in graph.nodes:
-            if node.target == exir_ops.edge.aten.index_put.default:
-                if (
-                    copy_node := list(node.users)[0]
-                ) and copy_node.target == exir_ops.edge.aten.copy.default:
-                    m_buffer_node = copy_node.args[0]
-                    dq_node = node.args[0]
-                    bad_frozen_node = dq_node.args[0]
-                    if QCOM_QUANT_ATTRS in bad_frozen_node.meta:
-                        m_buffer_node.meta[QCOM_QUANT_ATTRS] = bad_frozen_node.meta[
-                            QCOM_QUANT_ATTRS
-                        ]
-                        m_buffer_node.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING] = (
-                            self.dq_q_map[
-                                m_buffer_node.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]
-                            ]
-                        )
-                    with graph.inserting_after(dq_node):
-                        node.replace_input_with(dq_node, m_buffer_node)
-                else:
-                    continue
-
-        graph.eliminate_dead_code()
-        graph_module.recompile()
-        return PassResult(graph_module, True)
diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py
index ef52d2c190a..b23efc8714f 100755
--- a/backends/qualcomm/_passes/utils.py
+++ b/backends/qualcomm/_passes/utils.py
@@ -76,7 +76,6 @@ def get_passes_dependency_for_capture_program():
         RecomposePixelUnshuffle,
         RecomposeRmsNorm,
         RemoveRedundancy,
-        ReplaceIndexPutInput,
         TagQuantIO,
     )
 
@@ -103,8 +102,7 @@ def get_passes_dependency_for_capture_program():
         ],
         RecomposePixelUnshuffle: [RemoveRedundancy],
         RecomposeRmsNorm: [RemoveRedundancy],
-        ReplaceIndexPutInput: [LayoutTransform],
-        TagQuantIO: [ReplaceIndexPutInput],
+        TagQuantIO: [LayoutTransform],
     }
 
 
diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py
index 37fe3615268..8d77a5f47aa 100644
--- a/backends/qualcomm/builders/node_visitor.py
+++ b/backends/qualcomm/builders/node_visitor.py
@@ -41,6 +41,8 @@
     get_parameter,
     is_graph_input,
     is_graph_output,
+    is_mutable_buffer_input,
+    is_mutable_buffer_output,
     is_parameter,
 )
 
@@ -307,7 +309,9 @@ def get_tensor_type(
         node: torch.fx.Node,
         tensor_type: PyQnnWrapper.Qnn_TensorType_t,
     ) -> PyQnnWrapper.Qnn_TensorType_t:
-        is_input = is_graph_input(node, self.edge_program)
+        is_input = is_graph_input(node, self.edge_program) or is_mutable_buffer_input(
+            node, self.edge_program
+        )
         is_output = is_graph_output(node)
         # handle logic for input/output tensors
         if is_input or is_output:
@@ -352,6 +356,33 @@ def get_dynamic_dimension(self, dims):
 
         return dynamic_dims if any(dynamic_dims) else [], nominal_dims
 
+    def get_tensor_name(
+        self,
+        node: torch.fx.Node,
+        wrapper_idx: int = 0,
+    ):
+        tensor_name = f"{node.name}_{wrapper_idx}"
+        # The `input_{id}` prefix is used for sorting at runtime. Due to multiple passes in qnn_preprocess,
+        # the input order between QNN and the original graph's forward function may differ.
+        # The `mutbuf_{id}` tag is used to map the input and output of a mutable buffer at runtime.
+        # The `output_` prefix identifies graph outputs at runtime, to prevent confusion with per_tensor_dump.
+        if is_mutable_buffer_input(node, self.edge_program):
+            fqn = self.edge_program.graph_signature.inputs_to_buffers[node.target]
+            position_index = list(
+                self.edge_program.graph_signature.buffers_to_mutate.values()
+            ).index(fqn)
+            tensor_name = f"input_{str(self.external_ids[node])}_mutbuf_{str(position_index)}_{tensor_name}"
+        elif is_graph_input(node, self.edge_program):
+            tensor_name = f"input_{str(self.external_ids[node])}_{tensor_name}"
+        elif is_mutable_buffer_output(node, self.edge_program):
+            position_index = list(
+                self.edge_program.graph_signature.buffers_to_mutate.keys()
+            ).index(node.name)
+            tensor_name = f"output_mutbuf_{position_index}_{tensor_name}"
+        elif is_graph_output(node):
+            tensor_name = f"output_{tensor_name}"
+        return tensor_name
+
     def define_custom_tensor_wrapper(
         self,
         node_name: str,
@@ -413,16 +444,7 @@ def define_tensor(
         if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
             return cached
 
-        tensor_name = f"{tensor_source_node.name}_{wrapper_idx}"
-        if is_graph_input(tensor_source_node, self.edge_program):
-            tensor_name = (
-                "input_"
-                + str(self.external_ids[tensor_source_node])
-                + "_"
-                + tensor_name
-            )
-        if is_graph_output(tensor_source_node):
-            tensor_name = "output_" + tensor_name
+        tensor_name = self.get_tensor_name(tensor_source_node, wrapper_idx)
         dims = torch.Size([1]) if len(tensor.size()) == 0 else tensor.size()
         dynamic_dims, nominal_dims = self.get_dynamic_dimension(dims)
         tensor_type = self.get_tensor_type(tensor_source_node, tensor_type)
diff --git a/backends/qualcomm/builders/node_visitor_manager.py b/backends/qualcomm/builders/node_visitor_manager.py
index fa9d51db1ad..8c1733fcec3 100644
--- a/backends/qualcomm/builders/node_visitor_manager.py
+++ b/backends/qualcomm/builders/node_visitor_manager.py
@@ -13,7 +13,7 @@
 
 from .node_visitor import NodeVisitor
 from .op_custom_op import CustomOp
-from .utils import is_graph_input, is_graph_output
+from .utils import is_graph_input, is_graph_output, is_mutable_buffer_input
 
 
 # This will hold mapping of all node names to the visitor class
@@ -39,7 +39,9 @@ def generate_node_to_external_map(
         # The order in which we visit the placeholder node is same as the *args
         # order for the forward(*args) signature for this gm. Using the order of
         # the nodes as external_id to extract the right arg from *args at runtime
-        if is_graph_input(node, edge_program):
+        if is_graph_input(node, edge_program) or is_mutable_buffer_input(
+            node, edge_program
+        ):
             node_to_external_map[node] = len(node_to_external_map)
     for node in edge_program.graph_module.graph.nodes:
         if is_graph_output(node):
diff --git a/backends/qualcomm/builders/op_index_put.py b/backends/qualcomm/builders/op_index_put.py
index a58075bf06c..de59b1a0489 100644
--- a/backends/qualcomm/builders/op_index_put.py
+++ b/backends/qualcomm/builders/op_index_put.py
@@ -1,9 +1,10 @@
 from typing import Dict
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
-
 import torch
 
+from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
+
 from .node_visitor import NodeVisitor
 from .node_visitor_manager import register_node_visitor
 from .qnn_constants import OpScatterNd, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -22,6 +23,10 @@ def define_node(
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
     ) -> PyQnnWrapper.PyQnnOpWrapper:
         input_node = self.get_node(node.args[0])
+        # Because args[0] of the index_put op is not annotated, fill in its quant_attrs from the node here.
+        if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
+            quant_attrs = quant_attrs.copy()
+            input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
         input_tensor = self.get_tensor(input_node, node)
         input_tensor_wrapper = self.define_tensor(
             input_node,
diff --git a/backends/qualcomm/builders/utils.py b/backends/qualcomm/builders/utils.py
index c82ebaf1bb3..3345f2e1fc9 100755
--- a/backends/qualcomm/builders/utils.py
+++ b/backends/qualcomm/builders/utils.py
@@ -75,6 +75,20 @@ def is_graph_input(
     return tensor.op == "placeholder" and not is_parameter(tensor, edge_program)
 
 
+def is_mutable_buffer_input(
+    tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
+) -> bool:
+    """
+    Check if the given tensor is a mutable buffer input
+    Args:
+        tensor: EdgeIR Tensor that is being checked for mutable buffer input
+    """
+    if tensor.op == "placeholder" and is_buffer(edge_program, tensor):
+        fqn = edge_program.graph_signature.inputs_to_buffers[tensor.target]
+        # the placeholder is a mutable buffer input only if its buffer is mutated
+        return fqn in edge_program.graph_signature.buffers_to_mutate.values()
+
+
 def is_graph_output(node: torch.fx.Node) -> bool:
     """
     Check if the given tensor is used as a graph output
@@ -83,7 +97,7 @@ def is_graph_output(node: torch.fx.Node) -> bool:
         tensor: EdgeIR Tensor that is being checked for graph input
     """
     for user in node.users.keys():
-        # getitem node is skiped, check the op_skip_ops.py
+        # getitem node is skipped, check the op_skip_ops.py
         if user.op == "output" or (
             user.target.__name__ == "getitem" and is_graph_output(user)
         ):
@@ -91,6 +105,25 @@ def is_graph_output(node: torch.fx.Node) -> bool:
     return False
 
 
+def is_mutable_buffer_output(
+    tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
+) -> bool:
+    """
+    Check if the given tensor is a mutable buffer output
+    Args:
+        tensor: EdgeIR Tensor that is being checked for mutable buffer output
+    """
+    return (
+        any(
+            user.op == "output"
+            or user.target.__name__ == "getitem"
+            and is_graph_output(user)
+            for user in tensor.users.keys()
+        )
+        and tensor.name in edge_program.graph_signature.buffers_to_mutate.keys()
+    )
+
+
 def is_constant(
     tensor: torch.fx.Node, edge_program: torch.export.ExportedProgram
 ) -> bool:
diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py
index 776923a1493..9a8ce92e739 100644
--- a/backends/qualcomm/partition/qnn_partitioner.py
+++ b/backends/qualcomm/partition/qnn_partitioner.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+import logging
 from collections import defaultdict
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
@@ -29,7 +30,7 @@
     Partitioner,
     PartitionResult,
 )
-from executorch.exir.backend.utils import tag_constant_data
+from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
 from torch.export.exported_program import ExportedProgram
 from torch.fx.passes.infra.partitioner import Partition
 from torch.fx.passes.operator_support import OperatorSupportBase
@@ -42,6 +43,9 @@
 )
 from .utils import filter_fn, generate_qnn_executorch_option, get_skip_decomp_table
 
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
 
 class QnnOperatorSupport(OperatorSupportBase):
     def __init__(
@@ -124,6 +128,7 @@ def __init__(
         compiler_specs: List[CompileSpec],
         skip_node_id_set: set = None,
         skip_node_op_set: set = None,
+        skip_mutable_buffer: bool = False,
     ):
         self.compiler_specs_snapshot = copy.deepcopy(compiler_specs)
 
@@ -133,6 +138,7 @@ def __init__(
         self.partition_tags: Dict[str, DelegationSpec] = {}
         self.skip_node_id_set = set() if skip_node_id_set is None else skip_node_id_set
         self.skip_node_op_set = set() if skip_node_op_set is None else skip_node_op_set
+        self.skip_mutable_buffer = skip_mutable_buffer
 
     def generate_partitions(
         self, edge_program: torch.export.ExportedProgram
@@ -178,6 +184,15 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResu
         if len(partitions) != 0:
             self.tag_nodes(partitions, edge_program)
             tag_constant_data(edge_program)
+            if not self.skip_mutable_buffer:
+                logger.info(
+                    "Qnn partitioner will delegate torch mutable buffer with the same I/O address during the runtime, "
+                    "so if your model contains mutable buffer, "
+                    "then you can get the better performance with skip_mutable_buffer=False. "
+                    "If you encounter accuracy issue during the runtime, "
+                    "then please set `skip_mutable_buffer=True` and try again."
+                )
+                tag_mutated_buffer(edge_program)
         for node in edge_program.graph_module.graph.nodes:
             if hasattr(node, "meta"):
                 # pop certain keys in meta for not affecting the passes in compilation
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index eca889a1610..e1e2ca6dff6 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -817,16 +817,32 @@ def annotate_index(node: Node, quantization_config: QuantizationConfig) -> None:
     [torch.ops.aten.index_put.default, torch.ops.aten.index_put_.default]
 )
 def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None:
-    input = node.args[0]
+    # Avoid annotating the input node because mutable buffers will be folded during the convert_pt2e process.
     value = node.args[2]
 
     input_qspec_map = {}
-    input_qspec_map[input] = quantization_config.input_activation
-    input_qspec_map[value] = SharedQuantizationSpec((input, node))
+    input_qspec_map[value] = quantization_config.input_activation
 
     node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
         input_qspec_map=input_qspec_map,
-        output_qspec=SharedQuantizationSpec((input, node)),
+        output_qspec=SharedQuantizationSpec((value, node)),
+        _annotated=True,
+    )
+
+
+@register_annotator(
+    [torch.ops.aten.index_copy.default, torch.ops.aten.index_copy_.default]
+)
+def annotate_index_copy(node: Node, quantization_config: QuantizationConfig) -> None:
+    # Avoid annotating the input node because mutable buffers will be folded during the convert_pt2e process.
+    value = node.args[3]
+
+    input_qspec_map = {}
+    input_qspec_map[value] = quantization_config.input_activation
+
+    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+        input_qspec_map=input_qspec_map,
+        output_qspec=SharedQuantizationSpec((value, node)),
         _annotated=True,
     )
 
diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp
index ab038404582..01bf13603d6 100644
--- a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp
+++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp
@@ -129,33 +129,37 @@ Error QnnExecuTorchBackend::execute(
   std::vector<Qnn_Tensor_t> input_tensor_structs;
   std::vector<Qnn_Tensor_t> output_tensor_structs;
 
+  int args_index = 0;
   input_tensor_structs.reserve(input_tensors.size());
-  for (int i = 0; i < input_tensors.size(); ++i) {
-    if (qnn_manager->RegisterMem(
-            args[i]->toTensor().mutable_data_ptr(), input_tensors[i]) !=
-        Error::Ok) {
-      // update data ptr only should be fine
-      input_tensors[i]->FillDataBuffer(
-          args[i]->toTensor().const_data_ptr(), false /* copy_data */);
+  for (const auto& input_tensor : input_tensors) {
+    if (input_tensor->GetName().find("mutbuf_") == std::string::npos) {
+      if (qnn_manager->RegisterMem(
+              args[args_index]->toTensor().mutable_data_ptr(), input_tensor) !=
+          Error::Ok) {
+        // update data ptr only should be fine
+        input_tensor->FillDataBuffer(
+            args[args_index]->toTensor().const_data_ptr(),
+            false /* copy_data */);
+        // use the real input shape instead of nominal one to make sure
+        // dynamic shape is functional
+        auto dims = args[args_index]->toTensor().sizes();
+        input_tensor->SetDims(dims.data(), dims.size());
+      }
+      args_index++;
     }
-    // use the real input shape instead of nominal one to make sure
-    // dynamic shape is functional
-    auto dims = args[i]->toTensor().sizes();
-    input_tensors[i]->SetDims(dims.data(), dims.size());
-    input_tensor_structs.emplace_back(input_tensors[i]->CloneTensorStruct());
+    input_tensor_structs.emplace_back(input_tensor->CloneTensorStruct());
   }
 
-  int output_index = input_tensors.size();
   for (const auto& output_tensor : output_tensors) {
     // pos=0 limits the search to the prefix
-    if (output_tensor->GetName().rfind("output_", 0) == 0) {
-      void* mutable_data_ptr =
-          args[output_index]->toTensor().mutable_data_ptr();
+    if (output_tensor->GetName().rfind("output_", 0) == 0 &&
+        output_tensor->GetName().find("mutbuf_") == std::string::npos) {
+      void* mutable_data_ptr = args[args_index]->toTensor().mutable_data_ptr();
       if (qnn_manager->RegisterMem(mutable_data_ptr, output_tensor) !=
           Error::Ok) {
         output_tensor->FillDataBuffer(mutable_data_ptr, false /* copy_data */);
       }
-      output_index++;
+      args_index++;
     }
     output_tensor_structs.push_back(output_tensor->CloneTensorStruct());
   }
diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp
index 0f64e8b9cce..0dd0470a2b0 100644
--- a/backends/qualcomm/runtime/QnnManager.cpp
+++ b/backends/qualcomm/runtime/QnnManager.cpp
@@ -18,6 +18,7 @@
 #include <cstring>
 #include <fstream>
 #include <string>
+#include <unordered_map>
 
 namespace executorch {
 namespace backends {
@@ -35,6 +36,16 @@ bool CompareExportedInput(
   return numA < numB;
 }
 
+int ExtractMutableBufferNumber(const std::string& name) {
+  std::string prefix = "mutbuf_";
+  size_t startPos = name.find(prefix);
+  if (startPos != std::string::npos) {
+    startPos += prefix.length();
+    return std::stoi(name.substr(startPos));
+  }
+  return -1;
+}
+
 QnnManager::~QnnManager() {
   Destroy();
 }
@@ -363,9 +374,21 @@ Error QnnManager::AllocateTensor(const std::string& graph_name) {
   std::vector<Qnn_Tensor_t> output_tensors =
       backend_params_ptr_->qnn_context_ptr_->GetGraphOutputs(graph_name);
 
+  // Map the memory address shared by the input and output of each mutable buffer
+  std::unordered_map<int, const void*> mutable_buffer_id_to_memory_map;
+
   for (auto& tensor : input_tensors) {
     std::shared_ptr<TensorWrapper> tensor_wrapper = CreateTensorWrapper(tensor);
     tensor_wrapper->UpdateQnnTensorMeta(tensor);
+
+    int mutable_buffer_id =
+        ExtractMutableBufferNumber(tensor_wrapper->GetName());
+    if (mutable_buffer_id != -1) {
+      // The delegate maintains the memory for the mutable buffer
+      tensor_wrapper->AllocateDataBuffer();
+      mutable_buffer_id_to_memory_map[mutable_buffer_id] =
+          tensor_wrapper->GetStaticTensorData();
+    }
     input_tensors_[graph_name].emplace_back(std::move(tensor_wrapper));
   }
   if (!options_->is_from_context_binary()) {
@@ -388,6 +411,16 @@ Error QnnManager::AllocateTensor(const std::string& graph_name) {
     if (IsTensorDump()) {
       tensor_wrapper->AllocateDataBuffer();
     }
+    int mutable_buffer_id =
+        ExtractMutableBufferNumber(tensor_wrapper->GetName());
+    if (mutable_buffer_id != -1 &&
+        mutable_buffer_id_to_memory_map.find(mutable_buffer_id) !=
+            mutable_buffer_id_to_memory_map.end()) {
+      // Share the same memory between the input and output of the mutable buffer
+      tensor_wrapper->FillDataBuffer(
+          mutable_buffer_id_to_memory_map[mutable_buffer_id],
+          false /* copy_data */);
+    }
     output_tensors_[graph_name].emplace_back(std::move(tensor_wrapper));
   }
   return Error::Ok;
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 091c2d94cd0..8be05d46688 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -909,17 +909,35 @@ def forward(self, x):
         return self.dispatcher[self.axis](x)
 
 
+class IndexCopy(torch.nn.Module):
+    def __init__(self, skip_mutable_buffer=False):
+        super().__init__()
+        self.skip_mutable_buffer = skip_mutable_buffer
+        self.register_buffer(
+            "k_cache",
+            torch.zeros((1, 1024, 12, 64), dtype=torch.float32),
+            persistent=True,
+        )
+
+    def forward(self, input_pos, k_val):
+        k_out = self.k_cache
+        k_out.index_copy_(1, input_pos, k_val)
+        return k_out + 0
+
+
 class IndexPut(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, skip_mutable_buffer=False):
         super().__init__()
+        self.skip_mutable_buffer = skip_mutable_buffer
         self.register_buffer(
             "k_cache",
             torch.zeros((1, 1024, 12, 64), dtype=torch.float32),
+            persistent=True,
         )
 
     def forward(self, input_pos, k_val):
         k_out = torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val)
-        return k_out
+        return k_out + 0
 
 
 class InstanceNorm2d(torch.nn.Module):
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 7666e36b985..747a6804957 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -618,13 +618,55 @@ def test_qnn_backend_index(self):
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_index_copy(self):
+        test_comb = [
+            {
+                QCOM_MODULE: IndexCopy(skip_mutable_buffer=False),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (
+                    torch.tensor([2], dtype=torch.int64),
+                    torch.randn([1, 1, 12, 64]),
+                ),
+            },
+            {
+                QCOM_MODULE: IndexCopy(skip_mutable_buffer=True),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (
+                    torch.tensor([2], dtype=torch.int64),
+                    torch.randn([1, 1, 12, 64]),
+                ),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(
+                    test[QCOM_MODULE],
+                    test[QCOM_SAMPLE_INPUTS],
+                    skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer,
+                )
+
     def test_qnn_backend_index_put(self):
-        module = IndexPut()  # noqa: F405
-        sample_input = (
-            torch.tensor([2], dtype=torch.int32),
-            torch.randn([1, 1, 12, 64]),
-        )
-        self.lower_module_and_test_output(module, sample_input)
+        test_comb = [
+            {
+                QCOM_MODULE: IndexPut(skip_mutable_buffer=False),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (
+                    torch.tensor([2], dtype=torch.int32),
+                    torch.randn([1, 1, 12, 64]),
+                ),
+            },
+            {
+                QCOM_MODULE: IndexPut(skip_mutable_buffer=True),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (
+                    torch.tensor([2], dtype=torch.int32),
+                    torch.randn([1, 1, 12, 64]),
+                ),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(
+                    test[QCOM_MODULE],
+                    test[QCOM_SAMPLE_INPUTS],
+                    skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer,
+                )
 
     def test_qnn_backend_instance_norm_2d(self):
         modules = [InstanceNorm2d(32), InstanceNorm2d(32, affine=False)]  # noqa: F405
@@ -1860,14 +1902,61 @@ def test_qnn_backend_index(self):
                 module = self.get_qdq_module(module, sample_input)
                 self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_index_copy(self):
+        test_comb = [
+            {
+                QCOM_MODULE: IndexCopy(skip_mutable_buffer=False),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (
+                    torch.tensor([2], dtype=torch.int64),
+                    torch.randn([1, 1, 12, 64]),
+                ),
+            },
+            {
+                QCOM_MODULE: IndexCopy(skip_mutable_buffer=True),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (
+                    torch.tensor([2], dtype=torch.int64),
+                    torch.randn([1, 1, 12, 64]),
+                ),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+                self.lower_module_and_test_output(
+                    module,
+                    test[QCOM_SAMPLE_INPUTS],
+                    skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer,
+                )
+
     def test_qnn_backend_index_put(self):
-        module = IndexPut()  # noqa: F405
-        sample_input = (
-            torch.tensor([2], dtype=torch.int32),
-            torch.randn([1, 1, 12, 64]),
-        )
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        test_comb = [
+            {
+                QCOM_MODULE: IndexPut(skip_mutable_buffer=False),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (
+                    torch.tensor([2], dtype=torch.int32),
+                    torch.randn([1, 1, 12, 64]),
+                ),
+            },
+            {
+                QCOM_MODULE: IndexPut(skip_mutable_buffer=True),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (
+                    torch.tensor([2], dtype=torch.int32),
+                    torch.randn([1, 1, 12, 64]),
+                ),
+            },
+        ]
+        for i, test in enumerate(test_comb):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS]
+                )
+                self.lower_module_and_test_output(
+                    module,
+                    test[QCOM_SAMPLE_INPUTS],
+                    skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer,
+                )
 
     def test_qnn_backend_instance_norm_2d(self):
         modules = [InstanceNorm2d(32), InstanceNorm2d(32, affine=False)]  # noqa: F405
@@ -3030,7 +3119,17 @@ def test_qnn_backend_generate_optrace(self):
                 for _, (optrace, qhas) in binaries_trace.items():
                     with open(optrace, "r") as optrace_file:
                         optrace_data = json.load(optrace_file)
-                        for row in optrace_data:
+                        # {
+                        #  header:
+                        #    {
+                        #     'header_version': {'major': x, 'minor': y, 'patch': z},
+                        #     'version': {'major': x, 'minor': y, 'patch': z},
+                        #     'artifact_type': 'OP_TRACE'
+                        #    }
+                        #  traceEvents:
+                        #    {...}
+                        # }
+                        for row in optrace_data["traceEvents"]:
                             self.assertIn("pid", row)
                     with open(qhas, "r") as qhas_file:
                         qhas_data = json.load(qhas_file)
@@ -3726,7 +3825,17 @@ def test_qnn_backend_generate_optrace(self):
                 for _, (optrace, qhas) in binaries_trace.items():
                     with open(optrace, "r") as optrace_file:
                         optrace_data = json.load(optrace_file)
-                        for row in optrace_data:
+                        # {
+                        #  header:
+                        #    {
+                        #     'header_version': {'major': x, 'minor': y, 'patch': z},
+                        #     'version': {'major': x, 'minor': y, 'patch': z},
+                        #     'artifact_type': 'OP_TRACE'
+                        #    }
+                        #  traceEvents:
+                        #    {...}
+                        # }
+                        for row in optrace_data["traceEvents"]:
                             self.assertIn("pid", row)
                     with open(qhas, "r") as qhas_file:
                         qhas_data = json.load(qhas_file)
diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py
index 2968086d7a5..2e923b92250 100644
--- a/backends/qualcomm/tests/utils.py
+++ b/backends/qualcomm/tests/utils.py
@@ -462,6 +462,7 @@ def lower_module_and_test_output(
         passes_job: Optional[OrderedDict] = None,
         skip_node_id_set: set = None,
         skip_node_op_set: set = None,
+        skip_mutable_buffer: bool = False,
         dynamic_shapes: Dict = None,
     ):
         delegated_program = to_edge_transform_and_lower_to_qnn(
@@ -472,6 +473,7 @@ def lower_module_and_test_output(
             passes_job=passes_job,
             skip_node_id_set=skip_node_id_set,
             skip_node_op_set=skip_node_op_set,
+            skip_mutable_buffer=skip_mutable_buffer,
         )
 
         # this is needed for the ETRecord as lowering modifies the graph in-place
diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py
index 45f08c8f2c1..3471b0155bd 100644
--- a/backends/qualcomm/utils/utils.py
+++ b/backends/qualcomm/utils/utils.py
@@ -332,6 +332,7 @@ def to_edge_transform_and_lower_to_qnn(
     passes_job: Optional[Union[OrderedDict, Dict[str, OrderedDict]]] = None,
     skip_node_id_set: Optional[set] = None,
     skip_node_op_set: Optional[set] = None,
+    skip_mutable_buffer: bool = False,
 ) -> EdgeProgramManager:
     """
     Transforms and lowers a given PyTorch module to the QNN backend.
@@ -356,6 +357,8 @@ def to_edge_transform_and_lower_to_qnn(
             Set of node IDs to skip during partitioning.
         skip_node_op_set (Optional[set]):
             Set of node operations to skip during partitioning.
+        skip_mutable_buffer (bool):
+            Whether to skip delegating mutable buffers to the QNN backend.
 
     Returns:
         EdgeProgramManager:
@@ -407,6 +410,7 @@ def ensure_graph_specific_dict(value, graph_names):
                 compiler_specs[graph_name],
                 skip_node_id_set=skip_node_id_set,
                 skip_node_op_set=skip_node_op_set,
+                skip_mutable_buffer=skip_mutable_buffer,
             )
         ]
         for graph_name in graph_names
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 1f055d65822..f79a25b6077 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1187,7 +1187,6 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         calibration_seq_length=llm_config.quantization.calibration_seq_length,
         calibration_data=llm_config.quantization.calibration_data,
         tokenizer_path=llm_config.base.tokenizer_path,
-        use_legacy_export=llm_config.backend.qnn.enabled,
         save_exported_program=llm_config.export.export_only,
         verbose=llm_config.debug.verbose,
         metadata=_load_llama_model_metadata(
diff --git a/examples/models/llama/source_transformation/attention.py b/examples/models/llama/source_transformation/attention.py
index d5f065550d2..3c37cddee69 100644
--- a/examples/models/llama/source_transformation/attention.py
+++ b/examples/models/llama/source_transformation/attention.py
@@ -45,12 +45,10 @@ def __init__(
             self.register_buffer(
                 f"past_k_caches_{i}",
                 torch.zeros(cache_shape, dtype=dtype, device="cpu"),
-                persistent=False,
             )
             self.register_buffer(
                 f"past_v_caches_{i}",
                 torch.zeros(cache_shape, dtype=dtype, device="cpu"),
-                persistent=False,
             )
 
     def update(
diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index 1fb3d97a9c7..59823b533a3 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -493,12 +493,10 @@ def __init__(
         self.register_buffer(
             "past_k_caches",
             torch.zeros(cache_shape, dtype=dtype, device="cpu"),
-            persistent=False,
         )
         self.register_buffer(
             "past_v_caches",
             torch.zeros(cache_shape, dtype=dtype, device="cpu"),
-            persistent=False,
         )
 
     def update(
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 4128bfd8198..079d89848dd 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -10,11 +10,9 @@
 
 # pyre-unsafe
 
-import contextlib
 import logging
 from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple
-from unittest.mock import patch
 
 import torch
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
@@ -96,7 +94,6 @@ def __init__(
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
-        use_legacy_export: bool = False,
         save_exported_program: bool = False,
     ):
         # Store necessary constructor arguments.
@@ -117,7 +114,6 @@ def __init__(
         self.verbose = verbose
         self.metadata = metadata
         self.dynamic_shapes = dynamic_shapes
-        self.use_legacy_export = use_legacy_export
         self.save_exported_program = save_exported_program
 
         # Note: treat this as the source of truth for the result of
@@ -229,39 +225,20 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if self.use_legacy_export:
-                # TODO: for use cases such as qnn, which does not work with new, non-functional export IR.
-                # See issue: https://github.com/pytorch/executorch/issues/7373
-
-                with patch.object(
-                    torch._utils_internal,
-                    "export_training_ir_rollout_check",
-                    return_value=False,
-                ):
-                    # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
-                    # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
-                    exported_module = torch.export.export(
-                        self.model if not module else module,
-                        self.example_inputs,
-                        self.example_kwarg_inputs,
-                        dynamic_shapes=dynamic_shape,
-                        strict=True,
-                    )
+            if module:
+                logging.info("Re-exporting with:")
             else:
-                if module:
-                    logging.info("Re-exporting with:")
-                else:
-                    logging.info("Exporting with:")
-                logging.info(f"inputs: {self.example_inputs}")
-                logging.info(f"kwargs: {self.example_kwarg_inputs}")
-                logging.info(f"dynamic shapes: {dynamic_shape}")
-                exported_module = export_for_training(
-                    self.model if not module else module,
-                    self.example_inputs,
-                    kwargs=self.example_kwarg_inputs,
-                    dynamic_shapes=dynamic_shape,
-                    strict=True,
-                )
+                logging.info("Exporting with:")
+            logging.info(f"inputs: {self.example_inputs}")
+            logging.info(f"kwargs: {self.example_kwarg_inputs}")
+            logging.info(f"dynamic shapes: {dynamic_shape}")
+            exported_module = export_for_training(
+                self.model if not module else module,
+                self.example_inputs,
+                kwargs=self.example_kwarg_inputs,
+                dynamic_shapes=dynamic_shape,
+                strict=True,
+            )
         return exported_module
 
     def export(self) -> "LLMEdgeManager":
@@ -446,24 +423,15 @@ def export_to_edge(self) -> "LLMEdgeManager":
                 # Run export() if it didn't run
                 self.export()
 
-            override_export_behaviour = contextlib.nullcontext()
-            if self.use_legacy_export:
-                override_export_behaviour = patch.object(
-                    torch._utils_internal,
-                    "export_training_ir_rollout_check",
-                    return_value=False,
-                )
-
-            with override_export_behaviour:
-                self.edge_manager = export_to_edge(
-                    self.pre_autograd_graph_module,  # pyre-fixme[6]
-                    self.example_inputs,
-                    example_kwarg_inputs=self.example_kwarg_inputs,
-                    dynamic_shapes=dynamic_shape,
-                    edge_constant_methods=self.metadata,
-                    edge_compile_config=edge_config,
-                    verbose=self.verbose,
-                )
+            self.edge_manager = export_to_edge(
+                self.pre_autograd_graph_module,  # pyre-fixme[6]
+                self.example_inputs,
+                example_kwarg_inputs=self.example_kwarg_inputs,
+                dynamic_shapes=dynamic_shape,
+                edge_constant_methods=self.metadata,
+                edge_compile_config=edge_config,
+                verbose=self.verbose,
+            )
         return self
 
     def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py
index 20604bbf635..e35358f56b7 100644
--- a/extension/llm/export/partitioner_lib.py
+++ b/extension/llm/export/partitioner_lib.py
@@ -216,4 +216,5 @@ def get_qnn_partitioner(
         ),
         skip_node_id_set={},
         skip_node_op_set=skip_node_op_set,
+        skip_mutable_buffer=False,
     )

From 19c5aa13cb232eca0b0d385b41c93c5ebe4cbad5 Mon Sep 17 00:00:00 2001
From: shewu-quic <quic_shewu@quicinc.com>
Date: Thu, 19 Jun 2025 12:14:53 +0800
Subject: [PATCH 2/3] Fix the CI for Meta's llama

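This patch registers ConvertBmmToMatmul as an optional capture-program pass and
turns it on for the llama export path. A minimal sketch of how the pass is
activated; the import paths are assumed from the modules touched by this patch:

    from executorch.backends.qualcomm._passes import ConvertBmmToMatmul
    from executorch.backends.qualcomm._passes.qnn_pass_manager import (
        get_capture_program_passes,
    )
    from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY

    # ConvertBmmToMatmul is registered but disabled by default; activate it so
    # bmm nodes are rewritten to matmul before lowering to QNN.
    passes_job = get_capture_program_passes()
    passes_job[ConvertBmmToMatmul][QCOM_PASS_ACTIVATE_KEY] = True

The resulting passes_job is then handed to the lowering flow as in the existing
llama export (see export_llama_lib.py below).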
---
 backends/qualcomm/_passes/__init__.py         |  2 +
 .../qualcomm/_passes/convert_bmm_to_matmul.py | 76 +++++++++++++++++++
 backends/qualcomm/_passes/qnn_pass_manager.py |  2 +
 backends/qualcomm/_passes/utils.py            |  3 +
 .../qualcomm/quantizer/custom_annotation.py   |  9 ++-
 examples/models/llama/export_llama_lib.py     |  2 +
 6 files changed, 90 insertions(+), 4 deletions(-)
 create mode 100644 backends/qualcomm/_passes/convert_bmm_to_matmul.py

diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index 2b743454bf9..01710aa8d80 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -8,6 +8,7 @@
 from .annotate_quant_attrs import AnnotateQuantAttrs
 from .annotate_stack import AnnotateStack
 from .annotate_unbind import AnnotateUnbind
+from .convert_bmm_to_matmul import ConvertBmmToMatmul
 from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
 from .convert_square_to_pow import ConvertSquareToPow
 from .decompose_any import DecomposeAny
@@ -44,6 +45,7 @@
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
+    ConvertBmmToMatmul,
     ConvertConv1dToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
diff --git a/backends/qualcomm/_passes/convert_bmm_to_matmul.py b/backends/qualcomm/_passes/convert_bmm_to_matmul.py
new file mode 100644
index 00000000000..84e1ff26aa1
--- /dev/null
+++ b/backends/qualcomm/_passes/convert_bmm_to_matmul.py
@@ -0,0 +1,76 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import operator
+from collections import Counter
+from typing import List
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
+
+
+class ConvertBmmToMatmul(ExportPass):
+    """
+    Replace bmm with matmul, because bmm is equal to matmul in QNN.
+    Handle the missing quantization tag for the bmm op.
+    """
+
+    view_copy = exir_ops.edge.aten.view_copy.default
+    expand_copy = exir_ops.edge.aten.expand_copy.default
+    clone = exir_ops.edge.aten.clone.default
+    bmm = exir_ops.edge.aten.bmm.default
+    matmul = exir_ops.edge.aten.matmul.default
+    patterns = [
+        {expand_copy: 2, view_copy: 3, bmm: 1},
+        {expand_copy: 2, view_copy: 3, bmm: 1, clone: 1},
+        {bmm: 1},
+    ]
+
+    def __init__(self):
+        super(ConvertBmmToMatmul, self).__init__()
+
+    def _get_ordered_inputs(
+        self, inputs: List[torch.fx.Node], output: torch.fx.Node
+    ) -> List[torch.fx.Node]:
+        bmm_inputs = []
+        for arg in output.args:
+            while arg not in inputs:
+                arg = arg.args[0]
+            bmm_inputs.append(arg)
+        return bmm_inputs
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        partitions = get_source_partitions(
+            graph,
+            [operator.matmul, torch.matmul, torch.bmm, torch.ops.aten.matmul.default],
+        )
+        for _, src_partitions in partitions.items():
+            for src_partition in src_partitions:
+                op_cnt = Counter([n.target for n in src_partition.nodes])
+                if op_cnt not in self.patterns:
+                    raise AssertionError(
+                        "Found a new pattern needed be converted to linear op"
+                    )
+
+                inputs = src_partition.input_nodes
+                bmm_node = [n for n in src_partition.nodes if n.target == self.bmm][0]
+                output = src_partition.output_nodes[0]
+                # the order of src_partition.input_nodes is not guaranteed.
+                lhs, rhs = self._get_ordered_inputs(inputs, bmm_node)
+                with graph_module.graph.inserting_before(output):
+                    # replace bmm with matmul, because bmm is equal to matmul in qnn.
+                    matmul_node = graph.create_node(
+                        "call_function", self.matmul, (lhs, rhs)
+                    )
+                    matmul_node.meta = output.meta
+                    for user in output.users.copy():
+                        user.replace_input_with(output, matmul_node)
+
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py
index 9d2726eaaed..8340fa6209e 100644
--- a/backends/qualcomm/_passes/qnn_pass_manager.py
+++ b/backends/qualcomm/_passes/qnn_pass_manager.py
@@ -13,6 +13,7 @@
     AnnotateQuantAttrs,
     AnnotateStack,
     AnnotateUnbind,
+    ConvertBmmToMatmul,
     ConvertConv1dToConv2d,
     ConvertSquareToPow,
     DecomposeAny,
@@ -79,6 +80,7 @@ def get_capture_program_passes():
         (AnnotateQuantAttrs, True),
         (AnnotateStack, True),
         (AnnotateUnbind, True),
+        (ConvertBmmToMatmul, False),
         (ConvertConv1dToConv2d, True),
         (DecomposeAny, True),
         (DecomposeColIm, True),
diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py
index b23efc8714f..ae11ba7b325 100755
--- a/backends/qualcomm/_passes/utils.py
+++ b/backends/qualcomm/_passes/utils.py
@@ -64,6 +64,7 @@ def get_passes_dependency_for_capture_program():
         AnnotateQuantAttrs,
         AnnotateStack,
         AnnotateUnbind,
+        ConvertBmmToMatmul,
         ConvertConv1dToConv2d,
         DecomposeAny,
         DecomposeColIm,
@@ -82,11 +83,13 @@ def get_passes_dependency_for_capture_program():
     return {
         AnnotateAdaptiveAvgPool1D: [RemoveRedundancy],
         AnnotateQuantAttrs: [
+            ConvertBmmToMatmul,
             RecomposePixelUnshuffle,
             RemoveRedundancy,
         ],
         AnnotateStack: [RemoveRedundancy],
         AnnotateUnbind: [RemoveRedundancy],
+        ConvertBmmToMatmul: [RecomposePixelUnshuffle],
         DecomposeAny: [RemoveRedundancy],
         DecomposeColIm: [FoldQDQ],
         DecomposeLinalgVectorNorm: [RemoveRedundancy],
diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index b5531cd492f..9e9370ef697 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -292,14 +292,15 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
         )
 
     def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None:
-        input = node.args[0]
+        # Avoid annotating the input node because mutable buffers will be folded during the convert_pt2e process.
         value = node.args[2]
+
         input_qspec_map = {}
-        input_qspec_map[input] = quantization_config.input_activation
-        input_qspec_map[value] = SharedQuantizationSpec((input, node))
+        input_qspec_map[value] = quantization_config.input_activation
+
         node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
             input_qspec_map=input_qspec_map,
-            output_qspec=SharedQuantizationSpec((input, node)),
+            output_qspec=SharedQuantizationSpec((value, node)),
             _annotated=True,
         )
 
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index f79a25b6077..db2b781825f 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -914,6 +914,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm._passes`
         from executorch.backends.qualcomm._passes import (
             AnnotateStack,
+            ConvertBmmToMatmul,
             FoldQDQ,
             RecomposeRmsNorm,
             TagQuantIO,
@@ -956,6 +957,7 @@ def _to_edge_and_lower_llama(  # noqa: C901
         passes_job = get_capture_program_passes()
         dep_table = get_passes_dependency_for_capture_program()
         passes_job[AnnotateStack][QCOM_PASS_ACTIVATE_KEY] = True
+        passes_job[ConvertBmmToMatmul][QCOM_PASS_ACTIVATE_KEY] = True
         passes_job[RecomposeRmsNorm][QCOM_PASS_ACTIVATE_KEY] = True
         passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True
         passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][

From 5fde193fa730e0a9fa7d20aa10c037108da4f1ae Mon Sep 17 00:00:00 2001
From: shewu-quic <quic_shewu@quicinc.com>
Date: Fri, 20 Jun 2025 10:11:52 +0800
Subject: [PATCH 3/3] Revert the use_legacy_export removal in the llama export flow

---
 examples/models/llama/export_llama_lib.py |  1 +
 extension/llm/export/builder.py           | 76 ++++++++++++++++-------
 extension/llm/export/partitioner_lib.py   |  3 +-
 3 files changed, 57 insertions(+), 23 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index db2b781825f..9d5fcfdba25 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1189,6 +1189,7 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager":
         calibration_seq_length=llm_config.quantization.calibration_seq_length,
         calibration_data=llm_config.quantization.calibration_data,
         tokenizer_path=llm_config.base.tokenizer_path,
+        use_legacy_export=llm_config.backend.qnn.enabled,
         save_exported_program=llm_config.export.export_only,
         verbose=llm_config.debug.verbose,
         metadata=_load_llama_model_metadata(
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 079d89848dd..4128bfd8198 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -10,9 +10,11 @@
 
 # pyre-unsafe
 
+import contextlib
 import logging
 from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple
+from unittest.mock import patch
 
 import torch
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
@@ -94,6 +96,7 @@ def __init__(
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
+        use_legacy_export: bool = False,
         save_exported_program: bool = False,
     ):
         # Store necessary constructor arguments.
@@ -114,6 +117,7 @@ def __init__(
         self.verbose = verbose
         self.metadata = metadata
         self.dynamic_shapes = dynamic_shapes
+        self.use_legacy_export = use_legacy_export
         self.save_exported_program = save_exported_program
 
         # Note: treat this as the source of truth for the result of
@@ -225,20 +229,39 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if module:
-                logging.info("Re-exporting with:")
+            if self.use_legacy_export:
+                # TODO: for use cases such as qnn, which does not work with new, non-functional export IR.
+                # See issue: https://github.com/pytorch/executorch/issues/7373
+
+                with patch.object(
+                    torch._utils_internal,
+                    "export_training_ir_rollout_check",
+                    return_value=False,
+                ):
+                    # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
+                    # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
+                    exported_module = torch.export.export(
+                        self.model if not module else module,
+                        self.example_inputs,
+                        self.example_kwarg_inputs,
+                        dynamic_shapes=dynamic_shape,
+                        strict=True,
+                    )
             else:
-                logging.info("Exporting with:")
-            logging.info(f"inputs: {self.example_inputs}")
-            logging.info(f"kwargs: {self.example_kwarg_inputs}")
-            logging.info(f"dynamic shapes: {dynamic_shape}")
-            exported_module = export_for_training(
-                self.model if not module else module,
-                self.example_inputs,
-                kwargs=self.example_kwarg_inputs,
-                dynamic_shapes=dynamic_shape,
-                strict=True,
-            )
+                if module:
+                    logging.info("Re-exporting with:")
+                else:
+                    logging.info("Exporting with:")
+                logging.info(f"inputs: {self.example_inputs}")
+                logging.info(f"kwargs: {self.example_kwarg_inputs}")
+                logging.info(f"dynamic shapes: {dynamic_shape}")
+                exported_module = export_for_training(
+                    self.model if not module else module,
+                    self.example_inputs,
+                    kwargs=self.example_kwarg_inputs,
+                    dynamic_shapes=dynamic_shape,
+                    strict=True,
+                )
         return exported_module
 
     def export(self) -> "LLMEdgeManager":
@@ -423,15 +446,24 @@ def export_to_edge(self) -> "LLMEdgeManager":
                 # Run export() if it didn't run
                 self.export()
 
-            self.edge_manager = export_to_edge(
-                self.pre_autograd_graph_module,  # pyre-fixme[6]
-                self.example_inputs,
-                example_kwarg_inputs=self.example_kwarg_inputs,
-                dynamic_shapes=dynamic_shape,
-                edge_constant_methods=self.metadata,
-                edge_compile_config=edge_config,
-                verbose=self.verbose,
-            )
+            override_export_behaviour = contextlib.nullcontext()
+            if self.use_legacy_export:
+                override_export_behaviour = patch.object(
+                    torch._utils_internal,
+                    "export_training_ir_rollout_check",
+                    return_value=False,
+                )
+
+            with override_export_behaviour:
+                self.edge_manager = export_to_edge(
+                    self.pre_autograd_graph_module,  # pyre-fixme[6]
+                    self.example_inputs,
+                    example_kwarg_inputs=self.example_kwarg_inputs,
+                    dynamic_shapes=dynamic_shape,
+                    edge_constant_methods=self.metadata,
+                    edge_compile_config=edge_config,
+                    verbose=self.verbose,
+                )
         return self
 
     def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py
index e35358f56b7..7b093a7f1a3 100644
--- a/extension/llm/export/partitioner_lib.py
+++ b/extension/llm/export/partitioner_lib.py
@@ -216,5 +216,6 @@ def get_qnn_partitioner(
         ),
         skip_node_id_set={},
         skip_node_op_set=skip_node_op_set,
-        skip_mutable_buffer=False,
+        # TODO: once legacy export is deprecated, skip_mutable_buffer can be set to False
+        skip_mutable_buffer=True,
     )