
WIP: FFI support#43

Closed
mreineck wants to merge 1 commit into main from ffi

Conversation

@mreineck
Collaborator

@mreineck mreineck commented Apr 7, 2025

I have started adding FFI-related C++ code to the repo. At the moment this is just additional, uncalled code, but at least it compiles.
To proceed, we need someone to make the necessary changes on the Python side (I don't feel qualified for this...), and then we need to discuss how to pass the function pointer, the number of input and result arrays, and the kwargs over the interface. As far as I understand, this can perhaps be done in a much simpler way than the one we are currently using.
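One way such kwargs can travel across a C-style call boundary is as an opaque serialized blob that the C++ side decodes. The sketch below is only an illustration of that idea, not jaxbind's actual mechanism; the helper names are hypothetical and JSON is chosen purely for readability.

```python
import json


def pack_kwargs(**kwargs) -> bytes:
    # Serialize keyword arguments into an opaque byte string that a
    # C/C++ custom-call implementation could receive as a void* blob.
    return json.dumps(kwargs, sort_keys=True).encode("utf-8")


def unpack_kwargs(blob: bytes) -> dict:
    # The inverse operation, as it would run on the receiving side
    # (sketched in Python here for brevity).
    return json.loads(blob.decode("utf-8"))


opaque = pack_kwargs(epsilon=1e-6, nthreads=4)
restored = unpack_kwargs(opaque)
```

A real implementation would need a scheme that handles non-JSON-serializable values, but the round-trip structure would be the same.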

@Edenhofer @roth-jakob @SepandKashani, any comment is much appreciated!

@roth-jakob
Collaborator

I can look at the Python side. First, however, I need to check the FFI documentation.

@SepandKashani

I have been knee-deep in the FFI docs for a few days trying to understand how to use it, so I'm happy to pitch in on the Python side.
I can't commit to any updates this week since I'm traveling, but I'll have some input next week.

@Edenhofer
Contributor

I don't have the capacity at the moment to contribute code to this. I have an insufficient understanding of JAX's FFI code, and I don't immediately see how the kwargs and pointer parsing would become much simpler. I'm happy to test any changes, though, if that would be helpful!

@mreineck
Collaborator Author

Adding @nschaeff, since he also indicated interest in this project.

@mreineck
Collaborator Author

Personally, I can't tell whether adding GPU support will become easier if we switch to FFI; perhaps we have to find out by trying.
In any case, I'm pretty sure that all the strictly necessary prerequisites for GPU support are already in place, the main one being nanobind. So we could also try to build directly on what we have now. Opinions?

@nschaeff

nschaeff commented Apr 16, 2025

Hello,

What I do not see right now is how to pass GPU pointers corresponding to JAX arrays to some custom function.
Or do we need to copy the data first?
The same goes for the output data: how is it referenced or copied by JAX? What underlying Python object carries it?

All this does not apply to the CPU, since NumPy arrays are more or less directly interchangeable with JAX arrays.

Also, one needs to pass a cudaStream, which, although not difficult to handle, needs to be taken care of.

@nschaeff

OK, I just read this:
jax-ml/jax#1100 (comment)

which mentions that pydlpack may allow us to exchange data between libraries (JAX, CuPy, ...) without copying:
https://github.com/pearu/pydlpack

Is this the way to go for jaxbind?

@mreineck
Collaborator Author

mreineck commented Apr 16, 2025

What I do not see right now is how to pass GPU pointers corresponding to JAX arrays to some custom function.

That's correct. Supporting this has become possible by the switch to nanobind, but the support is not implemented yet.

Or do we need to copy data first?

No, the idea is to get the relevant pointers directly from JAX and use the data without copying. It would be very similar to what we have on the CPU, the only difference being that the pointers refer to memory on a device (we probably need to communicate device IDs in addition to what we currently pass through the interface).

The underlying Python object would most likely be a CuPy array, and I think those can also be mapped onto JAX array structures in the same way we use for NumPy arrays.

@mreineck
Collaborator Author

mreineck commented Apr 16, 2025

pydlpack sounds like a very promising package, since nanobind also works internally with DLPack-conforming array structures.

@mreineck
Collaborator Author

If only JAX provided a custom call interface that could call Python functions directly ... then all these low-level acrobatics could go away.

@nschaeff

nschaeff commented Apr 16, 2025

Apparently, JAX supports the __cuda_array_interface__, just like CuPy, PyTorch, Numba, and probably others.
This would then be good common ground, and maybe simpler to deal with than pydlpack?

See:
https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html
and
jax-ml/jax#16440 (this issue is fixed, but it tells us that this is actually supported in JAX)
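The protocol itself is lightweight: a producer just exposes a dict-valued attribute describing the device buffer, per Numba's specification. The class below is a hypothetical stand-in (no real CUDA allocation) that shows the shape of that dict; real producers such as CuPy, Numba, or JAX on GPU fill in an actual device pointer.

```python
class FakeDeviceArray:
    # Hypothetical producer of the __cuda_array_interface__ protocol,
    # for illustration only -- the pointer here is not a real allocation.
    def __init__(self, ptr, shape, typestr):
        self._ptr = ptr
        self._shape = shape
        self._typestr = typestr

    @property
    def __cuda_array_interface__(self):
        return {
            "shape": self._shape,        # tuple of ints
            "typestr": self._typestr,    # e.g. "<f8" = little-endian float64
            "data": (self._ptr, False),  # (device pointer, read-only flag)
            "strides": None,             # None means C-contiguous
            "version": 3,
        }


arr = FakeDeviceArray(ptr=0xDEADBEEF, shape=(4, 4), typestr="<f8")
iface = arr.__cuda_array_interface__
```

A consumer only needs to read these five keys to wrap the buffer, which is why the protocol makes a convenient common ground between libraries.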

@mreineck
Collaborator Author

mreineck commented Apr 16, 2025

The problem is that we intercept the JAX callback within C++ code, not Python, because that's the only supported way; Python packages won't help us. In C++, we get the arguments as void ** and void * (if we are using the old custom call interface), or as xla::ffi::AnyBuffer objects if we use the FFI. Starting from there, I don't think there is any convenient mechanism to turn that into NumPy, CuPy, or Torch arrays; we have to do it manually.
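On the CPU, the manual reconstruction described above amounts to the following bookkeeping, sketched here in Python with ctypes (the hypothetical helper name is mine; the real jaxbind code does the equivalent in C++ starting from a void* argument):

```python
import ctypes

import numpy as np


def array_from_pointer(ptr, shape, dtype):
    # Rebuild an array view over existing memory, given only a raw
    # pointer, a shape, and a dtype -- the information a custom-call
    # handler receives alongside its void* / AnyBuffer arguments.
    nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize
    buf = (ctypes.c_char * nbytes).from_address(ptr)
    return np.frombuffer(buf, dtype=dtype).reshape(shape)


src = np.arange(12, dtype=np.float32).reshape(3, 4)
view = array_from_pointer(src.ctypes.data, src.shape, src.dtype)
```

Because the result is a view, writes through it land in the original buffer, which is exactly the zero-copy behavior the custom-call path relies on. A GPU version would do the same arithmetic but with device pointers that host code must not dereference directly.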

@roth-jakob
Collaborator

Building on this PR, I have completed the switch to the FFI interface in #46. #46 still only supports CPU custom call functions. We can work on adding GPU support in a separate PR.

@mreineck
Collaborator Author

Superseded by #46, closing.

@mreineck mreineck closed this Sep 29, 2025