-
Notifications
You must be signed in to change notification settings - Fork 149
Numba AdvancedIndexing: Complete support for integer advanced indexing #1778
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR completes support for integer advanced indexing in PyTensor's Numba backend by implementing a comprehensive dispatch mechanism that handles all forms of advanced indexing with basic and vector integer indices, including non-consecutive indices and mixed basic/advanced indexing patterns.
Key Changes:
- Replaces specialized multidimensional bool/int index raveling rewrites with a simpler
bool_idx_to_nonzerorewrite that converts boolean indices to integer indices vianonzero() - Implements a new codegen-based
vector_integer_advanced_indexingdispatcher that generates custom Numba functions for each indexing pattern, eliminating most object mode fallbacks - Updates tests to remove the
objmode_neededparameter and consolidates test parameters forignore_duplicatesbehavior
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| tests/link/numba/test_subtensor.py | Removes objmode_needed parameter from tests and consolidates duplicate index test parameters, reflecting that most indexing operations no longer require object mode fallback |
| pytensor/tensor/rewriting/subtensor.py | Simplifies boolean indexing rewrite by converting to nonzero() instead of raveling, and removes complex multidimensional integer indexing rewrites that are now handled by the dispatcher |
| pytensor/link/numba/dispatch/subtensor.py | Implements comprehensive codegen-based dispatcher for advanced indexing that handles transpose operations, broadcasting, and index raveling at runtime without requiring object mode |
59b83a4 to
ebef3b1
Compare
…) advanced indexing When default `ignore_updates=True` for inc_subtensor, and boolean indices were rewritten during specialize
ebef3b1 to
61a9309
Compare
(and mixed basic indexing)
This covers the cases with default
ignore_updates=Falsefor inc_subtensor,Boolean indices are converted to integer indices during specialize rewriting.
This was motivated by #811 where it found the logic for when object fallback was needed was wrong.
Fixing it, revealed we had less cases covered by pure numba than thought, hence this PR.