Skip to content

Support Callable parameters in user-defined functions#1454

Open
shi-eric wants to merge 1 commit into
NVIDIA:mainfrom
shi-eric:shi-eric/callable-func-design
Open

Support Callable parameters in user-defined functions#1454
shi-eric wants to merge 1 commit into
NVIDIA:mainfrom
shi-eric:shi-eric/callable-func-design

Conversation

@shi-eric

@shi-eric shi-eric commented May 11, 2026

Copy link
Copy Markdown
Contributor

Description

Closes #1424.

This PR adds support for Callable-typed parameters in user-defined @wp.func functions. User-defined Warp functions and simple built-in Warp functions such as wp.sin() and wp.min() can be used as callable targets from kernels or other functions, including through positional arguments, keyword arguments, defaults, parameterized Callable[[...], ...] annotations, nested calls, and generic helpers that combine Callable with Any.

Callable targets are specialization inputs, not runtime values. User functions are specialized for the concrete callable target, callable parameters are omitted from generated runtime signatures, and callable targets participate in module hashing and dependency invalidation. Built-ins that require special dispatch or replay behavior, such as wp.printf(), and callable-specialized functions with custom grad or replay hooks are rejected explicitly.

Changes

  • Add Callable parameter recognition for typing.Callable, collections.abc.Callable, and parameterized Callable[[...], ...] annotations.
  • Specialize user-defined functions for concrete Callable targets passed through explicit arguments, keyword arguments, default arguments, local aliases, and nested calls.
  • Allow simple built-in Warp functions such as wp.sin(), wp.sqrt(), wp.add(), and wp.min() as Callable targets.
  • Include Callable targets in module hashes and dependency discovery so rebinding a target or unloading a provider module invalidates dependent generated code.
  • Reject unsupported built-in Callable targets and Callable-specialized functions with custom grad or replay hooks.
  • Document Callable parameters in the user guide, add the current limitations, update the changelog, and add focused regression coverage in warp/tests/test_func_callable.py.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • CHANGELOG.md is updated for any user-facing changes under the Unreleased section.

Validation summary

Callable parameter coverage lives in warp/tests/test_func_callable.py and is included in the default and debug unittest suites. The tests cover user-defined Callable targets, simple built-in Callable targets, defaults, keyword arguments, parameterized annotations, generic Any helpers, function-valued aliases, wp.static(...) targets, wp.grad() specialization, return-type validation, unsupported built-ins, and custom grad/replay rejection.

The regression coverage also verifies that Callable targets affect module hashes for explicit arguments, defaults, built-ins, and non-primary overloads. Dependency tests cover explicit arguments, default arguments, kernel-local aliases, overloaded functions, and Callable targets imported from Python modules so provider module unloads invalidate dependent consumer modules.

Validation performed:

  • Ran the dedicated TestFuncCallable suite on CPU and CUDA devices.
  • Ran pre-commit checks for the changed code, docs, changelog, and tests.
  • Built the documentation with build_docs.py.
  • Rebuilt Warp after rebasing onto current upstream/main because the native library was stale after upstream native API changes.
  • After the final docstring-only amend, reran pre-commit for warp/_src/codegen.py, warp/_src/context.py, and warp/tests/test_func_callable.py.

Bug fix

Without this PR, the call to apply(double_it, 3.0) fails during Warp overload resolution/codegen.

from typing import Callable

import warp as wp


@wp.func
def double_it(x: float):
    return x * 2.0


@wp.func
def apply(g: Callable, x: float):
    return g(x)


@wp.kernel
def k(out: wp.array[float]):
    out[0] = apply(double_it, 3.0)

New feature / enhancement

Callable markers from both typing and collections.abc are accepted. Parameterized forms are recognized as type-erased Callable markers, but their signatures are not validated in this first pass.

from collections.abc import Callable

import warp as wp


@wp.func
def triple_it(x: float):
    return x * 3.0


@wp.func
def apply(g: Callable[[float], float], x: float):
    return g(x)

@copy-pr-bot

copy-pr-bot Bot commented May 11, 2026

Copy link
Copy Markdown

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented May 11, 2026

Copy link
Copy Markdown

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR implements support for Callable-typed parameters in user-defined @wp.func functions. Callable parameters are treated as compile-time function references that trigger per-call-site specialization; callable arguments/defaults are extracted and hashed for dependency tracking, bound into adjoint symbols, and omitted from emitted C++ signatures.

