
Row indexing a dataset with numpy integers #7423

Open
DavidRConnell opened this issue Feb 25, 2025 · 1 comment
Labels: enhancement (New feature or request)

Comments

@DavidRConnell
Feature request

Allow indexing datasets with a scalar numpy integer type.

Motivation

Indexing a dataset with a scalar numpy.int* object raises a TypeError. This is due to the type check in key_to_query_type in datasets/formatting/formatting.py:

def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
    if isinstance(key, int):
        return "row"
    elif isinstance(key, str):
        return "column"
    elif isinstance(key, (slice, range, Iterable)):
        return "batch"
    _raise_bad_key_type(key)

In the row case, it checks whether key is an int, which returns False when key is integer-like but not a built-in Python integer type. This is counterintuitive because a numpy array of np.int64 values can be used for the batch case.

For example:

import numpy as np

import datasets

dataset = datasets.Dataset.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})

# Regular indexing
dataset[0]
dataset[:2]

# Indexing with numpy data types (expect same results)
idx = np.asarray([0, 1])
dataset[idx]  # Succeeds when using an array of np.int64 values
dataset[idx[0]]  # Fails with TypeError when using scalar np.int64
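
The divergence comes down to the isinstance checks above: a NumPy integer scalar is not a subclass of the built-in int, while a NumPy array satisfies the Iterable check used for the batch branch:

import numpy as np
from collections.abc import Iterable

isinstance(np.int64(0), int)              # False: np.int64 does not subclass int
isinstance(np.asarray([0, 1]), Iterable)  # True: ndarray defines __iter__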

For the user, this can be worked around by wrapping idx[0] in int, but the check in key_to_query_type could also be relaxed to accept a less strict definition of an integer:

+import numbers
+
 def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
-    if isinstance(key, int):
+    if isinstance(key, numbers.Integral):
         return "row"
     elif isinstance(key, str):
         return "column"
     elif isinstance(key, (slice, range, Iterable)):
         return "batch"
     _raise_bad_key_type(key)
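
As a quick sanity check (NumPy registers its integer scalar types with the numbers ABCs), all of the following would be routed to the "row" branch under the relaxed check:

import numbers
import numpy as np

for key in (3, np.int8(3), np.int64(3), np.uint32(3)):
    assert isinstance(key, numbers.Integral)  # all pass under the proposed check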

Looking at how others handle this, pandas has an is_integer check that relies on is_integer_object, defined in pandas/_libs/utils.pxd:

cdef inline bint is_integer_object(object obj) noexcept:
    """
    Cython equivalent of

    `isinstance(val, (int, np.integer)) and not isinstance(val, (bool, np.timedelta64))`

    Parameters
    ----------
    val : object

    Returns
    -------
    is_integer : bool

    Notes
    -----
    This counts np.timedelta64 objects as integers.
    """
    return (not PyBool_Check(obj) and isinstance(obj, (int, cnp.integer))
            and not is_timedelta64_object(obj))

This would be less flexible, as it explicitly checks for NumPy integers, but it is worth noting that pandas found it necessary to ensure the key is not a bool.
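
The bool caveat applies to the numbers.Integral approach as well: since Python's bool is a subclass of int, a bare True would be classified as a row index rather than rejected. A small illustration (to the best of my understanding, np.bool_ is not registered with the numbers ABCs):

import numbers
import numpy as np

isinstance(True, numbers.Integral)            # True: bool subclasses int, so dataset[True] would mean row 1
isinstance(np.bool_(True), numbers.Integral)  # False: np.bool_ is not an np.integer subclass
isinstance(np.int64(0), numbers.Integral)     # True: the desired behavior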

Your contribution

I can submit a pull request with the above changes after checking that indexing succeeds with numpy integer types. Alternatively, if a different integer check would be preferred, I could add that instead.

If there is a reason not to want this behavior, that is fine too.
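
A minimal test along these lines might look like the following (a hypothetical sketch, not yet part of the test suite):

import numpy as np
import datasets

def test_row_indexing_with_numpy_integer():
    dataset = datasets.Dataset.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})
    idx = np.asarray([0, 1])
    # A scalar numpy integer should behave exactly like the built-in int
    assert dataset[idx[0]] == dataset[0]
    assert dataset[np.int64(2)] == dataset[2]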

DavidRConnell added the enhancement label on Feb 25, 2025
@lhoestq (Member) commented on Mar 3, 2025

It would be cool to be consistent when it comes to indexing with numpy objects: if we accept numpy arrays, we should indeed accept numpy integers. Your idea sounds reasonable; I'd also be in favor of adding a simple test.
