From 9b56a73d47c07d7bc980cdea2125c0640e9c1655 Mon Sep 17 00:00:00 2001 From: Eric Shi Date: Tue, 9 Jun 2026 18:06:59 +0000 Subject: [PATCH] Support Callable function parameters (GH-1424) Callable parameters let user functions accept another Warp function as a specialization argument. Before this, helpers had to hard-code the callee, which made higher-order patterns unavailable and left specialization-sensitive behavior untested. Support these parameters in codegen, module hashing, dependency tracking, documentation, and autograd diagnostics. Keep the coverage in a dedicated test module so future regressions around callable specialization are easier to isolate. Signed-off-by: Eric Shi --- CHANGELOG.md | 3 + docs/user_guide/basics.rst | 63 +++ docs/user_guide/limitations.rst | 9 + warp/_src/codegen.py | 434 +++++++++++++-- warp/_src/context.py | 66 ++- warp/_src/types.py | 11 +- warp/tests/aux_test_callable_double.py | 11 + warp/tests/aux_test_callable_triple.py | 11 + warp/tests/test_func_callable.py | 695 +++++++++++++++++++++++++ warp/tests/unittest_suites.py | 4 + 10 files changed, 1265 insertions(+), 42 deletions(-) create mode 100644 warp/tests/aux_test_callable_double.py create mode 100644 warp/tests/aux_test_callable_triple.py create mode 100644 warp/tests/test_func_callable.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 8330375ff0..5e1d3c70a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ ### Added +- Add support for `Callable`-typed parameters in user-defined `@wp.func` functions, allowing user-defined Warp functions + and simple built-in Warp functions such as `wp.sin()` and `wp.min()` to be used as callable targets from kernels or + other functions, including through defaults ([GH-1424](https://github.com/NVIDIA/warp/issues/1424)). - Add mipmap (texture level-of-detail) support to `wp.Texture1D`, `wp.Texture2D`, and `wp.Texture3D` via the new `num_mip_levels` and `mip_filter_mode` constructor parameters, and allow `wp.texture_sample()` to accept an optional trailing `lod` argument for controlling sampled detail level ([GH-1409](https://github.com/NVIDIA/warp/issues/1409)). diff --git a/docs/user_guide/basics.rst b/docs/user_guide/basics.rst index b1051f7346..9bc3243e03 100644 --- a/docs/user_guide/basics.rst +++ b/docs/user_guide/basics.rst @@ -335,6 +335,69 @@ User functions may also be overloaded by defining multiple function signatures w def custom(x: wp.vec3): return x + wp.vec3(1.0, 0.0, 0.0) +.. _callable-parameters: + +Callable Parameters +^^^^^^^^^^^^^^^^^^^ + +User functions can accept another user-defined Warp function or simple built-in +Warp function by annotating the parameter as ``Callable`` from +``collections.abc`` or ``typing``. +Callable targets are specialization inputs, not runtime values: the target must +be known during code generation, and Warp generates a separate specialization +for each target function. The callable target is chosen where the user function +is called and can be invoked directly inside the user function body: + +.. code-block:: python + + from collections.abc import Callable + + import warp as wp + + @wp.func + def square(x: float): + return x * x + + + @wp.func + def cube(x: float): + return x * x * x + + + @wp.func + def apply(f: Callable, x: float): + return f(x) + + + @wp.kernel + def apply_kernel( + values: wp.array[float], + square_out: wp.array[float], + cube_out: wp.array[float], + ): + i = wp.tid() + square_out[i] = apply(square, values[i]) + cube_out[i] = apply(cube, values[i]) + +Parameterized callable annotations such as ``Callable[[float], float]`` are +accepted, but Warp currently treats them the same as bare ``Callable``. The +argument and return types in the annotation are not validated against the target +function signature. The callable target is checked only through the actual calls +made in the function body during code generation. + +Callable parameters may also use defaults and keyword arguments: + +.. code-block:: python + + @wp.func + def apply_default(f: Callable = square, x: float = 0.0): + return f(x) + +Pass only user-defined :func:`@wp.func ` functions or simple built-in +functions such as ``wp.sin``, ``wp.cos``, ``wp.sqrt``, ``wp.add``, and ``wp.min`` +as callable targets. See :doc:`limitations` for unsupported callable targets and +other restrictions. + Tiles may also be passed to user functions. The function signature tile argument should include dtype and shape parameters to match the tile type intended to be used in the function. For example: diff --git a/docs/user_guide/limitations.rst b/docs/user_guide/limitations.rst index 383a312007..61446356ed 100644 --- a/docs/user_guide/limitations.rst +++ b/docs/user_guide/limitations.rst @@ -35,6 +35,15 @@ Kernels and User Functions (e.g., ``wp.float64(wp.PI)`` or ``wp.int64(large_value)``). * Python ``IntFlag`` values behave like raw integers in Warp kernels: bitwise negation (``~``) produces the integer negation, not a masked combination of flags as in standard Python ``IntFlag`` behavior. +* :ref:`Callable parameters ` in user functions only support direct inline calls with + specialization-time targets: user-defined :func:`@wp.func ` functions and simple built-in Warp functions + such as ``wp.sin``, ``wp.cos``, ``wp.sqrt``, ``wp.add``, and ``wp.min``. + Arbitrary Python callables and built-in Warp functions that require special code-generation behavior, dispatch logic, + or side effects, such as ``wp.printf``, are not supported. + Rebinding a function-valued local to a different function or to a non-function value is not supported. + User functions with ``Callable`` parameters also cannot define custom gradient or replay functions. + Callable argument and return types in annotations such as ``Callable[[float], float]`` are accepted but are not + validated against the target function signature. A limitation of Warp is that each dimension of the grid used to launch a kernel must be representable as a 32-bit signed integer. Therefore, no single dimension of a grid should exceed :math:`2^{31}-1`. diff --git a/warp/_src/codegen.py b/warp/_src/codegen.py index 2a525c576d..f4446cb235 100644 --- a/warp/_src/codegen.py +++ b/warp/_src/codegen.py @@ -20,6 +20,7 @@ import types from collections import deque from collections.abc import Callable, Mapping, Sequence +from copy import copy as shallowcopy from typing import Any, ClassVar, get_args, get_origin import warp.config @@ -940,8 +941,11 @@ def func_match_args(func, arg_types, kwarg_types): if func_arg_type is Any: continue - # handle function refs as a special case - if func_arg_type is Callable and isinstance(bound_arg_type, warp._src.context.Function): + # Callable parameters are type-erased during overload matching; the + # concrete target is bound later during call specialization. + if warp._src.types.is_callable_annotation(func_arg_type) and isinstance( + bound_arg_type, warp._src.context.Function + ): continue bound_arg_type_stripped = strip_reference(bound_arg_type) @@ -963,6 +967,332 @@ def func_match_args(func, arg_types, kwarg_types): return True +def is_regular_builtin_callable_target(func): + """Return whether ``func`` is a simple built-in ``Callable`` target. + + Callable parameters can specialize on built-ins that lower directly through + the normal built-in function path. Built-ins that need variadic handling, + dispatch callbacks, replay suppression, or LTO dispatch are excluded because + they need special codegen behavior that cannot be represented by replacing a + callable parameter with a direct function call. + + Args: + func: Function object to check. + + Returns: + ``True`` if every overload of ``func`` is supported as a Callable target. + """ + + if not isinstance(func, warp._src.context.Function) or not func.is_builtin(): + return False + + overloads = getattr(func, "overloads", None) or (func,) + for overload in overloads: + if ( + not overload.is_builtin() + or overload.variadic + or overload.dispatch_func is not None + or overload.lto_dispatch_func is not None + or overload.skip_replay + ): + return False + + return True + + +def get_callable_arg_values(func, bound_args): + """Return concrete function targets for ``Callable`` parameters. + + ``bound_args`` already includes defaults. A non-empty result means the call + needs a specialized clone where callable parameter names resolve directly to + function objects during codegen instead of runtime variables. + + Args: + func: Function being called. + bound_args: Bound argument values for the call, including defaults. + + Returns: + A mapping from Callable parameter names to concrete Warp functions, or + ``None`` when the call has no concrete Callable targets. + """ + + if func.is_builtin(): + return None + + callable_arg_values = {} + + for name, value in bound_args.items(): + if not warp._src.types.is_callable_annotation(func.input_types.get(name)): + continue + + if not isinstance(value, warp._src.context.Function): + continue + + if value.is_builtin() and not is_regular_builtin_callable_target(value): + raise WarpCodegenError( + "Callable parameters support user-defined Warp functions and simple built-in Warp functions " + "such as wp.sin() and wp.min(), " + f"but parameter '{name}' of '{func.key}' received unsupported built-in function '{value.key}'." + ) + + callable_arg_values[name] = value + + if callable_arg_values: + return callable_arg_values + + return None + + +def get_default_arg_value(func, name, value): + """Return the codegen value for a default argument. + + Callable defaults are specialization inputs, not runtime constants, so they + stay as raw function objects. Other defaults are represented as constant + variables and emitted through the regular default-argument path. + + Args: + func: Function that owns the default argument. + name: Parameter name for ``value``. + value: Python default value from the function signature. + + Returns: + A Warp function for Callable defaults, otherwise a constant ``Var``. + """ + + if warp._src.types.is_callable_annotation(func.input_types.get(name)) and isinstance( + value, warp._src.context.Function + ): + # Callable defaults need the same specialization path as explicit + # callable arguments. + return value + + return Var(None, type=type(value), constant=value) + + +def bind_call_arg_nodes(func, call_node): + """Bind a call AST to ``func`` and return AST/default arguments by name.""" + + try: + bound_args = func.signature.bind(*call_node.args, **{kw.arg: kw.value for kw in call_node.keywords}) + except TypeError: + return {} + + default_args = {k: v for k, v in func.defaults.items() if k not in bound_args.arguments and v is not None} + apply_defaults(bound_args, default_args) + return bound_args.arguments + + +def resolve_callable_arg_target(adj, arg_node, callable_arg_values=None): + """Resolve a callable argument node or default to a concrete Warp function. + + Args: + adj: Adjoint whose symbols and globals should be used for resolution. + arg_node: AST node or default value bound to a Callable parameter. + callable_arg_values: Specialized Callable targets already bound in the + caller, keyed by parameter name. + + Returns: + The resolved Warp function, or the unresolved object when resolution + does not produce a function. + """ + + if isinstance(arg_node, warp._src.context.Function): + return arg_node + + if callable_arg_values and isinstance(arg_node, ast.Name): + callable_func = callable_arg_values.get(arg_node.id) + if callable_func is not None: + return callable_func + + callable_func, _ = adj.resolve_static_expression(arg_node, eval_types=False) + return callable_func + + +_UNRESOLVED_CALL_ARG = object() + + +def resolve_call_arg_type(adj, arg_node, callable_arg_values=None): + """Best-effort static type resolution for call arguments during reference scans.""" + + if isinstance(arg_node, ast.Name): + if callable_arg_values: + callable_func = callable_arg_values.get(arg_node.id) + if callable_func is not None: + return get_arg_type(callable_func) + + symbol = adj.symbols.get(arg_node.id) + if symbol is not None: + return get_arg_type(symbol) + + obj = adj.resolve_external_reference(arg_node.id) + if obj is not None: + return get_arg_type(obj) + + return _UNRESOLVED_CALL_ARG + + if isinstance(arg_node, ast.Attribute): + obj, _ = adj.resolve_static_expression(arg_node, eval_types=False) + if obj is not None: + return get_arg_type(obj) + + return _UNRESOLVED_CALL_ARG + + if isinstance(arg_node, ast.Constant): + return get_arg_type(arg_node.value) + + try: + return get_arg_type(ast.literal_eval(arg_node)) + except (TypeError, ValueError): + return _UNRESOLVED_CALL_ARG + + +def resolve_call_func_overload(adj, func, call_node, callable_arg_values=None): + """Resolve ``func`` to the overload selected by static call arguments. + + Callable target discovery needs the same overload that normal call + resolution will choose. If any argument cannot be resolved statically, this + returns the original function so callers can fall back to conservative + behavior. + + Args: + adj: Adjoint whose symbols and globals should be used for resolution. + func: Function object referenced by the call. + call_node: AST call node whose arguments select the overload. + callable_arg_values: Specialized Callable targets already bound in the + caller, keyed by parameter name. + + Returns: + The selected overload when it can be resolved, otherwise ``func``. + """ + + if not isinstance(func, warp._src.context.Function) or func.is_builtin(): + return func + + arg_types = [] + for arg_node in call_node.args: + if isinstance(arg_node, ast.Starred): + return func + + arg_type = resolve_call_arg_type(adj, arg_node, callable_arg_values) + if arg_type is _UNRESOLVED_CALL_ARG: + return func + + arg_types.append(arg_type) + + kwarg_types = {} + for kw_node in call_node.keywords: + if kw_node.arg is None: + return func + + arg_type = resolve_call_arg_type(adj, kw_node.value, callable_arg_values) + if arg_type is _UNRESOLVED_CALL_ARG: + return func + + kwarg_types[kw_node.arg] = arg_type + + overload = func.get_overload(tuple(arg_types), kwarg_types) + return overload or func + + +def iter_call_callable_arg_targets(adj, func, call_node, callable_arg_values=None): + """Yield Warp function targets passed to ``Callable`` parameters. + + Args: + adj: Adjoint whose symbols and globals should be used for resolution. + func: Function object referenced by the call. + call_node: AST call node whose arguments may include Callable targets. + callable_arg_values: Specialized Callable targets already bound in the + caller, keyed by parameter name. + + Yields: + Concrete Warp functions supplied to Callable parameters by explicit + arguments or defaults. + """ + + if not isinstance(func, warp._src.context.Function) or func.is_builtin(): + return + + func = resolve_call_func_overload(adj, func, call_node, callable_arg_values) + bound_arg_nodes = bind_call_arg_nodes(func, call_node) + + for arg_name, arg_node in bound_arg_nodes.items(): + if not warp._src.types.is_callable_annotation(func.input_types.get(arg_name)): + continue + + callable_func = resolve_callable_arg_target(adj, arg_node, callable_arg_values) + if isinstance(callable_func, warp._src.context.Function): + yield callable_func + + +def specialize_callable_func(func, callable_arg_values): + """Clone ``func`` for a concrete set of Callable parameter targets. + + Callable targets affect generated code but are omitted from the native C++ + signature, so each target set needs a cached specialization with a distinct + native function name. The clone keeps the original Python source and arg + types while storing the concrete Callable targets on the new adjoint. + + Args: + func: User-defined function to specialize. + callable_arg_values: Mapping from Callable parameter names to concrete + Warp functions. + + Returns: + A cached specialized clone of ``func`` for ``callable_arg_values``. + """ + + if func.custom_grad_func is not None or func.custom_replay_func is not None: + raise WarpCodegenError( + "Callable parameters are not supported on functions with custom gradients or replay functions: " + f"'{func.key}'" + ) + + specialization_key = tuple( + (name, callable_arg_values[name]) for name in func.input_types if name in callable_arg_values + ) + + specializations = getattr(func, "_callable_specializations", None) + if specializations is None: + specializations = {} + func._callable_specializations = specializations + + specialized_func = specializations.get(specialization_key) + if specialized_func is not None: + return specialized_func + + # The callable targets are inlined by name while being omitted from the C++ + # function parameters, so each target set needs a distinct native name. + suffix_hash = hashlib.sha256() + suffix_hash.update(bytes(func.native_func, "utf-8")) + for name, callable_func in specialization_key: + suffix_hash.update(bytes(name, "utf-8")) + suffix_hash.update(bytes(callable_func.key, "utf-8")) + suffix_hash.update(bytes(callable_func.native_func, "utf-8")) + + specialized_func = shallowcopy(func) + # Specialization clones should not share the parent specialization cache. + specialized_func.__dict__.pop("_callable_specializations", None) + specialized_func.native_func = f"{func.native_func}_callable_{suffix_hash.hexdigest()[:12]}" + specialized_func.value_func = None + specialized_func.adj = Adjoint( + func.func, + overload_annotations=func.adj.arg_types, + is_user_function=func.adj.is_user_function, + skip_forward_codegen=func.adj.skip_forward_codegen, + skip_reverse_codegen=func.adj.skip_reverse_codegen, + custom_reverse_mode=func.adj.custom_reverse_mode, + custom_reverse_num_input_args=func.adj.custom_reverse_num_input_args, + transformers=func.adj.transformers, + source=func.adj.source, + ) + specialized_func.adj.callable_arg_values = dict(callable_arg_values) + specialized_func.adj.used_by_backward_kernel = func.adj.used_by_backward_kernel + specialized_func.adj.force_adjoint_codegen = func.adj.force_adjoint_codegen + + specializations[specialization_key] = specialized_func + return specialized_func + + def get_arg_type(arg: Var | Any) -> type: arg = strip_reference(arg) @@ -1341,7 +1671,7 @@ def _try_extract_function_source(code: types.CodeType) -> tuple[str, int] | None # generate function ssa form and adjoint @synchronized(_codegen_lock) - def build(adj, builder, default_builder_options=None): + def build(adj, builder, default_builder_options=None, callable_arg_values=None): # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build for arg in adj.args: arg.is_read = False @@ -1350,6 +1680,9 @@ def build(adj, builder, default_builder_options=None): if adj.skip_build: return + if callable_arg_values is None: + callable_arg_values = getattr(adj, "callable_arg_values", None) + adj.builder = builder if default_builder_options is None: @@ -1386,9 +1719,13 @@ def build(adj, builder, default_builder_options=None): # tracks how much additional shared memory is required by any dependent function calls adj.max_required_extra_shared_memory = 0 - # update symbol map for each argument + # Callable-specialized functions replace selected argument Vars with + # Function objects so calls like `op(x)` resolve statically. for a in adj.args: - adj.symbols[a.label] = a + if callable_arg_values is not None and a.label in callable_arg_values: + adj.symbols[a.label] = callable_arg_values[a.label] + else: + adj.symbols[a.label] = a # recursively evaluate function body try: @@ -1746,7 +2083,7 @@ def resolve_func(adj, func, arg_types, kwarg_types, min_outputs): f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_type_reprs)}]" ) - def add_call(adj, func, args, kwargs, type_args, min_outputs=None): + def resolve_call(adj, func, args, kwargs, type_args=None, min_outputs=None): # Extract the types and values passed as arguments to the function call. arg_types = tuple(get_arg_type(x) for x in args) kwarg_types = {k: get_arg_type(v) for k, v in kwargs.items()} @@ -1758,6 +2095,9 @@ def add_call(adj, func, args, kwargs, type_args, min_outputs=None): # in order to process them as Python does it. bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs) + if type_args is None: + type_args = {} + # Type args are the "compile time" argument values we get from codegen. # For example, when calling `wp.vec3f(...)` from within a kernel, # this translates in fact to calling the `vector()` built-in augmented @@ -1787,13 +2127,21 @@ def add_call(adj, func, args, kwargs, type_args, min_outputs=None): if func.defaults: default_vars = { - k: Var(None, type=type(v), constant=v) + k: get_default_arg_value(func, k, v) for k, v in func.defaults.items() if k not in bound_args.arguments and v is not None } apply_defaults(bound_args, default_vars) bound_args = bound_args.arguments + callable_arg_values = get_callable_arg_values(func, bound_args) + if callable_arg_values is not None: + func = specialize_callable_func(func, callable_arg_values) + + return func, bound_args + + def add_call(adj, func, args, kwargs, type_args, min_outputs=None): + func, bound_args = adj.resolve_call(func, args, kwargs, type_args, min_outputs) # Constant precision preservation: when calling a 64-bit scalar type # constructor with a single compile-time constant argument, emit @@ -1832,6 +2180,8 @@ def add_call(adj, func, args, kwargs, type_args, min_outputs=None): # we need to ensure its adjoint is also being generated. if adj.used_by_backward_kernel: func.adj.used_by_backward_kernel = True + if adj.force_adjoint_codegen: + func.adj.force_adjoint_codegen = True if adj.builder is None: func.build(None) @@ -1898,7 +2248,12 @@ def add_call(adj, func, args, kwargs, type_args, min_outputs=None): elif func.dispatch_func is not None: func_args, template_args = func.dispatch_func(func.input_types, return_type, bound_args) else: - func_args = tuple(bound_args.values()) + func_args = tuple( + value + for name, value in bound_args.items() + if func.is_builtin() or not warp._src.types.is_callable_annotation(func.input_types.get(name)) + ) + # Callable parameters are specialization inputs, not C++ arguments. template_args = () func_args = tuple(adj.register_var(x) for x in func_args) @@ -1916,6 +2271,8 @@ def add_call(adj, func, args, kwargs, type_args, min_outputs=None): if isinstance(func_arg_var, warp._src.context.Function) and not func_arg_var.is_builtin(): if adj.used_by_backward_kernel: func_arg_var.adj.used_by_backward_kernel = True + if adj.force_adjoint_codegen: + func_arg_var.adj.force_adjoint_codegen = True adj.builder.build_function(func_arg_var) @@ -2005,10 +2362,7 @@ def add_grad_call(adj, func, args, kwargs): This gradient call is forward-only and does NOT participate in automatic differentiation. """ - # Resolve the function overload based on argument types - arg_types = tuple(get_arg_type(x) for x in args) - kwarg_types = {k: get_arg_type(v) for k, v in kwargs.items()} - func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs=None) + func, bound_args = adj.resolve_call(func, args, kwargs) if not func.is_differentiable: raise WarpCodegenError(f"Cannot compute gradient of non-differentiable function '{func.key}'") @@ -2038,20 +2392,6 @@ def add_grad_call(adj, func, args, kwargs): elif func not in adj.builder.functions: adj.builder.build_function(func) - # Get function's input types - input_types = func.input_types - - # Bind arguments to function signature - bound_args = func.signature.bind(*args, **kwargs) - if func.defaults: - default_vars = { - k: Var(None, type=type(v), constant=v) - for k, v in func.defaults.items() - if k not in bound_args.arguments and v is not None - } - apply_defaults(bound_args, default_vars) - bound_args = bound_args.arguments - # Get return type bound_arg_types = {k: get_arg_type(v) for k, v in bound_args.items()} bound_arg_values = {k: get_arg_value(v) for k, v in bound_args.items()} @@ -2063,6 +2403,12 @@ def add_grad_call(adj, func, args, kwargs): if return_type is None: raise WarpCodegenError(f"Cannot compute gradient of void function '{func.key}'") + # Callable parameters are specialization inputs, not native function + # arguments, and their adjoints are not representable. + input_types = { + name: typ for name, typ in func.input_types.items() if not warp._src.types.is_callable_annotation(typ) + } + # Load input arguments into variables fwd_args_loaded = [adj.load(bound_args[name]) for name in input_types.keys()] @@ -3239,22 +3585,25 @@ def emit_Call(adj, node): out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs) if adj.builder_options.get("verify_autograd_array_access", False): - # Extract the types and values passed as arguments to the function call. - arg_types = tuple(get_arg_type(x) for x in args) - kwarg_types = {k: get_arg_type(v) for k, v in kwargs.items()} - - # Resolve the exact function signature among any existing overload. - resolved_func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs) + resolved_func, resolved_bound_args = adj.resolve_call(func, args, kwargs, type_args, min_outputs) # update arg read/write states according to what happens to that arg in the called function if hasattr(resolved_func, "adj"): - for i, arg in enumerate(args): - if resolved_func.adj.args[i].is_write: + resolved_args_by_name = {arg.label: arg for arg in resolved_func.adj.args} + for name, arg in resolved_bound_args.items(): + if warp._src.types.is_callable_annotation(resolved_func.input_types.get(name)): + continue + + resolved_arg = resolved_args_by_name.get(name) + if resolved_arg is None or not isinstance(arg, Var): + continue + + if resolved_arg.is_write: kernel_name = adj.fun_name filename = adj.filename lineno = adj.lineno + adj.fun_lineno arg.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno) - if resolved_func.adj.args[i].is_read: + if resolved_arg.is_read: arg.mark_read() return out @@ -4746,6 +5095,7 @@ def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp._src. types: dict[Struct | type, Any] = {} functions: dict[warp._src.context.Function, Any] = {} max_dim = 0 # thread-grid dimension, inferred from wp.tid() unpack arity + callable_arg_values = getattr(adj, "callable_arg_values", None) or {} # Shared single traversal (see reference_nodes); resolved here at hash time. for node in adj.reference_nodes(): @@ -4762,9 +5112,19 @@ def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp._src. elif isinstance(node, ast.Call): func, _ = adj.resolve_static_expression(node.func, eval_types=False) + if func is None and isinstance(node.func, ast.Name): + func = callable_arg_values.get(node.func.id) + if isinstance(func, warp._src.context.Function) and not func.is_builtin(): # calling user-defined function functions[func] = None + + # Callable targets are passed as values, so they must be + # added explicitly to the function reference set. Built-in + # targets are hash inputs too, but they are filtered out by + # module dependency discovery because they have no module. + for callable_func in iter_call_callable_arg_targets(adj, func, node, callable_arg_values): + functions[callable_func] = None elif isinstance(func, Struct): # calling struct constructor types[func] = None @@ -5393,6 +5753,8 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None, forward_only # forward args for i, arg in enumerate(adj.args): + if warp._src.types.is_callable_annotation(arg.type): + continue if is_tile(arg.type) or is_tile_stack(arg.type): tname = f"tile_{arg.label}" template_params.append(tname) @@ -5409,6 +5771,8 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None, forward_only # reverse args for i, arg in enumerate(adj.args): + if warp._src.types.is_callable_annotation(arg.type): + continue if adj.custom_reverse_mode and i >= adj.custom_reverse_num_input_args: break # indexed array gradients are regular arrays diff --git a/warp/_src/context.py b/warp/_src/context.py index 03238ac077..043909bb6a 100644 --- a/warp/_src/context.py +++ b/warp/_src/context.py @@ -484,6 +484,11 @@ def get_overload(self, arg_types: list[type], kwarg_types: Mapping[str, type]) - args_matched = True for i in range(len(arg_types)): + # Callable annotations stay type-erased here; specialization + # handles the concrete function target. + if warp._src.types.is_callable_annotation(template_types[i]) and isinstance(arg_types[i], Function): + continue + if not warp._src.types.type_matches_template(arg_types[i], template_types[i]): args_matched = False break @@ -492,11 +497,22 @@ def get_overload(self, arg_types: list[type], kwarg_types: Mapping[str, type]) - # instantiate this function with the specified argument types arg_names = f.input_types.keys() - overload_annotations = dict(zip(arg_names, arg_types, strict=False)) + overload_annotations = {} + for name, arg_type, template_type in zip(arg_names, arg_types, template_types, strict=False): + if warp._src.types.is_callable_annotation(template_type) and isinstance(arg_type, Function): + overload_annotations[name] = template_type + else: + overload_annotations[name] = arg_type + # add defaults for k, d in f.defaults.items(): if k not in overload_annotations: - overload_annotations[k] = warp._src.codegen.strip_reference(warp._src.codegen.get_arg_type(d)) + template_type = f.input_types[k] + default_type = warp._src.codegen.strip_reference(warp._src.codegen.get_arg_type(d)) + if warp._src.types.is_callable_annotation(template_type) and isinstance(default_type, Function): + overload_annotations[k] = template_type + else: + overload_annotations[k] = default_type ovl = shallowcopy(f) ovl.adj = warp._src.codegen.Adjoint(f.func, overload_annotations, source=f.adj.source) @@ -504,7 +520,9 @@ def get_overload(self, arg_types: list[type], kwarg_types: Mapping[str, type]) - ovl.value_func = None ovl.generic_parent = f - sig = warp._src.types.get_signature(arg_types, func_name=self.key) + sig = warp._src.types.get_signature( + list(overload_annotations.values()), func_name=self.key, arg_names=list(overload_annotations.keys()) + ) self.user_overloads[sig] = ovl return ovl @@ -2345,6 +2363,30 @@ def hash_function(self, func: Function) -> bytes: return h + @staticmethod + def hash_builtin_function(func: Function) -> bytes: + """Hash the identity of a built-in function used as a specialization input. + + Built-in Callable targets do not add module dependency edges, but they + still change generated code. Including their stable identity in the + module hash prevents a cached module compiled for one built-in target + from being reused after the target changes. + + Args: + func: Built-in function to hash. + + Returns: + Digest bytes representing the built-in function identity. + """ + + ch = hashlib.sha256() + + ch.update(bytes("builtin", "utf-8")) + ch.update(bytes(func.key, "utf-8")) + ch.update(bytes(func.native_func, "utf-8")) + + return ch.digest() + def hash_adjoint(self, adj: warp._src.codegen.Adjoint) -> bytes: # NOTE: We don't cache adjoint hashes, because adjoints are always unique. # Even instances of generic kernels and functions have unique adjoints with @@ -2395,7 +2437,10 @@ def hash_adjoint(self, adj: warp._src.codegen.Adjoint) -> bytes: # hash referenced functions for f in functions.keys(): if f not in self.functions_in_progress: - ch.update(self.hash_function(f)) + if f.is_builtin(): + ch.update(self.hash_builtin_function(f)) + else: + ch.update(self.hash_function(f)) return ch.digest() @@ -3007,6 +3052,8 @@ def add_ref(ref): self.references.add(ref) ref.dependents.add(self) + callable_arg_values = getattr(adj, "callable_arg_values", None) or {} + # scan for function calls and kernel-local function bindings. ``reference_nodes`` shares # a single AST traversal with Adjoint.get_references; it also yields Name/Attribute # nodes, which this dependency scan ignores. @@ -3015,11 +3062,22 @@ def add_ref(ref): try: # try to resolve the function func, _ = adj.resolve_static_expression(node.func, eval_types=False) + if func is None and isinstance(node.func, ast.Name): + func = callable_arg_values.get(node.func.id) # if this is a user-defined function, add a module reference if isinstance(func, warp._src.context.Function) and func.module is not None: add_ref(func.module) + if isinstance(func, warp._src.context.Function) and not func.is_builtin(): + # Callable targets can come from arguments or defaults; + # either way their modules must invalidate this module. + for callable_func in warp._src.codegen.iter_call_callable_arg_targets( + adj, func, node, callable_arg_values + ): + if not callable_func.is_builtin() and callable_func.module is not None: + add_ref(callable_func.module) + except Exception: # Lookups may fail for builtins, but that's ok. # Lookups may also fail for functions in this module that haven't been imported yet, diff --git a/warp/_src/types.py b/warp/_src/types.py index 74aff6176c..765ad00d82 100644 --- a/warp/_src/types.py +++ b/warp/_src/types.py @@ -7186,11 +7186,19 @@ def infer_argument_types(args: list[Any], template_types, arg_names: list[str] | } +def is_callable_annotation(annotation) -> bool: + """Return whether an annotation denotes a type-erased callable.""" + + return annotation is Callable or get_origin(annotation) is Callable + + def get_type_code(arg_type) -> str: if arg_type is Any: # special case for generics # note: since Python 3.11 Any is a type, so we check for it first return "?" + elif is_callable_annotation(arg_type): + return "c" elif ( sys.version_info < (3, 11) and hasattr(types, "GenericAlias") @@ -7270,9 +7278,6 @@ def get_type_code(arg_type) -> str: elif arg_type == Int: # generic int return "i?" - elif isinstance(arg_type, Callable): - # TODO: elaborate on Callable type? - return "c" elif arg_type is Ellipsis: return "?" else: diff --git a/warp/tests/aux_test_callable_double.py b/warp/tests/aux_test_callable_double.py new file mode 100644 index 0000000000..154622ec29 --- /dev/null +++ b/warp/tests/aux_test_callable_double.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Callable target fixture imported as a Python module.""" + +import warp as wp + + +@wp.func +def callable_external_module_double_it(x: float): + return x * 2.0 diff --git a/warp/tests/aux_test_callable_triple.py b/warp/tests/aux_test_callable_triple.py new file mode 100644 index 0000000000..6befcf2d79 --- /dev/null +++ b/warp/tests/aux_test_callable_triple.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Callable target fixture imported as a Python module.""" + +import warp as wp + + +@wp.func +def callable_external_module_triple_it(x: float): + return x * 3.0 diff --git a/warp/tests/test_func_callable.py b/warp/tests/test_func_callable.py new file mode 100644 index 0000000000..1aaaa99b5d --- /dev/null +++ b/warp/tests/test_func_callable.py @@ -0,0 +1,695 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for ``Callable`` parameters in user-defined Warp functions. + +These tests live outside ``test_func.py`` because ``Callable`` parameter support +needs a dedicated set of helper functions, kernels, module dependency checks, +hashing checks, and rejection-path coverage. Keep future ``Callable`` +parameter tests in this module so ``test_func.py`` remains focused on general +``@wp.func`` behavior. +""" + +import unittest +from collections.abc import Callable as CollectionsCallable +from typing import Any +from typing import Callable as TypingCallable # noqa: UP035 + +import numpy as np + +import warp as wp +import warp.tests.aux_test_callable_double as callable_double_module +import warp.tests.aux_test_callable_triple as callable_triple_module +from warp.tests.unittest_utils import * + + +# User-facing helpers used by the runtime tests below. Keep these examples close +# to patterns a Warp user would write in kernels or user functions. +@wp.func +def callable_double_it(x: float): + return x * 2.0 + + +@wp.func +def callable_triple_it(x: float): + return x * 3.0 + + +@wp.func +def callable_apply_typing(g: TypingCallable, x: float): + return g(x) + + +@wp.func +def callable_apply_collections(g: CollectionsCallable, x: float): + return g(x) + + +@wp.func +def callable_apply_typing_parameterized(g: TypingCallable[[float], float], x: float): + return g(x) + + +@wp.func +def callable_apply_collections_parameterized(g: CollectionsCallable[[float], float], x: float): + return g(x) + + +@wp.func +def callable_apply_generic(g: TypingCallable, x: Any): + return g(x) + + +@wp.func +def callable_apply_default(g: TypingCallable = callable_double_it, x: float = 3.0): + return g(x) + + +@wp.func +def callable_apply_builtin_default(g: TypingCallable = wp.sin, x: float = 0.5): + return g(x) + + +@wp.func +def callable_apply_nested(x: float): + return callable_apply_typing(callable_double_it, x) + + +@wp.func +def callable_apply_binary(g: TypingCallable, x: float, y: float): + return g(x, y) + + +@wp.func +def callable_apply_bound_parameter(g: TypingCallable, x: float): + f = g + return f(x) + + +@wp.func +def callable_square_it(x: float): + return x * x + + +@wp.func +def callable_apply_for_grad(g: TypingCallable, x: float): + return g(x) + + +@wp.func +def callable_apply_builtin_for_grad(g: TypingCallable, x: float): + return g(x) + + +@wp.kernel +def callable_func_parameter_kernel(out: wp.array[float]): + out[0] = callable_apply_typing(callable_double_it, 3.0) + out[1] = callable_apply_typing(callable_triple_it, 4.0) + out[2] = callable_apply_collections(callable_double_it, 5.0) + out[3] = callable_apply_collections(callable_triple_it, 6.0) + out[4] = callable_apply_nested(7.0) + out[5] = callable_apply_default() + out[6] = callable_apply_typing(g=callable_triple_it, x=8.0) + out[7] = callable_apply_typing_parameterized(callable_double_it, 9.0) + out[8] = callable_apply_collections_parameterized(callable_triple_it, 10.0) + out[9] = callable_apply_generic(callable_double_it, 11.0) + + +def test_callable_func_parameter(test, device): + """Verify Callable parameters accept common user-facing annotation forms.""" + out = wp.empty(10, dtype=float, device=device) + + wp.launch(callable_func_parameter_kernel, dim=1, outputs=[out], device=device) + + assert_np_equal( + out.numpy(), + np.array([6.0, 12.0, 10.0, 18.0, 14.0, 6.0, 24.0, 18.0, 30.0, 22.0], dtype=np.float32), + ) + + +# ``wp.static()`` needs a global container so the kernel can resolve a function +# target statically while exercising the local-binding composition path. +CALLABLE_STATIC_TARGETS = {"double": callable_double_it} + + +@wp.kernel +def callable_func_parameter_local_binding_kernel(out: wp.array[float]): + f = callable_double_it + out[0] = callable_apply_typing(f, 3.0) + + out[1] = callable_apply_bound_parameter(callable_triple_it, 4.0) + + static_f = wp.static(CALLABLE_STATIC_TARGETS["double"]) + out[2] = callable_apply_typing(static_f, 5.0) + + +def test_callable_func_parameter_local_binding(test, device): + """Verify Callable parameters compose with function-valued local aliases.""" + out = wp.empty(3, dtype=float, device=device) + + wp.launch(callable_func_parameter_local_binding_kernel, dim=1, outputs=[out], device=device) + + assert_np_equal(out.numpy(), np.array([6.0, 12.0, 10.0], dtype=np.float32)) + + +@wp.kernel +def callable_func_parameter_external_module_kernel(cond: wp.array(dtype=bool), out: wp.array(dtype=float)): + i = wp.tid() + if cond[i]: + out[i] = callable_apply_typing(callable_double_module.callable_external_module_double_it, 3.0) + else: + out[i] = callable_apply_typing(callable_triple_module.callable_external_module_triple_it, 3.0) + + +def test_callable_func_parameter_external_module(test, device): + """Verify Callable parameters accept targets imported from Python modules.""" + cond = wp.array([True, False], dtype=bool, device=device) + out = wp.empty(2, dtype=float, device=device) + + wp.launch(callable_func_parameter_external_module_kernel, dim=2, inputs=[cond], outputs=[out], device=device) + + assert_np_equal(out.numpy(), np.array([6.0, 9.0], dtype=np.float32)) + + +@wp.kernel +def callable_builtin_func_parameter_kernel(out: wp.array(dtype=float)): + out[0] = callable_apply_typing(wp.sin, 0.5) + out[1] = callable_apply_typing(wp.sqrt, 9.0) + out[2] = callable_apply_binary(wp.add, 2.0, 3.0) + out[3] = callable_apply_binary(wp.min, 7.0, 4.0) + out[4] = callable_apply_builtin_default() + out[5] = callable_apply_typing(g=wp.cos, x=0.0) + + +def test_callable_builtin_func_parameter(test, device): + """Verify Callable parameters accept simple built-in Warp functions.""" + out = wp.empty(6, dtype=float, device=device) + + wp.launch(callable_builtin_func_parameter_kernel, dim=1, outputs=[out], device=device) + + assert_np_equal( + out.numpy(), + np.array([np.sin(0.5), 3.0, 5.0, 4.0, np.sin(0.5), 1.0], dtype=np.float32), + ) + + +# This global is intentionally mutated by +# ``test_callable_argument_target_affects_module_hash``. The kernel reads it as +# a Callable argument so the module hash must change when the target changes. +CALLABLE_TARGET = callable_double_it + + +@wp.kernel +def callable_global_target_kernel(out: wp.array[float]): + out[0] = callable_apply_typing(CALLABLE_TARGET, 3.0) + + +CALLABLE_BUILTIN_TARGET = wp.sin + + +@wp.kernel +def callable_builtin_global_target_kernel(out: wp.array[float]): + out[0] = callable_apply_typing(CALLABLE_BUILTIN_TARGET, 0.5) + + +@wp.kernel +def callable_default_target_kernel(out: wp.array[float]): + out[0] = callable_apply_default() + + +@wp.kernel +def callable_builtin_default_target_kernel(out: wp.array[float]): + out[0] = callable_apply_builtin_default() + + +CALLABLE_OVERLOAD_TARGET = callable_double_it + + +@wp.func +def callable_apply_overloaded(x: int): + return float(x) + + +@wp.func +def callable_apply_overloaded(g: TypingCallable, x: float): + return g(x) + + +@wp.kernel +def callable_overloaded_global_target_kernel(out: wp.array[float]): + out[0] = callable_apply_overloaded(CALLABLE_OVERLOAD_TARGET, 3.0) + + +@wp.func +def callable_apply_overloaded_default(x: int): + return float(x) + + +@wp.func +def callable_apply_overloaded_default(x: float, g: TypingCallable = callable_double_it): + return g(x) + + +@wp.kernel +def callable_overloaded_default_target_kernel(out: wp.array[float]): + out[0] = callable_apply_overloaded_default(3.0) + + +@wp.kernel(enable_backward=False, module="unique") +def callable_grad_kernel(out: wp.array[float]): + out[0] = wp.grad(callable_apply_for_grad)(callable_square_it, 3.0) + + +@wp.kernel(enable_backward=False, module="unique") +def callable_builtin_grad_kernel(out: wp.array[float]): + out[0] = wp.grad(callable_apply_builtin_for_grad)(wp.sin, 0.5) + + +@wp.func +def callable_read_array(arr: wp.array(dtype=float), i: int): + return arr[i] + + +@wp.func +def callable_apply_array(g: TypingCallable, arr: wp.array(dtype=float), i: int): + return g(arr, i) + + +@wp.kernel(module="unique") +def callable_array_read_kernel(arr: wp.array(dtype=float), out: wp.array(dtype=float)): + out[0] = callable_apply_array(callable_read_array, arr, 0) + + +# These explicit modules are part of the behavior under test. Callable targets +# from provider modules must be registered as dependencies of consumer modules so +# provider unloads invalidate stale consumer kernels. +CALLABLE_DEPENDENCY_EXPLICIT_PROVIDER_MODULE = wp.Module("callable_dependency_explicit_provider") +CALLABLE_DEPENDENCY_EXPLICIT_CONSUMER_MODULE = wp.Module("callable_dependency_explicit_consumer") +CALLABLE_DEPENDENCY_DEFAULT_PROVIDER_MODULE = wp.Module("callable_dependency_default_provider") +CALLABLE_DEPENDENCY_DEFAULT_CONSUMER_MODULE = wp.Module("callable_dependency_default_consumer") +CALLABLE_DEPENDENCY_LOCAL_PROVIDER_MODULE = wp.Module("callable_dependency_local_provider") +CALLABLE_DEPENDENCY_LOCAL_CONSUMER_MODULE = wp.Module("callable_dependency_local_consumer") +CALLABLE_DEPENDENCY_OVERLOAD_PROVIDER_MODULE = wp.Module("callable_dependency_overload_provider") +CALLABLE_DEPENDENCY_OVERLOAD_CONSUMER_MODULE = wp.Module("callable_dependency_overload_consumer") +CALLABLE_DEPENDENCY_EXTERNAL_CONSUMER_MODULE = wp.Module("callable_dependency_external_consumer") + + +@wp.func(module=CALLABLE_DEPENDENCY_EXPLICIT_PROVIDER_MODULE) +def callable_dependency_explicit_target(x: float): + return x + 1.0 + + +@wp.func(module=CALLABLE_DEPENDENCY_EXPLICIT_CONSUMER_MODULE) +def callable_dependency_apply_explicit(g: TypingCallable, x: float): + return g(x) + + +@wp.kernel(module=CALLABLE_DEPENDENCY_EXPLICIT_CONSUMER_MODULE) +def callable_dependency_explicit_kernel(out: wp.array[float]): + out[0] = callable_dependency_apply_explicit(callable_dependency_explicit_target, 2.0) + + +@wp.func(module=CALLABLE_DEPENDENCY_DEFAULT_PROVIDER_MODULE) +def callable_dependency_default_target(x: float): + return x + 1.0 + + +@wp.func(module=CALLABLE_DEPENDENCY_DEFAULT_CONSUMER_MODULE) +def callable_dependency_apply_default(g: TypingCallable = callable_dependency_default_target, x: float = 2.0): + return g(x) + + +@wp.kernel(module=CALLABLE_DEPENDENCY_DEFAULT_CONSUMER_MODULE) +def callable_dependency_default_kernel(out: wp.array[float]): + out[0] = callable_dependency_apply_default() + + +@wp.func(module=CALLABLE_DEPENDENCY_LOCAL_PROVIDER_MODULE) +def callable_dependency_local_target(x: float): + return x + 1.0 + + +@wp.func(module=CALLABLE_DEPENDENCY_LOCAL_CONSUMER_MODULE) +def callable_dependency_apply_local(g: TypingCallable, x: float): + return g(x) + + +@wp.kernel(module=CALLABLE_DEPENDENCY_LOCAL_CONSUMER_MODULE) +def callable_dependency_local_kernel(out: wp.array[float]): + f = callable_dependency_local_target + out[0] = callable_dependency_apply_local(f, 2.0) + + +@wp.func(module=CALLABLE_DEPENDENCY_OVERLOAD_PROVIDER_MODULE) +def callable_dependency_overload_target(x: float): + return x + 1.0 + + +@wp.func(module=CALLABLE_DEPENDENCY_OVERLOAD_CONSUMER_MODULE) +def callable_dependency_apply_overload(x: int): + return float(x) + + +@wp.func(module=CALLABLE_DEPENDENCY_OVERLOAD_CONSUMER_MODULE) +def callable_dependency_apply_overload(g: TypingCallable, x: float): + return g(x) + + +@wp.kernel(module=CALLABLE_DEPENDENCY_OVERLOAD_CONSUMER_MODULE) +def callable_dependency_overload_kernel(out: wp.array[float]): + out[0] = callable_dependency_apply_overload(callable_dependency_overload_target, 2.0) + + +@wp.kernel(module=CALLABLE_DEPENDENCY_EXTERNAL_CONSUMER_MODULE) +def callable_dependency_external_module_kernel(cond: wp.array(dtype=bool), out: wp.array(dtype=float)): + i = wp.tid() + if cond[i]: + out[i] = callable_apply_typing(callable_double_module.callable_external_module_double_it, 3.0) + else: + out[i] = callable_apply_typing(callable_triple_module.callable_external_module_triple_it, 3.0) + + +# These rejection fixtures live at module scope because custom grad and replay +# hooks are registered against a concrete ``@wp.func`` object. +@wp.func +def callable_custom_grad_unsupported(g: TypingCallable, x: float): + return x + + +@wp.func_grad(callable_custom_grad_unsupported) +def adj_callable_custom_grad_unsupported(g: TypingCallable, x: float, adj_ret: float): + wp.adjoint[x] += adj_ret + + +@wp.func +def callable_custom_replay_unsupported(g: TypingCallable, x: float): + return x + + +@wp.func_replay(callable_custom_replay_unsupported) +def replay_callable_custom_replay_unsupported(g: TypingCallable, x: float): + return x + + +class TestFuncCallable(unittest.TestCase): + def test_callable_argument_target_affects_module_hash(self): + """Verify explicit Callable targets participate in module hashes.""" + global CALLABLE_TARGET + + original_target = CALLABLE_TARGET + try: + CALLABLE_TARGET = callable_double_it + double_hash = callable_global_target_kernel.module.hash_module() + + CALLABLE_TARGET = callable_triple_it + triple_hash = callable_global_target_kernel.module.hash_module() + finally: + CALLABLE_TARGET = original_target + + self.assertNotEqual(double_hash, triple_hash) + + def test_callable_default_target_affects_module_hash(self): + """Verify default Callable targets participate in module hashes.""" + original_defaults = callable_apply_default.defaults.copy() + try: + callable_apply_default.defaults["g"] = callable_double_it + double_hash = callable_default_target_kernel.module.hash_module() + + callable_apply_default.defaults["g"] = callable_triple_it + triple_hash = callable_default_target_kernel.module.hash_module() + finally: + callable_apply_default.defaults = original_defaults + + self.assertNotEqual(double_hash, triple_hash) + + def test_callable_builtin_argument_target_affects_module_hash(self): + """Verify changing a global built-in Callable target changes the module hash.""" + global CALLABLE_BUILTIN_TARGET + + original_target = CALLABLE_BUILTIN_TARGET + try: + CALLABLE_BUILTIN_TARGET = wp.sin + sin_hash = callable_builtin_global_target_kernel.module.hash_module() + + CALLABLE_BUILTIN_TARGET = wp.cos + cos_hash = callable_builtin_global_target_kernel.module.hash_module() + finally: + CALLABLE_BUILTIN_TARGET = original_target + + self.assertNotEqual(sin_hash, cos_hash) + + def test_callable_builtin_default_target_affects_module_hash(self): + """Verify changing a default built-in Callable target changes the module hash.""" + original_defaults = callable_apply_builtin_default.defaults.copy() + try: + callable_apply_builtin_default.defaults["g"] = wp.sin + sin_hash = callable_builtin_default_target_kernel.module.hash_module() + + callable_apply_builtin_default.defaults["g"] = wp.cos + cos_hash = callable_builtin_default_target_kernel.module.hash_module() + finally: + callable_apply_builtin_default.defaults = original_defaults + + self.assertNotEqual(sin_hash, cos_hash) + + def test_callable_overload_target_affects_module_hash(self): + """Verify changing a non-primary Callable overload target changes the module hash.""" + global CALLABLE_OVERLOAD_TARGET + + original_target = CALLABLE_OVERLOAD_TARGET + try: + CALLABLE_OVERLOAD_TARGET = callable_double_it + double_hash = callable_overloaded_global_target_kernel.module.hash_module() + + CALLABLE_OVERLOAD_TARGET = callable_triple_it + triple_hash = callable_overloaded_global_target_kernel.module.hash_module() + finally: + CALLABLE_OVERLOAD_TARGET = original_target + + self.assertNotEqual(double_hash, triple_hash) + + def test_callable_overload_default_target_affects_module_hash(self): + """Verify changing a non-primary Callable overload default changes the module hash.""" + callable_overload = callable_apply_overloaded_default.get_overload([float], {}) + original_defaults = callable_overload.defaults.copy() + try: + callable_overload.defaults["g"] = callable_double_it + double_hash = callable_overloaded_default_target_kernel.module.hash_module() + + callable_overload.defaults["g"] = callable_triple_it + triple_hash = callable_overloaded_default_target_kernel.module.hash_module() + finally: + callable_overload.defaults = original_defaults + + self.assertNotEqual(double_hash, triple_hash) + + def test_callable_grad_call(self): + """Verify wp.grad() specializes functions with Callable targets.""" + + out = wp.empty(1, dtype=float, device="cpu") + + wp.launch(callable_grad_kernel, dim=1, outputs=[out], device="cpu") + + assert_np_equal(out.numpy(), np.array([6.0], dtype=np.float32)) + + def test_callable_builtin_grad_call(self): + """Verify wp.grad() specializes functions with built-in Callable targets.""" + + out = wp.empty(1, dtype=float, device="cpu") + + wp.launch(callable_builtin_grad_kernel, dim=1, outputs=[out], device="cpu") + + assert_np_equal(out.numpy(), np.array([np.cos(0.5)], dtype=np.float32)) + + def test_callable_target_array_read_tracks_access(self): + """Verify Callable target array reads propagate to tape access tracking.""" + + original = wp.config.verify_autograd_array_access + wp.config.verify_autograd_array_access = True + try: + arr = wp.array([2.0], dtype=float, device="cpu") + out = wp.empty(1, dtype=float, device="cpu") + + with wp.Tape(): + wp.launch(callable_array_read_kernel, dim=1, inputs=[arr], outputs=[out], device="cpu") + + self.assertTrue(arr._is_read) + finally: + wp.config.verify_autograd_array_access = original + + def test_callable_wrong_return_annotation_reports_error(self): + """Verify Callable calls report annotated return type errors.""" + + @wp.func + def callable_wrong_return_annotation(g: TypingCallable, x: float) -> int: + return g(x) + + @wp.kernel(module="unique") + def callable_wrong_return_annotation_kernel(out: wp.array[float]): + out[0] = float(callable_wrong_return_annotation(callable_double_it, 2.0)) + + out = wp.empty(1, dtype=float, device="cpu") + + with self.assertRaisesRegex( + wp.WarpCodegenError, + r"The function `callable_wrong_return_annotation` has its return type " + r"annotated as `int` but the code returns a value of type `float32`.", + ): + wp.launch(callable_wrong_return_annotation_kernel, dim=1, outputs=[out], device="cpu") + + def test_callable_argument_target_updates_module_dependents(self): + """Verify Callable targets register provider modules as dependencies. + + Explicit arguments, default arguments, and kernel-local aliases exercise + the paths where callable targets can otherwise be missed during module + reference discovery. + """ + + def unload_recursive(module, visited): + module.unload() + visited.add(module) + for dependent in module.dependents: + if dependent not in visited: + unload_recursive(dependent, visited) + + cases = ( + ( + "explicit", + CALLABLE_DEPENDENCY_EXPLICIT_PROVIDER_MODULE, + CALLABLE_DEPENDENCY_EXPLICIT_CONSUMER_MODULE, + callable_dependency_explicit_kernel, + ), + ( + "default", + CALLABLE_DEPENDENCY_DEFAULT_PROVIDER_MODULE, + CALLABLE_DEPENDENCY_DEFAULT_CONSUMER_MODULE, + callable_dependency_default_kernel, + ), + ( + "local", + CALLABLE_DEPENDENCY_LOCAL_PROVIDER_MODULE, + CALLABLE_DEPENDENCY_LOCAL_CONSUMER_MODULE, + callable_dependency_local_kernel, + ), + ( + "overload", + CALLABLE_DEPENDENCY_OVERLOAD_PROVIDER_MODULE, + CALLABLE_DEPENDENCY_OVERLOAD_CONSUMER_MODULE, + callable_dependency_overload_kernel, + ), + ) + + for name, provider_module, consumer_module, kernel in cases: + with self.subTest(name=name): + out = wp.empty(1, dtype=float, device="cpu") + wp.launch(kernel, dim=1, outputs=[out], device="cpu") + + assert_np_equal(out.numpy(), np.array([3.0], dtype=np.float32)) + self.assertIn(provider_module, consumer_module.references) + self.assertIn(consumer_module, provider_module.dependents) + self.assertTrue(consumer_module.hashers) + + unload_recursive(provider_module, visited=set()) + + self.assertFalse(consumer_module.hashers) + + def test_callable_external_module_targets_update_dependents(self): + """Verify module-qualified Callable targets register provider modules.""" + + def unload_recursive(module, visited): + module.unload() + visited.add(module) + for dependent in module.dependents: + if dependent not in visited: + unload_recursive(dependent, visited) + + cond = wp.array([True, False], dtype=bool, device="cpu") + out = wp.empty(2, dtype=float, device="cpu") + + wp.launch(callable_dependency_external_module_kernel, dim=2, inputs=[cond], outputs=[out], device="cpu") + + assert_np_equal(out.numpy(), np.array([6.0, 9.0], dtype=np.float32)) + + consumer_module = CALLABLE_DEPENDENCY_EXTERNAL_CONSUMER_MODULE + provider_modules = ( + callable_double_module.callable_external_module_double_it.module, + callable_triple_module.callable_external_module_triple_it.module, + ) + + for provider_module in provider_modules: + self.assertIn(provider_module, consumer_module.references) + self.assertIn(consumer_module, provider_module.dependents) + + self.assertTrue(consumer_module.hashers) + + unload_recursive(provider_modules[0], visited=set()) + + self.assertFalse(consumer_module.hashers) + + def test_callable_custom_grad_rejected(self): + """Verify Callable-specialized functions reject custom grad and replay hooks.""" + + @wp.kernel(module="unique") + def custom_grad_rejection_kernel(out: wp.array[float]): + out[0] = callable_custom_grad_unsupported(callable_double_it, 2.0) + + @wp.kernel(module="unique") + def custom_replay_rejection_kernel(out: wp.array[float]): + out[0] = callable_custom_replay_unsupported(callable_double_it, 2.0) + + for kernel in (custom_grad_rejection_kernel, custom_replay_rejection_kernel): + with self.subTest(kernel=kernel.key): + out = wp.empty(1, dtype=float, device="cpu") + + with self.assertRaisesRegex( + wp.WarpCodegenError, + "Callable parameters.*custom gradients or replay functions", + ): + wp.launch(kernel, dim=1, outputs=[out], device="cpu") + + def test_callable_non_regular_builtin_target_rejected(self): + """Verify Callable parameters reject built-ins that need special dispatch.""" + + @wp.kernel(module="unique") + def callable_non_regular_builtin_target_kernel(out: wp.array[float]): + out[0] = callable_apply_typing(wp.printf, 0.5) + + out = wp.empty(1, dtype=float, device="cpu") + + with self.assertRaisesRegex( + wp.WarpCodegenError, + "unsupported built-in function 'printf'", + ): + wp.launch(callable_non_regular_builtin_target_kernel, dim=1, outputs=[out], device="cpu") + + +devices = get_test_devices() + +add_function_test( + TestFuncCallable, + func=test_callable_func_parameter, + name="test_callable_func_parameter", + devices=devices, +) +add_function_test( + TestFuncCallable, + func=test_callable_func_parameter_local_binding, + name="test_callable_func_parameter_local_binding", + devices=devices, +) +add_function_test( + TestFuncCallable, + func=test_callable_func_parameter_external_module, + name="test_callable_func_parameter_external_module", + devices=devices, +) +add_function_test( + TestFuncCallable, + func=test_callable_builtin_func_parameter, + name="test_callable_builtin_func_parameter", + devices=devices, +) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/warp/tests/unittest_suites.py b/warp/tests/unittest_suites.py index e17aaaa937..150836eae0 100644 --- a/warp/tests/unittest_suites.py +++ b/warp/tests/unittest_suites.py @@ -179,6 +179,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader) from warp.tests.test_fastcall import TestFastcall, TestFastcallAvailable from warp.tests.test_fp16 import TestFp16 from warp.tests.test_func import TestFunc + from warp.tests.test_func_callable import TestFuncCallable from warp.tests.test_future_annotations import TestFutureAnnotations from warp.tests.test_generics import TestGenerics from warp.tests.test_grad import TestGrad @@ -321,6 +322,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader) TestFemShape, TestFp16, TestFunc, + TestFuncCallable, TestFutureAnnotations, TestGenerics, TestGrad, @@ -474,6 +476,7 @@ def debug_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader): from warp.tests.test_fast_math import TestFastMath from warp.tests.test_fp16 import TestFp16 from warp.tests.test_func import TestFunc + from warp.tests.test_func_callable import TestFuncCallable from warp.tests.test_generics import TestGenerics from warp.tests.test_grad import TestGrad from warp.tests.test_grad_customs import TestGradCustoms @@ -511,6 +514,7 @@ def debug_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader): TestEnum, TestFastMath, TestFunc, + TestFuncCallable, TestGenerics, TestMath, TestModuleHashing,