Skip to content

Commit 7d3fa21

Browse files
authored
🐛stats.unitary_group.rvs: fix return dtype (#993)
2 parents bb6b6bd + 42ad29c commit 7d3fa21

File tree

2 files changed

+86
-6
lines changed

2 files changed

+86
-6
lines changed

scipy-stubs/stats/_multivariate.pyi

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ __all__ = [
3131
]
3232

3333
_ScalarT = TypeVar("_ScalarT", bound=np.generic, default=np.float64)
34+
_ScalarT_co = TypeVar("_ScalarT_co", bound=np.generic, default=np.float64, covariant=True)
3435

3536
# TODO(@jorenham): rename as {}T_co
3637
_RVG_co = TypeVar("_RVG_co", bound=multi_rv_generic, default=multi_rv_generic, covariant=True)
@@ -363,24 +364,26 @@ class multinomial_frozen(multi_rv_frozen[multinomial_gen]):
363364
def rvs(self, /, size: onp.AtLeast1D | int = 1, random_state: onp.random.ToRNG | None = None) -> _Array2ND: ...
364365

365366
@type_check_only
366-
class _group_rv_gen_mixin(Generic[_RVF_co]):
367+
class _group_rv_gen_mixin(Generic[_RVF_co, _ScalarT_co]):
367368
def __call__(self, /, dim: onp.ToJustInt | None = None, seed: onp.random.ToRNG | None = None) -> _RVF_co: ...
368369
def rvs(
369370
self, /, dim: onp.ToJustInt, size: onp.ToJustInt | None = 1, random_state: onp.random.ToRNG | None = None
370-
) -> onp.Array3D[np.float64]: ...
371+
) -> onp.Array3D[_ScalarT_co]: ...
371372

372373
@type_check_only
373-
class _group_rv_frozen_mixin:
374+
class _group_rv_frozen_mixin(Generic[_ScalarT_co]):
374375
dim: onp.ToJustInt
375376
def __init__(self, /, dim: onp.ToJustInt | None = None, seed: onp.random.ToRNG | None = None) -> None: ...
376-
def rvs(self, /, size: onp.ToJustInt | None = 1, random_state: onp.random.ToRNG | None = None) -> onp.Array3D[np.float64]: ...
377+
def rvs(
378+
self, /, size: onp.ToJustInt | None = 1, random_state: onp.random.ToRNG | None = None
379+
) -> onp.Array3D[_ScalarT_co]: ...
377380

378381
class special_ortho_group_gen(_group_rv_gen_mixin[special_ortho_group_frozen], multi_rv_generic): ...
379382
class special_ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[special_ortho_group_gen]): ...
380383
class ortho_group_gen(_group_rv_gen_mixin[ortho_group_frozen], multi_rv_generic): ...
381384
class ortho_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[ortho_group_gen]): ...
382-
class unitary_group_gen(_group_rv_gen_mixin[unitary_group_frozen], multi_rv_generic): ...
383-
class unitary_group_frozen(_group_rv_frozen_mixin, multi_rv_frozen[unitary_group_gen]): ...
385+
class unitary_group_gen(_group_rv_gen_mixin[unitary_group_frozen, np.complex128], multi_rv_generic): ...
386+
class unitary_group_frozen(_group_rv_frozen_mixin[np.complex128], multi_rv_frozen[unitary_group_gen]): ...
384387

385388
#
386389
class uniform_direction_gen(multi_rv_generic):

tests/stats/test_multivariate.pyi

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from typing import assert_type
2+
3+
import numpy as np
4+
5+
from scipy.stats import (
6+
dirichlet,
7+
invwishart,
8+
matrix_normal,
9+
multinomial,
10+
multivariate_hypergeom,
11+
multivariate_normal,
12+
multivariate_t,
13+
normal_inverse_gamma,
14+
ortho_group,
15+
random_correlation,
16+
random_table,
17+
special_ortho_group,
18+
uniform_direction,
19+
unitary_group,
20+
vonmises_fisher,
21+
wishart,
22+
)
23+
24+
###
25+
# rvs dtype checks
26+
27+
# the pyright ignore comments are needed due to https://github.com/microsoft/pyright/issues/11127
28+
29+
assert_type(multivariate_normal.rvs().dtype, np.dtype[np.float64])
30+
assert_type(multivariate_normal().rvs().dtype, np.dtype[np.float64])
31+
32+
assert_type(matrix_normal.rvs().dtype, np.dtype[np.float64])
33+
assert_type(matrix_normal().rvs().dtype, np.dtype[np.float64])
34+
35+
assert_type(dirichlet.rvs([1, 2]).dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType]
36+
assert_type(dirichlet([1, 2]).rvs().dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType]
37+
38+
assert_type(wishart.rvs(1, 1).dtype, np.dtype[np.float64])
39+
assert_type(wishart().rvs().dtype, np.dtype[np.float64])
40+
41+
assert_type(invwishart.rvs(1, 1).dtype, np.dtype[np.float64])
42+
assert_type(invwishart().rvs().dtype, np.dtype[np.float64])
43+
44+
assert_type(multinomial.rvs([1], [0.5]).dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType]
45+
assert_type(multinomial([1], [0.5]).rvs().dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType]
46+
47+
assert_type(ortho_group.rvs(3).dtype, np.dtype[np.float64])
48+
assert_type(ortho_group().rvs(3).dtype, np.dtype[np.float64])
49+
50+
assert_type(special_ortho_group.rvs(3).dtype, np.dtype[np.float64])
51+
assert_type(special_ortho_group().rvs(3).dtype, np.dtype[np.float64])
52+
53+
assert_type(unitary_group.rvs(3).dtype, np.dtype[np.complex128])
54+
assert_type(unitary_group().rvs(3).dtype, np.dtype[np.complex128])
55+
56+
assert_type(uniform_direction.rvs(3).dtype, np.dtype[np.float64])
57+
assert_type(uniform_direction(1).rvs(3).dtype, np.dtype[np.float64])
58+
59+
assert_type(random_correlation.rvs([1, 1]).dtype, np.dtype[np.float64])
60+
assert_type(random_correlation([1, 1]).rvs().dtype, np.dtype[np.float64])
61+
62+
assert_type(multivariate_t.rvs().dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType]
63+
assert_type(multivariate_t().rvs().dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType]
64+
65+
assert_type(multivariate_hypergeom.rvs([1], 1).dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType]
66+
assert_type(multivariate_hypergeom([1], 1).rvs().dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType]
67+
68+
assert_type(random_table.rvs([1, 2], [2, 1]).dtype, np.dtype[np.float64])
69+
assert_type(random_table([1, 2], [2, 1]).rvs().dtype, np.dtype[np.float64])
70+
71+
# `dirichlet_multinomial` has no `rvs` method
72+
73+
assert_type(vonmises_fisher.rvs([0.8, 0.6]).dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType]
74+
assert_type(vonmises_fisher([0.8, 0.6]).rvs().dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType]
75+
76+
assert_type(normal_inverse_gamma.rvs()[0].dtype, np.dtype[np.float64])
77+
assert_type(normal_inverse_gamma().rvs()[0].dtype, np.dtype[np.float64])

0 commit comments

Comments
 (0)