diff --git a/backends/arm/operators/op_index_select.py b/backends/arm/operators/op_index_select.py index a4f541e65d9..ba2aa03c7ff 100644 --- a/backends/arm/operators/op_index_select.py +++ b/backends/arm/operators/op_index_select.py @@ -19,7 +19,7 @@ ) from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.utils import build_reshape_tosa_1_0 +from executorch.backends.arm.tosa.utils import build_reshape_tosa from torch.fx import Node @@ -67,7 +67,7 @@ def define_node( weights_new_shape, weights.dtype, ) - build_reshape_tosa_1_0( + build_reshape_tosa( tosa_graph, weights.name, weights_new_shape, weights_reshaped.name ) @@ -89,7 +89,7 @@ def define_node( indices_new_shape, indices.dtype, ) - build_reshape_tosa_1_0( + build_reshape_tosa( tosa_graph, indices.name, indices_new_shape, indices_reshaped.name ) @@ -106,6 +106,4 @@ def define_node( if len(weights.shape) == 2: output_real_shape = [output.shape[0], output.shape[1]] - build_reshape_tosa_1_0( - tosa_graph, output_name, output_real_shape, output.name - ) + build_reshape_tosa(tosa_graph, output_name, output_real_shape, output.name) diff --git a/backends/arm/operators/op_index_tensor.py b/backends/arm/operators/op_index_tensor.py index 710b5f8e1d8..cd0809df95b 100644 --- a/backends/arm/operators/op_index_tensor.py +++ b/backends/arm/operators/op_index_tensor.py @@ -180,7 +180,7 @@ def define_node( gather_idx_shape, index_dtype, ) - tutils.build_reshape_tosa_1_0( + tutils.build_reshape_tosa( tosa_graph, stride_shifted_indices.name, gather_idx_shape, @@ -212,7 +212,7 @@ def define_node( gather_vals_shape = [N, K, C] reshaped_input = tosa_graph.addIntermediate(gather_vals_shape, values.dtype) - tutils.build_reshape_tosa_1_0( + tutils.build_reshape_tosa( tosa_graph, values.name, gather_vals_shape, @@ -238,7 +238,7 @@ def define_node( output_shape = tutils.tosa_shape(output.shape, output.dim_order) - tutils.build_reshape_tosa_1_0( + tutils.build_reshape_tosa( tosa_graph, gather_out.name, list(output_shape), diff --git a/backends/arm/tosa/utils.py b/backends/arm/tosa/utils.py index 60ed0376697..df77153e29f 100644 --- a/backends/arm/tosa/utils.py +++ b/backends/arm/tosa/utils.py @@ -110,7 +110,7 @@ def broadcast_tensors( tens_dtype, ) - build_reshape_tosa_1_0(tosa_fb, node.name, new_shape, reshaped.name) + build_reshape_tosa(tosa_fb, node.name, new_shape, reshaped.name) tiled = tosa_fb.addIntermediate(common_shape, tens_dtype) multipliers = [ @@ -137,7 +137,7 @@ def broadcast_tensors( return broadcast_tensors -def build_reshape_tosa_1_0( +def build_reshape_tosa( tosa_graph, input_name, new_shape, output_name, shape_name_override="" ): """Insert a TOSA reshape operator using the v1.0 semantics.