From 703b72a9939ea339bc59a73dd197efcaceb089ff Mon Sep 17 00:00:00 2001 From: Prajwal Date: Wed, 4 Jun 2025 18:07:56 -0700 Subject: [PATCH 1/2] Fixed the get_var func --- pytensor/graph/basic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 512f0ef3ab..baf6b4e381 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -2032,7 +2032,7 @@ def compare_nodes(nd_x, nd_y, common, different): def get_var_by_name( - graphs: Iterable[Variable], target_var_id: str, ids: str = "CHAR" + graphs: Iterable[Variable], target_var_id: str, include_inner_graphs: bool = False ) -> tuple[Variable, ...]: r"""Get variables in a graph using their names. @@ -2057,7 +2057,7 @@ def expand(r) -> list[Variable] | None: res = list(r.owner.inputs) - if isinstance(r.owner.op, HasInnerGraph): + if include_inner_graphs and isinstance(r.owner.op, HasInnerGraph): res.extend(r.owner.op.inner_outputs) return res From 62be8b82aa954b6ca90ae18f521968c645b4b41c Mon Sep 17 00:00:00 2001 From: Prajwal Date: Wed, 4 Jun 2025 18:38:41 -0700 Subject: [PATCH 2/2] Added a test case for get_var_by_name func fix --- tests/graph/test_basic.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index 84ffb365b5..4eb9ba735a 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -628,6 +628,31 @@ def test_get_var_by_name(): assert res == exp_res +def test_get_var_by_name_include_inner_graphs_flag(): + r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) + o1 = MyOp(r1, r2) + o1.name = "o1" + + # Inner graph + igo_in_1 = MyVariable(4) + igo_in_2 = MyVariable(5) + igo_out_1 = MyOp(igo_in_1, igo_in_2) + igo_out_1.name = "igo1" + + igo = MyInnerGraphOp([igo_in_1, igo_in_2], [igo_out_1]) + o2 = igo(r3, o1) + + res = get_var_by_name([o1, o2], "igo1", include_inner_graphs=False) + assert ( + res == () + ), "Should not return inner graph variable when include_inner_graphs is False" + + res = get_var_by_name([o1, o2], "igo1", include_inner_graphs=True) + assert any( + v.name == "igo1" for v in res + ), "Should return inner graph variable when include_inner_graphs is True" + + def test_clone_new_inputs(): """Make sure that `Apply.clone_with_new_inputs` properly handles `Type` changes."""