Commit 5a55382

Arm backend: Initial support of conditional operator (#15549)
- Add a partition check to make sure that the submodules with the if/else code paths are fully delegated.
- Fix some partitioning issues with submodule nodes: since they point to a submodule rather than a tensor, they don't have a fake tensor.
- Add a node visitor.
- Add tests.

Arm backend: Use output.name in node visitors

As mentioned in #15381, TOSA tensors need unique naming, which gets tricky with submodules. This is handled in the TosaArg object, and therefore node visitors need to use output.name rather than node.name when creating new tensors.

cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai

---------

Signed-off-by: Erik Lundell <[email protected]>
1 parent 0bd635e commit 5a55382

19 files changed: +470 −46 lines
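
For context, a minimal sketch of the kind of graph this commit targets (the model and names are illustrative, not from the commit): exporting a model that uses torch.cond turns each branch into a submodule of the exported graph, referenced by a get_attr node whose target is a graph rather than a tensor, so it carries no FakeTensor in node.meta["val"]:

import torch


class CondModel(torch.nn.Module):
    def forward(self, x):
        def true_branch(x):
            return x + 1.0

        def false_branch(x):
            return x - 1.0

        # torch.cond traces both branches into submodules of the exported graph.
        return torch.cond(x.sum() > 0, true_branch, false_branch, (x,))


ep = torch.export.export(CondModel(), (torch.randn(2, 2),))
for node in ep.graph.nodes:
    if node.op == "get_attr":
        # Branch submodule nodes: target names a graph, and there is
        # typically no FakeTensor under node.meta["val"].
        print(node.op, node.target, "val" in node.meta)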

backends/arm/_passes/arm_pass_utils.py

Lines changed: 16 additions & 2 deletions

@@ -31,11 +31,25 @@
 from torch.export.graph_signature import InputKind
 
 
+def is_submodule_node(node: torch.fx.Node):
+    if node.op not in ("get_attr", "placeholder"):
+        return False
+    try:
+        node.graph.owning_module.get_submodule(node.target)
+    except AttributeError:
+        return False
+    return True
+
+
 def is_get_attr_node(node: torch.fx.Node) -> bool:
     """
-    Returns true if the given node is a get attr node for a tensor of the model
+    Returns true if the given node is a get attr node for a tensor of the model.
     """
-    return isinstance(node, torch.fx.Node) and node.op == "get_attr"
+    return (
+        isinstance(node, torch.fx.Node)
+        and node.op == "get_attr"
+        and not is_submodule_node(node)
+    )
 
 
 def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
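
Continuing the sketch above, the two helpers now split get_attr nodes into the two cases (a rough illustration; the import path is taken from the diff, "ep" is the exported program from the earlier torch.cond sketch):

from executorch.backends.arm._passes.arm_pass_utils import (
    is_get_attr_node,
    is_submodule_node,
)

for node in ep.graph.nodes:
    if node.op == "get_attr":
        # A cond branch submodule: is_submodule_node -> True, and the new
        # check makes is_get_attr_node -> False for it.
        # A get_attr that references a tensor of the model still yields
        # is_get_attr_node -> True.
        print(node.target, is_submodule_node(node), is_get_attr_node(node))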

backends/arm/_passes/cast_int64_pass.py

Lines changed: 2 additions & 0 deletions

@@ -41,6 +41,8 @@ def _to_int32(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             if len(node.users) == 0:
                 continue
+            if "val" not in node.meta:
+                continue
             fake_tensor = node.meta["val"]
             if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
                 continue

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 2 additions & 0 deletions

@@ -299,6 +299,8 @@ def remove_dim_order_kwargs(
 
     def call(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
+            if "val" not in node.meta:
+                continue
             node_data = get_first_fake_tensor(node).data
 
             self.remove_dim_order_kwargs(graph_module, node)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 139 additions & 8 deletions

@@ -7,12 +7,15 @@
 import itertools
 import operator
 import typing
-from typing import final, Optional, Sequence, Type
+from typing import cast, final, Optional, Sequence, Type
 
 import torch
 import torch.fx as fx
 
-from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
+from executorch.backends.arm._passes.arm_pass_utils import (
+    get_first_fake_tensor,
+    is_submodule_node,
+)
 from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
 from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
     FuseQuantizedActivationPass,
@@ -31,6 +34,7 @@
     TOSA_PRO_INT_SupportList,
 )
 from executorch.backends.arm.tosa import TosaSpecification
+from executorch.backends.arm.tosa.specification import Tosa_1_00
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.utils import WhyNoPartitionReporter
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -110,7 +114,9 @@ def tosa_support_factory(
     Additional checks can be supplied to avoid partitioning additional nodes.
     """
     # Postive checks: Add nodes to partitioning
-    positive_checks: list[OperatorSupportBase] = []
+    positive_checks: list[OperatorSupportBase] = [
+        CondSupported(exported_program, tosa_spec, reporter)
+    ]
 
     if tosa_spec.support_integer():
         positive_checks.append(TOSAProINTSupportList())
@@ -350,7 +356,8 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
     def is_node_supported(
         self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
     ) -> bool:
-
+        if is_submodule_node(node):
+            return True
         vals = node.meta["val"]
         tensor_list = vals if isinstance(vals, (list, tuple)) else [vals]
 
@@ -390,7 +397,11 @@
 
         # Ops with int64 inputs are only partitioned if input nodes are constant and will be partitioned.
        # If it is not partitioned, the partition will get an int64 input and fail.
-        for input_node in node.all_input_nodes:
+        for input_node in (
+            input_node
+            for input_node in node.all_input_nodes
+            if input_node.op != "get_attr"
+        ):
             tensor_in = get_first_fake_tensor(input_node)
             if tensor_in.dtype != torch.int64:
                 continue
@@ -426,8 +437,13 @@
     def is_node_supported(
         self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
     ) -> bool:
-
-        for input_node in node.all_input_nodes:
+        if is_submodule_node(node):
+            return True
+        for input_node in (
+            input_node
+            for input_node in node.all_input_nodes
+            if input_node.op != "get_attr"
+        ):
             tensor = get_first_fake_tensor(input_node)
             if tensor.dtype == torch.float64:
                 self.reporter.report_reject(
@@ -449,7 +465,13 @@ def __init__(self, reporter: WhyNoPartitionReporter, max_rank: int):
     def is_node_supported(
         self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
     ) -> bool:
-        input_nodes = node.all_input_nodes
+        if is_submodule_node(node):
+            return True
+        input_nodes = (
+            input_node
+            for input_node in node.all_input_nodes
+            if input_node.op != "get_attr"
+        )
         # check if any input node has an unsupported rank
         for input_node in input_nodes:
             input_node_shape = get_first_fake_tensor(input_node).shape
@@ -484,3 +506,112 @@ def is_node_supported(
                 )
                 return False
         return True
+
+
+class CondSupported(OperatorSupportBase):
+    """Checks whether the cond operator, and its submodule args, should be partitioned."""
+
+    def __init__(
+        self,
+        exported_program: ExportedProgram,
+        tosa_spec: TosaSpecification,
+        reporter: WhyNoPartitionReporter,
+    ):
+        self.exported_program = exported_program
+        self.reporter = reporter
+        self.tosa_spec = tosa_spec
+        super().__init__()
+
+    def _fully_partitioned(self, submodule: fx.GraphModule) -> bool:
+        partition_tag = None
+        for submodule_node in submodule.graph.nodes:
+            if submodule_node.op == "call_function":
+                # Input Q ops and output DQ ops will be de-tagged even if the submodule is fully supported.
+                if (
+                    submodule_node.target in Q_OPS
+                    and list(submodule_node.all_input_nodes)[0].op == "placeholder"
+                ):
+                    continue
+                if (
+                    submodule_node.target in DQ_OPS
+                    and list(submodule_node.users)[0].op == "output"
+                ):
+                    continue
+                if "delegation_tag" not in submodule_node.meta:
+                    return False
+                if partition_tag is None:
+                    partition_tag = submodule_node.meta["delegation_tag"]
+                elif submodule_node.meta["delegation_tag"] != partition_tag:
+                    return False
+        return True
+
+    def _cond_submodules_fully_partitioned(self, node: fx.Node) -> bool:
+        """Returns whether the submodule arguments to a cond node were fully partitioned.
+        Updates "val" meta of the submodules if they are.
+        """
+        cond_submodules = (
+            (
+                self.exported_program.graph_module.get_submodule(
+                    str(cast(torch.fx.Node, submodule_node).target)
+                ),
+                cast(torch.fx.Node, submodule_node),
+            )
+            for submodule_node in node.args[1:3]
+        )
+        for submodule, submodule_node in cond_submodules:
+            submodule = cast(torch.fx.GraphModule, submodule)
+
+            if self._fully_partitioned(submodule):
+                submodule_node.meta["val"] = submodule.graph.output_node().meta["val"]
+            else:
+                return False
+        return True
+
+    def is_node_supported(  # noqa: C901
+        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
+    ) -> bool:
+        if is_submodule_node(node):
+            if not isinstance(self.tosa_spec, Tosa_1_00):
+                self.reporter.report_reject(
+                    node, "Control flow extension not supported for TOSA version <1.0"
+                )
+                return False
+            if not self.tosa_spec.support_extension("cf"):
+                self.reporter.report_reject(
+                    node,
+                    f"TOSA spec {self.tosa_spec} does not support control flow extension.",
+                )
+                return False
+            for user in node.users:
+                if user.target != torch.ops.higher_order.cond:
+                    self.reporter.report_reject(
+                        node, f"Submodule had unsupported user {user}"
+                    )
+                    return False
+                if not self._cond_submodules_fully_partitioned(user):
+                    self.reporter.report_reject(
+                        node, "One submodule was not fully partitioned"
+                    )
+                    return False
+            return True
+        if node.target == torch.ops.higher_order.cond:
+            if not isinstance(self.tosa_spec, Tosa_1_00):
+                self.reporter.report_reject(
+                    node, "Control flow extension not supported for TOSA version <1.0"
+                )
+                return False
+            if not self.tosa_spec.support_extension("cf"):
+                self.reporter.report_reject(
+                    node,
+                    f"TOSA spec {self.tosa_spec} does not support control flow extension.",
+                )
+                return False
+
+            if not self._cond_submodules_fully_partitioned(node):
+                self.reporter.report_reject(
+                    node, "Submodule was not fully partitioned."
+                )
+                return False
+            return True
+
+        return False
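
Condensed, the gating above amounts to roughly the following (an illustrative restatement, not code from the commit; "fully_partitioned" is a simplified stand-in for CondSupported._fully_partitioned): cond nodes and their branch submodules are only partitioned on TOSA 1.0+ with the control-flow extension, and only when both branches were fully tagged into a single partition.

from executorch.backends.arm.tosa.specification import Tosa_1_00


def fully_partitioned(submodule) -> bool:
    # Simplified version of CondSupported._fully_partitioned: every compute
    # node must share a single delegation_tag (boundary Q/DQ handling omitted).
    tags = {
        n.meta.get("delegation_tag")
        for n in submodule.graph.nodes
        if n.op == "call_function"
    }
    return len(tags) == 1 and None not in tags


def cond_is_partitionable(cond_node, tosa_spec, owning_module) -> bool:
    # Illustrative restatement of the checks in CondSupported.is_node_supported.
    if not (isinstance(tosa_spec, Tosa_1_00) and tosa_spec.support_extension("cf")):
        return False  # control flow needs TOSA >= 1.0 with the "cf" extension
    # args[1] and args[2] are the get_attr nodes referencing the two branch graphs.
    branches = (
        owning_module.get_submodule(str(n.target)) for n in cond_node.args[1:3]
    )
    return all(fully_partitioned(branch) for branch in branches)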

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
     op_cat,
     op_ceil,
     op_clamp,
+    op_cond_if,
     op_constant_pad_nd,
     op_cos,
     op_eq,
backends/arm/operators/op_cond_if.py

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+from typing import Any, cast, List
+
+import tosa_serializer as ts
+
+from executorch.backends.arm.operators.node_visitor import (  # type: ignore
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.operators.operator_validation_utils import (
+    validate_num_inputs,
+    validate_valid_dtype,
+)
+from executorch.backends.arm.tosa.mapping import TosaArg  # type: ignore
+from executorch.backends.arm.tosa.specification import Tosa_1_00
+from torch.fx import Node
+
+
+@register_node_visitor
+class CondVisitor(NodeVisitor):
+    target = "cond"
+
+    tosa_specs = NodeVisitor.tosa_specs
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: Any,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+
+        validate_num_inputs(self.target, inputs, 4)
+        validate_valid_dtype(self.target, [inputs[0]], ts.DType.BOOL, self.tosa_spec)
+        if not isinstance(self.tosa_spec, Tosa_1_00):
+            raise ValueError("Trying to lower cond, but TOSA version is <1.0.")
+        if not self.tosa_spec.support_extension("cf"):
+            raise ValueError(
+                f"Trying to lower cond, but TOSA specification {self.tosa_spec} does not support the cf extension."
+            )
+
+        attr = ts.TosaSerializerAttribute()
+        if_graph, else_graph = (cast(Node, arg).target for arg in node.args[1:3])
+        attr.CondIfAttribute(if_graph, else_graph)
+
+        self._serialize_operator(
+            node,
+            tosa_graph,
+            ts.Op.COND_IF,
+            [
+                inputs[0].name,
+                *(subgraph_input.name for subgraph_input in inputs[-1].special),
+            ],
+            [output.name],
+            attr,
+        )
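
The visitor maps the FX call onto the TOSA operator roughly as follows (a schematic with illustrative graph names, not output of the code):

# FX node:
#   torch.ops.higher_order.cond(pred, true_graph_0, false_graph_0, [x])
#
# serializes to (schematic):
#   COND_IF
#     inputs  = [pred, x]        # inputs[0].name plus the names in inputs[-1].special
#     outputs = [output.name]
#     attr    = CondIfAttribute("true_graph_0", "false_graph_0")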

backends/arm/operators/op_index_tensor.py

Lines changed: 5 additions & 5 deletions

@@ -165,14 +165,14 @@ def define_node(
             # channels and thus the stride-shift.
             data = np.full(index_shape, int(values_strides[i] / C))
             mul_const = tosa_graph.addConst(index_shape, index_dtype, data)
-            tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_{i}_shift")
+            tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_{i}_shift")
             attr = ts.TosaSerializerAttribute()
             attr.MulAttribute()
             self._serialize_operator(
                 node,
                 tosa_graph,
                 ts.Op.MUL,
-                [index_name, mul_const.name, f"{node.name}_{i}_shift"],
+                [index_name, mul_const.name, f"{output.name}_{i}_shift"],
                 [stride_shifted_indices.name],
                 attr,
             )
@@ -186,7 +186,7 @@ def define_node(
                 stride_shifted_indices.name,
                 gather_idx_shape,
                 reshaped_idxs.name,
-                shape_name_override=f"{node.name}_{i}_shape",
+                shape_name_override=f"{output.name}_{i}_shape",
             )
 
             # Guarantees that the accumulation tensor is properly
@@ -218,7 +218,7 @@ def define_node(
             values.name,
             gather_vals_shape,
             reshaped_input.name,
-            shape_name_override=f"{node.name}_index_shape",
+            shape_name_override=f"{output.name}_index_shape",
         )
 
         gather_out_shape = (N, W, C)
@@ -244,5 +244,5 @@ def define_node(
             gather_out.name,
             list(output_shape),
             output.name,
-            shape_name_override=f"{node.name}_output_shape",
+            shape_name_override=f"{output.name}_output_shape",
         )

backends/arm/operators/op_mul.py

Lines changed: 2 additions & 2 deletions

@@ -48,14 +48,14 @@ def define_node(
             output.tosa_spec,
         )
 
-        tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift")
+        tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_shift")
         attr = ts.TosaSerializerAttribute()
         attr.MulAttribute()
         self._serialize_operator(
             node,
             tosa_graph,
             ts.Op.MUL,
-            [inputs[0].name, inputs[1].name, f"{node.name}_shift"],
+            [inputs[0].name, inputs[1].name, f"{output.name}_shift"],
             [output.name],
             attr,
         )
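
The motivation for this rename: node.name is only unique within a single graph, so when the same operator appears in both cond branches, names derived from it can collide in the flat TOSA file; output.name comes from TosaArg, which disambiguates across submodules. A sketch of the failure mode, written as it would appear inside a visitor's define_node (names illustrative):

# Both branch submodules may contain a node literally named "mul"; deriving
# serializer names from node.name would then create "mul_shift" twice in one
# TOSA file. TosaArg assigns output.name uniquely across submodules, so:
shift_name = f"{output.name}_shift"  # unique even when both branches have a "mul"
tosa_graph.addConst([1], ts.DType.INT8, 0, name=shift_name)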

backends/arm/operators/op_repeat.py

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ def define_node(
             (len(multiples),),
             ts.DType.SHAPE,
             list(tosa_shape(multiples, output.dim_order)),
-            name=node.name + "_multiples",
+            name=output.name + "_multiples",
         )
 
         attr = ts.TosaSerializerAttribute()
