|
| 1 | +from typing import assert_type |
| 2 | + |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +from scipy.stats import ( |
| 6 | + dirichlet, |
| 7 | + invwishart, |
| 8 | + matrix_normal, |
| 9 | + multinomial, |
| 10 | + multivariate_hypergeom, |
| 11 | + multivariate_normal, |
| 12 | + multivariate_t, |
| 13 | + normal_inverse_gamma, |
| 14 | + ortho_group, |
| 15 | + random_correlation, |
| 16 | + random_table, |
| 17 | + special_ortho_group, |
| 18 | + uniform_direction, |
| 19 | + unitary_group, |
| 20 | + vonmises_fisher, |
| 21 | + wishart, |
| 22 | +) |
| 23 | + |
| 24 | +### |
| 25 | +# rvs dtype checks |
| 26 | + |
| 27 | +# the pyright ignore comments are needed due to https://github.com/microsoft/pyright/issues/11127 |
| 28 | + |
| 29 | +assert_type(multivariate_normal.rvs().dtype, np.dtype[np.float64]) |
| 30 | +assert_type(multivariate_normal().rvs().dtype, np.dtype[np.float64]) |
| 31 | + |
| 32 | +assert_type(matrix_normal.rvs().dtype, np.dtype[np.float64]) |
| 33 | +assert_type(matrix_normal().rvs().dtype, np.dtype[np.float64]) |
| 34 | + |
| 35 | +assert_type(dirichlet.rvs([1, 2]).dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType] |
| 36 | +assert_type(dirichlet([1, 2]).rvs().dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType] |
| 37 | + |
| 38 | +assert_type(wishart.rvs(1, 1).dtype, np.dtype[np.float64]) |
| 39 | +assert_type(wishart().rvs().dtype, np.dtype[np.float64]) |
| 40 | + |
| 41 | +assert_type(invwishart.rvs(1, 1).dtype, np.dtype[np.float64]) |
| 42 | +assert_type(invwishart().rvs().dtype, np.dtype[np.float64]) |
| 43 | + |
| 44 | +assert_type(multinomial.rvs([1], [0.5]).dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType] |
| 45 | +assert_type(multinomial([1], [0.5]).rvs().dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType] |
| 46 | + |
| 47 | +assert_type(ortho_group.rvs(3).dtype, np.dtype[np.float64]) |
| 48 | +assert_type(ortho_group().rvs(3).dtype, np.dtype[np.float64]) |
| 49 | + |
| 50 | +assert_type(special_ortho_group.rvs(3).dtype, np.dtype[np.float64]) |
| 51 | +assert_type(special_ortho_group().rvs(3).dtype, np.dtype[np.float64]) |
| 52 | + |
| 53 | +assert_type(unitary_group.rvs(3).dtype, np.dtype[np.complex128]) |
| 54 | +assert_type(unitary_group().rvs(3).dtype, np.dtype[np.complex128]) |
| 55 | + |
| 56 | +assert_type(uniform_direction.rvs(3).dtype, np.dtype[np.float64]) |
| 57 | +assert_type(uniform_direction(1).rvs(3).dtype, np.dtype[np.float64]) |
| 58 | + |
| 59 | +assert_type(random_correlation.rvs([1, 1]).dtype, np.dtype[np.float64]) |
| 60 | +assert_type(random_correlation([1, 1]).rvs().dtype, np.dtype[np.float64]) |
| 61 | + |
| 62 | +assert_type(multivariate_t.rvs().dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType] |
| 63 | +assert_type(multivariate_t().rvs().dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType] |
| 64 | + |
| 65 | +assert_type(multivariate_hypergeom.rvs([1], 1).dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType] |
| 66 | +assert_type(multivariate_hypergeom([1], 1).rvs().dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType] |
| 67 | + |
| 68 | +assert_type(random_table.rvs([1, 2], [2, 1]).dtype, np.dtype[np.float64]) |
| 69 | +assert_type(random_table([1, 2], [2, 1]).rvs().dtype, np.dtype[np.float64]) |
| 70 | + |
| 71 | +# `dirichlet_multinomial` has no `rvs` method |
| 72 | + |
| 73 | +assert_type(vonmises_fisher.rvs([0.8, 0.6]).dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType] |
| 74 | +assert_type(vonmises_fisher([0.8, 0.6]).rvs().dtype, np.dtype[np.float64]) # pyright: ignore[reportUnknownMemberType] |
| 75 | + |
| 76 | +assert_type(normal_inverse_gamma.rvs()[0].dtype, np.dtype[np.float64]) |
| 77 | +assert_type(normal_inverse_gamma().rvs()[0].dtype, np.dtype[np.float64]) |
0 commit comments