|
| 1 | +from collections.abc import Iterator |
| 2 | +from typing import Any, assert_type |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import optype.numpy as onp |
| 6 | + |
| 7 | +from scipy.cluster.hierarchy import DisjointSet |
| 8 | + |
| 9 | +### |
| 10 | +# DisjointSet |
| 11 | + |
| 12 | +py_str_1d: list[str] |
| 13 | +py_int_1d: list[int] |
| 14 | + |
| 15 | +i32_1d: onp.Array1D[np.int32] |
| 16 | +i64_1d: onp.Array1D[np.int64] |
| 17 | +# DisjointSet(Iterable[T]) produces a DisjointSet[T] with universal set of type T. |
| 18 | +assert_type(DisjointSet(py_str_1d), DisjointSet[str]) |
| 19 | +assert_type(DisjointSet(py_int_1d), DisjointSet[int]) |
| 20 | +# NOTE: Directly using assert_type fails with numpy arrays for all numpy<=2.0. Instead, use assignment statements. |
| 21 | +_10: DisjointSet[np.int32] = DisjointSet(i32_1d) |
| 22 | +_11: DisjointSet[np.int64] = DisjointSet(i64_1d) |
| 23 | +# DisjointSet() produces a DisjointSet[Any] because T is unbound. |
| 24 | +assert_type(DisjointSet(), DisjointSet[Any]) |
| 25 | + |
| 26 | +disjoint_set_str: DisjointSet[str] |
| 27 | +disjoint_set_i64: DisjointSet[np.int64] |
| 28 | + |
| 29 | +# __iter__ produces an iterator over the universal set. |
| 30 | +assert_type(iter(disjoint_set_str), Iterator[str]) |
| 31 | +assert_type(iter(disjoint_set_i64), Iterator[np.int64]) |
| 32 | + |
| 33 | +# __len__ returns the length of the universal set |
| 34 | +assert_type(len(disjoint_set_str), int) |
| 35 | + |
| 36 | +# __contains__ accepts an element of the universal set and returns a boolean |
| 37 | +assert_type("a" in disjoint_set_str, bool) |
| 38 | +assert_type(np.int64(2) in disjoint_set_i64, bool) |
| 39 | + |
| 40 | +# __getitem__ returns an element of the universal set |
| 41 | +assert_type(disjoint_set_str["a"], str) |
| 42 | +disjoint_set_str[1] # type: ignore[index] # pyright: ignore[reportArgumentType] |
| 43 | +assert_type(disjoint_set_i64[np.int64(1)], np.int64) |
| 44 | + |
| 45 | +# add accepts an element of type T and adds it to the data structure (i.e. returns None) |
| 46 | +assert_type(disjoint_set_str.add("a"), None) |
| 47 | +disjoint_set_str.add(1) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] |
| 48 | +assert_type(disjoint_set_i64.add(np.int64(1)), None) |
| 49 | + |
| 50 | +# merge accepts two elements of type T and returns a boolean indicating if they belonged to the same subset |
| 51 | +assert_type(disjoint_set_str.merge("a", "b"), bool) |
| 52 | +disjoint_set_str.merge(1, 2) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] |
| 53 | +assert_type(disjoint_set_i64.merge(np.int64(1), np.int64(2)), bool) |
| 54 | + |
| 55 | +# connected accepts two elements of type T and returns a boolean indicating if they belonged to the same subset |
| 56 | +assert_type(disjoint_set_str.connected("a", "b"), bool) |
| 57 | +disjoint_set_str.connected(1, 2) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] |
| 58 | +assert_type(disjoint_set_i64.connected(np.int64(1), np.int64(2)), bool) |
| 59 | + |
| 60 | +# subset accepts one element of type T and returns its containing subset. |
| 61 | +assert_type(disjoint_set_str.subset("a"), set[str]) |
| 62 | +disjoint_set_str.subset(1) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] |
| 63 | +assert_type(disjoint_set_i64.subset(np.int64(1)), set[np.int64]) |
| 64 | + |
| 65 | +# subset_size accepts one element of type T and returns the *size* of its subset. |
| 66 | +assert_type(disjoint_set_str.subset_size("a"), int) |
| 67 | +disjoint_set_str.subset_size(1) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] |
| 68 | +assert_type(disjoint_set_i64.subset_size(np.int64(1)), int) |
| 69 | + |
| 70 | +# subsets returns a list of all subsets of type T |
| 71 | +assert_type(disjoint_set_str.subsets(), list[set[str]]) |
| 72 | +assert_type(disjoint_set_i64.subsets(), list[set[np.int64]]) |
0 commit comments