Add JAX FFI Host support #1446

Open
loney7 wants to merge 3 commits into NVIDIA:main from loney7:loney7/ffi-host-support

Conversation

@loney7

@loney7 loney7 commented May 8, 2026

Description

This PR adds support for running JAX FFI callbacks on the CPU (Host) in addition to CUDA.

Changes:

  • Registered FFI targets for both "CUDA" and "Host" platforms in register_ffi_callback.
  • Handled the "Host" platform case in ffi_callback by using the CPU device and bypassing CUDA-specific features like streams and graphs.
  • Updated FfiCallable to reconstruct arguments and execute the function on the CPU when running on the Host platform.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

Test plan

You can verify these changes by running the JAX interop tests, which include the FFI tests. Ensure they pass on both CPU and GPU (if available).

uv run warp/tests/interop/test_jax.py


## Summary by CodeRabbit

* **New Features**
  * Separate CUDA and CPU execution paths for FFI callbacks with device-aware execution scoping.
  * CUDA graph compatibility enabled only for CUDA execution.
  * CPU/host path can take a direct Python execution route for faster host calls.
  * Separate CUDA and Host callback registrations for JAX interop.

* **Tests**
  * Added CPU/host FFI tests covering kernels and callables (add, sincos, in/out args, scale).

@copy-pr-bot

copy-pr-bot Bot commented May 8, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented May 8, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

This pull request extends JAX experimental FFI callbacks to support both CUDA and Host (CPU) execution paths. Callbacks now accept a platform parameter, register separate FFI targets for each platform, and conditionally dispatch to platform-appropriate device and kernel launch logic. CUDA-specific traits and graph capture modes are guarded by platform checks.

Changes

CUDA and Host FFI Execution

| Layer / File(s) | Summary |
|---|---|
| **Callback Protocol & Platform Parameter** `warp/_src/jax_experimental/ffi.py` | FfiKernel.ffi_callback(), FfiCallable.ffi_callback(), and register_ffi_callback() callbacks gain a platform="CUDA" parameter. CUDA graph compatibility traits are now enabled only when platform=="CUDA". |
| **Dual-Platform FFI Registration** `warp/_src/jax_experimental/ffi.py` | FfiKernel and FfiCallable register separate FFI capsules for both CUDA and Host platforms, each passing the appropriate platform argument. register_ffi_callback() stores capsules under distinct registry keys (_cuda and _host suffixes). |
| **FfiKernel Platform-Conditional Launch** `warp/_src/jax_experimental/ffi.py` | FfiKernel execution branches by platform: CUDA selects CUDA device and stream, calls wp_cuda_launch_kernel; Host uses CPU device with no stream, calls wp_cpu_launch_kernel. |
| **FfiCallable Host Execution Short-Circuit** `warp/_src/jax_experimental/ffi.py` | When platform=="Host", FfiCallable reconstructs Warp arrays on the CPU device from the call frame and directly invokes the wrapped Python function, bypassing all CUDA graph logic. |
| **FfiCallable CUDA Execution & Graph Capture** `warp/_src/jax_experimental/ffi.py` | CUDA graph modes are guarded by device.is_cuda. Execution scopes conditionally use ScopedStream (when stream present) or ScopedDevice, restricting graph capture and replay to CUDA devices only. |
| **CPU Host Tests** `warp/tests/interop/test_jax.py` | New CPU-only jax_kernel and jax_callable tests added and registered in TestJax with devices=None. |
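
The conditional scoping summarized in the table (stream scope when a stream is present, device scope otherwise, graphs only on CUDA) might look roughly like the following. It is a sketch under stated assumptions: `scoped_stream` and `scoped_device` are hypothetical stand-ins for `wp.ScopedStream` and `wp.ScopedDevice`, and `run` and its `graph_mode` argument are invented for illustration.

```python
from contextlib import contextmanager

events = []  # records which scope was entered, for illustration only


@contextmanager
def scoped_stream(stream):
    """Stand-in for wp.ScopedStream (assumption)."""
    events.append(("stream", stream))
    yield


@contextmanager
def scoped_device(device):
    """Stand-in for wp.ScopedDevice (assumption)."""
    events.append(("device", device))
    yield


def execution_scope(device, stream=None):
    """Pick the stream scope when a stream exists, else the device scope."""
    return scoped_stream(stream) if stream is not None else scoped_device(device)


def run(func, device, stream=None, is_cuda=False, graph_mode="none"):
    # Graph capture and replay are only considered on CUDA devices.
    use_graph = is_cuda and graph_mode != "none"
    with execution_scope(device, stream):
        result = func()
    return result, use_graph


print(run(lambda: 42, "cpu"))  # (42, False)
```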

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 5.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title 'Add JAX FFI Host support' directly and clearly summarizes the main change: adding Host/CPU platform support to JAX FFI callbacks alongside existing CUDA support. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

Comment @coderabbitai help to get the list of available commands and usage tips.

@greptile-apps

greptile-apps Bot commented May 8, 2026

Greptile Summary

This PR extends JAX FFI support to the CPU (Host) platform by registering dual CUDA/Host callbacks in FfiKernel, FfiCallable, and register_ffi_callback, routing Host-platform invocations to a new CPU execution path that bypasses all CUDA-specific machinery (streams, graphs, ExecutionContext.stream).

  • FfiKernel Host path: Builds a ctypes ArgsStruct from arg_refs at call-time and launches via wp_cpu_launch_kernel, mirroring the existing context.py invoke pattern.
  • FfiCallable Host path: Adds _reconstruct_args to wrap FFI buffer pointers as wp.array views on the CPU device and calls the Python function directly under wp.ScopedDevice(cpu).
  • ExecutionContext: Guarded get_stream_from_callframe behind platform == "CUDA" so the CUDA-only XLA FFI stream API is never invoked on Host.
  • Tests: Six new CPU-only tests cover FfiKernel (add, sincos, in-out, scale-vec) and FfiCallable (scale-constant, in-out) on the JAX Host backend.
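
The idea behind wrapping FFI buffer pointers as array views on the CPU can be illustrated with `ctypes` alone. This is only a sketch of the zero-copy principle: `view_float_buffer` is a hypothetical helper, and the PR's `_reconstruct_args` builds `wp.array` objects rather than ctypes arrays.

```python
import ctypes


def view_float_buffer(ptr: int, count: int):
    """Wrap an existing float32 buffer at address ptr without copying it."""
    return (ctypes.c_float * count).from_address(ptr)


# Simulate a buffer owned by the caller (JAX owns it in the real code).
owner = (ctypes.c_float * 4)(1.0, 2.0, 3.0, 4.0)
ptr = ctypes.addressof(owner)

view = view_float_buffer(ptr, 4)
for i in range(4):
    view[i] *= 2.0  # writes land directly in the caller's memory

print(list(owner))  # [2.0, 4.0, 6.0, 8.0]
```

Because the view aliases the caller's memory, the owner must outlive the view; the same lifetime concern applies to any array built from a raw FFI pointer.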

Confidence Score: 4/5

The PR is safe to merge; the three previously flagged bugs (NameError on ffi_capsule_host, wrong wp_cpu_launch_kernel argument order, and unconditional CUDA stream fetch in ExecutionContext) are all correctly resolved in this revision.

The CPU kernel launch path is new and non-trivial: it builds a ctypes ArgsStruct from FFI buffer pointers, calls wp_cpu_launch_kernel, and uses a separate _reconstruct_args helper for FfiCallable. The structure closely mirrors the existing context.py reference implementation and the CUDA path, and six new tests cover the key variants. The one remaining gap is that ArgsStruct type construction is not cached in FfiKernel's Host path the way context.py does it, which is a performance concern on repeated calls but does not affect correctness. No new correctness bugs were identified.
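
One way to address the caching gap mentioned above is to memoize the dynamically built structure type per argument signature. The following is a sketch, not the approach used in context.py or in this PR: `make_args_struct_type` and the field names are hypothetical.

```python
import ctypes
from functools import lru_cache


@lru_cache(maxsize=None)
def make_args_struct_type(field_spec: tuple) -> type:
    """Build a ctypes Structure type once per (name, ctype) signature."""

    class ArgsStruct(ctypes.Structure):
        _fields_ = list(field_spec)

    return ArgsStruct


# Hypothetical kernel signature: two array pointers and a scalar.
spec = (("a", ctypes.c_void_p), ("b", ctypes.c_void_p), ("scale", ctypes.c_float))

T1 = make_args_struct_type(spec)
T2 = make_args_struct_type(spec)  # served from the cache, no new type is built
args = T1(a=0x1000, b=0x2000, scale=2.0)

print(T1 is T2)  # True
```

Only the type is cached here; a fresh instance is still filled in on every call, since the buffer pointers change between invocations.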

warp/_src/jax/ffi.py — specifically the new CPU kernel launch block (FfiKernel.ffi_callback else-branch) where ArgsStruct is rebuilt on every invocation.

Important Files Changed

| Filename | Overview |
|---|---|
| warp/_src/jax/ffi.py | Core FFI implementation: adds Host callbacks for FfiKernel, FfiCallable, and register_ffi_callback; introduces new CPU launch path with ArgsStruct construction and _reconstruct_args helper; ArgsStruct is rebuilt on every CPU kernel invocation without caching, unlike the context.py reference implementation. |
| warp/_src/jax/xla_ffi.py | ExecutionContext now correctly guards get_stream_from_callframe behind platform == "CUDA", fixing the previously flagged unconditional CUDA-only API call on Host. |
| warp/tests/interop/test_jax.py | Adds six new Host-platform FFI tests covering kernel and callable variants; tests are correctly gated on JAX >= 0.5.0 and registered with devices=None to avoid CUDA dependency. |
| CHANGELOG.md | Changelog entry added for the Host platform FFI support; references a GH issue number consistent with this PR. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[JAX FFI call] --> B{platform?}
    B -->|CUDA| C[FfiKernel / FfiCallable CUDA path]
    B -->|Host| D[FfiKernel / FfiCallable Host path]

    C --> C1[get_stream_from_callframe]
    C1 --> C2[wp_cuda_launch_kernel or func + stream]

    D --> D1[wp.get_device cpu]
    D1 --> D2{class?}
    D2 -->|FfiKernel| D3[Build ArgsStruct from arg_refs]
    D3 --> D4[wp_cpu_launch_kernel with hooks.forward and args_struct]
    D2 -->|FfiCallable| D5[_reconstruct_args: wp.array from buffer ptrs]
    D5 --> D6[ScopedDevice cpu: func with arg_list]

    subgraph ExecutionContext
        E{platform?}
        E -->|CUDA| F[stream = get_stream_from_callframe]
        E -->|Host| G[stream = None]
    end
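
The ExecutionContext branch in the flowchart can be sketched as follows. The class body and the dict-based call frame are assumptions made for illustration; only the names `ExecutionContext`, `get_stream_from_callframe`, and the `platform` parameter come from the discussion.

```python
def get_stream_from_callframe(callframe: dict) -> int:
    """Stand-in for the CUDA-only XLA FFI stream query (assumption)."""
    if "stream" not in callframe:
        raise RuntimeError("stream query is only valid on CUDA")
    return callframe["stream"]


class ExecutionContext:
    """Holds per-call state; the stream is fetched only on CUDA."""

    def __init__(self, callframe: dict, platform: str = "CUDA"):
        self.platform = platform
        if platform == "CUDA":
            self.stream = get_stream_from_callframe(callframe)
        else:
            # Host: never touch the CUDA-only API.
            self.stream = None


host_ctx = ExecutionContext({}, platform="Host")
cuda_ctx = ExecutionContext({"stream": 5})
print(host_ctx.stream, cuda_ctx.stream)  # None 5
```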

Reviews (11): Last reviewed commit: "Merge branch 'main' into loney7/ffi-host..." | Re-trigger Greptile

Comment thread warp/_src/jax/ffi.py
Comment thread warp/_src/jax/ffi.py

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@warp/_src/jax_experimental/ffi.py`:
- Around line 1772-1777: In register_ffi_callback, the Host ffi capsule is
constructed from the wrong variable (ffi_capsule_host) causing a NameError;
change the construction to use the host ccall address value by calling
jax.ffi.pycapsule(ffi_ccall_address_host.value) and then register that capsule
with jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host") so the
Host path mirrors the CUDA path (refer to ffi_ccall_address_host,
ffi_capsule_host, register_ffi_target).
- Around line 629-632: The code assigns ffi_ccall_address_host then creates
ffi_capsule_host but mistakenly uses ffi_capsule_host.value (self-referential
NameError); change the capsule creation to use the previously computed address
value (ffi_ccall_address_host.value) so the lines around callback_func_host,
ffi_ccall_address_host, ffi_capsule_host and the
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") call
use ffi_ccall_address_host.value when constructing the pycapsule for the Host
callback that wraps FFI_CCALLFUNC and calls self.ffi_callback.
- Around line 226-229: The Host FFI registration references an undefined
variable: replace the erroneous creation of ffi_capsule_host (currently using
ffi_capsule_host.value) with a capsule built from the c_void_p address you just
made; specifically, after creating callback_func_host and
ffi_ccall_address_host, set ffi_capsule_host =
jax.ffi.pycapsule(ffi_ccall_address_host.value) and then call
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") so the
capsule is created from the correct address (symbols: callback_func_host,
FFI_CCALLFUNC, ffi_ccall_address_host, ffi_capsule_host,
jax.ffi.register_ffi_target, self.name, self.ffi_callback).
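
The fix the three comments above ask for comes down to building the capsule from the address object rather than from the not-yet-defined capsule variable. A ctypes-only sketch of that pattern follows; `CALLBACK_TYPE` is a hypothetical signature standing in for FFI_CCALLFUNC, and `make_capsule` is a stand-in for `jax.ffi.pycapsule` so the example runs without JAX.

```python
import ctypes

# Hypothetical callback signature; the real code uses FFI_CCALLFUNC.
CALLBACK_TYPE = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int)


def make_capsule(address: int) -> dict:
    """Stand-in for jax.ffi.pycapsule: just records the function address."""
    return {"address": address}


def double(x: int) -> int:
    return 2 * x


# Keep the ctypes callback object alive for as long as the address is used.
callback_func_host = CALLBACK_TYPE(double)
ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)

# Buggy form flagged in review (NameError, the name is not defined yet):
#   ffi_capsule_host = make_capsule(ffi_capsule_host.value)
# Corrected form: use the address computed on the previous line.
ffi_capsule_host = make_capsule(ffi_ccall_address_host.value)

# Calling through the stored address reaches the Python function.
result = CALLBACK_TYPE(ffi_capsule_host["address"])(21)
print(result)  # 42
```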
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: d2c93d0f-cbef-4891-b361-773ad8140f6c

📥 Commits

Reviewing files that changed from the base of the PR and between 54327e3 and 09569ba.

📒 Files selected for processing (1)
  • warp/_src/jax_experimental/ffi.py

Comment thread warp/_src/jax/ffi.py
Comment thread warp/_src/jax/ffi.py
Comment thread warp/_src/jax/ffi.py
Comment thread warp/_src/jax/ffi.py
@loney7

loney7 commented May 8, 2026

Author

pre-commit.ci autofix

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@warp/tests/interop/test_jax.py`:
- Around line 2422-2450: The host-only test registrations for TestJax (the
add_function_test calls registering test_ffi_jax_kernel_host_add,
test_ffi_jax_kernel_host_sincos, test_ffi_jax_kernel_host_in_out,
test_ffi_jax_kernel_host_scale_vec_constant,
test_ffi_jax_callable_host_scale_constant, and
test_ffi_jax_callable_host_in_out) are currently inside the
jax_compatible_cuda_devices conditional; move these specific add_function_test
calls out of that CUDA-only if block so they are always registered on CPU-only
setups, keeping the existing devices=None argument and leaving CUDA/GPU-specific
registrations inside the original conditional.
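
The registration change requested above can be illustrated with `unittest` alone. This is a simplified sketch: `add_function_test` here is a reduced stand-in for Warp's test helper, and `check_host_add` and `check_cuda_add` are hypothetical test bodies.

```python
import unittest


class TestJax(unittest.TestCase):
    pass


def add_function_test(cls, name, func, devices=None):
    """Simplified stand-in for Warp's helper (assumption)."""
    if devices is None:
        # Device-independent test: register it once.
        setattr(cls, name, lambda self: func(self, None))
    else:
        for d in devices:
            setattr(cls, f"{name}_{d}", lambda self, d=d: func(self, d))


def check_host_add(test, device):
    test.assertEqual(1 + 1, 2)


def check_cuda_add(test, device):
    test.assertTrue(device.startswith("cuda"))


jax_compatible_cuda_devices = []  # e.g. a CPU-only machine

if jax_compatible_cuda_devices:
    # GPU-specific registrations stay inside the CUDA gate.
    add_function_test(TestJax, "test_cuda_add", check_cuda_add,
                      devices=jax_compatible_cuda_devices)

# Host tests are registered unconditionally, outside the gate.
add_function_test(TestJax, "test_host_add", check_host_add, devices=None)
```

With the host registration inside the `if` block, a CPU-only machine would silently run zero host tests, which is the failure mode the comment describes.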
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yml

Review profile: CHILL

Plan: Enterprise

Run ID: fcfb2e10-e509-41b9-b675-c8960056b6be

📥 Commits

Reviewing files that changed from the base of the PR and between d39169c and f6048f4.

📒 Files selected for processing (1)
  • warp/tests/interop/test_jax.py

Comment thread warp/tests/interop/test_jax.py Outdated
@loney7

loney7 commented May 9, 2026

Author

pre-commit.ci autofix

@shi-eric shi-eric requested a review from nvlukasz May 12, 2026 16:45
Comment thread warp/tests/interop/test_jax.py Outdated
Comment thread warp/_src/jax/ffi.py
@shi-eric

Contributor

Please add a CHANGELOG.md entry under Unreleased for this JAX Host FFI behavior change.

@shi-eric

Contributor

Please squash this PR down to a single coherent commit before merge.

@shi-eric

Contributor

Please rebase onto current main and move the fix/tests to the promoted JAX code paths. Commit 604a8961df6d40ea64ff1e740b23581e4c72c96f promoted the JAX code from jax_experimental to jax after this PR was opened, so the final diff should target the current locations.

@nvlukasz nvlukasz left a comment

Contributor

Thank you for the contribution. Please address outstanding comments or let us know if you are unable to do so.

Comment thread warp/_src/jax/ffi.py
Comment thread warp/_src/jax_experimental/ffi.py Outdated
Comment thread warp/_src/jax_experimental/ffi.py Outdated
@loney7

loney7 commented May 19, 2026

Author

> Thank you for the contribution. Please address outstanding comments or let us know if you are unable to do so.

@nvlukasz, please allow me until the end of this week to address the outstanding comments. I apologise for the delay.

@loney7 loney7 force-pushed the loney7/ffi-host-support branch 2 times, most recently from 766a819 to cc5aa99 Compare June 9, 2026 05:46
@loney7

loney7 commented Jun 9, 2026

Author

pre-commit.ci autofix

@loney7 loney7 force-pushed the loney7/ffi-host-support branch from cc5aa99 to c06e8d4 Compare June 9, 2026 05:48
@loney7

loney7 commented Jun 9, 2026

Author

pre-commit.ci autofix

@loney7 loney7 marked this pull request as ready for review June 9, 2026 07:27
@loney7 loney7 force-pushed the loney7/ffi-host-support branch from 420290b to 38f4e20 Compare June 9, 2026 07:37
@loney7

loney7 commented Jun 9, 2026

Author

pre-commit.ci autofix

@btaba

btaba commented Jun 9, 2026

Contributor

Hi @nvlukasz wondering if you could take another pass :)

Ankit Jain and others added 2 commits June 9, 2026 20:11
Register FFI targets for both CUDA and Host platforms in FfiKernel,
FfiCallable, and register_ffi_callback.  The Host platform selects the
CPU device and bypasses CUDA-specific features (streams, graph capture).

FfiKernel Host path builds a ctypes ArgsStruct and calls
wp_cpu_launch_kernel with the correct 5-argument ABI.

FfiCallable Host path uses a factored-out _reconstruct_args helper and
returns early with wp.ScopedDevice, avoiding all graph-capture logic.

ExecutionContext now accepts an optional platform parameter and skips
get_stream_from_callframe on Host, preventing a crash from the
CUDA-only XLA_FFI_Stream_Get call.

Add six CPU-focused integration tests registered outside the CUDA gate.

Signed-off-by: Ankit Jain <kitsrish@google.com>
@loney7 loney7 force-pushed the loney7/ffi-host-support branch from df57595 to f752ab9 Compare June 9, 2026 20:17
@loney7

loney7 commented Jun 9, 2026

Author

pre-commit.ci autofix

@loney7

loney7 commented Jun 9, 2026

Author

pre-commit.ci autofix

@btaba

btaba commented Jun 12, 2026

Contributor

friendly ping to @nvlukasz or @shi-eric (sorry for the trouble)

4 participants