Skip to content

Commit 0978d11

Browse files
committed
Add a second inline_ofg_expansion in xtensor for Ops that wrap OpFromGraph once lowered
1 parent e3e4afe commit 0978d11

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

pytensor/xtensor/rewriting/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pytensor.compile import optdb
2-
from pytensor.graph.rewriting.basic import NodeRewriter
2+
from pytensor.graph.rewriting.basic import NodeRewriter, in2out
33
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
4+
from pytensor.tensor.rewriting.ofg import inline_ofg_expansion
45

56

67
lower_xtensor_db = EquilibriumDB(ignore_newtrees=False)
@@ -14,6 +15,15 @@
1415
position=0.1,
1516
)
1617

18+
# Register OFG inline again after lowering xtensor
19+
optdb.register(
20+
"inline_ofg_expansion_xtensor",
21+
in2out(inline_ofg_expansion),
22+
"fast_run",
23+
"fast_compile",
24+
position=0.11,
25+
)
26+
1727

1828
def register_lower_xtensor(
1929
node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs

0 commit comments

Comments
 (0)