77import itertools
88import operator
99import typing
10- from typing import final , Optional , Sequence , Type
10+ from typing import cast , final , Optional , Sequence , Type
1111
1212import torch
1313import torch .fx as fx
1414
15- from executorch .backends .arm ._passes .arm_pass_utils import get_first_fake_tensor
15+ from executorch .backends .arm ._passes .arm_pass_utils import (
16+ get_first_fake_tensor ,
17+ is_submodule_node ,
18+ )
1619from executorch .backends .arm ._passes .fuse_constant_ops_pass import ComputeConstantOpsAOT
1720from executorch .backends .arm ._passes .fuse_quantized_activation_pass import (
1821 FuseQuantizedActivationPass ,
3134 TOSA_PRO_INT_SupportList ,
3235)
3336from executorch .backends .arm .tosa import TosaSpecification
37+ from executorch .backends .arm .tosa .specification import Tosa_1_00
3438from executorch .exir import ExportedProgram
3539from executorch .exir .backend .utils import WhyNoPartitionReporter
3640from executorch .exir .dialects ._ops import ops as exir_ops
@@ -110,7 +114,9 @@ def tosa_support_factory(
110114 Additional checks can be supplied to avoid partitioning additional nodes.
111115 """
112116 # Postive checks: Add nodes to partitioning
113- positive_checks : list [OperatorSupportBase ] = []
117+ positive_checks : list [OperatorSupportBase ] = [
118+ CondSupported (exported_program , tosa_spec , reporter )
119+ ]
114120
115121 if tosa_spec .support_integer ():
116122 positive_checks .append (TOSAProINTSupportList ())
@@ -350,7 +356,8 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
350356 def is_node_supported (
351357 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
352358 ) -> bool :
353-
359+ if is_submodule_node (node ):
360+ return True
354361 vals = node .meta ["val" ]
355362 tensor_list = vals if isinstance (vals , (list , tuple )) else [vals ]
356363
@@ -390,7 +397,11 @@ def is_node_supported(
390397
391398 # Ops with int64 inputs are only partitioned if input nodes are constant and will be partitioned.
392399 # If it is not partitioned, the partition will get an int64 input and fail.
393- for input_node in node .all_input_nodes :
400+ for input_node in (
401+ input_node
402+ for input_node in node .all_input_nodes
403+ if input_node .op != "get_attr"
404+ ):
394405 tensor_in = get_first_fake_tensor (input_node )
395406 if tensor_in .dtype != torch .int64 :
396407 continue
@@ -426,8 +437,13 @@ def __init__(
426437 def is_node_supported (
427438 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
428439 ) -> bool :
429-
430- for input_node in node .all_input_nodes :
440+ if is_submodule_node (node ):
441+ return True
442+ for input_node in (
443+ input_node
444+ for input_node in node .all_input_nodes
445+ if input_node .op != "get_attr"
446+ ):
431447 tensor = get_first_fake_tensor (input_node )
432448 if tensor .dtype == torch .float64 :
433449 self .reporter .report_reject (
@@ -449,7 +465,13 @@ def __init__(self, reporter: WhyNoPartitionReporter, max_rank: int):
449465 def is_node_supported (
450466 self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
451467 ) -> bool :
452- input_nodes = node .all_input_nodes
468+ if is_submodule_node (node ):
469+ return True
470+ input_nodes = (
471+ input_node
472+ for input_node in node .all_input_nodes
473+ if input_node .op != "get_attr"
474+ )
453475 # check if any input node has an unsupported rank
454476 for input_node in input_nodes :
455477 input_node_shape = get_first_fake_tensor (input_node ).shape
@@ -484,3 +506,112 @@ def is_node_supported(
484506 )
485507 return False
486508 return True
509+
510+
511+ class CondSupported (OperatorSupportBase ):
512+ """Checks whether the cond operator, and it's submodule args, should be partitioned."""
513+
514+ def __init__ (
515+ self ,
516+ exported_program : ExportedProgram ,
517+ tosa_spec : TosaSpecification ,
518+ reporter : WhyNoPartitionReporter ,
519+ ):
520+ self .exported_program = exported_program
521+ self .reporter = reporter
522+ self .tosa_spec = tosa_spec
523+ super ().__init__ ()
524+
525+ def _fully_partitioned (self , submodule : fx .GraphModule ) -> bool :
526+ partition_tag = None
527+ for submodule_node in submodule .graph .nodes :
528+ if submodule_node .op == "call_function" :
529+ # Input Q ops and output DQ ops will be de-tagged even if the submodule is fully supported.
530+ if (
531+ submodule_node .target in Q_OPS
532+ and list (submodule_node .all_input_nodes )[0 ].op == "placeholder"
533+ ):
534+ continue
535+ if (
536+ submodule_node .target in DQ_OPS
537+ and list (submodule_node .users )[0 ].op == "output"
538+ ):
539+ continue
540+ if "delegation_tag" not in submodule_node .meta :
541+ return False
542+ if partition_tag is None :
543+ partition_tag = submodule_node .meta ["delegation_tag" ]
544+ elif submodule_node .meta ["delegation_tag" ] != partition_tag :
545+ return False
546+ return True
547+
548+ def _cond_submodules_fully_partitioned (self , node : fx .Node ) -> bool :
549+ """Returns whether the submodule arguments to a cond node were fully partitioned.
550+ Updates "val" meta of the submodules if they are.
551+ """
552+ cond_submodules = (
553+ (
554+ self .exported_program .graph_module .get_submodule (
555+ str (cast (torch .fx .Node , submodule_node ).target )
556+ ),
557+ cast (torch .fx .Node , submodule_node ),
558+ )
559+ for submodule_node in node .args [1 :3 ]
560+ )
561+ for submodule , submodule_node in cond_submodules :
562+ submodule = cast (torch .fx .GraphModule , submodule )
563+
564+ if self ._fully_partitioned (submodule ):
565+ submodule_node .meta ["val" ] = submodule .graph .output_node ().meta ["val" ]
566+ else :
567+ return False
568+ return True
569+
570+ def is_node_supported ( # noqa: C901
571+ self , submodules : typing .Mapping [str , torch .nn .Module ], node : fx .Node
572+ ) -> bool :
573+ if is_submodule_node (node ):
574+ if not isinstance (self .tosa_spec , Tosa_1_00 ):
575+ self .reporter .report_reject (
576+ node , "Control flow extension not supported for TOSA version <1.0"
577+ )
578+ return False
579+ if not self .tosa_spec .support_extension ("cf" ):
580+ self .reporter .report_reject (
581+ node ,
582+ f"TOSA spec { self .tosa_spec } does not support control flow extension." ,
583+ )
584+ return False
585+ for user in node .users :
586+ if user .target != torch .ops .higher_order .cond :
587+ self .reporter .report_reject (
588+ node , f"Submodule had unsupported user { user } "
589+ )
590+ return False
591+ if not self ._cond_submodules_fully_partitioned (user ):
592+ self .reporter .report_reject (
593+ node , "One submodule was not fully partitioned"
594+ )
595+ return False
596+ return True
597+ if node .target == torch .ops .higher_order .cond :
598+ if not isinstance (self .tosa_spec , Tosa_1_00 ):
599+ self .reporter .report_reject (
600+ node , "Control flow extension not supported for TOSA version <1.0"
601+ )
602+ return False
603+ if not self .tosa_spec .support_extension ("cf" ):
604+ self .reporter .report_reject (
605+ node ,
606+ f"TOSA spec { self .tosa_spec } does not support control flow extension." ,
607+ )
608+ return False
609+
610+ if not self ._cond_submodules_fully_partitioned (node ):
611+ self .reporter .report_reject (
612+ node , "Submodule was not fully partitioned."
613+ )
614+ return False
615+ return True
616+
617+ return False
0 commit comments