diff --git a/src/cascade/frontend/ast_visitors/simplify_returns.py b/src/cascade/frontend/ast_visitors/simplify_returns.py index 443d1b9..d4205d7 100644 --- a/src/cascade/frontend/ast_visitors/simplify_returns.py +++ b/src/cascade/frontend/ast_visitors/simplify_returns.py @@ -1,9 +1,12 @@ +from typing import Any from klara.core.ssa_visitors import AstVisitor from klara.core import nodes def simplify_returns(node): sr = SimplifyReturns.replace(node) - for parent, n, target in sr.inserts: + + # Add the new assign nodes to the parent after the ast visit is complete + for parent, n, target in sr.new_nodes: try: i = parent.body.index(n) parent.body.insert(i, target) @@ -15,15 +18,19 @@ def simplify_returns(node): raise e class SimplifyReturns(AstVisitor): - """Replace attributes with "self" into "state", and remove SSA versioning. + """Put return statments in ANF form. + + Examples: - e.g.: - self_0.balance_0 -> state.balance + `return x+3` -> `__ret_0 = x + 3; return __ret_0` + `return self.balance` -> `__ret_1 = self.balance; return __ret_1` + `return cat` -> `return cat` """ def __init__(self): self.temps = 0 - self.inserts = [] + # (return_node parent block, modified return_node, new assign_node) + self.new_nodes: list[tuple[Any, nodes.Return, nodes.Assign]] = [] @classmethod def replace(cls, node): @@ -41,8 +48,7 @@ def replace_name(self, node: nodes.Return): node.value.postinit(target.id) assert hasattr(node.parent, "body"), type(node.parent) - print(f"replacing {node} in {node.parent} with {new_assign}") - self.inserts.append((node.parent, node, new_assign)) + self.new_nodes.append((node.parent, node, new_assign)) def visit_return(self, node: nodes.Return): diff --git a/src/cascade/preprocessing.py b/src/cascade/preprocessing.py index 8ef0449..35aede9 100644 --- a/src/cascade/preprocessing.py +++ b/src/cascade/preprocessing.py @@ -7,8 +7,8 @@ def setup_cfg(code: str, preprocess=True) -> tuple[Cfg, nodes.Module]: as_tree = AstBuilder().string_build(code) cfg = Cfg(as_tree) - cfg.convert_to_ssa() if preprocess: + cfg.convert_to_ssa() ReplaceSelfWithState.replace(as_tree) simplify_returns(as_tree) # TODO: do this in preprocessing diff --git a/tests/frontend/ast_visitors/test_simplify_returns.py b/tests/frontend/ast_visitors/test_simplify_returns.py index 02199b7..89712d1 100644 --- a/tests/frontend/ast_visitors/test_simplify_returns.py +++ b/tests/frontend/ast_visitors/test_simplify_returns.py @@ -1,48 +1,37 @@ from cascade.frontend.ast_visitors.simplify_returns import SimplifyReturns, simplify_returns from cascade.frontend.generator.unparser import unparse from cascade.preprocessing import setup_cfg -from klara.core import nodes -from klara.core.tree_rewriter import AstBuilder -from klara.core.cfg import Cfg -def setup_cfg_no_ssa(code: str) -> Cfg: - as_tree = AstBuilder().string_build(code) - cfg = Cfg(as_tree) - return cfg, as_tree def test_simplify_return_state(): code = "return self.balance" - cfg, tree = setup_cfg_no_ssa(code) - for s in tree.get_statements(): - print(repr(s)) - sr = SimplifyReturns.replace(tree) + cfg, tree = setup_cfg(code, preprocess=False) + simplify_returns(tree) - for s in tree.get_statements(): - print(repr(s)) + new = [unparse(s) for s in tree.get_statements()] + assert new == ["__ret_0 = self.balance", "return __ret_0"] + def test_simplify_return_name(): code = "return cat" - cfg, tree = setup_cfg_no_ssa(code) - for s in tree.get_statements(): - print(repr(s)) - sr = SimplifyReturns.replace(tree) + cfg, tree = setup_cfg(code, preprocess=False) + simplify_returns(tree) - for s in tree.get_statements(): - print(repr(s)) + new = [unparse(s) for s in tree.get_statements()] + assert new == ["return cat"] def test_simplify_return_binop(): code = """a = 1 -return 4+1""" - cfg, tree = setup_cfg_no_ssa(code) +return 4 + 1""" + cfg, tree = setup_cfg(code, preprocess=False) + - for s in tree.get_statements(): - print(repr(s)) simplify_returns(tree) - for s in tree.get_statements(): - print(repr(s)) + new = [unparse(s) for s in tree.get_statements()] + assert new == ["a = 1", "__ret_0 = 4 + 1", "return __ret_0"] def test_simplify_return_multiple(): code = """a = 1 @@ -50,11 +39,10 @@ def test_simplify_return_multiple(): return 3 + 2 else: return a""" - cfg, tree = setup_cfg_no_ssa(code) + cfg, tree = setup_cfg(code, preprocess=False) - for b in tree.get_statements(): - print(repr(b)) simplify_returns(tree) - for b in tree.get_statements(): - print(repr(b)) \ No newline at end of file + new = [unparse(s) for s in tree.get_statements()] + # if statements aren't returned in get_statements + assert new == ["a = 1", "__ret_0 = 3 + 2", "return __ret_0", "return a"] \ No newline at end of file