Support Callable parameters in user-defined functions#1454
Conversation
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughThis PR implements support for ChangesCallable Parameter Support
Sequence Diagram(s)sequenceDiagram
participant Kernel as Kernel (caller)
participant Codegen as Codegen (add_call / func_match_args)
participant Types as Types (is_callable_annotation)
participant Specialize as Specialization (specialize_callable_func / get_callable_arg_values)
participant Adjoint as Adjoint.build
participant CppGen as C++ Generation
Kernel->>Codegen: call apply(double_it, 3.0)
Codegen->>Types: detect callable parameter via is_callable_annotation()
Codegen->>Codegen: bind args, match overload (callable-aware)
Codegen->>Specialize: extract callable args (get_callable_arg_values)
Specialize->>Specialize: create/cache specialized clone (hash suffix)
Codegen->>Adjoint: Adjoint.build(callable_arg_values)
Adjoint->>Adjoint: populate adj.symbols with concrete Function values
Codegen->>CppGen: emit call (skip callable params in forward/reverse signatures)
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (2)
warp/_src/codegen.py (1)
945-992: 💤 Low valueConsider documenting the specialization key stability for caching.
The specialization cache key at lines 952-954 uses a tuple of
(name, callable_func)pairs, relying onFunctionobject identity for cache hits. If the same logical function is recreated (e.g., module reload), cache misses will occur, generating duplicate specializations.This is likely acceptable for correctness (each specialization is valid), but consider adding a brief comment explaining this design choice for future maintainers.
Also, the check at lines 946-950 correctly rejects functions with custom grad/replay—good defensive validation.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@warp/_src/codegen.py` around lines 945 - 992, The specialization cache key in specialize_callable_func currently builds specialization_key from (name, callable_arg_values[name]) and therefore relies on Function object identity for cache hits; add a brief explanatory comment above the specialization_key creation stating that this is an intentional design choice (it may produce cache misses if the same logical function object is recreated, e.g., on module reload), why that is acceptable for correctness, and note that using callable_func.key/native_func could be an alternative for a stable key if desired in the future; reference specialize_callable_func, specialization_key, and func._callable_specializations in the comment so maintainers can find the code easily.CHANGELOG.md (1)
32-33: ⚡ Quick winUse imperative “Add …” phrasing for this Unreleased entry.
Line 32 currently starts with “Support …”; the changelog convention here asks for imperative present tense.
Suggested wording
-- Support passing user-defined Warp functions to `Callable` parameters in `@wp.func` functions +- Add support for passing user-defined Warp functions to `Callable` parameters in `@wp.func` functionsAs per coding guidelines:
CHANGELOG.mdentries inUnreleasedshould use imperative present tense (“Add X”).🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@CHANGELOG.md` around lines 32 - 33, Update the Unreleased changelog entry that currently reads "Support passing user-defined Warp functions to `Callable` parameters in `@wp.func` functions" to use imperative present-tense phrasing; change it to something like "Add support for passing user-defined Warp functions to `Callable` parameters in `@wp.func` functions" so the entry follows the repository's changelog convention.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@design/callable-func-parameters.md`:
- Around line 41-43: The document's non-goal wording conflicts with the linked
GH-1424 objective about accepting built-in callables; update the text around the
callable targets section to reconcile them by either (A) removing or rephrasing
the line that excludes built-ins and explicitly stating that built-in Warp
functions like wp.sin and wp.add are supported as callable arguments, or (B) if
built-ins truly remain out of scope, change the GH-1424 reference/closure
language to reflect a narrower objective; ensure references to "wp.sin" and
"wp.add" and the GH-1424 issue ID are corrected so the scope described matches
the linked issue.
- Around line 15-18: The snippet imports Callable twice which shadows and
confuses the reader; update the examples so they don't clobber the same name by
either (a) using only one import (prefer collections.abc.Callable for modern
code) and removing the other, (b) aliasing one import (e.g., import Callable as
TypingCallable) to show the difference, or (c) present two separate, clearly
labeled snippets instead of both lines together—adjust the lines containing
"from typing import Callable" and "from collections.abc import Callable"
accordingly.
---
Nitpick comments:
In `@CHANGELOG.md`:
- Around line 32-33: Update the Unreleased changelog entry that currently reads
"Support passing user-defined Warp functions to `Callable` parameters in
`@wp.func` functions" to use imperative present-tense phrasing; change it to
something like "Add support for passing user-defined Warp functions to
`Callable` parameters in `@wp.func` functions" so the entry follows the
repository's changelog convention.
In `@warp/_src/codegen.py`:
- Around line 945-992: The specialization cache key in specialize_callable_func
currently builds specialization_key from (name, callable_arg_values[name]) and
therefore relies on Function object identity for cache hits; add a brief
explanatory comment above the specialization_key creation stating that this is
an intentional design choice (it may produce cache misses if the same logical
function object is recreated, e.g., on module reload), why that is acceptable
for correctness, and note that using callable_func.key/native_func could be an
alternative for a stable key if desired in the future; reference
specialize_callable_func, specialization_key, and func._callable_specializations
in the comment so maintainers can find the code easily.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Enterprise
Run ID: 9019ac3f-bc6a-4ec7-b04f-ab04997190aa
📒 Files selected for processing (6)
CHANGELOG.mddesign/callable-func-parameters.mdwarp/_src/codegen.pywarp/_src/context.pywarp/_src/types.pywarp/tests/test_func.py
230bf7f to
f6a270d
Compare
af680cc to
99a8d88
Compare
There was a problem hiding this comment.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
warp/_src/codegen.py (1)
4513-4539:⚠️ Potential issue | 🟠 Major | ⚡ Quick winTrack callable-parameter specializations in
get_references()too.This only finds callable targets when the AST argument itself resolves statically. For specialized higher-order functions, forwarded names like
inner(f, x)and direct calls likef(x)are backed byadj.callable_arg_values, not by a global/static lookup, so the concrete callable never gets added tofunctions. That leaves module hashing/dependency invalidation stale when the passed function changes.Suggested direction
def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp._src.context.Function, Any]]: """Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions.""" local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed constants: dict[str, Any] = {} types: dict[Struct | type, Any] = {} functions: dict[warp._src.context.Function, Any] = {} + callable_arg_values = getattr(adj, "callable_arg_values", None) or {} for node in ast.walk(adj.tree): ... elif isinstance(node, ast.Call): func, _ = adj.resolve_static_expression(node.func, eval_types=False) + if func is None and isinstance(node.func, ast.Name): + func = callable_arg_values.get(node.func.id) + + if isinstance(func, warp._src.context.Function) and not func.is_builtin(): + functions[func] = None + try: bound_args = func.signature.bind(*node.args, **{kw.arg: kw.value for kw in node.keywords}) except TypeError: bound_arg_nodes = {} else: ... for arg_name, arg_node in bound_arg_nodes.items(): if not warp._src.types.is_callable_annotation(func.input_types.get(arg_name)): continue if isinstance(arg_node, warp._src.context.Function): callable_func = arg_node + elif isinstance(arg_node, ast.Name): + callable_func = callable_arg_values.get(arg_node.id) else: callable_func, _ = adj.resolve_static_expression(arg_node, eval_types=False) if isinstance(callable_func, warp._src.context.Function) and not callable_func.is_builtin(): functions[callable_func] = None🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@warp/_src/codegen.py` around lines 4513 - 4539, get_references() currently only discovers callable arguments when the AST arg resolves statically via adj.resolve_static_expression, so specialized higher-order calls backed by adj.callable_arg_values (e.g., when func is forwarded into inner/f and invoked) are missed; update the Call-handling branch in get_references() (the block that binds arguments with func.signature.bind, apply_defaults, and iterates bound_arg_nodes) to also check adj.callable_arg_values for the bound argument node (and for raw arg names where resolution failed) and add any concrete warp._src.context.Function entries there into the functions dict (same guard: isinstance(..., Function) and not is_builtin()). Ensure you reference adj.callable_arg_values lookup when bound_arg_nodes contains nodes that didn't resolve statically or that are names/attributes representing forwarded callables so the concrete callable specializations are included.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@design/callable-func-parameters.md`:
- Around line 43-46: The PR currently conflicts with the design non-goal that
excludes built-in Warp callables (e.g., wp.sin, wp.add) from being accepted as
callable parameters for user-defined `@wp.func` per callable-func-parameters.md;
either update the PR metadata so it does not close GH-1424 (change "Closes
`#1424`" to "Addresses/Related to/Partial implementation of `#1424`" in the PR
description/commit message) or implement full support for built-in callables
(update the `@wp.func` callable-argument handling, specialization hashing, and
dependency tracking code paths that validate callables so built-ins are
recognized and included) and then remove/update the non-goal text and document
the addition.
In `@warp/_src/codegen.py`:
- Around line 972-975: The specialized callable currently clears its value_func
(specialized_func.value_func = None) which breaks downstream use where
add_call() invokes func.value_func(...); restore the original value_func on the
specialized copy instead of nulling it (e.g., assign specialized_func.value_func
= func.value_func or simply remove the line that sets it to None) so the
specialized_func keeps the callable-return resolver used later by add_call().
In `@warp/_src/context.py`:
- Around line 491-506: The overload_annotations map is being populated with the
template annotation (template_type) for callable parameters, causing all
Callable args to collapse to the same specialization; instead, when
warp._src.types.is_callable_annotation(template_type) and the runtime type
(arg_type or default_type) is a Function, store the concrete Function object
into overload_annotations (use arg_type for positional args and default_type for
defaults) so specializations keyed later (lines ~514-516) use the actual
Function instance rather than the generic annotation; update the branches that
currently assign template_type to assign arg_type or default_type accordingly
and keep using get_arg_type/strip_reference and is_callable_annotation checks as
present.
---
Outside diff comments:
In `@warp/_src/codegen.py`:
- Around line 4513-4539: get_references() currently only discovers callable
arguments when the AST arg resolves statically via
adj.resolve_static_expression, so specialized higher-order calls backed by
adj.callable_arg_values (e.g., when func is forwarded into inner/f and invoked)
are missed; update the Call-handling branch in get_references() (the block that
binds arguments with func.signature.bind, apply_defaults, and iterates
bound_arg_nodes) to also check adj.callable_arg_values for the bound argument
node (and for raw arg names where resolution failed) and add any concrete
warp._src.context.Function entries there into the functions dict (same guard:
isinstance(..., Function) and not is_builtin()). Ensure you reference
adj.callable_arg_values lookup when bound_arg_nodes contains nodes that didn't
resolve statically or that are names/attributes representing forwarded callables
so the concrete callable specializations are included.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Enterprise
Run ID: 075e71fa-9c64-4710-9b9b-afd2eee2f986
📒 Files selected for processing (7)
CHANGELOG.mddesign/callable-func-parameters.mddocs/user_guide/limitations.rstwarp/_src/codegen.pywarp/_src/context.pywarp/_src/types.pywarp/tests/test_func.py
✅ Files skipped from review due to trivial changes (2)
- docs/user_guide/limitations.rst
- CHANGELOG.md
🚧 Files skipped from review as they are similar to previous changes (2)
- warp/_src/types.py
- warp/tests/test_func.py
5414529 to
b574078
Compare
There was a problem hiding this comment.
♻️ Duplicate comments (1)
warp/_src/context.py (1)
493-506:⚠️ Potential issue | 🟠 Major | ⚡ Quick winSpecialize callable overload keys by the concrete
Function.Line 496 and Line 506 still write the annotation back into
overload_annotations. Since Lines 516-518 key the instantiated overload from that map, everyCallableargument collapses to the same specialization and a later call can reuse the wrong overload/body/hash.Proposed fix
overload_annotations = {} for name, arg_type, template_type in zip(arg_names, arg_types, template_types, strict=False): if warp._src.types.is_callable_annotation(template_type) and isinstance(arg_type, Function): - overload_annotations[name] = template_type + overload_annotations[name] = arg_type else: overload_annotations[name] = arg_type @@ template_type = f.input_types[k] default_type = warp._src.codegen.strip_reference(warp._src.codegen.get_arg_type(d)) if warp._src.types.is_callable_annotation(template_type) and isinstance(default_type, Function): - overload_annotations[k] = template_type + overload_annotations[k] = default_type else: overload_annotations[k] = default_typeAlso applies to: 516-518
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@warp/_src/context.py` around lines 493 - 506, The overload_annotations map currently stores the generic callable annotation (template_type) when a parameter is declared as Callable, causing all callable params to collapse to the same specialization; instead, when warp._src.types.is_callable_annotation(template_type) and the runtime value is a concrete Function, assign the concrete Function type (arg_type for parameters, default_type for defaults) into overload_annotations[name] so each callable parameter is specialized by its actual Function; update both the zip loop that handles arg_types (overload_annotations[name] = arg_type) and the defaults loop that handles f.defaults (overload_annotations[k] = default_type) and ensure the later instantiation that keys overloads uses those concrete Function entries.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Duplicate comments:
In `@warp/_src/context.py`:
- Around line 493-506: The overload_annotations map currently stores the generic
callable annotation (template_type) when a parameter is declared as Callable,
causing all callable params to collapse to the same specialization; instead,
when warp._src.types.is_callable_annotation(template_type) and the runtime value
is a concrete Function, assign the concrete Function type (arg_type for
parameters, default_type for defaults) into overload_annotations[name] so each
callable parameter is specialized by its actual Function; update both the zip
loop that handles arg_types (overload_annotations[name] = arg_type) and the
defaults loop that handles f.defaults (overload_annotations[k] = default_type)
and ensure the later instantiation that keys overloads uses those concrete
Function entries.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Enterprise
Run ID: 57311a82-690c-4a18-9a42-229586f4bef2
📒 Files selected for processing (7)
CHANGELOG.mddesign/callable-func-parameters.mddocs/user_guide/limitations.rstwarp/_src/codegen.pywarp/_src/context.pywarp/_src/types.pywarp/tests/test_func.py
✅ Files skipped from review due to trivial changes (3)
- docs/user_guide/limitations.rst
- CHANGELOG.md
- design/callable-func-parameters.md
🚧 Files skipped from review as they are similar to previous changes (3)
- warp/_src/types.py
- warp/_src/codegen.py
- warp/tests/test_func.py
b574078 to
72ab41a
Compare
Greptile SummaryThis PR introduces compile-time specialization for
Confidence Score: 4/5The feature is well-contained behind compile-time specialization; callable targets never reach the C++ ABI, so the risk of breaking existing non-callable code paths is low. The implementation is architecturally sound and comprehensively tested. The primary caution is the volume of new codegen machinery interacting with existing complex codegen paths. The warp/_src/codegen.py — the Important Files Changed
Sequence DiagramsequenceDiagram
participant K as Kernel / Caller
participant AC as Adjoint.add_call
participant RC as Adjoint.resolve_call
participant GC as get_callable_arg_values
participant SC as specialize_callable_func
participant BLD as Adjoint.build (clone)
K->>AC: "add_call(func, args=[callable_target, …])"
AC->>RC: resolve_call(func, args, kwargs, type_args)
RC->>RC: func.signature.bind(args) → bound_args
RC->>RC: apply defaults
RC->>GC: get_callable_arg_values(func, bound_args)
GC-->>RC: "{param_name: callable_target}"
RC->>SC: specialize_callable_func(func, callable_arg_values)
SC->>SC: lookup / create specialization cache entry
SC->>SC: shallowcopy(func), drop _callable_specializations
SC->>SC: new Adjoint(source) + callable_arg_values on clone adj
SC-->>RC: specialized_clone
RC-->>AC: (specialized_clone, bound_args)
AC->>AC: filter Callable params from func_args
AC->>BLD: builder.build_function(specialized_clone)
BLD->>BLD: "symbols[callable_param] = concrete Function"
BLD->>BLD: emit code: g(x) → resolved_function(x)
BLD-->>AC: built clone
Reviews (11): Last reviewed commit: "Support Callable function parameters (GH..." | Re-trigger Greptile |
ce230dc to
2055cb9
Compare
c203741 to
9f8b8f2
Compare
|
We'd be super interested in having something like this! |
Hey @thomasbbrunner, do you have any extra requirements for what the feature needs to be able to do? |
9f8b8f2 to
083fab1
Compare
|
@shi-eric I don't think there's anything that wouldn't be covered by the current examples in the PR's description. Maybe a small difference of our use-case would be that the functions themselves (e.g., # in ops/double.py
@wp.func
def double_it(x: float):
return x * 2.0
# in ops/triple.py
@wp.func
def triple_it(x: float):
return x * 3.0
# in main.py
from ops import double, triple
@wp.func
def apply(g: Callable, x: float):
return g(x)
@wp.kernel
def k(cond: wp.array[bool], out: wp.array[float]):
if cond[0]:
out[0] = apply(double.double_it, 3.0)
else:
out[0] = apply(triple.triple_it, 3.0)On another note, I'm curious about the motivation for using the generic @wp.func
def apply(g: wp.Function[[float], float], x: float):
return g(x) |
83d27f3 to
500cff5
Compare
Callable parameters let user functions accept another Warp function as a specialization argument. Before this, helpers had to hard-code the callee, which made higher-order patterns unavailable and left specialization-sensitive behavior untested. Support these parameters in codegen, module hashing, dependency tracking, documentation, and autograd diagnostics. Keep the coverage in a dedicated test module so future regressions around callable specialization are easier to isolate. Signed-off-by: Eric Shi <ershi@nvidia.com>
500cff5 to
9b56a73
Compare
Description
Closes #1424.
This PR adds support for
Callable-typed parameters in user-defined@wp.funcfunctions. User-defined Warp functions and simple built-in Warp functions such aswp.sin()andwp.min()can be used as callable targets from kernels or other functions, including through positional arguments, keyword arguments, defaults, parameterizedCallable[[...], ...]annotations, nested calls, and generic helpers that combineCallablewithAny.Callable targets are specialization inputs, not runtime values. User functions are specialized for the concrete callable target, callable parameters are omitted from generated runtime signatures, and callable targets participate in module hashing and dependency invalidation. Built-ins that require special dispatch or replay behavior, such as
wp.printf(), and callable-specialized functions with custom grad or replay hooks are rejected explicitly.Changes
typing.Callable,collections.abc.Callable, and parameterizedCallable[[...], ...]annotations.wp.sin(),wp.sqrt(),wp.add(), andwp.min()as Callable targets.warp/tests/test_func_callable.py.Checklist
Unreleasedsection.Validation summary
Callable parameter coverage lives in
warp/tests/test_func_callable.pyand is included in the default and debug unittest suites. The tests cover user-defined Callable targets, simple built-in Callable targets, defaults, keyword arguments, parameterized annotations, genericAnyhelpers, function-valued aliases,wp.static(...)targets,wp.grad()specialization, return-type validation, unsupported built-ins, and custom grad/replay rejection.The regression coverage also verifies that Callable targets affect module hashes for explicit arguments, defaults, built-ins, and non-primary overloads. Dependency tests cover explicit arguments, default arguments, kernel-local aliases, overloaded functions, and Callable targets imported from Python modules so provider module unloads invalidate dependent consumer modules.
Validation performed:
TestFuncCallablesuite on CPU and CUDA devices.build_docs.py.upstream/mainbecause the native library was stale after upstream native API changes.warp/_src/codegen.py,warp/_src/context.py, andwarp/tests/test_func_callable.py.Bug fix
Without this PR, the call to
apply(double_it, 3.0)fails during Warp overload resolution/codegen.New feature / enhancement
Callable markers from both
typingandcollections.abcare accepted. Parameterized forms are recognized as type-erased Callable markers, but their signatures are not validated in this first pass.