Changes

Callable Parameter Support

Layer / File(s) Summary
Design Documentation
design/callable-func-parameters.md
Design doc added: semantics, recognition rules, specialization, hashing, rejection cases, and test plan.
Callable Type Detection
warp/_src/types.py
is_callable_annotation() added; get_type_code() uses it to emit "c" for typing/collections.abc callables.
Overload Instantiation
warp/_src/context.py
Function.get_overload() preserved callable-template params and computes instantiated overload signature keys from callable-aware annotations (handles defaults).
Overload Matching
warp/_src/codegen.py
func_match_args() updated to match callable-annotated params when bound argument is a Warp Function.
Specialization Helpers
warp/_src/codegen.py
Added get_callable_arg_values(), get_default_arg_value(), specialize_callable_func(); shallow-copy/cache specialized function variants with hashed native_func suffixes.
Adjoint & Call Integration
warp/_src/codegen.py
Adjoint.build() accepts callable_arg_values; adj.symbols bind callable params to concrete Function values. add_call() applies callable-aware defaults, triggers specialization, and omits callable params from runtime dispatch.
Codegen Signatures / Dispatch
warp/_src/codegen.py
Forward and reverse signature generation skip callable-annotated parameters so callables are not emitted in the C++ ABI; dispatch filtering removes callable args from C++ calls.
Dependency Tracking
warp/_src/context.py, warp/_src/codegen.py
Reference discovery resolves callable-typed arguments to concrete Function targets and records their modules for hashing and invalidation.
Tests
warp/tests/test_func.py
Adds helpers, kernels, and unit tests covering typing/collections.abc callables, parameterized/generic/default/nested usage, module hash invalidation, return-annotation checks, dependency propagation, and rejection cases (builtins/custom grad/replay).
Docs / Changelog
docs/user_guide/limitations.rst, CHANGELOG.md
Adds limitation note and changelog entry (GH-1424).

Sequence Diagram(s)

sequenceDiagram
    participant Kernel as Kernel (caller)
    participant Codegen as Codegen (add_call / func_match_args)
    participant Types as Types (is_callable_annotation)
    participant Specialize as Specialization (specialize_callable_func / get_callable_arg_values)
    participant Adjoint as Adjoint.build
    participant CppGen as C++ Generation

    Kernel->>Codegen: call apply(double_it, 3.0)
    Codegen->>Types: detect callable parameter via is_callable_annotation()
    Codegen->>Codegen: bind args, match overload (callable-aware)
    Codegen->>Specialize: extract callable args (get_callable_arg_values)
    Specialize->>Specialize: create/cache specialized clone (hash suffix)
    Codegen->>Adjoint: Adjoint.build(callable_arg_values)
    Adjoint->>Adjoint: populate adj.symbols with concrete Function values
    Codegen->>CppGen: emit call (skip callable params in forward/reverse signatures)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 11.54% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title 'Support Callable parameters in user-defined functions' directly and clearly summarizes the main change in the PR, matching the primary objective of enabling Callable-typed parameters in @wp.func functions.
Linked Issues check ✅ Passed The PR comprehensively addresses all coding requirements from issue #1424: unified Callable annotation handling, overload resolution fixes, type code generation ('c' code), callable-aware specialization, module hashing/dependency tracking, and comprehensive test coverage with documented limitations.
Out of Scope Changes check ✅ Passed All changes are directly related to implementing Callable parameter support per GH-1424: type detection, overload matching, code generation, callable specialization, dependency tracking, tests, and documentation of limitations. No unrelated changes detected.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (2)
warp/_src/codegen.py (1)

945-992: 💤 Low value

Consider documenting the specialization key stability for caching.

The specialization cache key at lines 952-954 uses a tuple of (name, callable_func) pairs, relying on Function object identity for cache hits. If the same logical function is recreated (e.g., module reload), cache misses will occur, generating duplicate specializations.

This is likely acceptable for correctness (each specialization is valid), but consider adding a brief comment explaining this design choice for future maintainers.

