diff --git a/scipy-stubs/_lib/_disjoint_set.pyi b/scipy-stubs/_lib/_disjoint_set.pyi index c71e0179..5aeeda93 100644 --- a/scipy-stubs/_lib/_disjoint_set.pyi +++ b/scipy-stubs/_lib/_disjoint_set.pyi @@ -1,19 +1,23 @@ -from collections.abc import Iterator -from typing import Generic +from collections.abc import Iterable, Iterator +from typing import Any, Generic from typing_extensions import TypeVar +import numpy as np import optype as op -_T = TypeVar("_T", bound=op.CanHash, default=object) +# Only the existence of `__hash__` is required. However, in numpy < 2.1 the +# `__hash__` method is missing from numpy stubs on scalar values. Allowing +# `np.generic` fixes this for older numpy versions. +_T = TypeVar("_T", bound=op.CanHash | np.generic, default=Any) class DisjointSet(Generic[_T]): n_subsets: int - def __init__(self, /, elements: _T | None = None) -> None: ... + def __init__(self, /, elements: Iterable[_T] | None = None) -> None: ... def __iter__(self, /) -> Iterator[_T]: ... def __len__(self, /) -> int: ... def __contains__(self, x: object, /) -> bool: ... - def __getitem__(self, x: _T, /) -> int: ... + def __getitem__(self, x: _T, /) -> _T: ... def add(self, /, x: _T) -> None: ... def merge(self, /, x: _T, y: _T) -> bool: ... def connected(self, /, x: _T, y: _T) -> bool: ... diff --git a/tests/cluster/test_hierarchy.pyi b/tests/cluster/test_hierarchy.pyi new file mode 100644 index 00000000..fd7cb462 --- /dev/null +++ b/tests/cluster/test_hierarchy.pyi @@ -0,0 +1,72 @@ +from collections.abc import Iterator +from typing import Any, assert_type + +import numpy as np +import optype.numpy as onp + +from scipy.cluster.hierarchy import DisjointSet + +### +# DisjointSet + +py_str_1d: list[str] +py_int_1d: list[int] + +i32_1d: onp.Array1D[np.int32] +i64_1d: onp.Array1D[np.int64] +# DisjointSet(Iterable[T]) produces a DisjointSet[T] with universal set of type T. +assert_type(DisjointSet(py_str_1d), DisjointSet[str]) +assert_type(DisjointSet(py_int_1d), DisjointSet[int]) +# NOTE: Directly using assert_type fails with numpy arrays for all numpy<=2.0. Instead, use assignment statements. +_10: DisjointSet[np.int32] = DisjointSet(i32_1d) +_11: DisjointSet[np.int64] = DisjointSet(i64_1d) +# DisjointSet() produces a DisjointSet[Any] because T is unbound. +assert_type(DisjointSet(), DisjointSet[Any]) + +disjoint_set_str: DisjointSet[str] +disjoint_set_i64: DisjointSet[np.int64] + +# __iter__ produces an iterator over the universal set. +assert_type(iter(disjoint_set_str), Iterator[str]) +assert_type(iter(disjoint_set_i64), Iterator[np.int64]) + +# __len__ returns the length of the universal set +assert_type(len(disjoint_set_str), int) + +# __contains__ accepts an element of the universal set and returns a boolean +assert_type("a" in disjoint_set_str, bool) +assert_type(np.int64(2) in disjoint_set_i64, bool) + +# __getitem__ returns an element of the universal set +assert_type(disjoint_set_str["a"], str) +disjoint_set_str[1] # type: ignore[index] # pyright: ignore[reportArgumentType] +assert_type(disjoint_set_i64[np.int64(1)], np.int64) + +# add accepts an element of type T and adds it to the data structure (i.e. returns None) +assert_type(disjoint_set_str.add("a"), None) +disjoint_set_str.add(1) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] +assert_type(disjoint_set_i64.add(np.int64(1)), None) + +# merge accepts two elements of type T and returns a boolean indicating if they belonged to the same subset +assert_type(disjoint_set_str.merge("a", "b"), bool) +disjoint_set_str.merge(1, 2) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] +assert_type(disjoint_set_i64.merge(np.int64(1), np.int64(2)), bool) + +# connected accepts two elements of type T and returns a boolean indicating if they belonged to the same subset +assert_type(disjoint_set_str.connected("a", "b"), bool) +disjoint_set_str.connected(1, 2) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] +assert_type(disjoint_set_i64.connected(np.int64(1), np.int64(2)), bool) + +# subset accepts one element of type T and returns its containing subset. +assert_type(disjoint_set_str.subset("a"), set[str]) +disjoint_set_str.subset(1) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] +assert_type(disjoint_set_i64.subset(np.int64(1)), set[np.int64]) + +# subset_size accepts one element of type T and returns the *size* of its subset. +assert_type(disjoint_set_str.subset_size("a"), int) +disjoint_set_str.subset_size(1) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] +assert_type(disjoint_set_i64.subset_size(np.int64(1)), int) + +# subsets returns a list of all subsets of type T +assert_type(disjoint_set_str.subsets(), list[set[str]]) +assert_type(disjoint_set_i64.subsets(), list[set[np.int64]])