From 0e6b40bae7d82ea03d238611d0f16824ac1ccc54 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 4 Sep 2025 19:25:51 +0200 Subject: [PATCH 01/11] Test more FusionOptimizer graphs --- pytensor/tensor/rewriting/elemwise.py | 12 ++++--- tests/tensor/rewriting/test_elemwise.py | 45 +++++++++++++++++++++---- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index e2d420f361..0eb2900729 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -569,8 +569,6 @@ def elemwise_to_scalar(inputs, outputs): return scalar_inputs, scalar_outputs def apply(self, fgraph): - nb_replacement = 0 - if fgraph.profile: validate_before = fgraph.profile.validate_time callbacks_before = fgraph.execute_callbacks_times.copy() @@ -925,6 +923,8 @@ def update_fuseable_mappings_after_fg_replace( starting_nodes=starting_nodes, ) + nb_fused = 0 + nb_replacement = 0 for inputs, outputs in find_next_fuseable_subgraph(fgraph): if (len(inputs) + len(outputs)) > max_operands: warn( @@ -943,11 +943,13 @@ def update_fuseable_mappings_after_fg_replace( if old_out.name: composite_out.name = old_out.name + starting_nodes = len(fgraph.apply_nodes) fgraph.replace_all_validate( list(zip(outputs, composite_outputs, strict=True)), reason=self.__class__.__name__, ) - nb_replacement += 1 + nb_fused += 1 + nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1 if fgraph.profile: validate_time = fgraph.profile.validate_time - validate_before @@ -965,7 +967,7 @@ def update_fuseable_mappings_after_fg_replace( return ( self, - 1, # nb_iter + nb_fused, nb_replacement, 0, # nb_inconsintency_replace validate_time, @@ -978,7 +980,7 @@ def update_fuseable_mappings_after_fg_replace( def print_profile(stream, prof, level=0): blanc = " " * level print(blanc, "FusionOptimizer", file=stream) - print(blanc, " nb_iter", prof[1], file=stream) + print(blanc, " nb_fused", prof[1], file=stream) print(blanc, " nb_replacement", prof[2], file=stream) print(blanc, " nb_inconsistency_replace", prof[3], file=stream) print(blanc, " validate_time", prof[4], file=stream) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index c23d0ac23a..2ace386376 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -273,7 +273,8 @@ def my_init(dtype="float64", num=0): fwx = fw + fx ftanx = tan(fx) - def large_fuseable_graph(self, n): + @staticmethod + def large_fuseable_graph(n): factors = [] sd = dscalar() means = dvector() @@ -296,6 +297,28 @@ def large_fuseable_graph(self, n): dlogp = [pytensor.grad(logp, v) for v in vars] return vars, dlogp + @staticmethod + def deep_small_kernels(n): + x = pt.matrix("x") + out = x + for _ in range(n): + out = pt.sin(out.T) + pt.cos(out) + + return [x], [out] + + @staticmethod + def test_diamond_graph(): + a = pt.matrix("a") + b = pt.exp(a) + c = pt.log(b) + d = pt.sin(c) + e = c + d + + fg = FunctionGraph([a], [e], clone=False) + _, nb_fused, nb_replacement, *_ = FusionOptimizer().apply(fg) + assert nb_fused == 1 + assert nb_replacement == 4 + @pytest.mark.parametrize( "case", [ @@ -1347,16 +1370,26 @@ def test_eval_benchmark(self, benchmark): benchmark(func) @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") - def test_rewrite_benchmark(self, benchmark): - inps, outs = self.large_fuseable_graph(n=25) + @pytest.mark.parametrize( + "graph_fn, n, expected_n_repl", + [ + ("deep_small_kernels", 20, (20, 60)), + ("large_fuseable_graph", 25, (103, 876)), + ], + ) + def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark): + inps, outs = getattr(self, graph_fn)(n) fg = FunctionGraph(inps, outs) opt = FusionOptimizer() def rewrite_func(): - nb_replacement = opt.apply(fg.clone())[2] - return nb_replacement + fg_clone = fg.clone() + _, nb_fused, nb_replacement, *_ = opt.apply(fg_clone) + # fg_clone.dprint() + return nb_fused, nb_replacement - assert benchmark(rewrite_func) == 103 + assert rewrite_func() == expected_n_repl + benchmark.pedantic(rewrite_func, rounds=7, iterations=5) def test_no_warning_from_old_client(self): # There used to be a warning issued when creating fuseable mapping From b6cb3649dfc24999c57c46411bd8fd2243798af8 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Sat, 20 Sep 2025 10:19:40 +0200 Subject: [PATCH 02/11] Short-circuit `as_scalar` common cases faster --- pytensor/scalar/basic.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 339da84cd1..26d242d3f0 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -987,25 +987,28 @@ def constant(x, name=None, dtype=None) -> ScalarConstant: def as_scalar(x: Any, name: str | None = None) -> ScalarVariable: - from pytensor.tensor.basic import scalar_from_tensor - from pytensor.tensor.type import TensorType + if isinstance(x, ScalarVariable): + return x + + if isinstance(x, Variable): + from pytensor.tensor.basic import scalar_from_tensor + from pytensor.tensor.type import TensorType + + if isinstance(x.type, TensorType) and x.type.ndim == 0: + return scalar_from_tensor(x) + else: + raise TypeError(f"Cannot convert {x} to a scalar type") if isinstance(x, Apply): + # FIXME: Why do we support calling this with Apply? + # Also, if we do, why can't we support multiple outputs? if len(x.outputs) != 1: raise ValueError( "It is ambiguous which output of a multi-output" " Op has to be fetched.", x, ) - else: - x = x.outputs[0] - if isinstance(x, Variable): - if isinstance(x, ScalarVariable): - return x - elif isinstance(x.type, TensorType) and x.type.ndim == 0: - return scalar_from_tensor(x) - else: - raise TypeError(f"Cannot convert {x} to a scalar type") + return as_scalar(x.outputs[0]) return constant(x) From 244ef3a82569f5fad1578ce6bdff05b5100daa2d Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 19 Sep 2025 01:01:55 +0200 Subject: [PATCH 03/11] Speedup supports c_code Not using `__call__` avoids the test_value computation --- pytensor/scalar/basic.py | 32 +++++++++++++------------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 26d242d3f0..f12449cfc4 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1332,32 +1332,26 @@ def supports_c_code(self, inputs, outputs): the given Elemwise inputs, outputs. """ - try: - tmp_s_input = [] - # To keep the same aliasing between inputs - mapping = dict() - for ii in inputs: - if ii in mapping: - tmp_s_input.append(mapping[ii]) - else: - tmp = get_scalar_type(ii.dtype).make_variable() - tmp_s_input.append(tmp) - mapping[ii] = tmp_s_input[-1] - - with config.change_flags(compute_test_value="ignore"): - s_op = self(*tmp_s_input, return_list=True) + tmp_s_input = [] + # To keep the same aliasing between inputs + mapping = {} + for ii in inputs: + if ii in mapping: + tmp_s_input.append(mapping[ii]) + else: + tmp = mapping[ii] = get_scalar_type(ii.dtype).make_variable() + tmp_s_input.append(tmp) - # if the scalar_op don't have a c implementation, - # we skip its fusion to allow the fusion of the - # other ops. + try: self.c_code( - s_op[0].owner, + self.make_node(*tmp_s_input), "test_presence_of_c_code", + # FIXME: Shouldn't this be a unique name per unique variable? ["x" for x in inputs], ["z" for z in outputs], {"fail": "%(fail)s"}, ) - except (MethodNotDefined, NotImplementedError): + except (NotImplementedError, MethodNotDefined): return False return True From 7d399461e84db41d082bd2d634e6718d9cfd45ea Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 5 Sep 2025 13:33:36 +0200 Subject: [PATCH 04/11] Speedup FusionOptimizer.elemwise_to_scalar --- pytensor/scalar/basic.py | 8 ++-- pytensor/tensor/rewriting/elemwise.py | 55 +++++++++------------------ 2 files changed, 23 insertions(+), 40 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index f12449cfc4..cbf7b73542 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -779,9 +779,11 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType: This caches objects to save allocation and run time. """ - if dtype not in cache: - cache[dtype] = ScalarType(dtype=dtype) - return cache[dtype] + try: + return cache[dtype] + except KeyError: + cache[dtype] = res = ScalarType(dtype=dtype) + return res # Register C code for ViewOp on Scalars. diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 0eb2900729..1eb3d7c037 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -28,7 +28,7 @@ ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.unify import OpPattern -from pytensor.graph.traversal import ancestors +from pytensor.graph.traversal import ancestors, toposort from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.tensor.basic import ( @@ -530,43 +530,24 @@ def add_requirements(self, fgraph): @staticmethod def elemwise_to_scalar(inputs, outputs): - replace_inputs = [(inp, inp.clone()) for inp in inputs] - outputs = clone_replace(outputs, replace=replace_inputs) - - inputs = [inp for _, inp in replace_inputs] - fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=False) - middle_inputs = [] - - scalar_inputs = [ - ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs - ] - middle_scalar_inputs = [] - - for node in fg.toposort(): - node_scalar_inputs = [] - for inp in node.inputs: - if inp in inputs: - node_scalar_inputs.append(scalar_inputs[inputs.index(inp)]) - elif inp in middle_inputs: - node_scalar_inputs.append( - middle_scalar_inputs[middle_inputs.index(inp)] + replacement = { + inp: ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs + } + for node in toposort(outputs, blockers=inputs): + scalar_inputs = [replacement[inp] for inp in node.inputs] + replacement.update( + dict( + zip( + node.outputs, + node.op.scalar_op.make_node(*scalar_inputs).outputs, ) - else: - new_scalar_input = ps.get_scalar_type( - inp.type.dtype - ).make_variable() - node_scalar_inputs.append(new_scalar_input) - middle_scalar_inputs.append(new_scalar_input) - middle_inputs.append(inp) - - new_scalar_node = node.op.scalar_op.make_node(*node_scalar_inputs) - middle_scalar_inputs.append(new_scalar_node.outputs[0]) - middle_inputs.append(node.outputs[0]) - - scalar_outputs = [ - middle_scalar_inputs[middle_inputs.index(out)] for out in fg.outputs - ] - return scalar_inputs, scalar_outputs + ) + ) + + return ( + [replacement[inp] for inp in inputs], + [replacement[out] for out in outputs], + ) def apply(self, fgraph): if fgraph.profile: From fa3984fdf273aa80e7f00cf7f79e9b13efd809d3 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 20 Sep 2025 10:05:06 +0200 Subject: [PATCH 05/11] Avoid double cloning of Composite Ops created by FusionOptimizer --- pytensor/scalar/basic.py | 19 ++++++++++++------- pytensor/tensor/rewriting/elemwise.py | 13 +++++++------ 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index cbf7b73542..769a5dfeeb 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -13,7 +13,6 @@ import builtins import math from collections.abc import Callable -from copy import copy from itertools import chain from textwrap import dedent from typing import Any, TypeAlias @@ -4093,12 +4092,12 @@ def __init__(self, *args, **kwargs): self.prepare_node_called = set() super().__init__(*args, **kwargs) - def _cleanup_graph(self, inputs, outputs): + def _cleanup_graph(self, inputs, outputs, clone: builtins.bool = True): # TODO: We could convert to TensorVariable, optimize graph, # and then convert back to ScalarVariable. # This would introduce rewrites like `log(1 + x) -> log1p`. - fgraph = FunctionGraph(copy(inputs), copy(outputs)) + fgraph = FunctionGraph(inputs, outputs, clone=clone) # Validate node types for node in fgraph.apply_nodes: @@ -4281,7 +4280,9 @@ class Composite(ScalarInnerGraphOp): init_param: tuple[str, ...] = ("inputs", "outputs") - def __init__(self, inputs, outputs, name="Composite"): + def __init__( + self, inputs, outputs, name="Composite", clone_graph: builtins.bool = True + ): self.name = name self._name = None # We need to clone the graph as sometimes its nodes already @@ -4299,10 +4300,13 @@ def __init__(self, inputs, outputs, name="Composite"): if len(outputs) > 1 or not any( isinstance(var.owner.op, Composite) for var in outputs ): - # No inner Composite - inputs, outputs = clone(inputs, outputs) + if clone_graph: + inputs, outputs = clone(inputs, outputs) + else: # Inner Composite that we need to flatten + # FIXME: There could be a composite in the middle of the graph, why is this here? + # If anything it should be an optimization, but I suspect lower-level compilation can handle this anyway. assert len(outputs) == 1 # 1. Create a new graph from inputs up to the # Composite @@ -4321,7 +4325,8 @@ def __init__(self, inputs, outputs, name="Composite"): assert res[0] != inputs inputs, outputs = res[0], res2[1] - self.inputs, self.outputs = self._cleanup_graph(inputs, outputs) + # We already cloned the graph, or the user told us there was no need for it + self.inputs, self.outputs = self._cleanup_graph(inputs, outputs, clone=False) self.inputs_type = tuple(input.type for input in self.inputs) self.outputs_type = tuple(output.type for output in self.outputs) self.nin = len(inputs) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 1eb3d7c037..42f4b6fc67 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -915,12 +915,13 @@ def update_fuseable_mappings_after_fg_replace( break scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) - composite_outputs = Elemwise(ps.Composite(scalar_inputs, scalar_outputs))( - *inputs - ) - if not isinstance(composite_outputs, list): - composite_outputs = [composite_outputs] - for old_out, composite_out in zip(outputs, composite_outputs, strict=True): + composite_outputs = Elemwise( + # No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables + ps.Composite(scalar_inputs, scalar_outputs, clone_graph=False) + )(*inputs, return_list=True) + assert len(outputs) == len(composite_outputs) + for old_out, composite_out in zip(outputs, composite_outputs): + # Preserve any names on the original outputs if old_out.name: composite_out.name = old_out.name From e2d94e8be2a8f33d118778bf7e5c15e99aa549f1 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Sep 2025 18:06:06 +0200 Subject: [PATCH 06/11] Do not recompute toposort in every iteration of FusionOptimizer It's not really needed as we never expand on the new nodes --- pytensor/tensor/rewriting/elemwise.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 42f4b6fc67..689b47c28d 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -625,10 +625,10 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool: def find_fuseable_subgraph( *, - fg: FunctionGraph, visited_nodes: set[Apply], fuseable_clients: FUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING, + toposort_index: dict[Apply, int], ) -> tuple[list[Variable], list[Variable]]: KT = TypeVar("KT") VT = TypeVar("VT", list, set) @@ -648,8 +648,7 @@ def variables_depend_on( for a in ancestors(variables, blockers=stop_search_at) ) - toposort = fg.toposort() - for starting_node in toposort: + for starting_node in toposort_index: if starting_node in visited_nodes: continue @@ -791,7 +790,7 @@ def variables_depend_on( and inp.owner not in visited_nodes ) ), - key=lambda inp: toposort.index(inp.owner), + key=lambda inp: toposort_index[inp.owner], reverse=True, ): fuseable_nodes_to_visit.appendleft(inp.owner) @@ -803,7 +802,7 @@ def variables_depend_on( for node in fuseable_clients_temp.get(next_out, ()) if node not in visited_nodes ), - key=lambda node: toposort.index(node), + key=lambda node: toposort_index[node], ): fuseable_nodes_to_visit.append(next_node) @@ -877,20 +876,22 @@ def update_fuseable_mappings_after_fg_replace( # client (those that don't fit into 1)) fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg) visited_nodes: set[Apply] = set() + toposort_index = {node: i for i, node in enumerate(fgraph.toposort())} while True: - starting_nodes = fg.apply_nodes.copy() try: subgraph_inputs, subgraph_outputs = find_fuseable_subgraph( - fg=fg, visited_nodes=visited_nodes, fuseable_clients=fuseable_clients, unfuseable_clients=unfuseable_clients, + toposort_index=toposort_index, ) except ValueError: return else: # The caller is now expected to update fg in place, # by replacing the subgraph with a Composite Op + starting_nodes = fg.apply_nodes.copy() + yield subgraph_inputs, subgraph_outputs # This is where we avoid repeated work by using a stateful From adc74feadaa69a43ed9e4c74105f8d132902631f Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Sep 2025 23:49:50 +0200 Subject: [PATCH 07/11] Cleanup FusionOptimizer code --- pytensor/tensor/rewriting/elemwise.py | 173 ++++++++++++-------------- 1 file changed, 79 insertions(+), 94 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 689b47c28d..4eca867b4e 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -5,7 +5,7 @@ from collections import defaultdict, deque from collections.abc import Generator, Sequence from functools import cache, reduce -from typing import TypeVar +from typing import Literal from warnings import warn import pytensor.scalar.basic as ps @@ -555,8 +555,6 @@ def apply(self, fgraph): callbacks_before = fgraph.execute_callbacks_times.copy() callback_before = fgraph.execute_callbacks_time - max_operands = elemwise_max_operands_fct(None) - def find_next_fuseable_subgraph( fg: FunctionGraph, ) -> Generator[tuple[list[Variable], list[Variable]], None, None]: @@ -568,8 +566,7 @@ def find_next_fuseable_subgraph( This generator assumes that such subgraph is replaced by a single Elemwise Composite before being accessed again in the next iteration. """ - - FUSEABLE_MAPPING = defaultdict[Variable, list[Apply]] + FUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] def initialize_fuseable_mappings( @@ -591,35 +588,31 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool: # to ensure the rewrite remains deterministic. # This is not a problem from unfuseable ones, as they can never # become part of the graph. - fuseable_clients: FUSEABLE_MAPPING = defaultdict(list) + fuseable_clients: FUSEABLE_MAPPING = defaultdict(set) unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set) for out, clients in fg.clients.items(): - # Old FunctionGraph nodes remain in the clients dictionary - # even after they are removed by rewrites - if not clients: - continue - out_maybe_fuseable = ( - out.owner + out.owner is not None and isinstance(out.owner.op, Elemwise) # and not isinstance(out.owner.op.scalar_op, ps.Composite) and len(out.owner.outputs) == 1 and elemwise_scalar_op_has_c_code(out.owner) ) - for client, _ in clients: - if ( - out_maybe_fuseable - and isinstance(client.op, Elemwise) - # and not isinstance(client.op.scalar_op, ps.Composite) - and len(client.outputs) == 1 - and out.type.broadcastable - == client.outputs[0].type.broadcastable - and elemwise_scalar_op_has_c_code(client) - ): - if client not in fuseable_clients[out]: - fuseable_clients[out].append(client) - else: - unfuseable_clients[out].add(client) + if out_maybe_fuseable: + out_bcast = out.type.broadcastable + for client, _ in clients: + if ( + isinstance(client.op, Elemwise) + # and not isinstance(client.op.scalar_op, ps.Composite) + and len(client.outputs) == 1 + and out_bcast == client.outputs[0].type.broadcastable + and elemwise_scalar_op_has_c_code(client) + ): + fuseable_clients[out].add(client) + else: + unfuseable_clients[out].add(client) + else: + unfuseable_clients[out] = {client for client, _ in clients} return fuseable_clients, unfuseable_clients @@ -630,16 +623,6 @@ def find_fuseable_subgraph( unfuseable_clients: UNFUSEABLE_MAPPING, toposort_index: dict[Apply, int], ) -> tuple[list[Variable], list[Variable]]: - KT = TypeVar("KT") - VT = TypeVar("VT", list, set) - - def shallow_clone_defaultdict( - d: defaultdict[KT, VT], - ) -> defaultdict[KT, VT]: - new_dict: defaultdict[KT, VT] = defaultdict(d.default_factory) - new_dict.update({k: v.copy() for k, v in d.items()}) - return new_dict - def variables_depend_on( variables, depend_on, stop_search_at=None ) -> bool: @@ -657,17 +640,19 @@ def variables_depend_on( visited_nodes.add(starting_node) continue - subgraph_inputs: list[Variable] = [] - subgraph_outputs: list[Variable] = [] + subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set + subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set unfuseable_clients_subgraph: set[Variable] = set() # Shallow cloning of maps so that they can be manipulated in place - fuseable_clients_temp = shallow_clone_defaultdict(fuseable_clients) - unfuseable_clients_clone = shallow_clone_defaultdict( - unfuseable_clients + fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set) + fuseable_clients_clone.update( + {k: v.copy() for k, v in fuseable_clients.items()} + ) + unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set) + unfuseable_clients_clone.update( + {k: v.copy() for k, v in unfuseable_clients.items()} ) - - fuseable_nodes_to_visit = deque([starting_node]) # We now try to expand as much as possible towards the potentially # fuseable clients and ancestors to detect the largest possible @@ -676,6 +661,7 @@ def variables_depend_on( # some inputs or clients may depend on other nodes of the same # subgraph via a path that cannot be included in the Composite # (unfuseable) + fuseable_nodes_to_visit = deque([starting_node]) while fuseable_nodes_to_visit: next_node = fuseable_nodes_to_visit.popleft() visited_nodes.add(next_node) @@ -684,15 +670,14 @@ def variables_depend_on( # If the output variable of next_node has no fuseable clients # or has unfuseable clients, then next_node must become an output # if it is to be fused. - must_become_output = ( - next_out not in fuseable_clients_temp - or next_out in unfuseable_clients_clone - ) + must_become_output = not fuseable_clients_clone.get( + next_out + ) or unfuseable_clients_clone.get(next_out) # We have backtracked to this node, and it may no longer be a viable output, # so we remove it and check again as if we had never seen this node - if must_become_output and next_out in subgraph_outputs: - subgraph_outputs.remove(next_out) + if must_become_output: + subgraph_outputs.pop(next_out, None) required_unfuseable_inputs = [ inp @@ -744,18 +729,19 @@ def variables_depend_on( if ( inp.owner in visited_nodes # next_node could have the same input repeated - and next_node in fuseable_clients_temp[inp] + and next_node in fuseable_clients_clone[inp] ): - fuseable_clients_temp[inp].remove(next_node) + fuseable_clients_clone[inp].remove(next_node) unfuseable_clients_clone[inp].add(next_node) # This input must become an output of the subgraph, # because it can't be merged with next_node. # We will revisit it to make sure this is safe. fuseable_nodes_to_visit.appendleft(inp.owner) - for client in fuseable_clients_temp[next_out]: + # need to convert to tuple not to change set size during iteration + for client in tuple(fuseable_clients_clone[next_out]): if client in visited_nodes: - fuseable_clients_temp[next_out].remove(client) + fuseable_clients_clone[next_out].remove(client) unfuseable_clients_clone[next_out].add(client) # next_out must become an input of the subgraph. # We will revisit any of its clients currently @@ -771,74 +757,72 @@ def variables_depend_on( # mappings as if it next_node was part of it. # Useless inputs will be removed by the useless Composite rewrite for inp in new_required_unfuseable_inputs: - if inp not in subgraph_inputs: - subgraph_inputs.append(inp) + subgraph_inputs[inp] = None if must_become_output: - subgraph_outputs.append(next_out) + subgraph_outputs[next_out] = None unfuseable_clients_subgraph.update( new_implied_unfuseable_clients ) # Expand through unvisited fuseable ancestors - for inp in sorted( - ( - inp - for inp in next_node.inputs - if ( - inp not in required_unfuseable_inputs - and inp.owner not in visited_nodes - ) - ), - key=lambda inp: toposort_index[inp.owner], - reverse=True, - ): - fuseable_nodes_to_visit.appendleft(inp.owner) + fuseable_nodes_to_visit.extendleft( + sorted( + ( + inp.owner + for inp in next_node.inputs + if ( + inp not in required_unfuseable_inputs + and inp.owner not in visited_nodes + ) + ), + key=toposort_index.get, # type: ignore[arg-type] + ) + ) # Expand through unvisited fuseable clients - for next_node in sorted( - ( - node - for node in fuseable_clients_temp.get(next_out, ()) - if node not in visited_nodes - ), - key=lambda node: toposort_index[node], - ): - fuseable_nodes_to_visit.append(next_node) + fuseable_nodes_to_visit.extend( + sorted( + ( + node + for node in fuseable_clients_clone.get(next_out, ()) + if node not in visited_nodes + ), + key=toposort_index.get, # type: ignore[arg-type] + ) + ) # Don't return if final subgraph is just the original Elemwise if len(subgraph_outputs) == 1 and set( - subgraph_outputs[0].owner.inputs + next(iter(subgraph_outputs)).owner.inputs ) == set(subgraph_inputs): # Update global fuseable mappings # No input was actually fuseable for inp in starting_node.inputs: - if starting_node in fuseable_clients.get(inp, ()): - fuseable_clients[inp].remove(starting_node) - unfuseable_clients[inp].add(starting_node) + fuseable_clients[inp].discard(starting_node) + unfuseable_clients[inp].add(starting_node) # No client was actually fuseable unfuseable_clients[starting_out].update( fuseable_clients.pop(starting_out, ()) ) continue - return subgraph_inputs, subgraph_outputs + return list(subgraph_inputs), list(subgraph_outputs) raise ValueError def update_fuseable_mappings_after_fg_replace( *, - fg: FunctionGraph, visited_nodes: set[Apply], fuseable_clients: FUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING, starting_nodes: set[Apply], + updated_nodes: set[Apply], ) -> None: # Find new composite node and dropped intermediate nodes # by comparing the current fg.apply nodes with the cached # original nodes - next_nodes = fg.apply_nodes - (new_composite_node,) = next_nodes - starting_nodes - dropped_nodes = starting_nodes - next_nodes + (new_composite_node,) = updated_nodes - starting_nodes + dropped_nodes = starting_nodes - updated_nodes # Remove intermediate Composite nodes from mappings for dropped_node in dropped_nodes: @@ -850,11 +834,11 @@ def update_fuseable_mappings_after_fg_replace( # Update fuseable information for subgraph inputs for inp in subgraph_inputs: if inp in fuseable_clients: - new_fuseable_clients = [ + new_fuseable_clients = { client for client in fuseable_clients[inp] if client not in dropped_nodes - ] + } if new_fuseable_clients: fuseable_clients[inp] = new_fuseable_clients else: @@ -898,13 +882,15 @@ def update_fuseable_mappings_after_fg_replace( # generator. For large models (as in `TestFusion.test_big_fusion`) # this can provide huge speedups update_fuseable_mappings_after_fg_replace( - fg=fg, visited_nodes=visited_nodes, fuseable_clients=fuseable_clients, unfuseable_clients=unfuseable_clients, starting_nodes=starting_nodes, + updated_nodes=fg.apply_nodes, ) + max_operands = elemwise_max_operands_fct(None) + reason = self.__class__.__name__ nb_fused = 0 nb_replacement = 0 for inputs, outputs in find_next_fuseable_subgraph(fgraph): @@ -923,13 +909,12 @@ def update_fuseable_mappings_after_fg_replace( assert len(outputs) == len(composite_outputs) for old_out, composite_out in zip(outputs, composite_outputs): # Preserve any names on the original outputs - if old_out.name: - composite_out.name = old_out.name + if old_name := old_out.name: + composite_out.name = old_name starting_nodes = len(fgraph.apply_nodes) fgraph.replace_all_validate( - list(zip(outputs, composite_outputs, strict=True)), - reason=self.__class__.__name__, + tuple(zip(outputs, composite_outputs)), reason=reason ) nb_fused += 1 nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1 From f768c3381996ab5ffb0038cf425bf8fa4cb79d0b Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Sep 2025 23:49:50 +0200 Subject: [PATCH 08/11] Copy on write in FusionOptimizer --- pytensor/tensor/rewriting/elemwise.py | 82 ++++++++++++++++++++------- 1 file changed, 61 insertions(+), 21 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 4eca867b4e..aa89fb2e56 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -2,6 +2,7 @@ import itertools import operator import sys +import typing from collections import defaultdict, deque from collections.abc import Generator, Sequence from functools import cache, reduce @@ -522,6 +523,43 @@ def elemwise_max_operands_fct(node) -> int: return 1024 +class CopyOnWriteDictOfSets: + __slots__ = ("d", "d_copy") + + def __init__(self, d: dict[typing.Any, set]): + self.d = d + self.d_copy: dict[typing.Any, set] = {} + + def __getitem__(self, key): + try: + return self.d_copy[key] + except KeyError: + return self.d[key] + + def get(self, key, default=frozenset()): + try: + return self.d_copy[key] + except KeyError: + try: + return self.d[key] + except KeyError: + return default + + def remove_from_key(self, key, value): + try: + self.d_copy[key].remove(value) + except KeyError: + self.d_copy[key] = copied_value = self.d[key].copy() + copied_value.remove(value) + + def add_to_key(self, key, value): + try: + self.d_copy[key].add(value) + except KeyError: + self.d_copy[key] = copied_value = self.d[key].copy() + copied_value.add(value) + + class FusionOptimizer(GraphRewriter): """Graph optimizer that fuses consecutive Elemwise operations.""" @@ -644,15 +682,10 @@ def variables_depend_on( subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set unfuseable_clients_subgraph: set[Variable] = set() - # Shallow cloning of maps so that they can be manipulated in place - fuseable_clients_clone: FUSEABLE_MAPPING = defaultdict(set) - fuseable_clients_clone.update( - {k: v.copy() for k, v in fuseable_clients.items()} - ) - unfuseable_clients_clone: UNFUSEABLE_MAPPING = defaultdict(set) - unfuseable_clients_clone.update( - {k: v.copy() for k, v in unfuseable_clients.items()} - ) + # If we need to manipulate the maps in place, we'll do a shallow copy later + # For now we query on the original ones + fuseable_clients_clone = CopyOnWriteDictOfSets(fuseable_clients) + unfuseable_clients_clone = CopyOnWriteDictOfSets(unfuseable_clients) # We now try to expand as much as possible towards the potentially # fuseable clients and ancestors to detect the largest possible @@ -682,7 +715,7 @@ def variables_depend_on( required_unfuseable_inputs = [ inp for inp in next_node.inputs - if next_node in unfuseable_clients_clone.get(inp, ()) + if next_node in unfuseable_clients_clone.get(inp) ] new_required_unfuseable_inputs = [ inp @@ -705,7 +738,7 @@ def variables_depend_on( if not must_backtrack: implied_unfuseable_clients = { c - for client in unfuseable_clients_clone.get(next_out, ()) + for client in unfuseable_clients_clone.get(next_out) if not isinstance(client.op, Output) for c in client.outputs } @@ -726,13 +759,15 @@ def variables_depend_on( if must_backtrack: for inp in next_node.inputs: - if ( - inp.owner in visited_nodes - # next_node could have the same input repeated - and next_node in fuseable_clients_clone[inp] - ): - fuseable_clients_clone[inp].remove(next_node) - unfuseable_clients_clone[inp].add(next_node) + if inp.owner in visited_nodes: + if next_node not in fuseable_clients_clone[inp]: + # This can happen when next node has repeated inputs + continue + fuseable_clients_clone.remove_from_key( + inp, next_node + ) + unfuseable_clients_clone.add_to_key(inp, next_node) + # This input must become an output of the subgraph, # because it can't be merged with next_node. # We will revisit it to make sure this is safe. @@ -741,8 +776,13 @@ def variables_depend_on( # need to convert to tuple not to change set size during iteration for client in tuple(fuseable_clients_clone[next_out]): if client in visited_nodes: - fuseable_clients_clone[next_out].remove(client) - unfuseable_clients_clone[next_out].add(client) + fuseable_clients_clone.remove_from_key( + next_out, client + ) + unfuseable_clients_clone.add_to_key( + next_out, client + ) + # next_out must become an input of the subgraph. # We will revisit any of its clients currently # in the subgraph to make sure this is safe. @@ -785,7 +825,7 @@ def variables_depend_on( sorted( ( node - for node in fuseable_clients_clone.get(next_out, ()) + for node in fuseable_clients_clone.get(next_out) if node not in visited_nodes ), key=toposort_index.get, # type: ignore[arg-type] From 9b9034357bfde45526bd27957d969c046a829c35 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 12 Sep 2025 18:57:02 +0200 Subject: [PATCH 09/11] Use bitset to check ancestors more efficiently --- pytensor/tensor/rewriting/elemwise.py | 139 +++++++++++++------------- tests/test_printing.py | 14 +-- 2 files changed, 77 insertions(+), 76 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index aa89fb2e56..a1fdad51bc 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -6,6 +6,7 @@ from collections import defaultdict, deque from collections.abc import Generator, Sequence from functools import cache, reduce +from operator import or_ from typing import Literal from warnings import warn @@ -29,7 +30,7 @@ ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.unify import OpPattern -from pytensor.graph.traversal import ancestors, toposort +from pytensor.graph.traversal import toposort from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.tensor.basic import ( @@ -659,16 +660,9 @@ def find_fuseable_subgraph( visited_nodes: set[Apply], fuseable_clients: FUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING, + ancestors_bitset: dict[Apply, int], toposort_index: dict[Apply, int], ) -> tuple[list[Variable], list[Variable]]: - def variables_depend_on( - variables, depend_on, stop_search_at=None - ) -> bool: - return any( - a in depend_on - for a in ancestors(variables, blockers=stop_search_at) - ) - for starting_node in toposort_index: if starting_node in visited_nodes: continue @@ -680,7 +674,8 @@ def variables_depend_on( subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set - unfuseable_clients_subgraph: set[Variable] = set() + subgraph_inputs_ancestors_bitset = 0 + unfuseable_clients_subgraph_bitset = 0 # If we need to manipulate the maps in place, we'll do a shallow copy later # For now we query on the original ones @@ -712,50 +707,32 @@ def variables_depend_on( if must_become_output: subgraph_outputs.pop(next_out, None) - required_unfuseable_inputs = [ - inp - for inp in next_node.inputs - if next_node in unfuseable_clients_clone.get(inp) - ] - new_required_unfuseable_inputs = [ - inp - for inp in required_unfuseable_inputs - if inp not in subgraph_inputs - ] - - must_backtrack = False - if new_required_unfuseable_inputs and subgraph_outputs: - # We need to check that any new inputs required by this node - # do not depend on other outputs of the current subgraph, - # via an unfuseable path. - if variables_depend_on( - [next_out], - depend_on=unfuseable_clients_subgraph, - stop_search_at=subgraph_outputs, - ): - must_backtrack = True + # We need to check that any inputs required by this node + # do not depend on other outputs of the current subgraph, + # via an unfuseable path. + must_backtrack = ( + ancestors_bitset[next_node] + & unfuseable_clients_subgraph_bitset + ) if not must_backtrack: - implied_unfuseable_clients = { - c - for client in unfuseable_clients_clone.get(next_out) - if not isinstance(client.op, Output) - for c in client.outputs - } - - new_implied_unfuseable_clients = ( - implied_unfuseable_clients - unfuseable_clients_subgraph + implied_unfuseable_clients_bitset = reduce( + or_, + ( + 1 << toposort_index[client] + for client in unfuseable_clients_clone.get(next_out) + if not isinstance(client.op, Output) + ), + 0, ) - if new_implied_unfuseable_clients and subgraph_inputs: - # We need to check that any inputs of the current subgraph - # do not depend on other clients of this node, - # via an unfuseable path. - if variables_depend_on( - subgraph_inputs, - depend_on=new_implied_unfuseable_clients, - ): - must_backtrack = True + # We need to check that any inputs of the current subgraph + # do not depend on other clients of this node, + # via an unfuseable path. + must_backtrack = ( + subgraph_inputs_ancestors_bitset + & implied_unfuseable_clients_bitset + ) if must_backtrack: for inp in next_node.inputs: @@ -796,29 +773,24 @@ def variables_depend_on( # immediate dependency problems. Update subgraph # mappings as if it next_node was part of it. # Useless inputs will be removed by the useless Composite rewrite - for inp in new_required_unfuseable_inputs: - subgraph_inputs[inp] = None - if must_become_output: subgraph_outputs[next_out] = None - unfuseable_clients_subgraph.update( - new_implied_unfuseable_clients + unfuseable_clients_subgraph_bitset |= ( + implied_unfuseable_clients_bitset ) - # Expand through unvisited fuseable ancestors - fuseable_nodes_to_visit.extendleft( - sorted( - ( - inp.owner - for inp in next_node.inputs - if ( - inp not in required_unfuseable_inputs - and inp.owner not in visited_nodes - ) - ), - key=toposort_index.get, # type: ignore[arg-type] - ) - ) + for inp in sorted( + next_node.inputs, + key=lambda x: toposort_index.get(x.owner, -1), + ): + if next_node in unfuseable_clients_clone.get(inp, ()): + # input must become an input of the subgraph since it's unfuseable with new node + subgraph_inputs_ancestors_bitset |= ( + ancestors_bitset.get(inp.owner, 0) + ) + subgraph_inputs[inp] = None + elif inp.owner not in visited_nodes: + fuseable_nodes_to_visit.appendleft(inp.owner) # Expand through unvisited fuseable clients fuseable_nodes_to_visit.extend( @@ -855,6 +827,8 @@ def update_fuseable_mappings_after_fg_replace( visited_nodes: set[Apply], fuseable_clients: FUSEABLE_MAPPING, unfuseable_clients: UNFUSEABLE_MAPPING, + toposort_index: dict[Apply, int], + ancestors_bitset: dict[Apply, int], starting_nodes: set[Apply], updated_nodes: set[Apply], ) -> None: @@ -865,11 +839,25 @@ def update_fuseable_mappings_after_fg_replace( dropped_nodes = starting_nodes - updated_nodes # Remove intermediate Composite nodes from mappings + # And compute the ancestors bitset of the new composite node + # As well as the new toposort index for the new node + new_node_ancestor_bitset = 0 + new_node_toposort_index = len(toposort_index) for dropped_node in dropped_nodes: (dropped_out,) = dropped_node.outputs fuseable_clients.pop(dropped_out, None) unfuseable_clients.pop(dropped_out, None) visited_nodes.remove(dropped_node) + # The new composite ancestor bitset is the union + # of the ancestors of all the dropped nodes + new_node_ancestor_bitset |= ancestors_bitset[dropped_node] + # The new composite node can have the same order as the latest node that was absorbed into it + new_node_toposort_index = max( + new_node_toposort_index, toposort_index[dropped_node] + ) + + ancestors_bitset[new_composite_node] = new_node_ancestor_bitset + toposort_index[new_composite_node] = new_node_toposort_index # Update fuseable information for subgraph inputs for inp in subgraph_inputs: @@ -901,12 +889,23 @@ def update_fuseable_mappings_after_fg_replace( fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg) visited_nodes: set[Apply] = set() toposort_index = {node: i for i, node in enumerate(fgraph.toposort())} + # Create a bitset for each node of all its ancestors + # This allows to quickly check if a variable depends on a set + ancestors_bitset: dict[Apply, int] = {} + for node, index in toposort_index.items(): + node_ancestor_bitset = 1 << index + for inp in node.inputs: + if (inp_node := inp.owner) is not None: + node_ancestor_bitset |= ancestors_bitset[inp_node] + ancestors_bitset[node] = node_ancestor_bitset + while True: try: subgraph_inputs, subgraph_outputs = find_fuseable_subgraph( visited_nodes=visited_nodes, fuseable_clients=fuseable_clients, unfuseable_clients=unfuseable_clients, + ancestors_bitset=ancestors_bitset, toposort_index=toposort_index, ) except ValueError: @@ -925,6 +924,8 @@ def update_fuseable_mappings_after_fg_replace( visited_nodes=visited_nodes, fuseable_clients=fuseable_clients, unfuseable_clients=unfuseable_clients, + toposort_index=toposort_index, + ancestors_bitset=ancestors_bitset, starting_nodes=starting_nodes, updated_nodes=fg.apply_nodes, ) diff --git a/tests/test_printing.py b/tests/test_printing.py index 95c3c938cf..dbad8c063b 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -301,7 +301,8 @@ def test_debugprint(): Gemv_op_name = "CGemv" if pytensor.config.blas__ldflags else "Gemv" exp_res = dedent( r""" - Composite{(i2 + (i0 - i1))} 4 + Composite{(i0 + (i1 - i2))} 4 + ├─ A ├─ ExpandDims{axis=0} v={0: [0]} 3 """ f" │ └─ {Gemv_op_name}{{inplace}} d={{0: [0]}} 2" @@ -313,17 +314,16 @@ def test_debugprint(): │ ├─ B │ ├─ │ └─ 0.0 - ├─ D - └─ A + └─ D Inner graphs: - Composite{(i2 + (i0 - i1))} + Composite{(i0 + (i1 - i2))} ← add 'o0' - ├─ i2 - └─ sub ├─ i0 - └─ i1 + └─ sub + ├─ i1 + └─ i2 """ ).lstrip() From e2980981c17aa9457f44591666e7689a30840d69 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Thu, 18 Sep 2025 09:36:13 +0200 Subject: [PATCH 10/11] Avoid backtracking in FusionOptimizer The change in number of fused kernels has to do with the order of iteration, and could be replicated in the old approach by iterating in topological order. It was an accident that it happen to visit in an order where it connected two branches, instead of keeping them separate. The underlying limitation already existed and is described in https://github.com/pymc-devs/pytensor/issues/249 --- pytensor/tensor/rewriting/elemwise.py | 633 ++++++++++-------------- tests/tensor/rewriting/test_elemwise.py | 22 +- 2 files changed, 292 insertions(+), 363 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index a1fdad51bc..ff6b6c70e3 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -2,12 +2,10 @@ import itertools import operator import sys -import typing -from collections import defaultdict, deque from collections.abc import Generator, Sequence from functools import cache, reduce +from heapq import heapify, heappop, heappush from operator import or_ -from typing import Literal from warnings import warn import pytensor.scalar.basic as ps @@ -524,43 +522,6 @@ def elemwise_max_operands_fct(node) -> int: return 1024 -class CopyOnWriteDictOfSets: - __slots__ = ("d", "d_copy") - - def __init__(self, d: dict[typing.Any, set]): - self.d = d - self.d_copy: dict[typing.Any, set] = {} - - def __getitem__(self, key): - try: - return self.d_copy[key] - except KeyError: - return self.d[key] - - def get(self, key, default=frozenset()): - try: - return self.d_copy[key] - except KeyError: - try: - return self.d[key] - except KeyError: - return default - - def remove_from_key(self, key, value): - try: - self.d_copy[key].remove(value) - except KeyError: - self.d_copy[key] = copied_value = self.d[key].copy() - copied_value.remove(value) - - def add_to_key(self, key, value): - try: - self.d_copy[key].add(value) - except KeyError: - self.d_copy[key] = copied_value = self.d[key].copy() - copied_value.add(value) - - class FusionOptimizer(GraphRewriter): """Graph optimizer that fuses consecutive Elemwise operations.""" @@ -594,353 +555,300 @@ def apply(self, fgraph): callbacks_before = fgraph.execute_callbacks_times.copy() callback_before = fgraph.execute_callbacks_time - def find_next_fuseable_subgraph( + def find_fuseable_subgraphs( fg: FunctionGraph, - ) -> Generator[tuple[list[Variable], list[Variable]], None, None]: - """Find all subgraphs in a FunctionGraph that can be fused together - - Yields - ------- - List of inputs and outputs that determine subgraphs which can be fused. - This generator assumes that such subgraph is replaced by a single - Elemwise Composite before being accessed again in the next iteration. + ) -> Generator[tuple[tuple[Variable], tuple[Variable]], None, None]: + """Find subgraphs of Elemwise nodes that can be fused together. + + In general, there is no single solution. We try to find large subgraphs eagerly + + Any two consecutive Elemwise nodes that have the same broadcasting pattern, + and a C-implementation (historical accident that should be revisited), are potentially fuseable. + + However, not all collections of fuseable pairs make a valid fused subgraph. + A valid fused subgraph must be "convex", meaning that no two nodes in the subgraph + are connected via a path that goes outside the subgraph, either because they + are connected via unfuseable nodes, or nodes that have been claimed by another fused subgraph. + + For example the subgraph add(sin(exp(x)), sum(exp(x)) cannot be fused together, + because the sum node breaks the convexity of the subgraph {exp, sin, add}. + However, we can fuse {exp, sin}, and perhaps fuse add with something else. + + This function yields subgraph in reverse topological order so they can be safely replaced one at a time """ - FUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] - UNFUSEABLE_MAPPING = defaultdict[Variable, set[Apply]] - - def initialize_fuseable_mappings( - *, fg: FunctionGraph - ) -> tuple[FUSEABLE_MAPPING, UNFUSEABLE_MAPPING]: - @cache - def elemwise_scalar_op_has_c_code(node: Apply) -> bool: - # TODO: This should not play a role in non-c backends! - if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): - return True - else: - if config.optimizer_verbose: - warn( - f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation." - ) - return False - - # Fuseable nodes have to be accessed in a deterministic manner - # to ensure the rewrite remains deterministic. - # This is not a problem from unfuseable ones, as they can never - # become part of the graph. - fuseable_clients: FUSEABLE_MAPPING = defaultdict(set) - unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set) - for out, clients in fg.clients.items(): - out_maybe_fuseable = ( - out.owner is not None - and isinstance(out.owner.op, Elemwise) - # and not isinstance(out.owner.op.scalar_op, ps.Composite) - and len(out.owner.outputs) == 1 - and elemwise_scalar_op_has_c_code(out.owner) + + @cache + def elemwise_scalar_op_has_c_code( + node: Apply, optimizer_verbose=config.optimizer_verbose + ) -> bool: + # TODO: This should not play a role in non-c backends! + if node.op.scalar_op.supports_c_code(node.inputs, node.outputs): + return True + elif optimizer_verbose: + warn( + f"Loop fusion interrupted because {node.op.scalar_op} does not provide a C implementation." ) - if out_maybe_fuseable: - out_bcast = out.type.broadcastable - for client, _ in clients: - if ( - isinstance(client.op, Elemwise) - # and not isinstance(client.op.scalar_op, ps.Composite) - and len(client.outputs) == 1 - and out_bcast == client.outputs[0].type.broadcastable - and elemwise_scalar_op_has_c_code(client) - ): - fuseable_clients[out].add(client) - else: - unfuseable_clients[out].add(client) - else: - unfuseable_clients[out] = {client for client, _ in clients} - - return fuseable_clients, unfuseable_clients - - def find_fuseable_subgraph( - *, - visited_nodes: set[Apply], - fuseable_clients: FUSEABLE_MAPPING, - unfuseable_clients: UNFUSEABLE_MAPPING, - ancestors_bitset: dict[Apply, int], - toposort_index: dict[Apply, int], - ) -> tuple[list[Variable], list[Variable]]: - for starting_node in toposort_index: - if starting_node in visited_nodes: - continue + return False + + # Create a map from node to a set of fuseable client (successor) nodes + # A node and a client are fuseable if they are both single output Elemwise + # (with C-implementation) and have the same output broadcastable pattern + # Nodes that have no fuseable clients are not included + fuseable_clients: dict[Apply, set[Apply]] = {} + # We also create a set with candidate nodes from which to start a subgraph expansion + # These are Single output Elemwise nodes (with C-implementation) that may or not + # have fuseable ancestors/clients at the start. + candidate_starting_nodes = set() + fg_clients = fg.clients + for out, clients_and_indices in fg_clients.items(): + out_node = out.owner + + if not ( + out_node is not None + and len(out_node.outputs) == 1 + and isinstance(out_node.op, Elemwise) + and elemwise_scalar_op_has_c_code(out_node) + ): + continue + + candidate_starting_nodes.add(out_node) + out_bcast = out.type.broadcastable + out_fuseable_clients = { + client + for client, _ in clients_and_indices + if ( + len(client.outputs) == 1 + and isinstance(client.op, Elemwise) + and out_bcast == client.outputs[0].type.broadcastable + and elemwise_scalar_op_has_c_code(client) + ) + } + if out_fuseable_clients: + fuseable_clients[out_node] = out_fuseable_clients + + if not candidate_starting_nodes: + return None + + # To enable fast dependency queries, we create a bitset of ancestors for each node. + # Each node is first represented by a bit flag of it's position in the toposort + # This can be achieved with python integers, via 1 << toposort_idx (equivalent to slower 2 ** toposort_idx) + # The ancestors bitsets of each node are obtained by bitwise OR of the ancestor bitsets + # of each of the nodes' inputs, and the bit flag of the node itself. + # + # Example: With three variables {a, b, c} owned by nodes {A, B, C}, where a is an input of b, and b an input of c, + # the nodes bit flags would be {A: 0b001, B: 0b010, C: 0b100} (integers {A: 1, B: 2, C: 4}) + # and the ancestors bitset would be {A: 0b001, B: 0b011, C: 0b111} (integers {A: 1, B: 3, C: 7}) + # + # This allows us to quickly ask if one or more variables are ancestors of a node by a simple bitwise AND + # For example, to ask if A is an ancestor of C we can do `ancestors_bitset[C] & node_bitset[A] != 0` + # We can also easily handle multiple nodes at once, for example to ask if A or B are ancestors of C we can do + # `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0` + nodes_bitflags = {node: 1 << i for i, node in enumerate(fgraph.toposort())} + # Root variables have `None` as owner, which we can handle with a bitset of 0 + ancestors_bitset = {None: 0} + for node, node_bitflag in nodes_bitflags.items(): + # The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag + ancestors_bitset[node] = reduce( + or_, + (ancestors_bitset[inp.owner] for inp in node.inputs), + node_bitflag, + ) + # Handle root and leaf nodes gracefully + # We do it after the ancestors_bitset are built to simplify the previous loop. + # Root variables have `None` as owner, which we can handle with a bitflag of 0 + nodes_bitflags[None] = 0 + # Nothing ever depends on the special Output nodes, so just use a new bit for all of them + out_bitflag = 1 << len(nodes_bitflags) + for out in fg.outputs: + for client, _ in fg_clients[out]: + if isinstance(client.op, Output): + nodes_bitflags[client] = out_bitflag + + # Start main loop to find collection of fuseable subgraphs + # We store the collection in `sorted_subgraphs`, in reverse topological order + sorted_subgraphs: list[ + tuple[int, tuple[tuple[Variable], tuple[Variable]]] + ] = [] + # Keep a bitset of nodes that have been claimed by subgraphs + all_subgraphs_bitset = 0 + # Start exploring in reverse topological order from candidate sink nodes + # Sink nodes, are nodes that don't have any potential fuseable clients + for starting_node, starting_bitflag in reversed(nodes_bitflags.items()): + if ( + starting_bitflag & all_subgraphs_bitset + or starting_node not in candidate_starting_nodes + or starting_node in fuseable_clients + ): + continue - starting_out = starting_node.outputs[0] - if not fuseable_clients.get(starting_out): - visited_nodes.add(starting_node) + # We use an ordered queue to control the direction in which we expand the subgraph + # For simplicity, we always want to visit ancestors before clients + # For ancestors, we want to visit the later nodes first (those that have more dependencies) + # whereas for clients we want to visit earlier nodes first (those that have fewer dependencies) + # To achieve this we use the bitflag as the sorting key (which encodes the topological order) + # and negate it for ancestors. + fuseables_nodes_queue = [(-starting_bitflag, starting_node)] + heapify(fuseables_nodes_queue) + + # We keep 3 bitsets during the exploration of a new subgraph: + # - the nodes that are part of the subgraph + # - the unfuseable ancestors of the subgraph (i.e., ancestors that are not fuseable with a node in the subgraph) + # - the unfuseable clients of the subgraph (i.e., clients that are not fuseable with a node in the subgraph) + # Whenever we visit a candidate node, we check if the subgraph's unfuseable ancestors depend on it, + # or if it depends on one of the subgraphs' unfuseable client, in which case we can't add it. + # If we can add it, we then add its unfuseable ancestors/clients to the respective bitsets + # and add its fuseable ancestors/clients to the queue to explore later. + # To work correctly, we must visit candidate subgraph nodes in the order described by the queue above. + # Otherwise, we would need to perform more complex dependency checks in every iteration and/or backtrack. + subgraph_nodes = [] + subgraph_bitset = 0 + unfuseable_ancestors_bitset = 0 + unfuseable_clients_bitset = 0 + + while fuseables_nodes_queue: + node_bitflag, node = heappop(fuseables_nodes_queue) + is_ancestor = node_bitflag < 0 + if is_ancestor: + node_bitflag = -node_bitflag + + if node_bitflag & subgraph_bitset: + # Already part of the subgraph continue - subgraph_inputs: dict[Variable, Literal[None]] = {} # ordered set - subgraph_outputs: dict[Variable, Literal[None]] = {} # ordered set - subgraph_inputs_ancestors_bitset = 0 - unfuseable_clients_subgraph_bitset = 0 - - # If we need to manipulate the maps in place, we'll do a shallow copy later - # For now we query on the original ones - fuseable_clients_clone = CopyOnWriteDictOfSets(fuseable_clients) - unfuseable_clients_clone = CopyOnWriteDictOfSets(unfuseable_clients) - - # We now try to expand as much as possible towards the potentially - # fuseable clients and ancestors to detect the largest possible - # subgraph that can be Composed together into a single `Op`. The - # largest issue to watch out is for cyclical dependencies, where - # some inputs or clients may depend on other nodes of the same - # subgraph via a path that cannot be included in the Composite - # (unfuseable) - fuseable_nodes_to_visit = deque([starting_node]) - while fuseable_nodes_to_visit: - next_node = fuseable_nodes_to_visit.popleft() - visited_nodes.add(next_node) - next_out = next_node.outputs[0] - - # If the output variable of next_node has no fuseable clients - # or has unfuseable clients, then next_node must become an output - # if it is to be fused. - must_become_output = not fuseable_clients_clone.get( - next_out - ) or unfuseable_clients_clone.get(next_out) - - # We have backtracked to this node, and it may no longer be a viable output, - # so we remove it and check again as if we had never seen this node - if must_become_output: - subgraph_outputs.pop(next_out, None) - - # We need to check that any inputs required by this node - # do not depend on other outputs of the current subgraph, - # via an unfuseable path. - must_backtrack = ( - ancestors_bitset[next_node] - & unfuseable_clients_subgraph_bitset - ) - - if not must_backtrack: - implied_unfuseable_clients_bitset = reduce( - or_, - ( - 1 << toposort_index[client] - for client in unfuseable_clients_clone.get(next_out) - if not isinstance(client.op, Output) - ), - 0, - ) + if is_ancestor: + if node_bitflag & unfuseable_ancestors_bitset: + # An unfuseable ancestor of the subgraph depends on this node, can't fuse + continue + elif ancestors_bitset[node] & unfuseable_clients_bitset: + # This node depends on an unfuseable client of the subgraph, can't fuse + continue - # We need to check that any inputs of the current subgraph - # do not depend on other clients of this node, - # via an unfuseable path. - must_backtrack = ( - subgraph_inputs_ancestors_bitset - & implied_unfuseable_clients_bitset + # Add node to subgraph + subgraph_nodes.append(node) + subgraph_bitset |= node_bitflag + + # Expand through ancestors and client nodes + # A node can either be: + # - already part of the subgraph (skip) + # - fuseable (add to queue) + # - unfuseable (add to respective unfuseable bitset) + for inp in node.inputs: + ancestor_node = inp.owner + ancestor_bitflag = nodes_bitflags[ancestor_node] + if ancestor_bitflag & subgraph_bitset: + continue + if node in fuseable_clients.get(ancestor_node, ()): + heappush( + fuseables_nodes_queue, + (-ancestor_bitflag, ancestor_node), ) - - if must_backtrack: - for inp in next_node.inputs: - if inp.owner in visited_nodes: - if next_node not in fuseable_clients_clone[inp]: - # This can happen when next node has repeated inputs - continue - fuseable_clients_clone.remove_from_key( - inp, next_node - ) - unfuseable_clients_clone.add_to_key(inp, next_node) - - # This input must become an output of the subgraph, - # because it can't be merged with next_node. - # We will revisit it to make sure this is safe. - fuseable_nodes_to_visit.appendleft(inp.owner) - - # need to convert to tuple not to change set size during iteration - for client in tuple(fuseable_clients_clone[next_out]): - if client in visited_nodes: - fuseable_clients_clone.remove_from_key( - next_out, client - ) - unfuseable_clients_clone.add_to_key( - next_out, client - ) - - # next_out must become an input of the subgraph. - # We will revisit any of its clients currently - # in the subgraph to make sure this is safe. - fuseable_nodes_to_visit.appendleft(client) - - # Revisit node at a later time - visited_nodes.remove(next_node) + else: + # If the node is not in the ancestor's fuseable clients set, it's not fuseable with it, + # nor with any of the ancestor's ancestors + unfuseable_ancestors_bitset |= ancestors_bitset[ + ancestor_node + ] + + next_fuseable_clients = fuseable_clients.get(node, ()) + for client, _ in fg_clients[node.outputs[0]]: + client_bitflag = nodes_bitflags[client] + if client_bitflag & subgraph_bitset: continue + if client in next_fuseable_clients: + heappush(fuseables_nodes_queue, (client_bitflag, client)) + else: + # If a client is not in the node's fuseable clients set, it's nto fuseable with it, + # nor any of its clients. But we don't need to keep track of those as any downstream + # client we may consider later will also depend on this unfuseable client and be rejected + unfuseable_clients_bitset |= client_bitflag - # Adding next_node to subgraph does not result in any - # immediate dependency problems. Update subgraph - # mappings as if it next_node was part of it. - # Useless inputs will be removed by the useless Composite rewrite - if must_become_output: - subgraph_outputs[next_out] = None - unfuseable_clients_subgraph_bitset |= ( - implied_unfuseable_clients_bitset - ) + # Finished exploring this subgraph + all_subgraphs_bitset |= subgraph_bitset - for inp in sorted( - next_node.inputs, - key=lambda x: toposort_index.get(x.owner, -1), - ): - if next_node in unfuseable_clients_clone.get(inp, ()): - # input must become an input of the subgraph since it's unfuseable with new node - subgraph_inputs_ancestors_bitset |= ( - ancestors_bitset.get(inp.owner, 0) - ) - subgraph_inputs[inp] = None - elif inp.owner not in visited_nodes: - fuseable_nodes_to_visit.appendleft(inp.owner) - - # Expand through unvisited fuseable clients - fuseable_nodes_to_visit.extend( - sorted( - ( - node - for node in fuseable_clients_clone.get(next_out) - if node not in visited_nodes - ), - key=toposort_index.get, # type: ignore[arg-type] - ) - ) - - # Don't return if final subgraph is just the original Elemwise - if len(subgraph_outputs) == 1 and set( - next(iter(subgraph_outputs)).owner.inputs - ) == set(subgraph_inputs): - # Update global fuseable mappings - # No input was actually fuseable - for inp in starting_node.inputs: - fuseable_clients[inp].discard(starting_node) - unfuseable_clients[inp].add(starting_node) - # No client was actually fuseable - unfuseable_clients[starting_out].update( - fuseable_clients.pop(starting_out, ()) - ) - continue + if subgraph_bitset == starting_bitflag: + # We ended were we started, no fusion possible + continue - return list(subgraph_inputs), list(subgraph_outputs) - raise ValueError - - def update_fuseable_mappings_after_fg_replace( - *, - visited_nodes: set[Apply], - fuseable_clients: FUSEABLE_MAPPING, - unfuseable_clients: UNFUSEABLE_MAPPING, - toposort_index: dict[Apply, int], - ancestors_bitset: dict[Apply, int], - starting_nodes: set[Apply], - updated_nodes: set[Apply], - ) -> None: - # Find new composite node and dropped intermediate nodes - # by comparing the current fg.apply nodes with the cached - # original nodes - (new_composite_node,) = updated_nodes - starting_nodes - dropped_nodes = starting_nodes - updated_nodes - - # Remove intermediate Composite nodes from mappings - # And compute the ancestors bitset of the new composite node - # As well as the new toposort index for the new node - new_node_ancestor_bitset = 0 - new_node_toposort_index = len(toposort_index) - for dropped_node in dropped_nodes: - (dropped_out,) = dropped_node.outputs - fuseable_clients.pop(dropped_out, None) - unfuseable_clients.pop(dropped_out, None) - visited_nodes.remove(dropped_node) - # The new composite ancestor bitset is the union - # of the ancestors of all the dropped nodes - new_node_ancestor_bitset |= ancestors_bitset[dropped_node] - # The new composite node can have the same order as the latest node that was absorbed into it - new_node_toposort_index = max( - new_node_toposort_index, toposort_index[dropped_node] + # Find out the actual inputs/outputs variables of the subgraph + not_subgraph_bitset = ~subgraph_bitset + # Inputs are variables whose nodes are not part of the subgraph (including root variables without nodes) + # Use a dict to deduplicate while preserving order + subgraph_inputs = tuple( + dict.fromkeys( + inp + for node in subgraph_nodes + for inp in node.inputs + if (inp_node := inp.owner) is None + or nodes_bitflags[inp_node] & not_subgraph_bitset ) + ) + # Outputs are variables with client nodes that are not part of the subgraph (including special fgraph output nodes) + # Outputs are unique, no need to deduplicate + subgraph_outputs = tuple( + node.outputs[0] + for node in subgraph_nodes + if any( + nodes_bitflags[client] & not_subgraph_bitset + for client, _ in fg_clients[node.outputs[0]] + ) + ) - ancestors_bitset[new_composite_node] = new_node_ancestor_bitset - toposort_index[new_composite_node] = new_node_toposort_index - - # Update fuseable information for subgraph inputs + # Update fuseable clients mapping for subgraph inputs and outputs + # Inputs cannot be fused with nodes in the subgraph for inp in subgraph_inputs: - if inp in fuseable_clients: - new_fuseable_clients = { - client - for client in fuseable_clients[inp] - if client not in dropped_nodes - } - if new_fuseable_clients: - fuseable_clients[inp] = new_fuseable_clients - else: - fuseable_clients.pop(inp) - unfuseable_clients[inp] = ( - unfuseable_clients[inp] - dropped_nodes - ) | {new_composite_node} - - # Update fuseable information for subgraph outputs - for out in new_composite_node.outputs: - unfuseable_clients[out] = {client for client, _ in fg.clients[out]} - - visited_nodes.add(new_composite_node) - return - - # We start by creating two maps, 1) from each node to each potentially - # fuseable client (both nodes must be single output Elemwise with same - # broadcast type) and 2) from each node to each certainly unfuseable - # client (those that don't fit into 1)) - fuseable_clients, unfuseable_clients = initialize_fuseable_mappings(fg=fg) - visited_nodes: set[Apply] = set() - toposort_index = {node: i for i, node in enumerate(fgraph.toposort())} - # Create a bitset for each node of all its ancestors - # This allows to quickly check if a variable depends on a set - ancestors_bitset: dict[Apply, int] = {} - for node, index in toposort_index.items(): - node_ancestor_bitset = 1 << index - for inp in node.inputs: - if (inp_node := inp.owner) is not None: - node_ancestor_bitset |= ancestors_bitset[inp_node] - ancestors_bitset[node] = node_ancestor_bitset - - while True: - try: - subgraph_inputs, subgraph_outputs = find_fuseable_subgraph( - visited_nodes=visited_nodes, - fuseable_clients=fuseable_clients, - unfuseable_clients=unfuseable_clients, - ancestors_bitset=ancestors_bitset, - toposort_index=toposort_index, + if (inp_node := inp.owner) is not None and ( + inp_fuseable_clients := fuseable_clients.get(inp_node) + ): + inp_fuseable_clients.difference_update(subgraph_nodes) + # If there are no fuseable_clients left for this input delete it's entry + if not inp_fuseable_clients: + del fuseable_clients[inp_node] + # Outputs cannot be fused with anything else + for out in subgraph_outputs: + fuseable_clients.pop(out.owner, None) + + # Add new subgraph to sorted_subgraphs + # Because we start from sink nodes in reverse topological order, most times new subgraphs + # don't depend on previous subgraphs, so we can just append them at the end. + if not (unfuseable_ancestors_bitset & all_subgraphs_bitset): + # That's the case here + # None of the unfuseable_ancestors (i.e, the ancestors) are present in the previous collected subgraphs + sorted_subgraphs.append( + (subgraph_bitset, (subgraph_inputs, subgraph_outputs)) ) - except ValueError: - return else: - # The caller is now expected to update fg in place, - # by replacing the subgraph with a Composite Op - starting_nodes = fg.apply_nodes.copy() - - yield subgraph_inputs, subgraph_outputs - - # This is where we avoid repeated work by using a stateful - # generator. For large models (as in `TestFusion.test_big_fusion`) - # this can provide huge speedups - update_fuseable_mappings_after_fg_replace( - visited_nodes=visited_nodes, - fuseable_clients=fuseable_clients, - unfuseable_clients=unfuseable_clients, - toposort_index=toposort_index, - ancestors_bitset=ancestors_bitset, - starting_nodes=starting_nodes, - updated_nodes=fg.apply_nodes, + # But not here, so we need to find the right position for insertion. + # We iterate through the previous subgraphs in topological order (reverse of the stored order). + # We exclude cumulatively exclude each subgraph_bitset and perform the same dependency check again. + # The (index + 1) of the firs iteration where the check passes is the correct insertion position. + remaining_subgraphs_bitset = all_subgraphs_bitset + for index, (other_subgraph_bitset, _) in enumerate( + reversed(sorted_subgraphs) + ): + # Exclude subgraph bitset + remaining_subgraphs_bitset &= ~other_subgraph_bitset + if not ( + unfuseable_ancestors_bitset & remaining_subgraphs_bitset + ): + break # bingo + sorted_subgraphs.insert( + -(index + 1), + (subgraph_bitset, (subgraph_inputs, subgraph_outputs)), ) + # yield from sorted_subgraphs, discarding the subgraph_bitset + yield from (io for _, io in sorted_subgraphs) + max_operands = elemwise_max_operands_fct(None) reason = self.__class__.__name__ nb_fused = 0 nb_replacement = 0 - for inputs, outputs in find_next_fuseable_subgraph(fgraph): + for inputs, outputs in find_fuseable_subgraphs(fgraph): if (len(inputs) + len(outputs)) > max_operands: warn( - "Loop fusion failed because the resulting node would exceed " - "the kernel argument limit." + "Loop fusion failed because the resulting node would exceed the kernel argument limit." ) - break + continue scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) composite_outputs = Elemwise( @@ -955,7 +863,8 @@ def update_fuseable_mappings_after_fg_replace( starting_nodes = len(fgraph.apply_nodes) fgraph.replace_all_validate( - tuple(zip(outputs, composite_outputs)), reason=reason + tuple(zip(outputs, composite_outputs)), + reason=reason, ) nb_fused += 1 nb_replacement += (starting_nodes - len(fgraph.apply_nodes)) + 1 diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 2ace386376..523effb1d1 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -319,6 +319,26 @@ def test_diamond_graph(): assert nb_fused == 1 assert nb_replacement == 4 + def test_expansion_order(self): + # This test is designed to fail if we don't use the right expansion order in the current implementation + # It may be considered irrelevant if the algorithm changes and this is no longer a concern. + # In that case the test can be tweaked or removed + a = pt.vector("a") + b = pt.exp(a) + # Unique creates an unfuesable path between b and d/e + c = pt.unique(b) + d = pt.log(c) + # The critical aspect of the current implementation, is that we must visit d before c, + # so we learn about the unfuseable path by the time we visit c + e1 = b + d + e2 = d + b # test both orders + + fg = FunctionGraph([a], [e1, e2], clone=False) + _, nb_fused, nb_replacement, *_ = FusionOptimizer().apply(fg) + fg.dprint() + assert nb_fused == 1 + assert nb_replacement == 3 + @pytest.mark.parametrize( "case", [ @@ -1374,7 +1394,7 @@ def test_eval_benchmark(self, benchmark): "graph_fn, n, expected_n_repl", [ ("deep_small_kernels", 20, (20, 60)), - ("large_fuseable_graph", 25, (103, 876)), + ("large_fuseable_graph", 25, (128, 876)), ], ) def test_rewrite_benchmark(self, graph_fn, n, expected_n_repl, benchmark): From e5e58b2fab5aaae07561ad6df2166b4f18d8b776 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 29 Sep 2025 18:33:11 +0200 Subject: [PATCH 11/11] Use more direct imports in rewriting/elemwise.py --- pytensor/tensor/rewriting/elemwise.py | 69 +++++++++++++++------------ 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index ff6b6c70e3..dfcdfdd471 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -8,16 +8,15 @@ from operator import or_ from warnings import warn -import pytensor.scalar.basic as ps -from pytensor import clone_replace, compile from pytensor.compile.function.types import Supervisor -from pytensor.compile.mode import get_target_language +from pytensor.compile.mode import get_target_language, optdb from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates from pytensor.graph.features import ReplaceValidate from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.op import Op +from pytensor.graph.replace import clone_replace from pytensor.graph.rewriting.basic import ( GraphRewriter, copy_stack_trace, @@ -30,11 +29,21 @@ from pytensor.graph.rewriting.unify import OpPattern from pytensor.graph.traversal import toposort from pytensor.graph.utils import InconsistencyError, MethodNotDefined -from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop -from pytensor.tensor.basic import ( - MakeVector, - constant, +from pytensor.scalar import ( + Add, + Composite, + Mul, + ScalarOp, + get_scalar_type, + transfer_type, + upcast_out, + upgrade_to_float, ) +from pytensor.scalar import cast as scalar_cast +from pytensor.scalar import constant as scalar_constant +from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop +from pytensor.tensor.basic import MakeVector +from pytensor.tensor.basic import constant as tensor_constant from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import add, exp, mul from pytensor.tensor.rewriting.basic import ( @@ -280,7 +289,7 @@ def create_inplace_node(self, node, inplace_pattern): inplace_pattern = {i: o for i, [o] in inplace_pattern.items()} if hasattr(scalar_op, "make_new_inplace"): new_scalar_op = scalar_op.make_new_inplace( - ps.transfer_type( + transfer_type( *[ inplace_pattern.get(i, o.dtype) for i, o in enumerate(node.outputs) @@ -289,14 +298,14 @@ def create_inplace_node(self, node, inplace_pattern): ) else: new_scalar_op = type(scalar_op)( - ps.transfer_type( + transfer_type( *[inplace_pattern.get(i, None) for i in range(len(node.outputs))] ) ) return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs) -compile.optdb.register( +optdb.register( "inplace_elemwise", InplaceElemwiseOptimizer(), "inplace_elemwise_opt", # for historic reason @@ -428,10 +437,8 @@ def local_useless_dimshuffle_makevector(fgraph, node): @register_canonicalize @node_rewriter( [ - elemwise_of( - OpPattern(ps.ScalarOp, output_types_preference=ps.upgrade_to_float) - ), - elemwise_of(OpPattern(ps.ScalarOp, output_types_preference=ps.upcast_out)), + elemwise_of(OpPattern(ScalarOp, output_types_preference=upgrade_to_float)), + elemwise_of(OpPattern(ScalarOp, output_types_preference=upcast_out)), ] ) def local_upcast_elemwise_constant_inputs(fgraph, node): @@ -452,7 +459,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): changed = False for i, inp in enumerate(node.inputs): if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant): - new_inputs[i] = constant(inp.data.astype(output_dtype)) + new_inputs[i] = tensor_constant(inp.data.astype(output_dtype)) changed = True if not changed: @@ -531,7 +538,7 @@ def add_requirements(self, fgraph): @staticmethod def elemwise_to_scalar(inputs, outputs): replacement = { - inp: ps.get_scalar_type(inp.type.dtype).make_variable() for inp in inputs + inp: get_scalar_type(inp.type.dtype).make_variable() for inp in inputs } for node in toposort(outputs, blockers=inputs): scalar_inputs = [replacement[inp] for inp in node.inputs] @@ -853,7 +860,7 @@ def elemwise_scalar_op_has_c_code( scalar_inputs, scalar_outputs = self.elemwise_to_scalar(inputs, outputs) composite_outputs = Elemwise( # No need to clone Composite graph, because `self.elemwise_to_scalar` creates fresh variables - ps.Composite(scalar_inputs, scalar_outputs, clone_graph=False) + Composite(scalar_inputs, scalar_outputs, clone_graph=False) )(*inputs, return_list=True) assert len(outputs) == len(composite_outputs) for old_out, composite_out in zip(outputs, composite_outputs): @@ -913,7 +920,7 @@ def print_profile(stream, prof, level=0): @register_canonicalize @register_specialize -@node_rewriter([elemwise_of(ps.Composite)]) +@node_rewriter([elemwise_of(Composite)]) def local_useless_composite_outputs(fgraph, node): """Remove inputs and outputs of Composite Ops that are not used anywhere.""" comp = node.op.scalar_op @@ -934,7 +941,7 @@ def local_useless_composite_outputs(fgraph, node): node.outputs ): used_inputs = [node.inputs[i] for i in used_inputs_idxs] - c = ps.Composite(inputs=used_inner_inputs, outputs=used_inner_outputs) + c = Composite(inputs=used_inner_inputs, outputs=used_inner_outputs) e = Elemwise(scalar_op=c)(*used_inputs, return_list=True) return dict(zip([node.outputs[i] for i in used_outputs_idxs], e, strict=True)) @@ -948,7 +955,7 @@ def local_careduce_fusion(fgraph, node): # FIXME: This check is needed because of the faulty logic in the FIXME below! # Right now, rewrite only works for `Sum`/`Prod` - if not isinstance(car_scalar_op, ps.Add | ps.Mul): + if not isinstance(car_scalar_op, Add | Mul): return None elm_node = car_input.owner @@ -992,19 +999,19 @@ def local_careduce_fusion(fgraph, node): car_acc_dtype = node.op.acc_dtype scalar_elm_inputs = [ - ps.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs + get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs ] elm_output = elm_scalar_op(*scalar_elm_inputs) # This input represents the previous value in the `CAReduce` binary reduction - carried_car_input = ps.get_scalar_type(car_acc_dtype).make_variable() + carried_car_input = get_scalar_type(car_acc_dtype).make_variable() scalar_fused_output = car_scalar_op(carried_car_input, elm_output) if scalar_fused_output.type.dtype != car_acc_dtype: - scalar_fused_output = ps.cast(scalar_fused_output, car_acc_dtype) + scalar_fused_output = scalar_cast(scalar_fused_output, car_acc_dtype) - fused_scalar_op = ps.Composite( + fused_scalar_op = Composite( inputs=[carried_car_input, *scalar_elm_inputs], outputs=[scalar_fused_output] ) @@ -1025,7 +1032,7 @@ def local_careduce_fusion(fgraph, node): return [new_car_op(*elm_inputs)] -@node_rewriter([elemwise_of(ps.Composite)]) +@node_rewriter([elemwise_of(Composite)]) def local_inline_composite_constants(fgraph, node): """Inline scalar constants in Composite graphs.""" composite_op = node.op.scalar_op @@ -1041,7 +1048,7 @@ def local_inline_composite_constants(fgraph, node): and "complex" not in outer_inp.type.dtype ): if outer_inp.unique_value is not None: - inner_replacements[inner_inp] = ps.constant( + inner_replacements[inner_inp] = scalar_constant( outer_inp.unique_value, dtype=inner_inp.dtype ) continue @@ -1054,7 +1061,7 @@ def local_inline_composite_constants(fgraph, node): new_inner_outs = clone_replace( composite_op.fgraph.outputs, replace=inner_replacements ) - new_composite_op = ps.Composite(new_inner_inputs, new_inner_outs) + new_composite_op = Composite(new_inner_inputs, new_inner_outs) new_outputs = Elemwise(new_composite_op).make_node(*new_outer_inputs).outputs # Some of the inlined constants were broadcasting the output shape @@ -1095,7 +1102,7 @@ def constant_fold_branches_of_add_mul(fgraph, node): if other_inps: python_op = operator.mul if node.op == mul else operator.add folded_inputs = [reference_inp, *other_inps] - new_inp = constant( + new_inp = tensor_constant( reduce(python_op, (const.data for const in folded_inputs)) ) new_constants = [ @@ -1119,7 +1126,7 @@ def constant_fold_branches_of_add_mul(fgraph, node): add_mul_fusion_seqopt = SequenceDB() -compile.optdb.register( +optdb.register( "add_mul_fusion", add_mul_fusion_seqopt, "fast_run", @@ -1140,7 +1147,7 @@ def constant_fold_branches_of_add_mul(fgraph, node): # Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites) fuse_seqopt = SequenceDB() -compile.optdb.register( +optdb.register( "elemwise_fusion", fuse_seqopt, "fast_run", @@ -1271,7 +1278,7 @@ def split_2f1grad_loop(fgraph, node): return replacements -compile.optdb["py_only"].register( +optdb["py_only"].register( "split_2f1grad_loop", split_2f1grad_loop, "fast_compile",