Also, the check at lines 946-950 correctly rejects functions with custom grad/replay—good defensive validation.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@warp/_src/codegen.py` around lines 945 - 992, The specialization cache key in
specialize_callable_func currently builds specialization_key from (name,
callable_arg_values[name]) and therefore relies on Function object identity for
cache hits; add a brief explanatory comment above the specialization_key
creation stating that this is an intentional design choice (it may produce cache
misses if the same logical function object is recreated, e.g., on module
reload), why that is acceptable for correctness, and note that using
callable_func.key/native_func could be an alternative for a stable key if
desired in the future; reference specialize_callable_func, specialization_key,
and func._callable_specializations in the comment so maintainers can find the
code easily.
CHANGELOG.md (1)

32-33: ⚡ Quick win

Use imperative “Add …” phrasing for this Unreleased entry.

Line 32 currently starts with “Support …”; the changelog convention here asks for imperative present tense.

Suggested wording
-- Support passing user-defined Warp functions to `Callable` parameters in `@wp.func` functions
+- Add support for passing user-defined Warp functions to `Callable` parameters in `@wp.func` functions

As per coding guidelines: CHANGELOG.md entries in Unreleased should use imperative present tense (“Add X”).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@CHANGELOG.md` around lines 32 - 33, Update the Unreleased changelog entry
that currently reads "Support passing user-defined Warp functions to `Callable`
parameters in `@wp.func` functions" to use imperative present-tense phrasing;
change it to something like "Add support for passing user-defined Warp functions
to `Callable` parameters in `@wp.func` functions" so the entry follows the
repository's changelog convention.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@design/callable-func-parameters.md`:
- Around line 41-43: The document's non-goal wording conflicts with the linked
GH-1424 objective about accepting built-in callables; update the text around the
callable targets section to reconcile them by either (A) removing or rephrasing
the line that excludes built-ins and explicitly stating that built-in Warp
functions like wp.sin and wp.add are supported as callable arguments, or (B) if
built-ins truly remain out of scope, change the GH-1424 reference/closure
language to reflect a narrower objective; ensure references to "wp.sin" and
"wp.add" and the GH-1424 issue ID are corrected so the scope described matches
the linked issue.
- Around line 15-18: The snippet imports Callable twice which shadows and
confuses the reader; update the examples so they don't clobber the same name by
either (a) using only one import (prefer collections.abc.Callable for modern
code) and removing the other, (b) aliasing one import (e.g., import Callable as
TypingCallable) to show the difference, or (c) present two separate, clearly
labeled snippets instead of both lines together—adjust the lines containing
"from typing import Callable" and "from collections.abc import Callable"
accordingly.

---

Nitpick comments:
In `@CHANGELOG.md`:
- Around line 32-33: Update the Unreleased changelog entry that currently reads
"Support passing user-defined Warp functions to `Callable` parameters in
`@wp.func` functions" to use imperative present-tense phrasing; change it to
something like "Add support for passing user-defined Warp functions to
`Callable` parameters in `@wp.func` functions" so the entry follows the
repository's changelog convention.

In `@warp/_src/codegen.py`:
- Around line 945-992: The specialization cache key in specialize_callable_func
currently builds specialization_key from (name, callable_arg_values[name]) and
therefore relies on Function object identity for cache hits; add a brief
explanatory comment above the specialization_key creation stating that this is
an intentional design choice (it may produce cache misses if the same logical
function object is recreated, e.g., on module reload), why that is acceptable
for correctness, and note that using callable_func.key/native_func could be an
alternative for a stable key if desired in the future; reference
specialize_callable_func, specialization_key, and func._callable_specializations
in the comment so maintainers can find the code easily.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: 9019ac3f-bc6a-4ec7-b04f-ab04997190aa

📥 Commits

Reviewing files that changed from the base of the PR and between bd91b99 and 230bf7f.

📒 Files selected for processing (6)
  • CHANGELOG.md
  • design/callable-func-parameters.md
  • warp/_src/codegen.py
  • warp/_src/context.py
  • warp/_src/types.py
  • warp/tests/test_func.py

Comment thread design/callable-func-parameters.md Outdated
Comment thread design/callable-func-parameters.md Outdated
@shi-eric shi-eric force-pushed the shi-eric/callable-func-design branch from 230bf7f to f6a270d Compare May 11, 2026 03:26
@shi-eric shi-eric marked this pull request as ready for review May 11, 2026 03:29
@shi-eric shi-eric force-pushed the shi-eric/callable-func-design branch 2 times, most recently from af680cc to 99a8d88 Compare May 11, 2026 03:40

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
warp/_src/codegen.py (1)

