From f75dbf1d511cd2984d0d021e9f6404ce8ecc550a Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 3 Jul 2025 01:14:04 +0800 Subject: [PATCH 1/4] Refactor QR --- pytensor/tensor/nlinalg.py | 171 ---------------- pytensor/tensor/slinalg.py | 380 ++++++++++++++++++++++++++++++++++- tests/tensor/test_nlinalg.py | 98 --------- tests/tensor/test_slinalg.py | 104 ++++++++++ 4 files changed, 481 insertions(+), 272 deletions(-) diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 8fff2a2f59..74c985e1e6 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -5,15 +5,12 @@ import numpy as np -import pytensor.tensor as pt from pytensor import scalar as ps from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply from pytensor.graph.op import Op -from pytensor.ifelse import ifelse from pytensor.npy_2_compat import normalize_axis_tuple -from pytensor.raise_op import Assert from pytensor.tensor import TensorLike from pytensor.tensor import basic as ptb from pytensor.tensor import math as ptm @@ -468,173 +465,6 @@ def eigh(a, UPLO="L"): return Eigh(UPLO)(a) -class QRFull(Op): - """ - Full QR Decomposition. - - Computes the QR decomposition of a matrix. - Factor the matrix a as qr, where q is orthonormal - and r is upper-triangular. - - """ - - __props__ = ("mode",) - - def __init__(self, mode): - self.mode = mode - - def make_node(self, x): - x = as_tensor_variable(x) - - assert x.ndim == 2, "The input of qr function should be a matrix." - - in_dtype = x.type.numpy_dtype - out_dtype = np.dtype(f"f{in_dtype.itemsize}") - - q = matrix(dtype=out_dtype) - - if self.mode != "raw": - r = matrix(dtype=out_dtype) - else: - r = vector(dtype=out_dtype) - - if self.mode != "r": - q = matrix(dtype=out_dtype) - outputs = [q, r] - else: - outputs = [r] - - return Apply(self, [x], outputs) - - def perform(self, node, inputs, outputs): - (x,) = inputs - assert x.ndim == 2, "The input of qr function should be a matrix." - res = np.linalg.qr(x, self.mode) - if self.mode != "r": - outputs[0][0], outputs[1][0] = res - else: - outputs[0][0] = res - - def L_op(self, inputs, outputs, output_grads): - """ - Reverse-mode gradient of the QR function. - - References - ---------- - .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/ - .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2 - """ - - from pytensor.tensor.slinalg import solve_triangular - - (A,) = (cast(ptb.TensorVariable, x) for x in inputs) - m, n = A.shape - - def _H(x: ptb.TensorVariable): - return x.conj().mT - - def _copyltu(x: ptb.TensorVariable): - return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1)) - - if self.mode == "raw": - raise NotImplementedError("Gradient of qr not implemented for mode=raw") - - elif self.mode == "r": - # We need all the components of the QR to compute the gradient of A even if we only - # use the upper triangular component in the cost function. 
- Q, R = qr(A, mode="reduced") - dQ = Q.zeros_like() - dR = cast(ptb.TensorVariable, output_grads[0]) - - else: - Q, R = (cast(ptb.TensorVariable, x) for x in outputs) - if self.mode == "complete": - qr_assert_op = Assert( - "Gradient of qr not implemented for m x n matrices with m > n and mode=complete" - ) - R = qr_assert_op(R, ptm.le(m, n)) - - new_output_grads = [] - is_disconnected = [ - isinstance(x.type, DisconnectedType) for x in output_grads - ] - if all(is_disconnected): - # This should never be reached by Pytensor - return [DisconnectedType()()] # pragma: no cover - - for disconnected, output_grad, output in zip( - is_disconnected, output_grads, [Q, R], strict=True - ): - if disconnected: - new_output_grads.append(output.zeros_like()) - else: - new_output_grads.append(output_grad) - - (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads) - - # gradient expression when m >= n - M = R @ _H(dR) - _H(dQ) @ Q - K = dQ + Q @ _copyltu(M) - A_bar_m_ge_n = _H(solve_triangular(R, _H(K))) - - # gradient expression when m < n - Y = A[:, m:] - U = R[:, :m] - dU, dV = dR[:, :m], dR[:, m:] - dQ_Yt_dV = dQ + Y @ _H(dV) - M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q - X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M)))) - Y_bar = Q @ dV - A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1) - - return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)] - - -def qr(a, mode="reduced"): - """ - Computes the QR decomposition of a matrix. - Factor the matrix a as qr, where q - is orthonormal and r is upper-triangular. - - Parameters - ---------- - a : array_like, shape (M, N) - Matrix to be factored. - - mode : {'reduced', 'complete', 'r', 'raw'}, optional - If K = min(M, N), then - - 'reduced' - returns q, r with dimensions (M, K), (K, N) - - 'complete' - returns q, r with dimensions (M, M), (M, N) - - 'r' - returns r only with dimensions (K, N) - - 'raw' - returns h, tau with dimensions (N, M), (K,) - - Note that array h returned in 'raw' mode is - transposed for calling Fortran. - - Default mode is 'reduced' - - Returns - ------- - q : matrix of float or complex, optional - A matrix with orthonormal columns. When mode = 'complete' the - result is an orthogonal/unitary matrix depending on whether or - not a is real/complex. The determinant may be either +/- 1 in - that case. - r : matrix of float or complex, optional - The upper-triangular matrix. 
- - """ - return QRFull(mode)(a) - - class SVD(Op): """ Computes singular value decomposition of matrix A, into U, S, V such that A = U @ S @ V @@ -1291,7 +1121,6 @@ def kron(a, b): "det", "eig", "eigh", - "qr", "svd", "lstsq", "matrix_power", diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index 946abbb0d6..68d056fdc0 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -7,16 +7,19 @@ import numpy as np import scipy.linalg as scipy_linalg from numpy.exceptions import ComplexWarning +from scipy.linalg import get_lapack_funcs import pytensor -import pytensor.tensor as pt +from pytensor import ifelse +from pytensor import tensor as pt from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply from pytensor.graph.op import Op -from pytensor.tensor import TensorLike, as_tensor_variable +from pytensor.raise_op import Assert +from pytensor.tensor import TensorLike from pytensor.tensor import basic as ptb from pytensor.tensor import math as ptm -from pytensor.tensor.basic import diagonal +from pytensor.tensor.basic import as_tensor_variable, diagonal from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.nlinalg import kron, matrix_dot from pytensor.tensor.shape import reshape @@ -1714,6 +1717,376 @@ def block_diag(*matrices: TensorVariable): return _block_diagonal_matrix(*matrices) +class QR(Op): + """ + QR Decomposition + """ + + __props__ = ( + "overwrite_a", + "mode", + "pivoting", + "check_finite", + ) + + def __init__( + self, + mode: Literal["full", "r", "economic", "raw"] = "full", + overwrite_a: bool = False, + pivoting: bool = False, + check_finite: bool = False, + ): + self.mode = mode + self.overwrite_a = overwrite_a + self.pivoting = pivoting + self.check_finite = check_finite + + self.destroy_map = {} + + if overwrite_a: + self.destroy_map = {0: [0]} + + match self.mode: + case "economic": + self.gufunc_signature = "(m,n)->(m,k),(k,n)" + case "full": + self.gufunc_signature = "(m,n)->(m,m),(m,n)" + case "r": + self.gufunc_signature = "(m,n)->(m,n)" + case "raw": + self.gufunc_signature = "(m,n)->(n,m),(k),(m,n)" + case _: + raise ValueError( + f"Invalid mode '{mode}'. Supported modes are 'full', 'economic', 'r', and 'raw'." + ) + + if pivoting: + self.gufunc_signature += ",(n)" + + def make_node(self, x): + x = as_tensor_variable(x) + + assert x.ndim == 2, "The input of qr function should be a matrix." 
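+        # Batched inputs are handled by wrapping this Op in Blockwise (see `qr` below),
+        # so make_node only has to describe the core 2D case.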
+ + # Preserve static shape information if possible + M, N = x.type.shape + if M is not None and N is not None: + K = min(M, N) + else: + K = None + + in_dtype = x.type.numpy_dtype + out_dtype = np.dtype(f"f{in_dtype.itemsize}") + + match self.mode: + case "full": + outputs = [ + tensor(shape=(M, M), dtype=out_dtype), + tensor(shape=(M, N), dtype=out_dtype), + ] + case "economic": + outputs = [ + tensor(shape=(M, K), dtype=out_dtype), + tensor(shape=(K, N), dtype=out_dtype), + ] + case "r": + outputs = [ + tensor(shape=(M, N), dtype=out_dtype), + ] + case "raw": + outputs = [ + tensor(shape=(M, M), dtype=out_dtype), + tensor(shape=(K,), dtype=out_dtype), + tensor(shape=(M, N), dtype=out_dtype), + ] + case _: + raise NotImplementedError + + if self.pivoting: + outputs = [*outputs, tensor(shape=(N,), dtype="int32")] + + return Apply(self, [x], outputs) + + def infer_shape(self, fgraph, node, shapes): + (x_shape,) = shapes + + M, N = x_shape + K = ptm.minimum(M, N) + + Q_shape = None + R_shape = None + tau_shape = None + P_shape = None + + match self.mode: + case "full": + Q_shape = (M, M) + R_shape = (M, N) + case "economic": + Q_shape = (M, K) + R_shape = (K, N) + case "r": + R_shape = (M, N) + case "raw": + Q_shape = (M, M) # Actually this is H in this case + tau_shape = (K,) + R_shape = (M, N) + + if self.pivoting: + P_shape = (N,) + + return [ + shape + for shape in (Q_shape, tau_shape, R_shape, P_shape) + if shape is not None + ] + + def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": + if not allowed_inplace_inputs: + return self + new_props = self._props_dict() # type: ignore + new_props["overwrite_a"] = True + return type(self)(**new_props) + + def _call_and_get_lwork(self, fn, *args, lwork, **kwargs): + if lwork in [-1, None]: + *_, work, info = fn(*args, lwork=-1, **kwargs) + lwork = work.item() + + return fn(*args, lwork=lwork, **kwargs) + + def perform(self, node, inputs, outputs): + (x,) = inputs + M, N = x.shape + + if self.pivoting: + (geqp3,) = get_lapack_funcs(("geqp3",), (x,)) + qr, jpvt, tau, *work_info = self._call_and_get_lwork( + geqp3, x, lwork=-1, overwrite_a=self.overwrite_a + ) + jpvt -= 1 # geqp3 returns a 1-based index array, so subtract 1 + else: + (geqrf,) = get_lapack_funcs(("geqrf",), (x,)) + qr, tau, *work_info = self._call_and_get_lwork( + geqrf, x, lwork=-1, overwrite_a=self.overwrite_a + ) + + if self.mode not in ["economic", "raw"] or M < N: + R = np.triu(qr) + else: + R = np.triu(qr[:N, :]) + + if self.mode == "r" and self.pivoting: + outputs[0][0] = R + outputs[1][0] = jpvt + return + + elif self.mode == "r": + outputs[0][0] = R + return + + elif self.mode == "raw" and self.pivoting: + outputs[0][0] = qr + outputs[1][0] = tau + outputs[2][0] = R + outputs[3][0] = jpvt + return + + elif self.mode == "raw": + outputs[0][0] = qr + outputs[1][0] = tau + outputs[2][0] = R + return + + (gor_un_gqr,) = get_lapack_funcs(("orgqr",), (qr,)) + + if M < N: + Q, work, info = self._call_and_get_lwork( + gor_un_gqr, qr[:, :M], tau, lwork=-1, overwrite_a=1 + ) + elif self.mode == "economic": + Q, work, info = self._call_and_get_lwork( + gor_un_gqr, qr, tau, lwork=-1, overwrite_a=1 + ) + else: + t = qr.dtype.char + qqr = np.empty((M, M), dtype=t) + qqr[:, :N] = qr + + # Always overwite qqr -- it's a meaningless intermediate value + Q, work, info = self._call_and_get_lwork( + gor_un_gqr, qqr, tau, lwork=-1, overwrite_a=1 + ) + + outputs[0][0] = Q + outputs[1][0] = R + + if self.pivoting: + outputs[2][0] = jpvt + + def L_op(self, inputs, outputs, 
output_grads): + """ + Reverse-mode gradient of the QR function. + + References + ---------- + .. [1] Jinguo Liu. "Linear Algebra Autodiff (complex valued)", blog post https://giggleliu.github.io/posts/2019-04-02-einsumbp/ + .. [2] Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang. "Differentiable Programming Tensor Networks", arXiv:1903.09650v2 + """ + + from pytensor.tensor.slinalg import solve_triangular + + (A,) = (cast(ptb.TensorVariable, x) for x in inputs) + m, n = A.shape + + # Check if we have static shape info, if so we can get a better graph (avoiding the ifelse Op in the output) + M_static, N_static = A.type.shape + shapes_unknown = M_static is None or N_static is None + + def _H(x: ptb.TensorVariable): + return x.conj().mT + + def _copyltu(x: ptb.TensorVariable): + return ptb.tril(x, k=0) + _H(ptb.tril(x, k=-1)) + + if self.mode == "raw": + raise NotImplementedError("Gradient of qr not implemented for mode=raw") + + elif self.mode == "r": + k = pt.minimum(m, n) + + # We need all the components of the QR to compute the gradient of A even if we only + # use the upper triangular component in the cost function. + props_dict = self._props_dict() + props_dict["mode"] = "economic" + props_dict["pivoting"] = False + + qr_op = type(self)(**props_dict) + + Q, R = qr_op(A) + dQ = Q.zeros_like() + + # Unlike numpy.linalg.qr, scipy.linalg.qr returns the full (m,n) matrix when mode='r', *not* the (k,n) + # matrix that is computed by mode='economic'. The gradient assumes that dR is of shape (k,n), so we need to + # slice it to the first k rows. Note that if m <= n, then k = m, so this is safe in all cases. + dR = cast(ptb.TensorVariable, output_grads[0][:k, :]) + + else: + Q, R = (cast(ptb.TensorVariable, x) for x in outputs) + if self.mode == "full": + qr_assert_op = Assert( + "Gradient of qr not implemented for m x n matrices with m > n and mode=full" + ) + R = qr_assert_op(R, ptm.le(m, n)) + + new_output_grads = [] + is_disconnected = [ + isinstance(x.type, DisconnectedType) for x in output_grads + ] + if all(is_disconnected): + # This should never be reached by Pytensor + return [DisconnectedType()()] # pragma: no cover + + for disconnected, output_grad, output in zip( + is_disconnected, output_grads, [Q, R], strict=True + ): + if disconnected: + new_output_grads.append(output.zeros_like()) + else: + new_output_grads.append(output_grad) + + (dQ, dR) = (cast(ptb.TensorVariable, x) for x in new_output_grads) + + if shapes_unknown or M_static >= N_static: + # gradient expression when m >= n + M = R @ _H(dR) - _H(dQ) @ Q + K = dQ + Q @ _copyltu(M) + A_bar_m_ge_n = _H(solve_triangular(R, _H(K))) + + if not shapes_unknown: + return [A_bar_m_ge_n] + + # We have to trigger both branches if shapes_unknown is True, so this is purposefully not an elif branch + if shapes_unknown or M_static < N_static: + # gradient expression when m < n + Y = A[:, m:] + U = R[:, :m] + dU, dV = dR[:, :m], dR[:, m:] + dQ_Yt_dV = dQ + Y @ _H(dV) + M = U @ _H(dU) - _H(dQ_Yt_dV) @ Q + X_bar = _H(solve_triangular(U, _H(dQ_Yt_dV + Q @ _copyltu(M)))) + Y_bar = Q @ dV + A_bar_m_lt_n = pt.concatenate([X_bar, Y_bar], axis=1) + + if not shapes_unknown: + return [A_bar_m_lt_n] + + return [ifelse(ptm.ge(m, n), A_bar_m_ge_n, A_bar_m_lt_n)] + + +def qr( + A: TensorLike, + mode: Literal["full", "r", "economic", "raw", "complete", "reduced"] = "full", + overwrite_a: bool = False, + pivoting: bool = False, + lwork: int | None = None, +): + """ + QR Decomposition of input matrix `a`. 
+
+    The QR decomposition of a matrix `A` is a factorization of the form :math:`A = QR`, where `Q` is an orthogonal
+    matrix (:math:`Q Q^T = I`) and `R` is an upper triangular matrix.
+
+    This decomposition is useful in various numerical methods, including solving linear systems and least squares
+    problems.
+
+    Parameters
+    ----------
+    A: TensorLike
+        Input matrix of shape (M, N) to be decomposed.
+
+    mode: str, one of "full", "economic", "r", or "raw"
+        How the QR decomposition is computed and returned. Choosing the mode can avoid unnecessary computations,
+        depending on which of the return matrices are needed. Choices are:
+
+        - "full" (or "complete"): returns `Q` and `R` with dimensions `(M, M)` and `(M, N)`.
+        - "economic" (or "reduced"): returns `Q` and `R` with dimensions `(M, K)` and `(K, N)`,
+          where `K = min(M, N)`.
+        - "r": returns only `R` with dimensions `(M, N)`.
+        - "raw": returns `H` and `tau` with dimensions `(N, M)` and `(K,)`, where `H` is the matrix of
+          Householder reflections, and `tau` is the vector of Householder coefficients.
+
+    pivoting: bool, default False
+        If True, also return a vector of rank-revealing permutations `P` such that `A[:, P] = QR`.
+
+    overwrite_a: bool, ignored
+        Ignored. Included only for consistency with the function signature of `scipy.linalg.qr`. Pytensor will always
+        automatically overwrite the input matrix `A` if it is safe to do so.
+
+    lwork: int, ignored
+        Ignored. Included only for consistency with the function signature of `scipy.linalg.qr`. Pytensor will
+        automatically determine the optimal workspace size for the QR decomposition.
+
+    Returns
+    -------
+    Q or H: TensorVariable, optional
+        A matrix with orthonormal columns. When mode = 'full', the result is an orthogonal/unitary matrix
+        depending on whether `A` is real/complex. The determinant may be either +/- 1 in that case. If
+        mode = 'raw', it is the matrix of Householder reflections. If mode = 'r', Q is not returned.
+
+    R or tau : TensorVariable, optional
+        The upper-triangular matrix `R`. If mode = 'raw', the second output is instead `tau`, the vector of
+        Householder coefficients, and `R` is returned as the third output.
+
+    P: TensorVariable, optional
+        Vector of pivot indices, returned as the last output only when `pivoting=True`.
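+
+    Examples
+    --------
+    A minimal sketch of typical usage:
+
+    .. code-block:: python
+
+        import numpy as np
+        import pytensor
+        import pytensor.tensor as pt
+
+        A = pt.dmatrix("A")
+        Q, R = pt.linalg.qr(A, mode="economic")
+        f = pytensor.function([A], [Q, R])
+
+        A_val = np.random.normal(size=(5, 3))
+        Q_val, R_val = f(A_val)
+
+        # Q has shape (5, 3), R has shape (3, 3), and Q @ R reconstructs A
+        np.testing.assert_allclose(Q_val @ R_val, A_val, atol=1e-12)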
+ + """ + # backwards compatibility from the numpy API + if mode == "complete": + mode = "full" + elif mode == "reduced": + mode = "economic" + + return Blockwise(QR(mode=mode, pivoting=pivoting, overwrite_a=False))(A) + + __all__ = [ "cholesky", "solve", @@ -1728,4 +2101,5 @@ def block_diag(*matrices: TensorVariable): "lu", "lu_factor", "lu_solve", + "qr", ] diff --git a/tests/tensor/test_nlinalg.py b/tests/tensor/test_nlinalg.py index c8ae3ac4cb..cc4ed99a93 100644 --- a/tests/tensor/test_nlinalg.py +++ b/tests/tensor/test_nlinalg.py @@ -1,7 +1,6 @@ from functools import partial import numpy as np -import numpy.linalg import pytest from numpy.testing import assert_array_almost_equal @@ -25,7 +24,6 @@ matrix_power, norm, pinv, - qr, slogdet, svd, tensorinv, @@ -122,102 +120,6 @@ def test_matrix_dot(): assert _allclose(numpy_sol, pytensor_sol) -def test_qr_modes(): - rng = np.random.default_rng(utt.fetch_seed()) - - A = matrix("A", dtype=config.floatX) - a = rng.random((4, 4)).astype(config.floatX) - - f = function([A], qr(A)) - t_qr = f(a) - n_qr = np.linalg.qr(a) - assert _allclose(n_qr, t_qr) - - for mode in ["reduced", "r", "raw"]: - f = function([A], qr(A, mode)) - t_qr = f(a) - n_qr = np.linalg.qr(a, mode) - if isinstance(n_qr, list | tuple): - assert _allclose(n_qr[0], t_qr[0]) - assert _allclose(n_qr[1], t_qr[1]) - else: - assert _allclose(n_qr, t_qr) - - try: - n_qr = np.linalg.qr(a, "complete") - f = function([A], qr(A, "complete")) - t_qr = f(a) - assert _allclose(n_qr, t_qr) - except TypeError as e: - assert "name 'complete' is not defined" in str(e) - - -@pytest.mark.parametrize( - "shape, gradient_test_case, mode", - ( - [(s, c, "reduced") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]] - + [(s, c, "complete") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]] - + [(s, 0, "r") for s in [(3, 3), (6, 3), (3, 6)]] - + [((3, 3), 0, "raw")] - ), - ids=( - [ - f"shape={s}, gradient_test_case={c}, mode=reduced" - for s in [(3, 3), (6, 3), (3, 6)] - for c in ["Q", "R", "both"] - ] - + [ - f"shape={s}, gradient_test_case={c}, mode=complete" - for s in [(3, 3), (6, 3), (3, 6)] - for c in ["Q", "R", "both"] - ] - + [f"shape={s}, gradient_test_case=R, mode=r" for s in [(3, 3), (6, 3), (3, 6)]] - + ["shape=(3, 3), gradient_test_case=Q, mode=raw"] - ), -) -@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"]) -def test_qr_grad(shape, gradient_test_case, mode, is_complex): - rng = np.random.default_rng(utt.fetch_seed()) - - def _test_fn(x, case=2, mode="reduced"): - if case == 0: - return qr(x, mode=mode)[0].sum() - elif case == 1: - return qr(x, mode=mode)[1].sum() - elif case == 2: - Q, R = qr(x, mode=mode) - return Q.sum() + R.sum() - - if is_complex: - pytest.xfail("Complex inputs currently not supported by verify_grad") - - m, n = shape - a = rng.standard_normal(shape).astype(config.floatX) - if is_complex: - a += 1j * rng.standard_normal(shape).astype(config.floatX) - - if mode == "raw": - with pytest.raises(NotImplementedError): - utt.verify_grad( - partial(_test_fn, case=gradient_test_case, mode=mode), - [a], - rng=np.random, - ) - - elif mode == "complete" and m > n: - with pytest.raises(AssertionError): - utt.verify_grad( - partial(_test_fn, case=gradient_test_case, mode=mode), - [a], - rng=np.random, - ) - - else: - utt.verify_grad( - partial(_test_fn, case=gradient_test_case, mode=mode), [a], rng=np.random - ) - - class TestSvd(utt.InferShapeTester): op_class = SVD diff --git a/tests/tensor/test_slinalg.py b/tests/tensor/test_slinalg.py index 
8b48c33a3c..a82307a612 100644 --- a/tests/tensor/test_slinalg.py +++ b/tests/tensor/test_slinalg.py @@ -1,10 +1,12 @@ import functools import itertools +from functools import partial from typing import Literal import numpy as np import pytest import scipy +from scipy import linalg as scipy_linalg from pytensor import function, grad from pytensor import tensor as pt @@ -26,6 +28,7 @@ lu_factor, lu_solve, pivot_to_permutation, + qr, solve, solve_continuous_lyapunov, solve_discrete_are, @@ -1088,3 +1091,104 @@ def test_block_diagonal_blockwise(): B = np.random.normal(size=(1, batch_size, 4, 4)).astype(config.floatX) result = block_diag(A, B).eval() assert result.shape == (10, batch_size, 6, 6) + + +@pytest.mark.parametrize( + "mode, names", + [ + ("economic", ["Q", "R"]), + ("full", ["Q", "R"]), + ("r", ["R"]), + ("raw", ["H", "tau", "R"]), + ], +) +@pytest.mark.parametrize("pivoting", [True, False]) +def test_qr_modes(mode, names, pivoting): + rng = np.random.default_rng(utt.fetch_seed()) + A_val = rng.random((4, 4)).astype(config.floatX) + + if pivoting: + names = [*names, "pivots"] + + A = tensor("A", dtype=config.floatX, shape=(None, None)) + + f = function([A], qr(A, mode=mode, pivoting=pivoting)) + + outputs_pt = f(A_val) + outputs_sp = scipy_linalg.qr(A_val, mode=mode, pivoting=pivoting) + + if mode == "raw": + # The first output of scipy's qr is a tuple when mode is raw; flatten it for easier iteration + outputs_sp = (*outputs_sp[0], *outputs_sp[1:]) + elif mode == "r" and not pivoting: + # Here there's only one output from the pytensor function; wrap it in a list for iteration + outputs_pt = [outputs_pt] + + for out_pt, out_sp, name in zip(outputs_pt, outputs_sp, names): + np.testing.assert_allclose(out_pt, out_sp, err_msg=f"{name} disagrees") + + +@pytest.mark.parametrize( + "shape, gradient_test_case, mode", + ( + [(s, c, "economic") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]] + + [(s, c, "full") for s in [(3, 3), (6, 3), (3, 6)] for c in [0, 1, 2]] + + [(s, 0, "r") for s in [(3, 3), (6, 3), (3, 6)]] + + [((3, 3), 0, "raw")] + ), + ids=( + [ + f"shape={s}, gradient_test_case={c}, mode=economic" + for s in [(3, 3), (6, 3), (3, 6)] + for c in ["Q", "R", "both"] + ] + + [ + f"shape={s}, gradient_test_case={c}, mode=full" + for s in [(3, 3), (6, 3), (3, 6)] + for c in ["Q", "R", "both"] + ] + + [f"shape={s}, gradient_test_case=R, mode=r" for s in [(3, 3), (6, 3), (3, 6)]] + + ["shape=(3, 3), gradient_test_case=Q, mode=raw"] + ), +) +@pytest.mark.parametrize("is_complex", [True, False], ids=["complex", "real"]) +def test_qr_grad(shape, gradient_test_case, mode, is_complex): + rng = np.random.default_rng(utt.fetch_seed()) + + def _test_fn(x, case=2, mode="reduced"): + if case == 0: + return qr(x, mode=mode)[0].sum() + elif case == 1: + return qr(x, mode=mode)[1].sum() + elif case == 2: + Q, R = qr(x, mode=mode) + return Q.sum() + R.sum() + + if is_complex: + pytest.xfail("Complex inputs currently not supported by verify_grad") + + m, n = shape + a = rng.standard_normal(shape).astype(config.floatX) + if is_complex: + a += 1j * rng.standard_normal(shape).astype(config.floatX) + + if mode == "raw": + with pytest.raises(NotImplementedError): + utt.verify_grad( + partial(_test_fn, case=gradient_test_case, mode=mode), + [a], + rng=np.random, + ) + + elif mode == "full" and m > n: + with pytest.raises(AssertionError): + utt.verify_grad( + partial(_test_fn, case=gradient_test_case, mode=mode), + [a], + rng=np.random, + ) + + else: + utt.verify_grad( + partial(_test_fn, 
case=gradient_test_case, mode=mode), [a], rng=np.random + ) From f48fe8ce039cee25763f2ad5dc0d5718b3bff11c Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 3 Jul 2025 08:42:37 +0800 Subject: [PATCH 2/4] Update JAX QR dispatch --- pytensor/link/jax/dispatch/nlinalg.py | 11 ----------- pytensor/link/jax/dispatch/slinalg.py | 11 +++++++++++ tests/link/jax/test_nlinalg.py | 6 ------ tests/link/jax/test_slinalg.py | 12 ++++++++++++ 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/pytensor/link/jax/dispatch/nlinalg.py b/pytensor/link/jax/dispatch/nlinalg.py index 8b6fc62f2a..38690c7c03 100644 --- a/pytensor/link/jax/dispatch/nlinalg.py +++ b/pytensor/link/jax/dispatch/nlinalg.py @@ -9,7 +9,6 @@ KroneckerProduct, MatrixInverse, MatrixPinv, - QRFull, SLogDet, ) @@ -67,16 +66,6 @@ def matrix_inverse(x): return matrix_inverse -@jax_funcify.register(QRFull) -def jax_funcify_QRFull(op, **kwargs): - mode = op.mode - - def qr_full(x, mode=mode): - return jnp.linalg.qr(x, mode=mode) - - return qr_full - - @jax_funcify.register(MatrixPinv) def jax_funcify_Pinv(op, **kwargs): def pinv(x): diff --git a/pytensor/link/jax/dispatch/slinalg.py b/pytensor/link/jax/dispatch/slinalg.py index 4448e14f99..38803f11b5 100644 --- a/pytensor/link/jax/dispatch/slinalg.py +++ b/pytensor/link/jax/dispatch/slinalg.py @@ -5,6 +5,7 @@ from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.tensor.slinalg import ( LU, + QR, BlockDiagonal, Cholesky, CholeskySolve, @@ -168,3 +169,13 @@ def cho_solve(c, b): ) return cho_solve + + +@jax_funcify.register(QR) +def jax_funcify_QR(op, **kwargs): + mode = op.mode + + def qr(x, mode=mode): + return jax.scipy.linalg.qr(x, mode=mode) + + return qr diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 866d99ce71..18c1f36919 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -29,12 +29,6 @@ def assert_fn(x, y): outs = pt_nlinalg.eigh(x) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn) - outs = pt_nlinalg.qr(x, mode="full") - compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn) - - outs = pt_nlinalg.qr(x, mode="reduced") - compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn) - outs = pt_nlinalg.svd(x) compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn) diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py index 513ee2fa49..36354764d9 100644 --- a/tests/link/jax/test_slinalg.py +++ b/tests/link/jax/test_slinalg.py @@ -103,6 +103,18 @@ def test_jax_basic(): ], ) + def assert_fn(x, y): + np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3) + + M = rng.normal(size=(3, 3)) + X = M.dot(M.T) + + outs = pt_slinalg.qr(x, mode="full") + compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn) + + outs = pt_slinalg.qr(x, mode="economic") + compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn) + def test_jax_solve(): rng = np.random.default_rng(utt.fetch_seed()) From ac2a8208ba3cd760adfcf1b78d2b4239cf9a3dc7 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 3 Jul 2025 08:43:10 +0800 Subject: [PATCH 3/4] Update Torch QR dispatch --- pytensor/link/pytorch/dispatch/__init__.py | 1 + pytensor/link/pytorch/dispatch/nlinalg.py | 16 ------------- pytensor/link/pytorch/dispatch/slinalg.py | 23 ++++++++++++++++++ tests/link/pytorch/conftest.py | 16 +++++++++++++ tests/link/pytorch/test_nlinalg.py | 27 ---------------------- 
tests/link/pytorch/test_slinalg.py | 20 ++++++++++++++++ 6 files changed, 60 insertions(+), 43 deletions(-) create mode 100644 pytensor/link/pytorch/dispatch/slinalg.py create mode 100644 tests/link/pytorch/conftest.py create mode 100644 tests/link/pytorch/test_slinalg.py diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py index 4caabf3e03..f46e35a46e 100644 --- a/pytensor/link/pytorch/dispatch/__init__.py +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -8,6 +8,7 @@ import pytensor.link.pytorch.dispatch.math import pytensor.link.pytorch.dispatch.extra_ops import pytensor.link.pytorch.dispatch.nlinalg +import pytensor.link.pytorch.dispatch.slinalg import pytensor.link.pytorch.dispatch.shape import pytensor.link.pytorch.dispatch.sort import pytensor.link.pytorch.dispatch.subtensor diff --git a/pytensor/link/pytorch/dispatch/nlinalg.py b/pytensor/link/pytorch/dispatch/nlinalg.py index 91690489e9..c4a03406e6 100644 --- a/pytensor/link/pytorch/dispatch/nlinalg.py +++ b/pytensor/link/pytorch/dispatch/nlinalg.py @@ -9,7 +9,6 @@ KroneckerProduct, MatrixInverse, MatrixPinv, - QRFull, SLogDet, ) @@ -70,21 +69,6 @@ def matrix_inverse(x): return matrix_inverse -@pytorch_funcify.register(QRFull) -def pytorch_funcify_QRFull(op, **kwargs): - mode = op.mode - if mode == "raw": - raise NotImplementedError("raw mode not implemented in PyTorch") - - def qr_full(x): - Q, R = torch.linalg.qr(x, mode=mode) - if mode == "r": - return R - return Q, R - - return qr_full - - @pytorch_funcify.register(MatrixPinv) def pytorch_funcify_Pinv(op, **kwargs): hermitian = op.hermitian diff --git a/pytensor/link/pytorch/dispatch/slinalg.py b/pytensor/link/pytorch/dispatch/slinalg.py new file mode 100644 index 0000000000..b49d281993 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/slinalg.py @@ -0,0 +1,23 @@ +import torch + +from pytensor.link.pytorch.dispatch import pytorch_funcify +from pytensor.tensor.slinalg import QR + + +@pytorch_funcify.register(QR) +def pytorch_funcify_QR(op, **kwargs): + mode = op.mode + if mode == "raw": + raise NotImplementedError("raw mode not implemented in PyTorch") + elif mode == "full": + mode = "complete" + elif mode == "economic": + mode = "reduced" + + def qr(x): + Q, R = torch.linalg.qr(x, mode=mode) + if mode == "r": + return R + return Q, R + + return qr diff --git a/tests/link/pytorch/conftest.py b/tests/link/pytorch/conftest.py new file mode 100644 index 0000000000..0d128a1d3a --- /dev/null +++ b/tests/link/pytorch/conftest.py @@ -0,0 +1,16 @@ +import numpy as np +import pytest + +from pytensor import config +from pytensor.tensor.type import matrix + + +@pytest.fixture +def matrix_test(): + rng = np.random.default_rng(213234) + + M = rng.normal(size=(3, 3)) + test_value = M.dot(M.T).astype(config.floatX) + + x = matrix("x") + return x, test_value diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py index 7e061f7cfc..58b27b4a2b 100644 --- a/tests/link/pytorch/test_nlinalg.py +++ b/tests/link/pytorch/test_nlinalg.py @@ -8,17 +8,6 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py -@pytest.fixture -def matrix_test(): - rng = np.random.default_rng(213234) - - M = rng.normal(size=(3, 3)) - test_value = M.dot(M.T).astype(config.floatX) - - x = matrix("x") - return (x, test_value) - - @pytest.mark.parametrize( "func", (pt_nla.eig, pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det), @@ -34,22 +23,6 @@ def assert_fn(x, y): compare_pytorch_and_py([x], outs, [test_value], 
assert_fn=assert_fn) -@pytest.mark.parametrize( - "mode", - ( - "complete", - "reduced", - "r", - pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)), - ), -) -def test_qr(mode, matrix_test): - x, test_value = matrix_test - outs = pt_nla.qr(x, mode=mode) - - compare_pytorch_and_py([x], outs, [test_value]) - - @pytest.mark.parametrize("compute_uv", [True, False]) @pytest.mark.parametrize("full_matrices", [True, False]) def test_svd(compute_uv, full_matrices, matrix_test): diff --git a/tests/link/pytorch/test_slinalg.py b/tests/link/pytorch/test_slinalg.py new file mode 100644 index 0000000000..bf1d5e0a7b --- /dev/null +++ b/tests/link/pytorch/test_slinalg.py @@ -0,0 +1,20 @@ +import pytest + +import pytensor +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +@pytest.mark.parametrize( + "mode", + ( + "complete", + "reduced", + "r", + pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)), + ), +) +def test_qr(mode, matrix_test): + x, test_value = matrix_test + outs = pytensor.tensor.slinalg.qr(x, mode=mode) + + compare_pytorch_and_py([x], outs, [test_value]) From 112f6fd5685ec6a87f818cd9af401a5aad7cc567 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Thu, 3 Jul 2025 09:38:11 +0800 Subject: [PATCH 4/4] Update numba QR dispatch --- .../link/numba/dispatch/linalg/_LAPACK.py | 88 +- .../numba/dispatch/linalg/decomposition/qr.py | 880 ++++++++++++++++++ pytensor/link/numba/dispatch/nlinalg.py | 36 - pytensor/link/numba/dispatch/slinalg.py | 105 ++- tests/link/numba/test_nlinalg.py | 54 -- tests/link/numba/test_slinalg.py | 68 ++ 6 files changed, 1139 insertions(+), 92 deletions(-) create mode 100644 pytensor/link/numba/dispatch/linalg/decomposition/qr.py diff --git a/pytensor/link/numba/dispatch/linalg/_LAPACK.py b/pytensor/link/numba/dispatch/linalg/_LAPACK.py index 5ae7b78c50..421d182c94 100644 --- a/pytensor/link/numba/dispatch/linalg/_LAPACK.py +++ b/pytensor/link/numba/dispatch/linalg/_LAPACK.py @@ -283,7 +283,6 @@ def numba_xgetrs(cls, dtype): Called by scipy.linalg.lu_solve """ - ... lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs") functype = ctypes.CFUNCTYPE( None, @@ -457,3 +456,90 @@ def numba_xgtcon(cls, dtype): _ptr_int, # INFO ) return functype(lapack_ptr) + + @classmethod + def numba_xgeqrf(cls, dtype): + """ + Compute the QR factorization of a general M-by-N matrix A. + + Used in QR decomposition (no pivoting). + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqrf") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # M + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + float_pointer, # TAU + float_pointer, # WORK + _ptr_int, # LWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xgeqp3(cls, dtype): + """ + Compute the QR factorization with column pivoting of a general M-by-N matrix A. + + Used in QR decomposition with pivoting. + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # M + _ptr_int, # N + float_pointer, # A + _ptr_int, # LDA + _ptr_int, # JPVT + float_pointer, # TAU + float_pointer, # WORK + _ptr_int, # LWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xorgqr(cls, dtype): + """ + Generate the orthogonal matrix Q from a QR factorization (real types). + + Used in QR decomposition to form Q. 
+ """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "orgqr") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # M + _ptr_int, # N + _ptr_int, # K + float_pointer, # A + _ptr_int, # LDA + float_pointer, # TAU + float_pointer, # WORK + _ptr_int, # LWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) + + @classmethod + def numba_xungqr(cls, dtype): + """ + Generate the unitary matrix Q from a QR factorization (complex types). + + Used in QR decomposition to form Q for complex types. + """ + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "ungqr") + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # M + _ptr_int, # N + _ptr_int, # K + float_pointer, # A + _ptr_int, # LDA + float_pointer, # TAU + float_pointer, # WORK + _ptr_int, # LWORK + _ptr_int, # INFO + ) + return functype(lapack_ptr) diff --git a/pytensor/link/numba/dispatch/linalg/decomposition/qr.py b/pytensor/link/numba/dispatch/linalg/decomposition/qr.py new file mode 100644 index 0000000000..c64489a16f --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/decomposition/qr.py @@ -0,0 +1,880 @@ +import numpy as np +from numba.core.extending import overload +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack +from scipy.linalg import get_lapack_funcs, qr + +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _LAPACK, + _get_underlying_float, + int_ptr_to_val, + val_to_int_ptr, +) + + +def _xgeqrf(A: np.ndarray, overwrite_a: bool, lwork: int): + """LAPACK geqrf: Computes a QR factorization of a general M-by-N matrix A.""" + (geqrf,) = get_lapack_funcs(("geqrf",), (A,)) + return geqrf(A, overwrite_a=overwrite_a, lwork=lwork) + + +@overload(_xgeqrf) +def xgeqrf_impl(A, overwrite_a, lwork): + ensure_lapack() + dtype = A.dtype + w_type = _get_underlying_float(dtype) + geqrf = _LAPACK().numba_xgeqrf(dtype) + + def impl(A, overwrite_a, lwork): + M = np.int32(A.shape[0]) + N = np.int32(A.shape[1]) + + if overwrite_a and A.flags.f_contiguous: + A_copy = A + else: + A_copy = _copy_to_fortran_order(A) + + LDA = val_to_int_ptr(M) + TAU = np.empty(min(M, N), dtype=dtype) + + if lwork == -1: + WORK = np.empty(1, dtype=dtype) + LWORK = val_to_int_ptr(-1) + else: + WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype) + LWORK = val_to_int_ptr(WORK.size) + INFO = val_to_int_ptr(1) + + geqrf( + val_to_int_ptr(M), + val_to_int_ptr(N), + A_copy.view(w_type).ctypes, + LDA, + TAU.view(w_type).ctypes, + WORK.view(w_type).ctypes, + LWORK, + INFO, + ) + return A_copy, TAU, WORK, int_ptr_to_val(INFO) + + return impl + + +def _xgeqp3(A: np.ndarray, overwrite_a: bool, lwork: int): + """LAPACK geqp3: Computes a QR factorization with column pivoting of a general M-by-N matrix A.""" + (geqp3,) = get_lapack_funcs(("geqp3",), (A,)) + return geqp3(A, overwrite_a=overwrite_a, lwork=lwork) + + +@overload(_xgeqp3) +def xgeqp3_impl(A, overwrite_a, lwork): + ensure_lapack() + dtype = A.dtype + w_type = _get_underlying_float(dtype) + geqp3 = _LAPACK().numba_xgeqp3(dtype) + + def impl(A, overwrite_a, lwork): + M = np.int32(A.shape[0]) + N = np.int32(A.shape[1]) + + if overwrite_a and A.flags.f_contiguous: + A_copy = A + else: + A_copy = _copy_to_fortran_order(A) + + LDA = val_to_int_ptr(M) + JPVT = np.zeros(N, dtype=np.int32) + TAU = np.empty(min(M, N), dtype=dtype) + + if lwork == -1: + WORK = np.empty(1, dtype=dtype) + LWORK = val_to_int_ptr(-1) + else: + WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype) + LWORK = val_to_int_ptr(WORK.size) + INFO = val_to_int_ptr(1) + + geqp3( + val_to_int_ptr(M), + 
val_to_int_ptr(N), + A_copy.view(w_type).ctypes, + LDA, + JPVT.ctypes, + TAU.view(w_type).ctypes, + WORK.view(w_type).ctypes, + LWORK, + INFO, + ) + return A_copy, JPVT, TAU, WORK, int_ptr_to_val(INFO) + + return impl + + +def _xorgqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int): + """LAPACK orgqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (real types).""" + (orgqr,) = get_lapack_funcs(("orgqr",), (A,)) + return orgqr(A, tau, overwrite_a=overwrite_a, lwork=lwork) + + +@overload(_xorgqr) +def xorgqr_impl(A, tau, overwrite_a, lwork): + ensure_lapack() + dtype = A.dtype + w_type = _get_underlying_float(dtype) + orgqr = _LAPACK().numba_xorgqr(dtype) + + def impl(A, tau, overwrite_a, lwork): + M = np.int32(A.shape[0]) + N = np.int32(A.shape[1]) + K = np.int32(tau.shape[0]) + + if overwrite_a and A.flags.f_contiguous: + A_copy = A + else: + A_copy = _copy_to_fortran_order(A) + + if lwork == -1: + WORK = np.empty(1, dtype=dtype) + LWORK = val_to_int_ptr(-1) + else: + WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype) + LWORK = val_to_int_ptr(WORK.size) + + LDA = val_to_int_ptr(M) + INFO = val_to_int_ptr(1) + + orgqr( + val_to_int_ptr(M), + val_to_int_ptr(N), + val_to_int_ptr(K), + A_copy.view(w_type).ctypes, + LDA, + tau.view(w_type).ctypes, + WORK.view(w_type).ctypes, + LWORK, + INFO, + ) + return A_copy, WORK, int_ptr_to_val(INFO) + + return impl + + +def _xungqr(A: np.ndarray, tau: np.ndarray, overwrite_a: bool, lwork: int): + """LAPACK ungqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (complex types).""" + (ungqr,) = get_lapack_funcs(("ungqr",), (A,)) + return ungqr(A, tau, overwrite_a=overwrite_a, lwork=lwork) + + +@overload(_xungqr) +def xungqr_impl(A, tau, overwrite_a, lwork): + ensure_lapack() + dtype = A.dtype + w_type = _get_underlying_float(dtype) + ungqr = _LAPACK().numba_xungqr(dtype) + + def impl(A, tau, overwrite_a, lwork): + M = np.int32(A.shape[0]) + N = np.int32(A.shape[1]) + K = np.int32(tau.shape[0]) + + if overwrite_a and A.flags.f_contiguous: + A_copy = A + else: + A_copy = _copy_to_fortran_order(A) + LDA = val_to_int_ptr(M) + + if lwork == -1: + WORK = np.empty(1, dtype=dtype) + LWORK = val_to_int_ptr(-1) + else: + WORK = np.empty(lwork if lwork > 0 else 1, dtype=dtype) + LWORK = val_to_int_ptr(WORK.size) + INFO = val_to_int_ptr(1) + + ungqr( + val_to_int_ptr(M), + val_to_int_ptr(N), + val_to_int_ptr(K), + A_copy.view(w_type).ctypes, + LDA, + tau.view(w_type).ctypes, + WORK.view(w_type).ctypes, + LWORK, + INFO, + ) + + return A_copy, WORK, int_ptr_to_val(INFO) + + return impl + + +def _qr_full_pivot( + x: np.ndarray, + mode: str = "full", + pivoting: bool = True, + overwrite_a: bool = False, + check_finite: bool = False, + lwork: int | None = None, +): + """ + Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same + script. + + Corresponds to the case where mode not "r" or "raw", and pivoting is True, resulting in a return of arrays Q, R, and + P. + """ + return qr( + x, + mode=mode, + pivoting=pivoting, + overwrite_a=overwrite_a, + check_finite=check_finite, + lwork=lwork, + ) + + +def _qr_full_no_pivot( + x: np.ndarray, + mode: str = "full", + pivoting: bool = False, + overwrite_a: bool = False, + check_finite: bool = False, + lwork: int | None = None, +): + """ + Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same + script. 
+ + Corresponds to the case where mode not "r" or "raw", and pivoting is False, resulting in a return of arrays Q and R. + """ + return qr( + x, + mode=mode, + pivoting=pivoting, + overwrite_a=overwrite_a, + check_finite=check_finite, + lwork=lwork, + ) + + +def _qr_r_pivot( + x: np.ndarray, + mode: str = "r", + pivoting: bool = True, + overwrite_a: bool = False, + check_finite: bool = False, + lwork: int | None = None, +): + """ + Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same + script. + + Corresponds to the case where mode is "r" or "raw", and pivoting is True, resulting in a return of arrays R and P. + """ + return qr( + x, + mode=mode, + pivoting=pivoting, + overwrite_a=overwrite_a, + check_finite=check_finite, + lwork=lwork, + ) + + +def _qr_r_no_pivot( + x: np.ndarray, + mode: str = "r", + pivoting: bool = False, + overwrite_a: bool = False, + check_finite: bool = False, + lwork: int | None = None, +): + """ + Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same + script. + + Corresponds to the case where mode is "r" or "raw", and pivoting is False, resulting in a return of array R. + """ + return qr( + x, + mode=mode, + pivoting=pivoting, + overwrite_a=overwrite_a, + check_finite=check_finite, + lwork=lwork, + ) + + +def _qr_raw_no_pivot( + x: np.ndarray, + mode: str = "raw", + pivoting: bool = False, + overwrite_a: bool = False, + check_finite: bool = False, + lwork: int | None = None, +): + """ + Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same + script. + + Corresponds to the case where mode is "raw", and pivoting is False, resulting in a return of arrays H, tau, and R. + """ + (H, tau), R = qr( + x, + mode=mode, + pivoting=pivoting, + overwrite_a=overwrite_a, + check_finite=check_finite, + lwork=lwork, + ) + + return H, tau, R + + +def _qr_raw_pivot( + x: np.ndarray, + mode: str = "raw", + pivoting: bool = True, + overwrite_a: bool = False, + check_finite: bool = False, + lwork: int | None = None, +): + """ + Thin wrapper around scipy.linalg.qr, used to avoid side effects when users import pytensor and scipy in the same + script. + + Corresponds to the case where mode is "raw", and pivoting is True, resulting in a return of arrays H, tau, R, and P. 
+ """ + (H, tau), R, P = qr( + x, + mode=mode, + pivoting=pivoting, + overwrite_a=overwrite_a, + check_finite=check_finite, + lwork=lwork, + ) + + return H, tau, R, P + + +@overload(_qr_full_pivot) +def qr_full_pivot_impl( + x, mode="full", pivoting=True, overwrite_a=False, check_finite=False, lwork=None +): + ensure_lapack() + dtype = x.dtype + w_type = _get_underlying_float(dtype) + geqp3 = _LAPACK().numba_xgeqp3(dtype) + orgqr = _LAPACK().numba_xorgqr(dtype) + + def impl( + x, + mode="full", + pivoting=True, + overwrite_a=False, + check_finite=False, + lwork=None, + ): + M = np.int32(x.shape[0]) + N = np.int32(x.shape[1]) + K = min(M, N) + + if overwrite_a and x.flags.f_contiguous: + x_copy = x + else: + x_copy = _copy_to_fortran_order(x) + + LDA = val_to_int_ptr(M) + TAU = np.empty(K, dtype=dtype) + JPVT = np.zeros(N, dtype=np.int32) + + if lwork is None: + lwork = -1 + + if lwork == -1: + WORK = np.empty(1, dtype=dtype) + geqp3( + val_to_int_ptr(M), + val_to_int_ptr(N), + x_copy.view(w_type).ctypes, + LDA, + JPVT.ctypes, + TAU.view(w_type).ctypes, + WORK.view(w_type).ctypes, + val_to_int_ptr(-1), + val_to_int_ptr(1), + ) + lwork_val = int(WORK.item()) + + else: + lwork_val = lwork + + WORK = np.empty(lwork_val, dtype=dtype) + INFO = val_to_int_ptr(1) + geqp3( + val_to_int_ptr(M), + val_to_int_ptr(N), + x_copy.view(w_type).ctypes, + LDA, + JPVT.ctypes, + TAU.view(w_type).ctypes, + WORK.view(w_type).ctypes, + val_to_int_ptr(lwork_val), + INFO, + ) + JPVT = (JPVT - 1).astype(np.int32) + + if mode == "full" or M < N: + R = np.triu(x_copy) + else: + R = np.triu(x_copy[:N, :]) + + if M < N: + Q_in = x_copy[:, :M] + elif M == N or mode == "economic": + Q_in = x_copy + else: + # Transpose to put the matrix into Fortran order + Q_in = np.empty((M, M), dtype=dtype).T + Q_in[:, :N] = x_copy + + if lwork == -1: + WORKQ = np.empty(1, dtype=dtype) + orgqr( + val_to_int_ptr(M), + val_to_int_ptr(Q_in.shape[1]), + val_to_int_ptr(K), + Q_in.view(w_type).ctypes, + val_to_int_ptr(M), + TAU.view(w_type).ctypes, + WORKQ.view(w_type).ctypes, + val_to_int_ptr(-1), + val_to_int_ptr(1), + ) + lwork_q = int(WORKQ.item()) + + else: + lwork_q = lwork + + WORKQ = np.empty(lwork_q, dtype=dtype) + INFOQ = val_to_int_ptr(1) + orgqr( + val_to_int_ptr(M), + val_to_int_ptr(Q_in.shape[1]), + val_to_int_ptr(K), + Q_in.view(w_type).ctypes, + val_to_int_ptr(M), + TAU.view(w_type).ctypes, + WORKQ.view(w_type).ctypes, + val_to_int_ptr(lwork_q), + INFOQ, + ) + return Q_in, R, JPVT + + return impl + + +@overload(_qr_full_no_pivot) +def qr_full_no_pivot_impl( + x, mode="full", pivoting=False, overwrite_a=False, check_finite=False, lwork=None +): + ensure_lapack() + dtype = x.dtype + w_type = _get_underlying_float(dtype) + geqrf = _LAPACK().numba_xgeqrf(dtype) + orgqr = _LAPACK().numba_xorgqr(dtype) + + def impl( + x, + mode="full", + pivoting=False, + overwrite_a=False, + check_finite=False, + lwork=None, + ): + M = np.int32(x.shape[0]) + N = np.int32(x.shape[1]) + K = min(M, N) + + if overwrite_a and x.flags.f_contiguous: + x_copy = x + else: + x_copy = _copy_to_fortran_order(x) + + LDA = val_to_int_ptr(M) + TAU = np.empty(K, dtype=dtype) + + if lwork is None: + lwork = -1 + + if lwork == -1: + WORK = np.empty(1, dtype=dtype) + geqrf( + val_to_int_ptr(M), + val_to_int_ptr(N), + x_copy.view(w_type).ctypes, + LDA, + TAU.view(w_type).ctypes, + WORK.view(w_type).ctypes, + val_to_int_ptr(-1), + val_to_int_ptr(1), + ) + lwork_val = int(WORK.item()) + else: + lwork_val = lwork + + WORK = np.empty(lwork_val, dtype=dtype) + INFO = 
val_to_int_ptr(1) + + geqrf( + val_to_int_ptr(M), + val_to_int_ptr(N), + x_copy.view(w_type).ctypes, + LDA, + TAU.view(w_type).ctypes, + WORK.view(w_type).ctypes, + val_to_int_ptr(lwork_val), + INFO, + ) + + if M < N or mode == "full": + R = np.triu(x_copy) + else: + R = np.triu(x_copy[:N, :]) + + if M < N: + Q_in = x_copy[:, :M] + elif M == N or mode == "economic": + Q_in = x_copy + else: + # Transpose to put the matrix into Fortran order + Q_in = np.empty((M, M), dtype=dtype).T + Q_in[:, :N] = x_copy + + if lwork == -1: + WORKQ = np.empty(1, dtype=dtype) + orgqr( + val_to_int_ptr(M), + val_to_int_ptr(Q_in.shape[1]), + val_to_int_ptr(K), + Q_in.view(w_type).ctypes, + val_to_int_ptr(M), + TAU.view(w_type).ctypes, + WORKQ.view(w_type).ctypes, + val_to_int_ptr(-1), + val_to_int_ptr(1), + ) + lwork_q = int(WORKQ.item()) + else: + lwork_q = lwork + + WORKQ = np.empty(lwork_q, dtype=dtype) + INFOQ = val_to_int_ptr(1) + + orgqr( + val_to_int_ptr(M), # M + val_to_int_ptr(Q_in.shape[1]), # N + val_to_int_ptr(K), # K + Q_in.view(w_type).ctypes, # A + val_to_int_ptr(M), # LDA + TAU.view(w_type).ctypes, # TAU + WORKQ.view(w_type).ctypes, # WORK + val_to_int_ptr(lwork_q), # LWORK + INFOQ, # INFO + ) + return Q_in, R + + return impl + + +@overload(_qr_r_pivot) +def qr_r_pivot_impl( + x, mode="r", pivoting=True, overwrite_a=False, check_finite=False, lwork=None +): + ensure_lapack() + dtype = x.dtype + w_type = _get_underlying_float(dtype) + geqp3 = _LAPACK().numba_xgeqp3(dtype) + + def impl( + x, + mode="r", + pivoting=True, + overwrite_a=False, + check_finite=False, + lwork=None, + ): + M = np.int32(x.shape[0]) + N = np.int32(x.shape[1]) + + if overwrite_a and x.flags.f_contiguous: + x_copy = x + else: + x_copy = _copy_to_fortran_order(x) + + LDA = val_to_int_ptr(M) + K = min(M, N) + TAU = np.empty(K, dtype=dtype) + JPVT = np.zeros(N, dtype=np.int32) + + if lwork is None: + lwork = -1 + if lwork == -1: + WORK = np.empty(1, dtype=dtype) + geqp3( + val_to_int_ptr(M), + val_to_int_ptr(N), + x_copy.view(w_type).ctypes, + LDA, + JPVT.ctypes, + TAU.view(w_type).ctypes, + WORK.view(w_type).ctypes, + val_to_int_ptr(-1), + val_to_int_ptr(1), + ) + lwork_val = int(WORK.item()) + else: + lwork_val = lwork + + WORK = np.empty(lwork_val, dtype=dtype) + INFO = val_to_int_ptr(1) + + geqp3( + val_to_int_ptr(M), + val_to_int_ptr(N), + x_copy.view(w_type).ctypes, + LDA, + JPVT.ctypes, + TAU.view(w_type).ctypes, + WORK.view(w_type).ctypes, + val_to_int_ptr(lwork_val), + INFO, + ) + JPVT = (JPVT - 1).astype(np.int32) + + if M < N: + R = np.triu(x_copy) + else: + R = np.triu(x_copy[:N, :]) + + return R, JPVT + + return impl + + +@overload(_qr_r_no_pivot) +def qr_r_no_pivot_impl( + x, mode="r", pivoting=False, overwrite_a=False, check_finite=False, lwork=None +): + ensure_lapack() + dtype = x.dtype + w_type = _get_underlying_float(dtype) + geqrf = _LAPACK().numba_xgeqrf(dtype) + + def impl( + x, + mode="r", + pivoting=False, + overwrite_a=False, + check_finite=False, + lwork=None, + ): + M = np.int32(x.shape[0]) + N = np.int32(x.shape[1]) + + if overwrite_a and x.flags.f_contiguous: + x_copy = x + else: + x_copy = _copy_to_fortran_order(x) + + LDA = val_to_int_ptr(M) + K = min(M, N) + TAU = np.empty(K, dtype=dtype) + + if lwork is None: + lwork = -1 + if lwork == -1: + WORK = np.empty(1, dtype=dtype) + geqrf( + val_to_int_ptr(M), + val_to_int_ptr(N), + x_copy.view(w_type).ctypes, + LDA, + TAU.view(w_type).ctypes, + WORK.view(w_type).ctypes, + val_to_int_ptr(-1), + val_to_int_ptr(1), + ) + lwork_val = int(WORK.item()) + 
else: + lwork_val = lwork + + WORK = np.empty(lwork_val, dtype=dtype) + INFO = val_to_int_ptr(1) + + geqrf( + val_to_int_ptr(M), + val_to_int_ptr(N), + x_copy.view(w_type).ctypes, + LDA, + TAU.view(w_type).ctypes, + WORK.view(w_type).ctypes, + val_to_int_ptr(lwork_val), + INFO, + ) + + if M < N: + R = np.triu(x_copy) + else: + R = np.triu(x_copy[:N, :]) + + # Return a tuple with R only to match the scipy qr interface + return (R,) + + return impl + + +@overload(_qr_raw_no_pivot) +def qr_raw_no_pivot_impl( + x, mode="raw", pivoting=False, overwrite_a=False, check_finite=False, lwork=None +): + ensure_lapack() + dtype = x.dtype + w_type = _get_underlying_float(dtype) + geqrf = _LAPACK().numba_xgeqrf(dtype) + + def impl( + x, + mode="raw", + pivoting=False, + overwrite_a=False, + check_finite=False, + lwork=None, + ): + M = np.int32(x.shape[0]) + N = np.int32(x.shape[1]) + + if overwrite_a and x.flags.f_contiguous: + x_copy = x + else: + x_copy = _copy_to_fortran_order(x) + + LDA = val_to_int_ptr(M) + K = min(M, N) + TAU = np.empty(K, dtype=dtype) + + if lwork is None: + lwork = -1 + if lwork == -1: + WORK = np.empty(1, dtype=dtype) + geqrf( + val_to_int_ptr(M), + val_to_int_ptr(N), + x_copy.view(w_type).ctypes, + LDA, + TAU.view(w_type).ctypes, + WORK.view(w_type).ctypes, + val_to_int_ptr(-1), + val_to_int_ptr(1), + ) + lwork_val = int(WORK.item()) + else: + lwork_val = lwork + + WORK = np.empty(lwork_val, dtype=dtype) + INFO = val_to_int_ptr(1) + + geqrf( + val_to_int_ptr(M), + val_to_int_ptr(N), + x_copy.view(w_type).ctypes, + LDA, + TAU.view(w_type).ctypes, + WORK.view(w_type).ctypes, + val_to_int_ptr(lwork_val), + INFO, + ) + + if M < N: + R = np.triu(x_copy) + else: + R = np.triu(x_copy[:N, :]) + + return x_copy, TAU, R + + return impl + + +@overload(_qr_raw_pivot) +def qr_raw_pivot_impl( + x, mode="raw", pivoting=True, overwrite_a=False, check_finite=False, lwork=None +): + ensure_lapack() + dtype = x.dtype + w_type = _get_underlying_float(dtype) + geqp3 = _LAPACK().numba_xgeqp3(dtype) + + def impl( + x, + mode="raw", + pivoting=True, + overwrite_a=False, + check_finite=False, + lwork=None, + ): + M = np.int32(x.shape[0]) + N = np.int32(x.shape[1]) + + if overwrite_a and x.flags.f_contiguous: + x_copy = x + else: + x_copy = _copy_to_fortran_order(x) + + LDA = val_to_int_ptr(M) + K = min(M, N) + TAU = np.empty(K, dtype=dtype) + JPVT = np.zeros(N, dtype=np.int32) + + if lwork is None: + lwork = -1 + if lwork == -1: + WORK = np.empty(1, dtype=dtype) + geqp3( + val_to_int_ptr(M), + val_to_int_ptr(N), + x_copy.view(w_type).ctypes, + LDA, + JPVT.ctypes, + TAU.view(w_type).ctypes, + WORK.view(w_type).ctypes, + val_to_int_ptr(-1), + val_to_int_ptr(1), + ) + lwork_val = int(WORK.item()) + else: + lwork_val = lwork + + WORK = np.empty(lwork_val, dtype=dtype) + INFO = val_to_int_ptr(1) + + geqp3( + val_to_int_ptr(M), + val_to_int_ptr(N), + x_copy.view(w_type).ctypes, + LDA, + JPVT.ctypes, + TAU.view(w_type).ctypes, + WORK.view(w_type).ctypes, + val_to_int_ptr(lwork_val), + INFO, + ) + + JPVT = (JPVT - 1).astype(np.int32) + + if M < N: + R = np.triu(x_copy) + else: + R = np.triu(x_copy[:N, :]) + + return x_copy, TAU, R, JPVT + + return impl diff --git a/pytensor/link/numba/dispatch/nlinalg.py b/pytensor/link/numba/dispatch/nlinalg.py index 3271b5bd26..98d59a4595 100644 --- a/pytensor/link/numba/dispatch/nlinalg.py +++ b/pytensor/link/numba/dispatch/nlinalg.py @@ -16,7 +16,6 @@ Eigh, MatrixInverse, MatrixPinv, - QRFull, SLogDet, ) @@ -146,38 +145,3 @@ def matrixpinv(x): return 
np.linalg.pinv(inputs_cast(x)).astype(out_dtype) return matrixpinv - - -@numba_funcify.register(QRFull) -def numba_funcify_QRFull(op, node, **kwargs): - mode = op.mode - - if mode != "reduced": - warnings.warn( - ( - "Numba will use object mode to allow the " - "`mode` argument to `numpy.linalg.qr`." - ), - UserWarning, - ) - - if len(node.outputs) > 1: - ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs]) - else: - ret_sig = get_numba_type(node.outputs[0].type) - - @numba_basic.numba_njit - def qr_full(x): - with numba.objmode(ret=ret_sig): - ret = np.linalg.qr(x, mode=mode) - return ret - - else: - out_dtype = node.outputs[0].type.numpy_dtype - inputs_cast = int_to_float_fn(node.inputs, out_dtype) - - @numba_basic.numba_njit(inline="always") - def qr_full(x): - return np.linalg.qr(inputs_cast(x)) - - return qr_full diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 4630224f02..7d1e915298 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -2,6 +2,7 @@ import numpy as np +from pytensor import config from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit from pytensor.link.numba.dispatch.linalg.decomposition.cholesky import _cholesky from pytensor.link.numba.dispatch.linalg.decomposition.lu import ( @@ -11,6 +12,14 @@ _pivot_to_permutation, ) from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _lu_factor +from pytensor.link.numba.dispatch.linalg.decomposition.qr import ( + _qr_full_no_pivot, + _qr_full_pivot, + _qr_r_no_pivot, + _qr_r_pivot, + _qr_raw_no_pivot, + _qr_raw_pivot, +) from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd @@ -19,6 +28,7 @@ from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal from pytensor.tensor.slinalg import ( LU, + QR, BlockDiagonal, Cholesky, CholeskySolve, @@ -27,7 +37,7 @@ Solve, SolveTriangular, ) -from pytensor.tensor.type import complex_dtypes +from pytensor.tensor.type import complex_dtypes, integer_dtypes _COMPLEX_DTYPE_NOT_SUPPORTED_MSG = ( @@ -311,3 +321,96 @@ def cho_solve(c, b): ) return cho_solve + + +@numba_funcify.register(QR) +def numba_funcify_QR(op, node, **kwargs): + mode = op.mode + check_finite = op.check_finite + pivoting = op.pivoting + overwrite_a = op.overwrite_a + + dtype = node.inputs[0].dtype + if dtype in complex_dtypes: + raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) + + integer_input = dtype in integer_dtypes + in_dtype = config.floatX if integer_input else dtype + + @numba_njit(cache=False) + def qr(a): + if check_finite: + if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): + raise np.linalg.LinAlgError( + "Non-numeric values (nan or inf) found in input to qr" + ) + + if integer_input: + a = a.astype(in_dtype) + + if (mode == "full" or mode == "economic") and pivoting: + Q, R, P = _qr_full_pivot( + a, + mode=mode, + pivoting=pivoting, + overwrite_a=overwrite_a, + check_finite=check_finite, + ) + return Q, R, P + + elif (mode == "full" or mode == "economic") and not pivoting: + Q, R = _qr_full_no_pivot( + a, + mode=mode, + pivoting=pivoting, + overwrite_a=overwrite_a, + check_finite=check_finite, + ) + return Q, R + + elif mode == "r" and pivoting: + R, P = _qr_r_pivot( + a, + mode=mode, + pivoting=pivoting, + overwrite_a=overwrite_a, + 
check_finite=check_finite, + ) + return R, P + + elif mode == "r" and not pivoting: + (R,) = _qr_r_no_pivot( + a, + mode=mode, + pivoting=pivoting, + overwrite_a=overwrite_a, + check_finite=check_finite, + ) + return R + + elif mode == "raw" and pivoting: + H, tau, R, P = _qr_raw_pivot( + a, + mode=mode, + pivoting=pivoting, + overwrite_a=overwrite_a, + check_finite=check_finite, + ) + return H, tau, R, P + + elif mode == "raw" and not pivoting: + H, tau, R = _qr_raw_no_pivot( + a, + mode=mode, + pivoting=pivoting, + overwrite_a=overwrite_a, + check_finite=check_finite, + ) + return H, tau, R + + else: + raise NotImplementedError( + f"QR mode={mode}, pivoting={pivoting} not supported in numba mode." + ) + + return qr diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 8d7c3a449c..ca7c458d15 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -186,60 +186,6 @@ def test_matrix_inverses(op, x, exc, op_args): ) -@pytest.mark.parametrize( - "x, mode, exc", - [ - ( - ( - pt.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - "reduced", - None, - ), - ( - ( - pt.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - "r", - None, - ), - ( - ( - pt.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - "reduced", - None, - ), - ( - ( - pt.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - "complete", - UserWarning, - ), - ], -) -def test_QRFull(x, mode, exc): - x, test_x = x - g = nlinalg.QRFull(mode)(x) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - [x], - g, - [test_x], - ) - - @pytest.mark.parametrize( "x, full_matrices, compute_uv, exc", [ diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 7bf3a6e889..bbbb26010f 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -10,6 +10,7 @@ from pytensor import In, config from pytensor.tensor.slinalg import ( LU, + QR, Cholesky, CholeskySolve, LUFactor, @@ -720,3 +721,70 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo # Can never destroy non-contiguous inputs np.testing.assert_allclose(b_val_not_contig, b_val) + + +@pytest.mark.parametrize( + "mode, pivoting", + [("economic", False), ("full", True), ("r", False), ("raw", True)], + ids=["economic", "full_pivot", "r", "raw_pivot"], +) +@pytest.mark.parametrize( + "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"] +) +def test_qr(mode, pivoting, overwrite_a): + shape = (5, 5) + rng = np.random.default_rng() + A = pt.tensor( + "A", + shape=shape, + dtype=config.floatX, + ) + A_val = rng.normal(size=shape).astype(config.floatX) + + qr_outputs = pt.linalg.qr(A, mode=mode, pivoting=pivoting) + + fn, res = compare_numba_and_py( + [In(A, mutable=overwrite_a)], + qr_outputs, + [A_val], + numba_mode=numba_inplace_mode, + inplace=True, + ) + + op = fn.maker.fgraph.outputs[0].owner.op + assert isinstance(op, QR) + + destroy_map = op.destroy_map + + if overwrite_a: + assert destroy_map == {0: [0]} + else: + assert destroy_map == {} + + # Test F-contiguous input + val_f_contig = np.copy(A_val, order="F") + res_f_contig = fn(val_f_contig) + + for x, x_f_contig in zip(res, res_f_contig, strict=True): + np.testing.assert_allclose(x, x_f_contig) + + # Should always be destroyable + assert (A_val == val_f_contig).all() 
== (not overwrite_a) + + # Test C-contiguous input + val_c_contig = np.copy(A_val, order="C") + res_c_contig = fn(val_c_contig) + for x, x_c_contig in zip(res, res_c_contig, strict=True): + np.testing.assert_allclose(x, x_c_contig) + + # Cannot destroy C-contiguous input + np.testing.assert_allclose(val_c_contig, A_val) + + # Test non-contiguous input + val_not_contig = np.repeat(A_val, 2, axis=0)[::2] + res_not_contig = fn(val_not_contig) + for x, x_not_contig in zip(res, res_not_contig, strict=True): + np.testing.assert_allclose(x, x_not_contig) + + # Cannot destroy non-contiguous input + np.testing.assert_allclose(val_not_contig, A_val)
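A quick end-to-end sketch of the new pivoting interface introduced by this series (illustrative only; assumes the
default float64 configuration):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    rng = np.random.default_rng(0)
    A_val = rng.normal(size=(6, 4))

    A = pt.dmatrix("A")
    # mode="economic" with pivoting=True returns Q (6, 4), R (4, 4), and pivot indices P (4,)
    Q, R, P = pt.linalg.qr(A, mode="economic", pivoting=True)
    f = pytensor.function([A], [Q, R, P])

    Q_val, R_val, P_val = f(A_val)
    # Column-pivoted QR factors the column-permuted matrix: A[:, P] == Q @ R
    np.testing.assert_allclose(A_val[:, P_val], Q_val @ R_val, atol=1e-12)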