Skip to content

Commit 7433b5e

Browse files
authored
🐛 cluster.hierarchy: use generic return type for DisjointSet.__getitem__ (#992)
1 parent bfa4343 commit 7433b5e

File tree

2 files changed

+81
-5
lines changed

2 files changed

+81
-5
lines changed

scipy-stubs/_lib/_disjoint_set.pyi

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
1-
from collections.abc import Iterator
2-
from typing import Generic
1+
from collections.abc import Iterable, Iterator
2+
from typing import Any, Generic
33
from typing_extensions import TypeVar
44

5+
import numpy as np
56
import optype as op
67

7-
_T = TypeVar("_T", bound=op.CanHash, default=object)
8+
# Only the existence of `__hash__` is required. However, in numpy < 2.1 the
9+
# `__hash__` method is missing from numpy stubs on scalar values. Allowing
10+
# `np.generic` fixes this for older numpy versions.
11+
_T = TypeVar("_T", bound=op.CanHash | np.generic, default=Any)
812

913
class DisjointSet(Generic[_T]):
1014
n_subsets: int
1115

12-
def __init__(self, /, elements: _T | None = None) -> None: ...
16+
def __init__(self, /, elements: Iterable[_T] | None = None) -> None: ...
1317
def __iter__(self, /) -> Iterator[_T]: ...
1418
def __len__(self, /) -> int: ...
1519
def __contains__(self, x: object, /) -> bool: ...
16-
def __getitem__(self, x: _T, /) -> int: ...
20+
def __getitem__(self, x: _T, /) -> _T: ...
1721
def add(self, /, x: _T) -> None: ...
1822
def merge(self, /, x: _T, y: _T) -> bool: ...
1923
def connected(self, /, x: _T, y: _T) -> bool: ...

tests/cluster/test_hierarchy.pyi

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from collections.abc import Iterator
2+
from typing import Any, assert_type
3+
4+
import numpy as np
5+
import optype.numpy as onp
6+
7+
from scipy.cluster.hierarchy import DisjointSet
8+
9+
###
10+
# DisjointSet
11+
12+
py_str_1d: list[str]
13+
py_int_1d: list[int]
14+
15+
i32_1d: onp.Array1D[np.int32]
16+
i64_1d: onp.Array1D[np.int64]
17+
# DisjointSet(Iterable[T]) produces a DisjointSet[T] with universal set of type T.
18+
assert_type(DisjointSet(py_str_1d), DisjointSet[str])
19+
assert_type(DisjointSet(py_int_1d), DisjointSet[int])
20+
# NOTE: Directly using assert_type fails with numpy arrays for all numpy<=2.0. Instead, use assignment statements.
21+
_10: DisjointSet[np.int32] = DisjointSet(i32_1d)
22+
_11: DisjointSet[np.int64] = DisjointSet(i64_1d)
23+
# DisjointSet() produces a DisjointSet[Any] because T is unbound.
24+
assert_type(DisjointSet(), DisjointSet[Any])
25+
26+
disjoint_set_str: DisjointSet[str]
27+
disjoint_set_i64: DisjointSet[np.int64]
28+
29+
# __iter__ produces an iterator over the universal set.
30+
assert_type(iter(disjoint_set_str), Iterator[str])
31+
assert_type(iter(disjoint_set_i64), Iterator[np.int64])
32+
33+
# __len__ returns the length of the universal set
34+
assert_type(len(disjoint_set_str), int)
35+
36+
# __contains__ accepts an element of the universal set and returns a boolean
37+
assert_type("a" in disjoint_set_str, bool)
38+
assert_type(np.int64(2) in disjoint_set_i64, bool)
39+
40+
# __getitem__ returns an element of the universal set
41+
assert_type(disjoint_set_str["a"], str)
42+
disjoint_set_str[1] # type: ignore[index] # pyright: ignore[reportArgumentType]
43+
assert_type(disjoint_set_i64[np.int64(1)], np.int64)
44+
45+
# add accepts an element of type T and adds it to the data structure (i.e. returns None)
46+
assert_type(disjoint_set_str.add("a"), None)
47+
disjoint_set_str.add(1) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
48+
assert_type(disjoint_set_i64.add(np.int64(1)), None)
49+
50+
# merge accepts two elements of type T and returns a boolean indicating if they belonged to the same subset
51+
assert_type(disjoint_set_str.merge("a", "b"), bool)
52+
disjoint_set_str.merge(1, 2) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
53+
assert_type(disjoint_set_i64.merge(np.int64(1), np.int64(2)), bool)
54+
55+
# connected accepts two elements of type T and returns a boolean indicating if they belonged to the same subset
56+
assert_type(disjoint_set_str.connected("a", "b"), bool)
57+
disjoint_set_str.connected(1, 2) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
58+
assert_type(disjoint_set_i64.connected(np.int64(1), np.int64(2)), bool)
59+
60+
# subset accepts one element of type T and returns its containing subset.
61+
assert_type(disjoint_set_str.subset("a"), set[str])
62+
disjoint_set_str.subset(1) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
63+
assert_type(disjoint_set_i64.subset(np.int64(1)), set[np.int64])
64+
65+
# subset_size accepts one element of type T and returns the *size* of its subset.
66+
assert_type(disjoint_set_str.subset_size("a"), int)
67+
disjoint_set_str.subset_size(1) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
68+
assert_type(disjoint_set_i64.subset_size(np.int64(1)), int)
69+
70+
# subsets returns a list of all subsets of type T
71+
assert_type(disjoint_set_str.subsets(), list[set[str]])
72+
assert_type(disjoint_set_i64.subsets(), list[set[np.int64]])

0 commit comments

Comments
 (0)