4513-4539: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Track callable-parameter specializations in get_references() too.

This only finds callable targets when the AST argument itself resolves statically. For specialized higher-order functions, forwarded names like inner(f, x) and direct calls like f(x) are backed by adj.callable_arg_values, not by a global/static lookup, so the concrete callable never gets added to functions. That leaves module hashing/dependency invalidation stale when the passed function changes.

Suggested direction
 def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp._src.context.Function, Any]]:
     """Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions."""

     local_variables = set()  # Track local variables appearing on the LHS so we know when variables are shadowed

     constants: dict[str, Any] = {}
     types: dict[Struct | type, Any] = {}
     functions: dict[warp._src.context.Function, Any] = {}
+    callable_arg_values = getattr(adj, "callable_arg_values", None) or {}

     for node in ast.walk(adj.tree):
         ...
         elif isinstance(node, ast.Call):
             func, _ = adj.resolve_static_expression(node.func, eval_types=False)
+            if func is None and isinstance(node.func, ast.Name):
+                func = callable_arg_values.get(node.func.id)
+
+            if isinstance(func, warp._src.context.Function) and not func.is_builtin():
+                functions[func] = None
+
                 try:
                     bound_args = func.signature.bind(*node.args, **{kw.arg: kw.value for kw in node.keywords})
                 except TypeError:
                     bound_arg_nodes = {}
                 else:
                     ...
                 for arg_name, arg_node in bound_arg_nodes.items():
                     if not warp._src.types.is_callable_annotation(func.input_types.get(arg_name)):
                         continue

                     if isinstance(arg_node, warp._src.context.Function):
                         callable_func = arg_node
+                    elif isinstance(arg_node, ast.Name):
+                        callable_func = callable_arg_values.get(arg_node.id)
                     else:
                         callable_func, _ = adj.resolve_static_expression(arg_node, eval_types=False)

                     if isinstance(callable_func, warp._src.context.Function) and not callable_func.is_builtin():
                         functions[callable_func] = None
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@warp/_src/codegen.py` around lines 4513 - 4539, get_references() currently
only discovers callable arguments when the AST arg resolves statically via
adj.resolve_static_expression, so specialized higher-order calls backed by
adj.callable_arg_values (e.g., when func is forwarded into inner/f and invoked)
are missed; update the Call-handling branch in get_references() (the block that
binds arguments with func.signature.bind, apply_defaults, and iterates
bound_arg_nodes) to also check adj.callable_arg_values for the bound argument
node (and for raw arg names where resolution failed) and add any concrete
warp._src.context.Function entries there into the functions dict (same guard:
isinstance(..., Function) and not is_builtin()). Ensure you reference
adj.callable_arg_values lookup when bound_arg_nodes contains nodes that didn't
resolve statically or that are names/attributes representing forwarded callables
so the concrete callable specializations are included.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@design/callable-func-parameters.md`:
- Around line 43-46: The PR currently conflicts with the design non-goal that
excludes built-in Warp callables (e.g., wp.sin, wp.add) from being accepted as
callable parameters for user-defined `@wp.func` per callable-func-parameters.md;
either update the PR metadata so it does not close GH-1424 (change "Closes
`#1424`" to "Addresses/Related to/Partial implementation of `#1424`" in the PR
description/commit message) or implement full support for built-in callables
(update the `@wp.func` callable-argument handling, specialization hashing, and
dependency tracking code paths that validate callables so built-ins are
recognized and included) and then remove/update the non-goal text and document
the addition.

In `@warp/_src/codegen.py`:
- Around line 972-975: The specialized callable currently clears its value_func
(specialized_func.value_func = None) which breaks downstream use where
add_call() invokes func.value_func(...); restore the original value_func on the
specialized copy instead of nulling it (e.g., assign specialized_func.value_func
= func.value_func or simply remove the line that sets it to None) so the
specialized_func keeps the callable-return resolver used later by add_call().

In `@warp/_src/context.py`:
- Around line 491-506: The overload_annotations map is being populated with the
template annotation (template_type) for callable parameters, causing all
Callable args to collapse to the same specialization; instead, when
warp._src.types.is_callable_annotation(template_type) and the runtime type
(arg_type or default_type) is a Function, store the concrete Function object
into overload_annotations (use arg_type for positional args and default_type for
defaults) so specializations keyed later (lines ~514-516) use the actual
Function instance rather than the generic annotation; update the branches that
currently assign template_type to assign arg_type or default_type accordingly
and keep using get_arg_type/strip_reference and is_callable_annotation checks as
present.

