Skip to content

Commit 5547eb0

Browse files
Jesse GrabowskiricardoV94
authored andcommitted
Don't return tuple from jax.scipy.linalg.qr when mode = 'r' (only one return)
1 parent 84252b5 commit 5547eb0

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,8 @@ def jax_funcify_QR(op, **kwargs):
177177
mode = op.mode
178178

179179
def qr(x, mode=mode):
180-
return jax.scipy.linalg.qr(x, mode=mode)
180+
res = jax.scipy.linalg.qr(x, mode=mode)
181+
return res[0] if len(res) == 1 else res
181182

182183
return qr
183184

tests/link/jax/test_slinalg.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,15 @@ def test_jax_expm():
370370
out = pt_slinalg.expm(A)
371371

372372
compare_jax_and_py([A], [out], [A_val])
373+
374+
375+
@pytest.mark.parametrize("mode", ["full", "r"])
376+
def test_jax_qr(mode):
377+
# "full" and "r" modes are tested because "full" returns two matrices (Q, R), while (R,) returns only one.
378+
# Pytensor does not return a tuple when only one output is expected.
379+
rng = np.random.default_rng(utt.fetch_seed())
380+
A = pt.tensor(name="A", shape=(5, 5))
381+
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
382+
out = pt_slinalg.qr(A, mode=mode)
383+
384+
compare_jax_and_py([A], out, [A_val])

0 commit comments

Comments
 (0)