Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions src/cascade/frontend/ast_visitors/simplify_returns.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Any
from klara.core.ssa_visitors import AstVisitor
from klara.core import nodes

def simplify_returns(node):
sr = SimplifyReturns.replace(node)
for parent, n, target in sr.inserts:

# Add the new assign nodes to the parent after the ast visit is complete
for parent, n, target in sr.new_nodes:
try:
i = parent.body.index(n)
parent.body.insert(i, target)
Expand All @@ -15,15 +18,19 @@ def simplify_returns(node):
raise e

class SimplifyReturns(AstVisitor):
"""Replace attributes with "self" into "state", and remove SSA versioning.
"""Put return statments in ANF form.

Examples:

e.g.:
self_0.balance_0 -> state.balance
`return x+3` -> `__ret_0 = x + 3; return __ret_0`
`return self.balance` -> `__ret_1 = self.balance; return __ret_1`
`return cat` -> `return cat`
"""

def __init__(self):
self.temps = 0
self.inserts = []
# (return_node parent block, modified return_node, new assign_node)
self.new_nodes: list[tuple[Any, nodes.Return, nodes.Assign]] = []

@classmethod
def replace(cls, node):
Expand All @@ -41,8 +48,7 @@ def replace_name(self, node: nodes.Return):
node.value.postinit(target.id)

assert hasattr(node.parent, "body"), type(node.parent)
print(f"replacing {node} in {node.parent} with {new_assign}")
self.inserts.append((node.parent, node, new_assign))
self.new_nodes.append((node.parent, node, new_assign))

def visit_return(self, node: nodes.Return):

Expand Down
2 changes: 1 addition & 1 deletion src/cascade/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
def setup_cfg(code: str, preprocess=True) -> tuple[Cfg, nodes.Module]:
as_tree = AstBuilder().string_build(code)
cfg = Cfg(as_tree)
cfg.convert_to_ssa()
if preprocess:
cfg.convert_to_ssa()
ReplaceSelfWithState.replace(as_tree)
simplify_returns(as_tree)
# TODO: do this in preprocessing
Expand Down
48 changes: 18 additions & 30 deletions tests/frontend/ast_visitors/test_simplify_returns.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,48 @@
from cascade.frontend.ast_visitors.simplify_returns import SimplifyReturns, simplify_returns
from cascade.frontend.generator.unparser import unparse
from cascade.preprocessing import setup_cfg
from klara.core import nodes
from klara.core.tree_rewriter import AstBuilder
from klara.core.cfg import Cfg

def setup_cfg_no_ssa(code: str) -> Cfg:
as_tree = AstBuilder().string_build(code)
cfg = Cfg(as_tree)
return cfg, as_tree

def test_simplify_return_state():
code = "return self.balance"
cfg, tree = setup_cfg_no_ssa(code)
for s in tree.get_statements():
print(repr(s))
sr = SimplifyReturns.replace(tree)
cfg, tree = setup_cfg(code, preprocess=False)

simplify_returns(tree)

for s in tree.get_statements():
print(repr(s))
new = [unparse(s) for s in tree.get_statements()]
assert new == ["__ret_0 = self.balance", "return __ret_0"]


def test_simplify_return_name():
code = "return cat"
cfg, tree = setup_cfg_no_ssa(code)
for s in tree.get_statements():
print(repr(s))
sr = SimplifyReturns.replace(tree)
cfg, tree = setup_cfg(code, preprocess=False)

simplify_returns(tree)

for s in tree.get_statements():
print(repr(s))
new = [unparse(s) for s in tree.get_statements()]
assert new == ["return cat"]

def test_simplify_return_binop():
code = """a = 1
return 4+1"""
cfg, tree = setup_cfg_no_ssa(code)
return 4 + 1"""
cfg, tree = setup_cfg(code, preprocess=False)


for s in tree.get_statements():
print(repr(s))
simplify_returns(tree)

for s in tree.get_statements():
print(repr(s))
new = [unparse(s) for s in tree.get_statements()]
assert new == ["a = 1", "__ret_0 = 4 + 1", "return __ret_0"]

def test_simplify_return_multiple():
code = """a = 1
if a == 1:
return 3 + 2
else:
return a"""
cfg, tree = setup_cfg_no_ssa(code)
cfg, tree = setup_cfg(code, preprocess=False)

for b in tree.get_statements():
print(repr(b))
simplify_returns(tree)

for b in tree.get_statements():
print(repr(b))
new = [unparse(s) for s in tree.get_statements()]
# if statements aren't returned in get_statements
assert new == ["a = 1", "__ret_0 = 3 + 2", "return __ret_0", "return a"]
Loading