Skip to content

Commit f023eaa

Browse files
authored
Merge pull request #173 from bcaller/recursion
Recursive function calls shouldn't raise RecursionError
2 parents c7b244d + 093f506 commit f023eaa

File tree

7 files changed

+62
-2
lines changed

7 files changed

+62
-2
lines changed

examples/vulnerable_code/recursive.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from flask import Flask, request
2+
3+
app = Flask(__name__)
4+
5+
6+
def recur_without_any_propagation(x):
7+
if len(x) < 20:
8+
return recur_without_any_propagation("a" * 24)
9+
return "Done"
10+
11+
12+
def recur_no_propagation_false_positive(x):
13+
if len(x) < 20:
14+
return recur_no_propagation_false_positive(x + "!")
15+
return "Done"
16+
17+
18+
def recur_with_propagation(x):
19+
if len(x) < 20:
20+
return recur_with_propagation(x + "!")
21+
return x
22+
23+
24+
@app.route('/recursive')
25+
def route():
26+
param = request.args.get('param', 'not set')
27+
repeated_completely_untainted = recur_without_any_propagation(param)
28+
app.db.execute(repeated_completely_untainted)
29+
repeated_untainted = recur_no_propagation_false_positive(param)
30+
app.db.execute(repeated_untainted)
31+
repeated_tainted = recur_with_propagation(param)
32+
app.db.execute(repeated_tainted)

pyt/__main__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,9 @@ def main(command_line_args=sys.argv[1:]): # noqa: C901
125125
)
126126

127127
initialize_constraint_table(cfg_list)
128+
log.info("Analysing")
128129
analyse(cfg_list)
130+
log.info("Finding vulnerabilities")
129131
vulnerabilities = find_vulnerabilities(
130132
cfg_list,
131133
args.blackbox_mapping_file,

pyt/cfg/expr_visitor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import logging
23

34
from .alias_helper import handle_aliases_in_calls
45
from ..core.ast_helper import (
@@ -30,6 +31,8 @@
3031
from .stmt_visitor import StmtVisitor
3132
from .stmt_visitor_helper import CALL_IDENTIFIER
3233

34+
log = logging.getLogger(__name__)
35+
3336

3437
class ExprVisitor(StmtVisitor):
3538
def __init__(
@@ -52,6 +55,7 @@ def __init__(
5255
self.undecided = False
5356
self.function_names = list()
5457
self.function_return_stack = list()
58+
self.function_definition_stack = list() # used to avoid recursion
5559
self.module_definitions_stack = list()
5660
self.prev_nodes_to_avoid = list()
5761
self.last_control_flow_nodes = list()
@@ -543,6 +547,7 @@ def process_function(self, call_node, definition):
543547
first_node
544548
)
545549
self.function_return_stack.pop()
550+
self.function_definition_stack.pop()
546551

547552
return self.nodes[-1]
548553

@@ -560,11 +565,15 @@ def visit_Call(self, node):
560565
last_attribute = _id.rpartition('.')[-1]
561566

562567
if definition:
568+
if definition in self.function_definition_stack:
569+
log.debug("Recursion encountered in function %s", _id)
570+
return self.add_blackbox_or_builtin_call(node, blackbox=True)
563571
if isinstance(definition.node, ast.ClassDef):
564572
self.add_blackbox_or_builtin_call(node, blackbox=False)
565573
elif isinstance(definition.node, ast.FunctionDef):
566574
self.undecided = False
567575
self.function_return_stack.append(_id)
576+
self.function_definition_stack.append(definition)
568577
return self.process_function(node, definition)
569578
else:
570579
raise Exception('Definition was neither FunctionDef or ' +

pyt/web_frameworks/framework_adaptor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""A generic framework adaptor that leaves route criteria to the caller."""
22

33
import ast
4+
import logging
45

56
from ..cfg import make_cfg
67
from ..core.ast_helper import Arguments
@@ -10,6 +11,8 @@
1011
TaintedNode
1112
)
1213

14+
log = logging.getLogger(__name__)
15+
1316

1417
class FrameworkAdaptor():
1518
"""An engine that uses the template pattern to find all
@@ -31,6 +34,7 @@ def __init__(
3134

3235
def get_func_cfg_with_tainted_args(self, definition):
3336
"""Build a function cfg and return it, with all arguments tainted."""
37+
log.debug("Getting CFG for %s", definition.name)
3438
func_cfg = make_cfg(
3539
definition.node,
3640
self.project_modules,

tests/cfg/cfg_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .cfg_base_test_case import CFGBaseTestCase
44

55
from pyt.core.node_types import (
6+
BBorBInode,
67
EntryOrExitNode,
78
Node
89
)
@@ -1389,6 +1390,13 @@ def test_call_on_call(self):
13891390
path = 'examples/example_inputs/call_on_call.py'
13901391
self.cfg_create_from_file(path)
13911392

1393+
def test_recursive_function(self):
1394+
path = 'examples/example_inputs/recursive.py'
1395+
self.cfg_create_from_file(path)
1396+
recursive_call = self.cfg.nodes[7]
1397+
assert recursive_call.label == '~call_3 = ret_rec(wat)'
1398+
assert isinstance(recursive_call, BBorBInode) # Not RestoreNode
1399+
13921400

13931401
class CFGCallWithAttributeTest(CFGBaseTestCase):
13941402
def setUp(self):

tests/main_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,11 @@ def test_targets_with_recursive(self):
108108
excluded_files = ""
109109

110110
included_files = discover_files(targets, excluded_files, True)
111-
self.assertEqual(len(included_files), 31)
111+
self.assertEqual(len(included_files), 32)
112112

113113
def test_targets_with_recursive_and_excluded(self):
114114
targets = ["examples/vulnerable_code/"]
115115
excluded_files = "inter_command_injection.py"
116116

117117
included_files = discover_files(targets, excluded_files, True)
118-
self.assertEqual(len(included_files), 30)
118+
self.assertEqual(len(included_files), 31)

tests/vulnerabilities/vulnerabilities_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,11 @@ def assert_vulnerable(fixture):
465465
assert_vulnerable('result = repr(str("%s" % TAINT.lower().upper()))')
466466
assert_vulnerable('result = repr(str("{}".format(TAINT.lower())))')
467467

468+
def test_recursion(self):
469+
# Really this file only has one vulnerability, but for now it's safer to keep the false positive.
470+
vulnerabilities = self.run_analysis('examples/vulnerable_code/recursive.py')
471+
self.assert_length(vulnerabilities, expected_length=2)
472+
468473

469474
class EngineDjangoTest(VulnerabilitiesBaseTestCase):
470475
def run_analysis(self, path):

0 commit comments

Comments
 (0)