forked from jax-ml/jax
add pull_request event #4
Open
skye
wants to merge
26
commits into
main
from
tpu_presubmit
This ensures all existing JAX buffer types have a `delete` method that can be used to free the device buffer allocation eagerly. User code sometimes has lingering Python references due to cyclic dependencies and other reasons, yet users may know for sure that certain arrays will no longer be used after a certain point. Calling `foo_array.delete()` on a DeviceArray/ShardedDeviceArray/GlobalDeviceArray/Array lets users force-free the device-side allocation to minimize device memory usage. PiperOrigin-RevId: 482892157
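As an illustration of the eager-free pattern this commit message describes, here is a minimal sketch. It assumes a recent JAX release, where the array types named above are unified under `jax.Array`; the array shape and variable names are arbitrary and not taken from the commit.

```python
import jax.numpy as jnp

# Allocate an array on the default device and finish all work that needs it.
x = jnp.ones((1024, 1024))
total = x.sum()
total.block_until_ready()  # make sure the computation reading x has completed

# We know x is no longer needed, even if some Python reference to it lingers,
# so free the device-side buffer now instead of waiting for garbage collection.
x.delete()

# The Python object still exists, but its device buffer is gone.
deleted = x.is_deleted()
print("x deleted:", deleted)
```

Using a deleted array afterwards raises an error, so this is only appropriate when the caller is certain the array is dead.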
…e in CUDA 11.1 PiperOrigin-RevId: 482897448
PiperOrigin-RevId: 482902569
PiperOrigin-RevId: 482903592
PiperOrigin-RevId: 482905407
…ther than trivial computation. PiperOrigin-RevId: 482919649
PiperOrigin-RevId: 482945880
The shape function of DotGeneralOp can't be integrated into MHLO yet: the shape function only predicts the return shape and cannot predict the element type. However, the current Python binding infra generates the constructor __init__() without the `return` as the first arg, which assumes the shape function can provide a fully inferred type (including an accurate element type). This leads to "inferred type does not match actual result type" errors in JAX. This needs a future solution.

This CL is the corresponding change with openxla/stablehlo#269

Related Python __init__() interface changes (used by JAX):
  batch_norm_grad: not used by JAX
  batch_norm_inference: not used by JAX
  batch_norm_training: not used by JAX
  case: no change*
  dot_general: open new b/253644255 to track the issue
  if: no change*
  map: no change*
  reduce: no change*
  reduce_window: no change*
  sort: no change*
  triangular_solve: updated in `linalg.py`
  while: no change*

no change*: the signature of __init__() for the op is not changed because of the existence of regions
https://github.com/llvm/llvm-project/blob/main/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp#L577

PiperOrigin-RevId: 482951512
fix some shape and type issues
import into namespace
imports into non-_src library
working logpdf test
cleanup
working tests for cdf and sf after fixing select
relax need for x to be in (a, b)
ensure behavior with invalid input matches scipy
remove enforcing valid parameters in tests
added truncnorm to docs
whoops alphabetical
fix linter error
fix circular import issue
No functional changes intended. PiperOrigin-RevId: 483413031
PiperOrigin-RevId: 483425197
PiperOrigin-RevId: 483425722
overriden -> overridden
PiperOrigin-RevId: 483471278
PiperOrigin-RevId: 483490497
PiperOrigin-RevId: 483493703
skye
pushed a commit
that referenced
this pull request
Apr 15, 2025
When run under an optimized build and Python 3.13.2t, I saw the following high probability crash in lax_control_flow_test:

```
Stack trace of thread 3526917:
#0  0x00007f0898c4bf91 dump_frame (libpython3.13t.so.1.0 + 0x24bf91)
#1  0x00007f0898c4b73f dump_traceback (libpython3.13t.so.1.0 + 0x24b73f)
#2  0x00007f0898c4b86f _Py_DumpTracebackThreads (libpython3.13t.so.1.0 + 0x24b86f)
#3  0x00007f0898cd4fe0 faulthandler_dump_traceback (libpython3.13t.so.1.0 + 0x2d4fe0)
#4  0x00007f0898cd4f44 faulthandler_fatal_error (libpython3.13t.so.1.0 + 0x2d4f44)
#5  0x00007f0898849e20 __restore_rt (libc.so.6 + 0x3fe20)
#6  0x00007f07eb80e493 _ZNSt8__detail16_Hashtable_allocISaINS_10_Hash_nodeISt4pairIKN3jax15WeakrefLRUCache15WeakrefCacheKeyENS4_17WeakrefCacheValueEELb1EEEEE18_M_deallocate_nodeEPS9_ (libjax_common.so + 0x2c0e493)
#7  0x00007f07eb80e13e _ZN3jax15WeakrefLRUCache5ClearEv (libjax_common.so + 0x2c0e13e)
#8  0x00007f07eb812e37 _ZZN8nanobind6detail11func_createILb0ELb1EZNS_16cpp_function_defIN3jax15WeakrefLRUCacheEvS4_JEJNS_5scopeENS_4nameENS_9is_methodENS_9lock_selfEEEEvMT1_FT0_DpT2_EDpRKT3_EUlPS4_E_vJSJ_EJLm0EEJS5_S6_S7_S8_EEEP>
#9  0x00007f07eb7fff70 _ZN8nanobind6detailL25nb_func_vectorcall_simpleEP7_objectPKS2_mS2_ (libjax_common.so + 0x2bfff70)
#10 0x00007f0898dbbdee _PyObject_VectorcallTstate (libpython3.13t.so.1.0 + 0x3bbdee)
#11 0x00007f0898d1d4db _PyEval_EvalFrame (libpython3.13t.so.1.0 + 0x31d4db)
#12 0x00007f0898d1ee78 _PyObject_VectorcallTstate (libpython3.13t.so.1.0 + 0x31ee78)
#13 0x00007f0898dc0054 _PyVectorcall_Call (libpython3.13t.so.1.0 + 0x3c0054)
#14 0x00007f0898d1d4db _PyEval_EvalFrame (libpython3.13t.so.1.0 + 0x31d4db)
#15 0x00007f0898d1e02c _PyObject_VectorcallDictTstate (libpython3.13t.so.1.0 + 0x31e02c)
#16 0x00007f0898ed8e35 slot_tp_call (libpython3.13t.so.1.0 + 0x4d8e35)
#17 0x00007f0898dbc312 _PyObject_MakeTpCall (libpython3.13t.so.1.0 + 0x3bc312)
#18 0x00007f0898d1d4db _PyEval_EvalFrame (libpython3.13t.so.1.0 + 0x31d4db)
#19 0x00007f0898d1ef54 _PyObject_VectorcallTstate (libpython3.13t.so.1.0 + 0x31ef54)
#20 0x00007f0899094c1f thread_run (libpython3.13t.so.1.0 + 0x694c1f)
#21 0x00007f0898fa0c58 pythread_wrapper (libpython3.13t.so.1.0 + 0x5a0c58)
#22 0x00007f089889c103 start_thread (libc.so.6 + 0x92103)
#23 0x00007f089891a7b8 __clone3 (libc.so.6 + 0x1107b8)
```

It appears that this is due to freeing Python objects during unordered_map::clear(), which may release the enclosing critical section (`nb::lock_self()` on the method). Fix this by deferring destruction of both the keys and the values until after the map's destruction.
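The deferred-destruction idea in this fix can be sketched in plain Python. This is not the C++ `WeakrefLRUCache` code from the commit: `DeferredClearCache` and `Probe` are hypothetical names, a `threading.Lock` stands in for the `nb::lock_self()` critical section, and the sketch relies on CPython running `__del__` as soon as the last reference is dropped.

```python
import threading


class DeferredClearCache:
    """Hypothetical cache that destroys its entries outside the lock."""

    def __init__(self):
        self.lock = threading.Lock()
        self._entries = {}

    def put(self, key, value):
        with self.lock:
            self._entries[key] = value

    def clear(self):
        # Under the lock, only detach the old map; no entry is destroyed yet
        # because `doomed` still holds every key and value.
        with self.lock:
            doomed, self._entries = self._entries, {}
        # The lock is released here. Dropping the entries now means any
        # finalizer they trigger runs outside the critical section.
        doomed.clear()


class Probe:
    """Records whether the cache lock is held when the object is finalized."""

    def __init__(self, lock, log):
        self._lock = lock
        self._log = log

    def __del__(self):
        self._log.append(self._lock.locked())


def demo():
    log = []
    cache = DeferredClearCache()
    cache.put("a", Probe(cache.lock, log))
    cache.put("b", Probe(cache.lock, log))
    cache.clear()
    return log


print(demo())
```

If `clear()` emptied the dictionary while still holding the lock, the finalizers would run inside the critical section, which is the situation the commit message identifies as unsafe.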
No description provided.