-
Notifications
You must be signed in to change notification settings - Fork 52
feat: add API specification for returning the k
largest elements
#722
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a2e33f9
30900eb
e5d3189
76873d8
07e62e9
efb985d
96461fc
c72d334
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,4 +25,5 @@ Objects in API | |
count_nonzero | ||
nonzero | ||
searchsorted | ||
top_k | ||
where |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,15 @@ | ||
__all__ = ["argmax", "argmin", "count_nonzero", "nonzero", "searchsorted", "where"] | ||
__all__ = [ | ||
"argmax", | ||
"argmin", | ||
"count_nonzero", | ||
"nonzero", | ||
"searchsorted", | ||
"top_k", | ||
"where", | ||
] | ||
|
||
|
||
from ._types import Optional, Tuple, Literal, Union, array | ||
from ._types import Optional, Literal, Tuple, Union, array | ||
|
||
|
||
def argmax(x: array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> array: | ||
|
@@ -168,6 +176,50 @@ def searchsorted( | |
""" | ||
|
||
|
||
def top_k( | ||
x: array, | ||
k: int, | ||
/, | ||
*, | ||
axis: Optional[int] = None, | ||
mode: Literal["largest", "smallest"] = "largest", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the value of using a string literal for this toggle? Are we anticipating other options? Why is string literals better than using an enum? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we're expecting other options. String literals is simply a standard way of doing this when only certain string values are valid.
Enums are awful for defining a public API. You'll need to make the enums themselves public so the users of your API can use them, meaning you will increase the size of your API for every argument with a fixed set of options you add. Keeping the API surface small and easy to understand is an important goal of this standard - it's mostly functions and a few constants and other objects. |
||
) -> Tuple[array, array]: | ||
""" | ||
Returns the values and indices of the ``k`` largest (or smallest) elements of an input array ``x`` along a specified dimension. | ||
|
||
Parameters | ||
---------- | ||
x: array | ||
input array. Should have a real-valued data type. | ||
k: int | ||
number of elements to find. Must be a positive integer value. | ||
axis: Optional[int] | ||
axis along which to search. If ``None``, the function must search the flattened array. Default: ``None``. | ||
mode: Literal['largest', 'smallest'] | ||
search mode. Must be one of the following modes: | ||
|
||
- ``'largest'``: return the ``k`` largest elements. | ||
- ``'smallest'``: return the ``k`` smallest elements. | ||
|
||
Default: ``'largest'``. | ||
|
||
Returns | ||
------- | ||
out: Tuple[array, array] | ||
a namedtuple ``(values, indices)`` whose | ||
|
||
- first element must have the field name ``values`` and must be an array containing the ``k`` largest (or smallest) elements of ``x``. The array must have the same data type as ``x``. If ``axis`` is ``None``, the array must be a one-dimensional array having shape ``(k,)``; otherwise, if ``axis`` is an integer value, the array must have the same rank (number of dimensions) and shape as ``x``, except for the axis specified by ``axis`` which must have size ``k``. | ||
- second element must have the field name ``indices`` and must be an array containing indices of ``x`` that result in ``values``. The array must have the same shape as ``values`` and must have the default array index data type. If ``axis`` is ``None``, ``indices`` must be the indices of a flattened ``x``. | ||
|
||
Notes | ||
----- | ||
|
||
- If ``k`` exceeds the number of elements in ``x`` or along the axis specified by ``axis``, behavior is left unspecified and thus implementation-dependent. Conforming implementations may choose, e.g., to raise an exception or return all elements. | ||
- The order of the returned values and indices is left unspecified and thus implementation-dependent. Conforming implementations may return sorted or unsorted values. | ||
- Conforming implementations may support complex numbers; however, inequality comparison of complex numbers is unspecified and thus implementation-dependent (see :ref:`complex-number-ordering`). | ||
""" | ||
|
||
|
||
def where(condition: array, x1: array, x2: array, /) -> array: | ||
""" | ||
Returns elements chosen from ``x1`` or ``x2`` depending on ``condition``. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally this look good. Other implementations I've seen pass a DeviceContext as well but I'm not sure if we want to tackle that as part of the initial implementation.