---

Outside diff comments:
In `@warp/_src/codegen.py`:
- Around line 4513-4539: get_references() currently only discovers callable
arguments when the AST arg resolves statically via
adj.resolve_static_expression, so specialized higher-order calls backed by
adj.callable_arg_values (e.g., when func is forwarded into inner/f and invoked)
are missed; update the Call-handling branch in get_references() (the block that
binds arguments with func.signature.bind, apply_defaults, and iterates
bound_arg_nodes) to also check adj.callable_arg_values for the bound argument
node (and for raw arg names where resolution failed) and add any concrete
warp._src.context.Function entries there into the functions dict (same guard:
isinstance(..., Function) and not is_builtin()). Ensure you reference
adj.callable_arg_values lookup when bound_arg_nodes contains nodes that didn't
resolve statically or that are names/attributes representing forwarded callables
so the concrete callable specializations are included.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: 075e71fa-9c64-4710-9b9b-afd2eee2f986

📥 Commits

Reviewing files that changed from the base of the PR and between af680cc and 99a8d88.

📒 Files selected for processing (7)
  • CHANGELOG.md
  • design/callable-func-parameters.md
  • docs/user_guide/limitations.rst
  • warp/_src/codegen.py
  • warp/_src/context.py
  • warp/_src/types.py
  • warp/tests/test_func.py
✅ Files skipped from review due to trivial changes (2)
  • docs/user_guide/limitations.rst
  • CHANGELOG.md
🚧 Files skipped from review as they are similar to previous changes (2)
  • warp/_src/types.py
  • warp/tests/test_func.py

Comment thread design/callable-func-parameters.md Outdated
Comment thread warp/_src/codegen.py
Comment thread warp/_src/context.py
@shi-eric shi-eric force-pushed the shi-eric/callable-func-design branch 2 times, most recently from 5414529 to b574078 Compare May 11, 2026 03:50

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (1)
warp/_src/context.py (1)

493-506: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Specialize callable overload keys by the concrete Function.

Line 496 and Line 506 still write the annotation back into overload_annotations. Since Lines 516-518 key the instantiated overload from that map, every Callable argument collapses to the same specialization and a later call can reuse the wrong overload/body/hash.

Proposed fix
                 overload_annotations = {}
                 for name, arg_type, template_type in zip(arg_names, arg_types, template_types, strict=False):
                     if warp._src.types.is_callable_annotation(template_type) and isinstance(arg_type, Function):
-                        overload_annotations[name] = template_type
+                        overload_annotations[name] = arg_type
                     else:
                         overload_annotations[name] = arg_type
@@
                         template_type = f.input_types[k]
                         default_type = warp._src.codegen.strip_reference(warp._src.codegen.get_arg_type(d))
                         if warp._src.types.is_callable_annotation(template_type) and isinstance(default_type, Function):
-                            overload_annotations[k] = template_type
+                            overload_annotations[k] = default_type
                         else:
                             overload_annotations[k] = default_type

