diff --git a/py/dml/codegen.py b/py/dml/codegen.py index fbc25cf3a..1e1e98073 100644 --- a/py/dml/codegen.py +++ b/py/dml/codegen.py @@ -3770,6 +3770,14 @@ def codegen_method_func(func): inline_scope.add(ExpressionSymbol(name, inlined_arg, method.site)) inp = [(n, t) for (n, t) in func.inp if isinstance(t, DMLType)] + if indices: + within_bounds = ' && '.join([ + f'_idx{i} < {dimsize}' + for (i, dimsize) in enumerate(method.dimsizes)]) + validate_indices = f'ASSERT({within_bounds});' + else: + validate_indices = None + with ErrorContext(method): location = Location(method, indices) if func.memoized: @@ -3781,7 +3789,8 @@ def codegen_method_func(func): method.site, inp, func.outp, func.throws, func.independent, memoization, method.astcode, method.default_method.default_sym(indices), - location, inline_scope, method.rbrace_site) + location, inline_scope, method.rbrace_site, + validate_indices) return code def codegen_return(site, outp, throws, retvals): @@ -3814,7 +3823,7 @@ def codegen_return(site, outp, throws, retvals): return mkCompound(site, stmts) def codegen_method(site, inp, outp, throws, independent, memoization, ast, - default, location, fnscope, rbrace_site): + default, location, fnscope, rbrace_site, validate_indices=None): with (crep.DeviceInstanceContext() if not independent else contextlib.nullcontext()): for (arg, etype) in inp: @@ -3867,6 +3876,8 @@ def prelude(): code.append(mkAssignStatement(site, param, init)) else: code = [] + if validate_indices: + code.append(mkInline(site, validate_indices)) with fail_handler, exit_handler: code.append(codegen_statement(ast, location, fnscope)) @@ -3882,6 +3893,8 @@ def prelude(): [subs] = ast.args with fail_handler, exit_handler: body = prelude() + if validate_indices: + body.append(mkInline(site, validate_indices)) body.extend(codegen_statements(subs, location, fnscope)) code = mkCompound(site, body) if code.control_flow().fallthrough: diff --git a/py/dml/ctree.py b/py/dml/ctree.py index 5d9fad091..aee79a6da 100644 --- a/py/dml/ctree.py +++ b/py/dml/ctree.py @@ -4601,6 +4601,8 @@ def value(self): def mkInlinedParam(site, expr, name, type): if not defined(expr): raise ICE(site, 'undefined parameter') + if isinstance(expr, InlinedParam): + expr = expr.expr if isinstance(expr, IntegerConstant): value = expr.value type = realtype(type) diff --git a/test/1.4/expressions/T_inlined_param.dml b/test/1.4/expressions/T_inlined_param.dml index 09422cf9d..e22577643 100644 --- a/test/1.4/expressions/T_inlined_param.dml +++ b/test/1.4/expressions/T_inlined_param.dml @@ -19,6 +19,10 @@ inline method m(inline x, void *y) { } } +inline method m2(inline x, void *y) { + m(x, y); +} + method init() { - m(NULL, NULL); + m2(NULL, NULL); }