Switch to JAX FFI#46

Merged
roth-jakob merged 8 commits into main from ffi_new
Sep 30, 2025
Conversation

@roth-jakob
Collaborator

I will take another look at it next week, but it seems to be working.

@roth-jakob
Collaborator Author

From my side, this is ready to be merged. @mreineck or @Edenhofer can you have a look at it and merge it?

roth-jakob mentioned this pull request Sep 28, 2025

@roth-jakob
Collaborator Author

After merging this, we should also make a new PyPI release. With the release of JAX 0.8.0 (which I believe is planned for November), the old custom call interface we have been using so far will be removed.

src/_jaxbind.cc Outdated
PyObject* raw_ptr = reinterpret_cast<PyObject*>(func_id);
nb::handle hnd(raw_ptr);
nb::object func = nb::borrow<nb::object>(hnd);
func(py_out, py_in, py_kwargs);
Collaborator

I think that errors in the called functions need to be communicated to the caller via return codes with the new FFI.
So I think the two lines below probably need to be replaced with just
return func(py_out, py_in, py_kwargs);
Still, that doesn't catch the exceptions we may be throwing in the lines above. Perhaps we need to guard everything by try ... catch; I'll look into that.
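
The pattern discussed in this comment, guarding the whole callback body and reporting any failure through the return value instead of letting an exception escape, can be modeled in plain Python. This is only an illustration of the control flow, not the C++ code in src/_jaxbind.cc; `Status`, `guarded_call`, and `failing` are hypothetical names introduced for the sketch.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Status:
    """Hypothetical stand-in for an FFI-style return code plus message."""

    code: int
    message: str = ""

    @property
    def ok(self) -> bool:
        return self.code == 0


def guarded_call(body: Callable[[], None]) -> Status:
    """Run `body` and report failure through the return value, never by raising."""
    try:
        body()
    except Exception as exc:  # every error becomes a status; none escapes
        return Status(code=1, message=f"{type(exc).__name__}: {exc}")
    return Status(code=0)


def failing() -> None:
    # Hypothetical callback that always fails.
    raise ValueError("bad input")


print(guarded_call(lambda: None))
print(guarded_call(failing))
```

The trade-off raised later in this thread is visible here: the caller gets a clean status, but only as much detail as the guard chooses to copy into the message.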

@mreineck
Collaborator

Thank you very much, I like this a lot! It simplifies the C++ code substantially.

Concerning the exception vs. return code issue, I'll try to get more information soon.

@mreineck
Collaborator

I've pushed something; sorry, this includes a lot of white space changes ...
This is just a framework for more detailed error handling, but I think we can leave it for now. Please tell me what you think!

@roth-jakob
Collaborator Author

Having better error handling is a good idea! However, I just tested it with the 01_linear_function demo by adding a raise in line 61, and the stacktrace seems to have become less understandable.

The stacktrace of the original version is:

XlaRuntimeError                           Traceback (most recent call last)
File ~/git/jaxbind/demos/01_linear_function.py:186
    182 inp = inp + 1j * jax.random.uniform(subkey, shape=(10, 10), dtype=jnp.float64)
    185 # apply the new primitive
--> 186 res = fftn_jax(inp)
    188 # apply the new primitive and pass the keyword argument "workers=2" to the scipy fft
    189 res2 = fftn_jax(inp, workers=2)

File ~/git/jaxbind/jaxbind/jaxbind.py:337, in _call(_func, *args, **kwargs)
    333 def _call(*args, _func: FunctionType, **kwargs):
    334     """helper function evaluating the JAX primitive for the function 'f' in
    335     _func for given *args and **kwargs.
    336     """
--> 337     return _prim.bind(*args, **kwargs, _func=_func)

File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:612, in Primitive.bind(self, *args, **params)
    610 def bind(self, *args, **params):
    611   args = args if self.skip_canonicalization else map(canonicalize_value, args)
--> 612   return self._true_bind(*args, **params)

File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:628, in Primitive._true_bind(self, *args, **params)
    626 trace_ctx.set_trace(eval_trace)
    627 try:
--> 628   return self.bind_with_trace(prev_trace, args, params)
    629 finally:
    630   trace_ctx.set_trace(prev_trace)

File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:638, in Primitive.bind_with_trace(self, trace, args, params)
    635   with set_current_trace(trace):
    636     return self.to_lojax(*args, **params)  # type: ignore
--> 638 return trace.process_primitive(self, args, params)

File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:1162, in EvalTrace.process_primitive(self, primitive, args, params)
   1160 args = map(full_lower, args)
   1161 check_eval_args(args)
-> 1162 return primitive.impl(*args, **params)

File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/dispatch.py:90, in apply_primitive(prim, *args, **params)
     88 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     89 try:
---> 90   outs = fun(*args)
     91 finally:
     92   lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 5 frame]

File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1362, in ExecuteReplicated.__call__(self, *args)
   1360   self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
   1361 else:
-> 1362   results = self.xla_executable.execute_sharded(input_bufs)
   1364 if dispatch.needs_check_special():
   1365   out_arrays = results.disassemble_into_single_device_arrays()

XlaRuntimeError: UNKNOWN: XLA FFI call failed: Traceback (most recent call last):
  File "/home/jakob/miniforge3/envs/nifty_rest/bin/ipython", line 8, in <module>
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/IPython/__init__.py", line 144, in start_ipython
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/traitlets/config/application.py", line 1074, in launch_instance
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/traitlets/config/application.py", line 118, in inner
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/IPython/terminal/ipapp.py", line 292, in initialize
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/IPython/core/shellapp.py", line 354, in init_code
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/IPython/core/shellapp.py", line 479, in _run_cmd_line_code
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/IPython/core/shellapp.py", line 404, in _exec_file
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 2906, in safe_execfile
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/IPython/utils/py3compat.py", line 56, in execfile
  File "/home/jakob/git/jaxbind/demos/01_linear_function.py", line 186, in <module>
  File "/home/jakob/git/jaxbind/jaxbind/jaxbind.py", line 337, in _call
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py", line 612, in bind
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py", line 628, in _true_bind
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py", line 638, in bind_with_trace
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py", line 1162, in process_primitive
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/pjit.py", line 270, in cache_miss
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/pjit.py", line 149, in _python_pjit_helper
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/pjit.py", line 1804, in _pjit_call_impl_python
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/profiler.py", line 364, in wrapper
  File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1362, in __call__
  File "/home/jakob/git/jaxbind/demos/01_linear_function.py", line 61, in fftn
RuntimeError: No active exception to reraise

With the try/catch, I get the following stack trace:

XlaRuntimeError                           Traceback (most recent call last)
File ~/git/jaxbind/demos/01_linear_function.py:186
    182 inp = inp + 1j * jax.random.uniform(subkey, shape=(10, 10), dtype=jnp.float64)
    185 # apply the new primitive
--> 186 res = fftn_jax(inp)
    188 # apply the new primitive and pass the keyword argument "workers=2" to the scipy fft
    189 res2 = fftn_jax(inp, workers=2)

File ~/git/jaxbind/jaxbind/jaxbind.py:337, in _call(_func, *args, **kwargs)
    333 def _call(*args, _func: FunctionType, **kwargs):
    334     """helper function evaluating the JAX primitive for the function 'f' in
    335     _func for given *args and **kwargs.
    336     """
--> 337     return _prim.bind(*args, **kwargs, _func=_func)

File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:612, in Primitive.bind(self, *args, **params)
    610 def bind(self, *args, **params):
    611   args = args if self.skip_canonicalization else map(canonicalize_value, args)
--> 612   return self._true_bind(*args, **params)

File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:628, in Primitive._true_bind(self, *args, **params)
    626 trace_ctx.set_trace(eval_trace)
    627 try:
--> 628   return self.bind_with_trace(prev_trace, args, params)
    629 finally:
    630   trace_ctx.set_trace(prev_trace)

File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:638, in Primitive.bind_with_trace(self, trace, args, params)
    635   with set_current_trace(trace):
    636     return self.to_lojax(*args, **params)  # type: ignore
--> 638 return trace.process_primitive(self, args, params)

File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:1162, in EvalTrace.process_primitive(self, primitive, args, params)
   1160 args = map(full_lower, args)
   1161 check_eval_args(args)
-> 1162 return primitive.impl(*args, **params)

File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/dispatch.py:90, in apply_primitive(prim, *args, **params)
     88 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     89 try:
---> 90   outs = fun(*args)
     91 finally:
     92   lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 5 frame]

File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1362, in ExecuteReplicated.__call__(self, *args)
   1360   self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
   1361 else:
-> 1362   results = self.xla_executable.execute_sharded(input_bufs)
   1364 if dispatch.needs_check_special():
   1365   out_arrays = results.disassemble_into_single_device_arrays()

XlaRuntimeError: INTERNAL: Something happened; no idea what

Thus, in the original version one can see that the FFI call failed, and even the line of the Python function containing the raise. In the new version one only sees XlaRuntimeError: INTERNAL, with no indication that the FFI call caused the error. Do you have any idea why that is, or how we could improve it?

@roth-jakob
Collaborator Author

The easiest fix would be to replace the message "Something happened; no idea what" with something like "Runtime error in JAXbind call". But maybe there is also a way to get the stack trace of the Python function, as in the original version.
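
One way to keep the Python stack trace in the error message is to capture it as text at the point where the exception is caught. The sketch below does this in pure Python with the standard `traceback` module; `describe_failure` and `failing_callback` are hypothetical names, and how such a string would be handed back through the C++ layer is left open here.

```python
import traceback


def describe_failure(func, *args, **kwargs):
    """Call `func`; return None on success, else a message including the traceback."""
    try:
        func(*args, **kwargs)
    except Exception:
        # format_exc() renders the exception currently being handled,
        # including the file, line, and function where it was raised.
        return "Runtime error in JAXbind call:\n" + traceback.format_exc()
    return None


def failing_callback(out, inp):
    # Hypothetical user function that fails, like the demo with an added raise.
    raise RuntimeError("demo failure")


msg = describe_failure(failing_callback, None, None)
print(msg)
```

Note that `traceback.format_exc()` only covers the frames between the `try` block and the raising line, so the output is shorter than the full interpreter stack shown in the original trace.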

@mreineck
Collaborator

If we can rely on Python exceptions being handled properly, then I'm absolutely in favor of the original approach! As you say, this preserves the error location and is much more convenient overall. I'm just not sure about this, since the FFI interface was (I think) not designed with the idea of calling back into Python from the compiled function.
But as your test case shows, it works nicely, at least at the moment, so it is probably best to revert my change for now. I can do that in the afternoon.

@roth-jakob
Collaborator Author

I'm not sure if we can rely on Python exceptions being handled properly. (Actually, I was surprised that the stacktrace looks so nice.)

I'm unsure of the best way to proceed.

@mreineck
Collaborator

I commented out the try/catch, and I think we can proceed with things as they are, at least for the moment. Giving the user the clearest possible feedback is valuable.

@mreineck
Collaborator

Since we rely on the new FFI now, do we have to adjust the minimum required JAX version in pyproject.toml?

@mreineck
Collaborator

Ah, no, we already require jax >= 0.5 ... strange.

@roth-jakob
Collaborator Author

Commenting out the try/catch sounds like a good idea for now. Yes, jax >= 0.5 already includes the FFI interface.
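
For illustration, a runtime guard for the `jax >= 0.5` requirement could look like the following sketch. `parse_version` and `meets_minimum` are hypothetical helpers, not part of JAXbind or JAX. They only handle plain numeric dotted versions (no pre-release suffixes such as `rc1`); in real code the `packaging` library is the more robust choice.

```python
def parse_version(text):
    """Turn '0.5.0' into (0, 5, 0); only numeric dotted versions are handled."""
    return tuple(int(part) for part in text.split("."))


def meets_minimum(installed, minimum):
    """Return True if `installed` is at least `minimum`, padding with zeros."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


print(meets_minimum("0.8.0", "0.5"))
print(meets_minimum("0.4.35", "0.5"))
```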

I don't have anything else in mind that we have to do. Unless any of you has something else, we could merge this by tomorrow or so.

@roth-jakob
Collaborator Author

@mreineck Thanks for your review and the tweaks!

roth-jakob merged commit 19a38b3 into main Sep 30, 2025