Also applies to: 516-518

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@warp/_src/context.py` around lines 493 - 506, The overload_annotations map
currently stores the generic callable annotation (template_type) when a
parameter is declared as Callable, causing all callable params to collapse to
the same specialization; instead, when
warp._src.types.is_callable_annotation(template_type) and the runtime value is a
concrete Function, assign the concrete Function type (arg_type for parameters,
default_type for defaults) into overload_annotations[name] so each callable
parameter is specialized by its actual Function; update both the zip loop that
handles arg_types (overload_annotations[name] = arg_type) and the defaults loop
that handles f.defaults (overload_annotations[k] = default_type) and ensure the
later instantiation that keys overloads uses those concrete Function entries.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Duplicate comments:
In `@warp/_src/context.py`:
- Around line 493-506: The overload_annotations map currently stores the generic
callable annotation (template_type) when a parameter is declared as Callable,
causing all callable params to collapse to the same specialization; instead,
when warp._src.types.is_callable_annotation(template_type) and the runtime value
is a concrete Function, assign the concrete Function type (arg_type for
parameters, default_type for defaults) into overload_annotations[name] so each
callable parameter is specialized by its actual Function; update both the zip
loop that handles arg_types (overload_annotations[name] = arg_type) and the
defaults loop that handles f.defaults (overload_annotations[k] = default_type)
and ensure the later instantiation that keys overloads uses those concrete
Function entries.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: 57311a82-690c-4a18-9a42-229586f4bef2

📥 Commits

Reviewing files that changed from the base of the PR and between 99a8d88 and 5414529.

📒 Files selected for processing (7)
  • CHANGELOG.md
  • design/callable-func-parameters.md
  • docs/user_guide/limitations.rst
  • warp/_src/codegen.py
  • warp/_src/context.py
  • warp/_src/types.py
  • warp/tests/test_func.py
✅ Files skipped from review due to trivial changes (3)
  • docs/user_guide/limitations.rst
  • CHANGELOG.md
  • design/callable-func-parameters.md
🚧 Files skipped from review as they are similar to previous changes (3)
  • warp/_src/types.py
  • warp/_src/codegen.py
  • warp/tests/test_func.py

@shi-eric shi-eric force-pushed the shi-eric/callable-func-design branch from b574078 to 72ab41a Compare May 11, 2026 03:55
@greptile-apps

greptile-apps Bot commented May 11, 2026

Copy link
Copy Markdown

Greptile Summary

This PR introduces compile-time specialization for Callable-typed parameters in user-defined @wp.func functions. Callable arguments are type-erased at the C++ boundary: each unique combination of callable targets produces a cloned function with a distinct native name, so overload resolution and code generation work without adding runtime dispatch overhead.

  • specialize_callable_func creates and caches a shallow-copy clone of the receiving function, injecting concrete Function objects into the clone's symbol table and stripping callable params from the generated C++ signature. Built-ins are accepted if every overload passes is_regular_builtin_callable_target.
  • iter_call_callable_arg_targets is the shared helper that both Adjoint.get_references() and Module._find_references() use to discover callable targets at each call site, ensuring module-hash soundness and dependency-graph invalidation when callable targets change or reload.
  • A comprehensive new test suite (test_func_callable.py) covers annotation variants, positional/keyword/default args, nested calls, local rebinding, built-in targets, module-hash sensitivity, and explicit rejection of unsupported built-ins and functions with custom grad/replay hooks.

Confidence Score: 4/5

The feature is well-contained behind compile-time specialization; callable targets never reach the C++ ABI, so the risk of breaking existing non-callable code paths is low.

The implementation is architecturally sound and comprehensively tested. The primary caution is the volume of new codegen machinery interacting with existing complex codegen paths. The verify_autograd_array_access path performs a redundant second resolve_call that, while harmless today, could silently diverge if resolve_call acquires stateful side effects in a future change.

warp/_src/codegen.py — the emit_call method (specifically the verify_autograd_array_access branch) and the new specialization helpers warrant a close look on any future change to resolve_call or add_call.

Important Files Changed

Filename Overview
warp/_src/codegen.py Core of the feature: adds ~330 lines of new helpers plus integration into Adjoint.build, add_call (via new resolve_call), get_references, and codegen_func. Logic is sound; verify_autograd_array_access path calls resolve_call a second time redundantly.
warp/_src/context.py Adds hash_builtin_function for identity-stable hashing of built-in callable targets, wires iter_call_callable_arg_targets into _find_references for module-dependency tracking, and corrects generic overload instantiation to preserve Callable annotations.
warp/_src/types.py Adds is_callable_annotation() helper and replaces the incorrect isinstance(arg_type, Callable) check in get_type_code with the new predicate; handles both bare and parameterized Callable annotations from typing and collections.abc.
warp/tests/test_func_callable.py Comprehensive 695-line test suite covering annotation variants, positional/keyword/default args, nested calls, local rebinding, built-in targets, module-hash sensitivity, cross-module dependency invalidation, grad calls, and explicit rejection paths.
warp/tests/aux_test_callable_double.py Small auxiliary module used to verify cross-module callable dependency invalidation.
warp/tests/aux_test_callable_triple.py Small auxiliary module used to verify cross-module callable dependency invalidation.
docs/user_guide/basics.rst Adds documentation for the new Callable parameter feature including annotation forms, built-in targets, defaults, and known limitations.
warp/tests/unittest_suites.py Registers TestFuncCallable in default and debug suites.

Sequence Diagram

sequenceDiagram
    participant K as Kernel / Caller
    participant AC as Adjoint.add_call
    participant RC as Adjoint.resolve_call
    participant GC as get_callable_arg_values
    participant SC as specialize_callable_func
    participant BLD as Adjoint.build (clone)

    K->>AC: "add_call(func, args=[callable_target, …])"
    AC->>RC: resolve_call(func, args, kwargs, type_args)
    RC->>RC: func.signature.bind(args) → bound_args
    RC->>RC: apply defaults
    RC->>GC: get_callable_arg_values(func, bound_args)
    GC-->>RC: "{param_name: callable_target}"
    RC->>SC: specialize_callable_func(func, callable_arg_values)
    SC->>SC: lookup / create specialization cache entry
    SC->>SC: shallowcopy(func), drop _callable_specializations
    SC->>SC: new Adjoint(source) + callable_arg_values on clone adj
    SC-->>RC: specialized_clone
    RC-->>AC: (specialized_clone, bound_args)
    AC->>AC: filter Callable params from func_args
    AC->>BLD: builder.build_function(specialized_clone)
    BLD->>BLD: "symbols[callable_param] = concrete Function"
    BLD->>BLD: emit code: g(x) → resolved_function(x)
    BLD-->>AC: built clone
Loading

Reviews (11): Last reviewed commit: "Support Callable function parameters (GH..." | Re-trigger Greptile

Comment thread warp/_src/codegen.py
Comment thread warp/_src/codegen.py Outdated
@shi-eric shi-eric force-pushed the shi-eric/callable-func-design branch 3 times, most recently from ce230dc to 2055cb9 Compare May 14, 2026 07:25
@shi-eric shi-eric force-pushed the shi-eric/callable-func-design branch 3 times, most recently from c203741 to 9f8b8f2 Compare June 7, 2026 20:35
@thomasbbrunner

Copy link
Copy Markdown
Contributor

We'd be super interested in having something like this!

@shi-eric

shi-eric commented Jun 9, 2026

Copy link
Copy Markdown
Contributor Author

We'd be super interested in having something like this!

Hey @thomasbbrunner, do you have any extra requirements for what the feature needs to be able to do?

@shi-eric shi-eric force-pushed the shi-eric/callable-func-design branch from 9f8b8f2 to 083fab1 Compare June 9, 2026 18:08
@thomasbbrunner

thomasbbrunner commented Jun 10, 2026

Copy link
Copy Markdown
Contributor

@shi-eric I don't think there's anything that wouldn't be covered by the current examples in the PR's description.

Maybe a small difference of our use-case would be that the functions themselves (e.g., double_it and triple_it in your example) would be located in different modules. Unsure if this matters. To illustrate:

# in ops/double.py

@wp.func
def double_it(x: float):
    return x * 2.0

# in ops/triple.py

@wp.func
def triple_it(x: float):
    return x * 3.0

# in main.py
from ops import double, triple

@wp.func
def apply(g: Callable, x: float):
    return g(x)

@wp.kernel
def k(cond: wp.array[bool], out: wp.array[float]):
    if cond[0]:
        out[0] = apply(double.double_it, 3.0)
    else:
        out[0] = apply(triple.triple_it, 3.0)

On another note, I'm curious about the motivation for using the generic Callable type instead of wp.Function. As in, not all callables could be passed as an argument, only Warp functions, no? To illustrate, this is what I'd have expected:

@wp.func
def apply(g: wp.Function[[float], float], x: float):
    return g(x)

@shi-eric shi-eric force-pushed the shi-eric/callable-func-design branch 2 times, most recently from 83d27f3 to 500cff5 Compare June 15, 2026 07:05
Callable parameters let user functions accept another Warp
function as a specialization argument. Before this, helpers had to
hard-code the callee, which made higher-order patterns unavailable and
left specialization-sensitive behavior untested.

Support these parameters in codegen, module hashing, dependency
tracking, documentation, and autograd diagnostics. Keep the coverage in
a dedicated test module so future regressions around callable
specialization are easier to isolate.

Signed-off-by: Eric Shi <ershi@nvidia.com>
@shi-eric shi-eric force-pushed the shi-eric/callable-func-design branch from 500cff5 to 9b56a73 Compare June 15, 2026 07:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support Callable-typed parameters in user-defined @wp.func

2 participants