Add JAX FFI Host support #1446
Note: Reviews paused
It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough
This pull request extends JAX experimental FFI callbacks to support both CUDA and Host (CPU) execution paths. Callbacks now accept a …
Changes: CUDA and Host FFI Execution
🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Greptile Summary
This PR extends JAX FFI support to the CPU (Host) platform by registering dual CUDA/Host callbacks in …
Confidence Score: 4/5
The PR is safe to merge; the three previously flagged bugs (NameError on ffi_capsule_host, wrong wp_cpu_launch_kernel argument order, and unconditional CUDA stream fetch in ExecutionContext) are all correctly resolved in this revision.
The CPU kernel launch path is new and non-trivial: it builds a ctypes ArgsStruct from FFI buffer pointers, calls wp_cpu_launch_kernel, and uses a separate _reconstruct_args helper for FfiCallable. The structure closely mirrors the existing context.py reference implementation and the CUDA path, and six new tests cover the key variants. The one remaining gap is that ArgsStruct type construction is not cached in FfiKernel's Host path the way context.py does it, which is a performance concern on repeated calls but does not affect correctness. No new correctness bugs were identified.
warp/_src/jax/ffi.py: specifically the new CPU kernel launch block (FfiKernel.ffi_callback else-branch) where ArgsStruct is rebuilt on every invocation.
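The caching gap noted above can be illustrated with a minimal sketch. It is not the code in context.py or ffi.py: `make_args_struct_type`, the field names, and the pointer values are all hypothetical, and the only point is that a ctypes Structure class can be built once per distinct field layout and reused on later calls.

```python
import ctypes
from functools import lru_cache


@lru_cache(maxsize=None)
def make_args_struct_type(field_spec):
    """Build a ctypes.Structure subclass once per distinct field layout.

    field_spec is a hashable tuple of (name, ctype) pairs (hypothetical helper).
    """

    class ArgsStruct(ctypes.Structure):
        _fields_ = list(field_spec)

    return ArgsStruct


# Hypothetical kernel signature: two input pointers and one output pointer.
spec = (
    ("a", ctypes.c_void_p),
    ("b", ctypes.c_void_p),
    ("out", ctypes.c_void_p),
)

first = make_args_struct_type(spec)
second = make_args_struct_type(spec)  # cache hit: no new class is created

# Per-call work is then limited to filling in the pointer values.
args = first(a=0x1000, b=0x2000, out=0x3000)
```

With this pattern, only the struct instance is created per invocation; the class definition, which is the more expensive step, is shared across calls with the same layout.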
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[JAX FFI call] --> B{platform?}
B -->|CUDA| C[FfiKernel / FfiCallable CUDA path]
B -->|Host| D[FfiKernel / FfiCallable Host path]
C --> C1[get_stream_from_callframe]
C1 --> C2[wp_cuda_launch_kernel or func + stream]
D --> D1[wp.get_device cpu]
D1 --> D2{class?}
D2 -->|FfiKernel| D3[Build ArgsStruct from arg_refs]
D3 --> D4[wp_cpu_launch_kernel with hooks.forward and args_struct]
D2 -->|FfiCallable| D5[_reconstruct_args: wp.array from buffer ptrs]
D5 --> D6[ScopedDevice cpu: func with arg_list]
subgraph ExecutionContext
E{platform?}
E -->|CUDA| F[stream = get_stream_from_callframe]
E -->|Host| G[stream = None]
end
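The ExecutionContext branch of the flowchart can be sketched as follows. This is an illustrative stand-in under stated assumptions, not the PR's class: `ExecutionContextSketch`, `make_context`, and `fake_get_stream` are hypothetical names, and the stream getter is passed in so the example runs without CUDA.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class ExecutionContextSketch:
    """Hypothetical stand-in holding the platform and an optional stream."""

    platform: str
    stream: Optional[Any]


def make_context(platform: str, get_stream: Callable[[], Any]) -> ExecutionContextSketch:
    if platform == "CUDA":
        # Only the CUDA path may query the stream from the call frame.
        return ExecutionContextSketch(platform, get_stream())
    if platform == "Host":
        # The Host path never calls the CUDA-only stream getter.
        return ExecutionContextSketch(platform, None)
    raise ValueError(f"Unsupported platform: {platform!r}")


calls = []


def fake_get_stream():
    # Records each lookup so we can confirm the Host path skips it.
    calls.append("stream lookup")
    return "cuda-stream-handle"


host_ctx = make_context("Host", fake_get_stream)
cuda_ctx = make_context("CUDA", fake_get_stream)
```

The design choice shown here is to branch before touching any CUDA-specific API, so a Host call cannot reach the stream lookup at all.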
Reviews (11): Last reviewed commit: "Merge branch 'main' into loney7/ffi-host..."
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@warp/_src/jax_experimental/ffi.py`:
- Around line 1772-1777: In register_ffi_callback, the Host ffi capsule is
constructed from the wrong variable (ffi_capsule_host) causing a NameError;
change the construction to use the host ccall address value by calling
jax.ffi.pycapsule(ffi_ccall_address_host.value) and then register that capsule
with jax.ffi.register_ffi_target(name, ffi_capsule_host, platform="Host") so the
Host path mirrors the CUDA path (refer to ffi_ccall_address_host,
ffi_capsule_host, register_ffi_target).
- Around line 629-632: The code assigns ffi_ccall_address_host then creates
ffi_capsule_host but mistakenly uses ffi_capsule_host.value (self-referential
NameError); change the capsule creation to use the previously computed address
value (ffi_ccall_address_host.value) so the lines around callback_func_host,
ffi_ccall_address_host, ffi_capsule_host and the
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") call
use ffi_ccall_address_host.value when constructing the pycapsule for the Host
callback that wraps FFI_CCALLFUNC and calls self.ffi_callback.
- Around line 226-229: The Host FFI registration references an undefined
variable: replace the erroneous creation of ffi_capsule_host (currently using
ffi_capsule_host.value) with a capsule built from the c_void_p address you just
made; specifically, after creating callback_func_host and
ffi_ccall_address_host, set ffi_capsule_host =
jax.ffi.pycapsule(ffi_ccall_address_host.value) and then call
jax.ffi.register_ffi_target(self.name, ffi_capsule_host, platform="Host") so the
capsule is created from the correct address (symbols: callback_func_host,
FFI_CCALLFUNC, ffi_ccall_address_host, ffi_capsule_host,
jax.ffi.register_ffi_target, self.name, self.ffi_callback).
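The ordering the three comments above ask for (callback, then address, then capsule built from that address) can be sketched with ctypes alone. The assumptions are: `CALLBACK_TYPE` is a hypothetical signature standing in for FFI_CCALLFUNC, and `make_capsule` is a local stand-in that calls CPython's `PyCapsule_New` directly, used here instead of `jax.ffi.pycapsule` so the example runs without JAX installed.

```python
import ctypes

# Hypothetical signature standing in for FFI_CCALLFUNC.
CALLBACK_TYPE = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_void_p)


def ffi_callback(call_frame):
    # Placeholder body; a real callback would decode the call frame.
    return None


def make_capsule(address):
    """Wrap a raw function address in a PyCapsule (stand-in for jax.ffi.pycapsule)."""
    new_capsule = ctypes.pythonapi.PyCapsule_New
    new_capsule.restype = ctypes.py_object
    new_capsule.argtypes = (ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p)
    return new_capsule(address, None, None)


# 1. Wrap the Python function; keep a reference so it is not garbage collected.
callback_func_host = CALLBACK_TYPE(ffi_callback)
# 2. Take the address of the wrapped callback.
ffi_ccall_address_host = ctypes.cast(callback_func_host, ctypes.c_void_p)
# 3. Build the capsule from the address computed in step 2. Writing
#    ffi_capsule_host.value here instead would raise NameError, because
#    ffi_capsule_host does not exist until this line has finished.
ffi_capsule_host = make_capsule(ffi_ccall_address_host.value)
```

In the PR's setting, the resulting capsule would then be passed to `jax.ffi.register_ffi_target(..., platform="Host")`, as the comments describe.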
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Enterprise
Run ID: d2c93d0f-cbef-4891-b361-773ad8140f6c
📒 Files selected for processing (1)
warp/_src/jax_experimental/ffi.py
pre-commit.ci autofix
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@warp/tests/interop/test_jax.py`:
- Around line 2422-2450: The host-only test registrations for TestJax (the
add_function_test calls registering test_ffi_jax_kernel_host_add,
test_ffi_jax_kernel_host_sincos, test_ffi_jax_kernel_host_in_out,
test_ffi_jax_kernel_host_scale_vec_constant,
test_ffi_jax_callable_host_scale_constant, and
test_ffi_jax_callable_host_in_out) are currently inside the
jax_compatible_cuda_devices conditional; move these specific add_function_test
calls out of that CUDA-only if block so they are always registered on CPU-only
setups, keeping the existing device=None argument and leaving CUDA/GPU-specific
registrations inside the original conditional.
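The registration layout requested above can be sketched with plain unittest. Everything here is a simplified assumption: `add_function_test` is a local stand-in rather than Warp's test utility, the check functions are placeholders, and `test_ffi_jax_kernel_cuda_add` is a hypothetical name for a CUDA-only test.

```python
import unittest


class TestJax(unittest.TestCase):
    pass


def add_function_test(cls, name, func, device=None):
    """Simplified stand-in: attach func to cls as a test method bound to device."""

    def test_method(self):
        func(self, device)

    setattr(cls, name, test_method)


def check_host_add(test, device):
    test.assertEqual(1 + 2, 3)  # placeholder body for a host-only check


def check_cuda_add(test, device):
    test.assertIsNotNone(device)  # placeholder body for a CUDA-only check


# Empty list simulates a CPU-only machine.
jax_compatible_cuda_devices = []

# Host-only tests are registered unconditionally, outside the CUDA gate.
add_function_test(TestJax, "test_ffi_jax_kernel_host_add", check_host_add, device=None)

if jax_compatible_cuda_devices:
    # CUDA-specific registrations stay inside the conditional.
    add_function_test(
        TestJax,
        "test_ffi_jax_kernel_cuda_add",
        check_cuda_add,
        device=jax_compatible_cuda_devices[0],
    )

registered = sorted(name for name in vars(TestJax) if name.startswith("test_"))
```

On a CPU-only setup, only the host test ends up on the class, which is the behaviour the review comment is asking for.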
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yml
Review profile: CHILL
Plan: Enterprise
Run ID: fcfb2e10-e509-41b9-b675-c8960056b6be
📒 Files selected for processing (1)
warp/tests/interop/test_jax.py
pre-commit.ci autofix
Please add a …
Please squash this PR down to a single coherent commit before merge.
Please rebase onto current …
nvlukasz left a comment
Thank you for the contribution. Please address outstanding comments or let us know if you are unable to do so.
@nvlukasz, please allow me until the end of this week to address the outstanding comments. I apologise for the delay.
766a819 to cc5aa99
pre-commit.ci autofix
cc5aa99 to c06e8d4
pre-commit.ci autofix
420290b to 38f4e20
pre-commit.ci autofix
Hi @nvlukasz, wondering if you could take another pass :)
Register FFI targets for both CUDA and Host platforms in FfiKernel, FfiCallable, and register_ffi_callback. The Host platform selects the CPU device and bypasses CUDA-specific features (streams, graph capture).

FfiKernel Host path builds a ctypes ArgsStruct and calls wp_cpu_launch_kernel with the correct 5-argument ABI. FfiCallable Host path uses a factored-out _reconstruct_args helper and returns early with wp.ScopedDevice, avoiding all graph-capture logic.

ExecutionContext now accepts an optional platform parameter and skips get_stream_from_callframe on Host, preventing a crash from the CUDA-only XLA_FFI_Stream_Get call.

Add six CPU-focused integration tests registered outside the CUDA gate.

Signed-off-by: Ankit Jain <kitsrish@google.com>
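The idea behind reconstructing arguments from buffer pointers on the Host path can be sketched with ctypes only. This is an analogy under stated assumptions, not the `_reconstruct_args` helper itself: the real helper is described as building wp.array objects, whereas `view_float_buffer` and `scale_on_host` below are hypothetical functions that alias raw addresses as ctypes float arrays.

```python
import ctypes


def view_float_buffer(ptr: int, count: int):
    """Return a ctypes float array aliasing `count` floats at address `ptr` (no copy)."""
    return (ctypes.c_float * count).from_address(ptr)


def scale_on_host(in_ptr: int, out_ptr: int, count: int, factor: float) -> None:
    """Hypothetical host callable: rebuild views from raw pointers, then compute."""
    src = view_float_buffer(in_ptr, count)
    dst = view_float_buffer(out_ptr, count)
    for i in range(count):
        dst[i] = src[i] * factor


# Stand-ins for the input and output buffers that the FFI caller would own.
in_buf = (ctypes.c_float * 4)(1.0, 2.0, 3.0, 4.0)
out_buf = (ctypes.c_float * 4)()

scale_on_host(ctypes.addressof(in_buf), ctypes.addressof(out_buf), 4, 2.0)
```

Because the views alias the caller's memory, the result is written directly into the output buffer, with no stream or graph-capture step involved.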
df57595 to f752ab9
pre-commit.ci autofix
pre-commit.ci autofix
Description
This PR adds support for running JAX FFI callbacks on the CPU (Host) in addition to CUDA.
Changes:
- … `"CUDA"` and `"Host"` platforms in `register_ffi_callback`.
- … `"Host"` platform case in `ffi_callback` by using the CPU device and bypassing CUDA-specific features like streams and graphs.
- … `FfiCallable` to reconstruct arguments and execute the function on the CPU when running on the Host platform.
Checklist
Test plan
You can verify these changes by running the JAX interop tests which include FFI tests. Ensure they pass on both CPU and GPU (if available).