Conversation
|
From my side, this is ready to be merged. @mreineck or @Edenhofer can you have a look at it and merge it? |
|
After merging this, we should also make a new PyPi release. With the release of JAX 0.8.0 (which I believe is planned for November), the old custom call interface we have been using so far will be removed. |
src/_jaxbind.cc
Outdated
| PyObject* raw_ptr = reinterpret_cast<PyObject*>(func_id); | ||
| nb::handle hnd(raw_ptr); | ||
| nb::object func = nb::borrow<nb::object>(hnd); | ||
| func(py_out, py_in, py_kwargs); |
There was a problem hiding this comment.
I think that errors in the called functions need to be communicated to the caller via return codes with the new FFI.
So I think the two lines below probably need to be replaced with just
return func(py_out, py_in, py_kwargs);
Still, that doesn't catch the exceptions we may be throwing in the lines above. Perhaps we need to guard everything by try ... catch; I'll look into that.
|
Thank you very much, I like this a lot! It simplifies the C++ code substantially. Concerning the exception vs. return code issue, I'll try to get more information soon. |
|
I've pushed something; sorry, this includes a lot of white space changes ... |
|
Having better error handling is a good idea! However, I just tested it with the 01_linear_function demo by adding a The stacktrace of the original version is: XlaRuntimeError Traceback (most recent call last)
File ~/git/jaxbind/demos/01_linear_function.py:186
182 inp = inp + 1j * jax.random.uniform(subkey, shape=(10, 10), dtype=jnp.float64)
185 # apply the new primitive
--> 186 res = fftn_jax(inp)
188 # apply the new primitive and pass the keyword argument "workers=2" to the scipy fft
189 res2 = fftn_jax(inp, workers=2)
File ~/git/jaxbind/jaxbind/jaxbind.py:337, in _call(_func, *args, **kwargs)
333 def _call(*args, _func: FunctionType, **kwargs):
334 """helper function evaluating the JAX primitive for the function 'f' in
335 _func for given *args and **kwargs.
336 """
--> 337 return _prim.bind(*args, **kwargs, _func=_func)
File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:612, in Primitive.bind(self, *args, **params)
610 def bind(self, *args, **params):
611 args = args if self.skip_canonicalization else map(canonicalize_value, args)
--> 612 return self._true_bind(*args, **params)
File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:628, in Primitive._true_bind(self, *args, **params)
626 trace_ctx.set_trace(eval_trace)
627 try:
--> 628 return self.bind_with_trace(prev_trace, args, params)
629 finally:
630 trace_ctx.set_trace(prev_trace)
File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:638, in Primitive.bind_with_trace(self, trace, args, params)
635 with set_current_trace(trace):
636 return self.to_lojax(*args, **params) # type: ignore
--> 638 return trace.process_primitive(self, args, params)
File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:1162, in EvalTrace.process_primitive(self, primitive, args, params)
1160 args = map(full_lower, args)
1161 check_eval_args(args)
-> 1162 return primitive.impl(*args, **params)
File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/dispatch.py:90, in apply_primitive(prim, *args, **params)
88 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
89 try:
---> 90 outs = fun(*args)
91 finally:
92 lib.jax_jit.swap_thread_local_state_disable_jit(prev)
[... skipping hidden 5 frame]
File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1362, in ExecuteReplicated.__call__(self, *args)
1360 self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
1361 else:
-> 1362 results = self.xla_executable.execute_sharded(input_bufs)
1364 if dispatch.needs_check_special():
1365 out_arrays = results.disassemble_into_single_device_arrays()
XlaRuntimeError: UNKNOWN: XLA FFI call failed: Traceback (most recent call last):
File "/home/jakob/miniforge3/envs/nifty_rest/bin/ipython", line 8, in <module>
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/IPython/__init__.py", line 144, in start_ipython
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/traitlets/config/application.py", line 1074, in launch_instance
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/traitlets/config/application.py", line 118, in inner
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/IPython/terminal/ipapp.py", line 292, in initialize
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/IPython/core/shellapp.py", line 354, in init_code
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/IPython/core/shellapp.py", line 479, in _run_cmd_line_code
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/IPython/core/shellapp.py", line 404, in _exec_file
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 2906, in safe_execfile
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/IPython/utils/py3compat.py", line 56, in execfile
File "/home/jakob/git/jaxbind/demos/01_linear_function.py", line 186, in <module>
File "/home/jakob/git/jaxbind/jaxbind/jaxbind.py", line 337, in _call
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py", line 612, in bind
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py", line 628, in _true_bind
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py", line 638, in bind_with_trace
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py", line 1162, in process_primitive
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/pjit.py", line 270, in cache_miss
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/pjit.py", line 149, in _python_pjit_helper
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/pjit.py", line 1804, in _pjit_call_impl_python
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/profiler.py", line 364, in wrapper
File "/home/jakob/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1362, in __call__
File "/home/jakob/git/jaxbind/demos/01_linear_function.py", line 61, in fftn
RuntimeError: No active exception to reraiseand with the XlaRuntimeError Traceback (most recent call last)
File ~/git/jaxbind/demos/01_linear_function.py:186
182 inp = inp + 1j * jax.random.uniform(subkey, shape=(10, 10), dtype=jnp.float64)
185 # apply the new primitive
--> 186 res = fftn_jax(inp)
188 # apply the new primitive and pass the keyword argument "workers=2" to the scipy fft
189 res2 = fftn_jax(inp, workers=2)
File ~/git/jaxbind/jaxbind/jaxbind.py:337, in _call(_func, *args, **kwargs)
333 def _call(*args, _func: FunctionType, **kwargs):
334 """helper function evaluating the JAX primitive for the function 'f' in
335 _func for given *args and **kwargs.
336 """
--> 337 return _prim.bind(*args, **kwargs, _func=_func)
File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:612, in Primitive.bind(self, *args, **params)
610 def bind(self, *args, **params):
611 args = args if self.skip_canonicalization else map(canonicalize_value, args)
--> 612 return self._true_bind(*args, **params)
File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:628, in Primitive._true_bind(self, *args, **params)
626 trace_ctx.set_trace(eval_trace)
627 try:
--> 628 return self.bind_with_trace(prev_trace, args, params)
629 finally:
630 trace_ctx.set_trace(prev_trace)
File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:638, in Primitive.bind_with_trace(self, trace, args, params)
635 with set_current_trace(trace):
636 return self.to_lojax(*args, **params) # type: ignore
--> 638 return trace.process_primitive(self, args, params)
File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/core.py:1162, in EvalTrace.process_primitive(self, primitive, args, params)
1160 args = map(full_lower, args)
1161 check_eval_args(args)
-> 1162 return primitive.impl(*args, **params)
File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/dispatch.py:90, in apply_primitive(prim, *args, **params)
88 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
89 try:
---> 90 outs = fun(*args)
91 finally:
92 lib.jax_jit.swap_thread_local_state_disable_jit(prev)
[... skipping hidden 5 frame]
File ~/miniforge3/envs/nifty_rest/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1362, in ExecuteReplicated.__call__(self, *args)
1360 self._handle_token_bufs(result_token_bufs, sharded_runtime_token)
1361 else:
-> 1362 results = self.xla_executable.execute_sharded(input_bufs)
1364 if dispatch.needs_check_special():
1365 out_arrays = results.disassemble_into_single_device_arrays()
XlaRuntimeError: INTERNAL: Something happened; no idea whatThus in the original version one can see that the FFI call failed, and sees even the line of the python function with the raise. In the new version one only sees |
|
The easiest would be to just replace the |
|
If we can rely on Python exceptions being handled properly, then I'm absolutely in favor of the original approach! As you say, this preserves the error location and is much more convenient overall. I'm just not sure of this, since th FFI interface was (I think) not designed with the idea of calling back into Python from the compiled function. |
|
I'm not sure if we can rely on Python exceptions being handled properly. (Actually, I was surprised that the stacktrace looks so nice.) I'm unsure of the best way to proceed. |
|
I commented out the |
|
Since we rely on the new FFI now, do we have to adjust the minimum required JAX version in |
|
Ah, no, we already require |
|
Commenting out the I don't have anything else in mind that we have to do. Unless any of you has something else, we could merge this by tomorrow or so. |
|
@mreineck Thanks for your review and the tweaks! |
I will take another look at it next week, but it seems to be working.