Conversation
|
I can look at the Python side. However, first, I need to check the FFI documentation. |
|
I have been knee-deep in the FFI docs for a few days to understand how to use it, so I'm happy to pitch in on the Python side. |
|
I don't have the capacity at the moment to contribute code to this. I have an insufficient understanding of JAX's FFI code and I don't see ad-hoc how the |
|
Adding @nschaeff, since he also indicated interest in this project. |
|
Personally, I can't tell whether adding GPU support will become easier if we switch to FFI; perhaps we have to find out by trying. |
|
Hello, what I do not see right now is how to pass the GPU pointers corresponding to JAX arrays to a custom function. None of this applies to the CPU, where NumPy arrays are more or less directly interchangeable with JAX arrays. One also needs to pass a cudaStream, which is not difficult to handle but does need to be taken care of. |
|
OK, I just read this: Which mentions that pydlpack may allow us to exchange data between libraries (JAX, CuPy, ...) without copying: Is this the way to go for jaxbind?
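For readers unfamiliar with DLPack, here is a minimal sketch of what "exchange without copy" means. It assumes only NumPy (1.22 or newer) and uses NumPy as both producer and consumer, purely as a CPU stand-in; JAX and CuPy expose the same protocol through `jax.dlpack.from_dlpack` and `cupy.from_dlpack`. The variable names are illustrative.

```python
import numpy as np

# Producer side: any object implementing __dlpack__ / __dlpack_device__.
producer = np.arange(4, dtype=np.float32)

# Consumer side: import the same buffer through the DLPack protocol.
consumer = np.from_dlpack(producer)

# Writing through the producer is visible in the consumer,
# because both objects refer to the same memory.
producer[0] = 42.0
shares = np.shares_memory(producer, consumer)
```

Note that this exchange happens between Python objects, so it only shows the protocol itself, not how a C++ callback would receive the data.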
That's correct. Supporting this has become possible by the switch to
No, the idea is to get the relevant pointers directly from JAX and use the data without copying. It would be very similar to what we have on the CPU, with the only difference being that the pointers refer to memory on a device (we probably need to communicate device IDs in addition to what we currently pass through the interface). The underlying Python object would most likely be a CuPy array, and I think these can also be mapped onto JAX array structures in the same way we use for NumPy arrays. |
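As a rough CPU-only illustration of "use the pointer directly, without copying", the sketch below wraps an existing buffer by its raw address using only the standard library. It is not jaxbind code and does not involve JAX; on a GPU the address would refer to device memory and could not be dereferenced from the host like this.

```python
import array
import ctypes

# An existing buffer owned by some other object.
values = array.array("d", [1.0, 2.0, 3.0])

# buffer_info() returns the raw address and the number of elements.
address, length = values.buffer_info()

# Wrap the same memory as a ctypes array; no data is copied.
view = (ctypes.c_double * length).from_address(address)

# A write through the view is visible in the original buffer.
view[0] = 10.0
```

The wrapper does not own the memory, so the original object has to stay alive for as long as the view is in use.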
|
|
|
If only |
|
Apparently, JAX supports the See: |
|
The problem is that we intercept the JAX callback within C++ code, not Python, because that's the only supported way; Python packages won't help us. In C++, we get the arguments as |
|
Superseded by #46, closing. |
I have started adding FFI-related C++ code to the repo. At the moment this is just additional, uncalled code, but at least it compiles.
To proceed, we need someone to make the necessary changes on the Python side (I don't feel qualified for this...), and then we need to discuss how to pass the function pointer, the number of input and result arrays, and the kwargs over the interface. As far as I understand, this can perhaps be done in a much simpler way than the one we are currently using.
@Edenhofer @roth-jakob @SepandKashani, any comments are much appreciated!
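To make the discussion point concrete, here is one possible way to move a function pointer, the array counts, and the kwargs across an interface, sketched with the standard library only. Everything here is hypothetical: the `CALLBACK` signature, the JSON encoding of the kwargs, and `call_through_interface` are assumptions for illustration, not jaxbind's current mechanism and not part of the JAX FFI API.

```python
import ctypes
import json

# Hypothetical C signature: int callback(size_t n_in, size_t n_out, const char *kwargs_json)
CALLBACK = ctypes.CFUNCTYPE(
    ctypes.c_int, ctypes.c_size_t, ctypes.c_size_t, ctypes.c_char_p
)

received = {}

def _callback(n_in, n_out, kwargs_json):
    # Record what arrived on the receiving side of the interface.
    received["n_in"] = n_in
    received["n_out"] = n_out
    received["kwargs"] = json.loads(kwargs_json.decode())
    return 0

# Keep a reference to the callback object so it is not garbage collected.
cb = CALLBACK(_callback)

# The function pointer travels across the interface as a plain integer.
address = ctypes.cast(cb, ctypes.c_void_p).value

def call_through_interface(address, n_in, n_out, kwargs_bytes):
    # Receiving side: rebuild a callable from the raw address and invoke it.
    fn = CALLBACK(address)
    return fn(n_in, n_out, kwargs_bytes)

status = call_through_interface(address, 2, 1, json.dumps({"axis": 0}).encode())
```

Passing the address as an integer and the kwargs as a byte string keeps the interface to scalar and string values, which is one reason such a scheme might turn out simpler; whether it fits the FFI attribute mechanism would still need to be checked.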