diff --git a/pytensor/compile/io.py b/pytensor/compile/io.py
index 9ce0421235..ff3531c343 100644
--- a/pytensor/compile/io.py
+++ b/pytensor/compile/io.py
@@ -95,7 +95,7 @@ def __init__(
         self.implicit = implicit
 
     def __str__(self):
-        if self.update:
+        if self.update is not None:
             return f"In({self.variable} -> {self.update})"
         else:
             return f"In({self.variable})"
diff --git a/pytensor/link/jax/dispatch/scan.py b/pytensor/link/jax/dispatch/scan.py
index 3c3080765c..3082c6481a 100644
--- a/pytensor/link/jax/dispatch/scan.py
+++ b/pytensor/link/jax/dispatch/scan.py
@@ -1,5 +1,8 @@
-import jax
+from itertools import chain
+
 import jax.numpy as jnp
+import numpy as np
+from jax._src.lax.control_flow import scan as jax_scan
 
 from pytensor.compile.mode import JAX, get_mode
 from pytensor.link.jax.dispatch.basic import jax_funcify
@@ -8,16 +11,22 @@
 
 @jax_funcify.register(Scan)
 def jax_funcify_Scan(op: Scan, **kwargs):
+    # Note: This implementation differs from the internal PyTensor Scan op.
+    # In particular, we don't make use of the provided buffers for recurring outputs (MIT-SOT, SIT-SOT).
+    # Those buffers include the initial state and enough space to store as many intermediate results as needed.
+    # Instead, we let JAX scan build the concatenated buffer itself from the values computed in each iteration,
+    # and then prepend the initial_state and/or truncate results we don't need at the end.
+    # Likewise, we allow JAX to stack NIT-SOT outputs itself, instead of writing into an empty buffer of the final size.
+    # In contrast, MIT-MOTs behave as they do in the PyTensor Scan: we read from and write to the original buffer as we iterate.
+    # Hopefully, JAX can do the same sort of memory optimizations as PyTensor does.
+    # Performance-wise, the benchmarks show this approach is better, especially when auto-diffing through JAX.
+    # For an implementation that is closer to the internal PyTensor Scan, check the intermediate commit in
+    # https://github.com/pymc-devs/pytensor/pull/1651
     info = op.info
 
     if info.as_while:
         raise NotImplementedError("While Scan cannot yet be converted to JAX")
 
-    if info.n_mit_mot:
-        raise NotImplementedError(
-            "Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
-        )
-
     # Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode)
     rewriter = (
         get_mode(op.mode).including("jax").excluding(*JAX._optimizer.exclude).optimizer
@@ -27,20 +36,28 @@ def jax_funcify_Scan(op: Scan, **kwargs):
     def scan(*outer_inputs):
         # Extract JAX scan inputs
+        # JAX doesn't want some inputs to be tuples and others lists (e.g., those built by list comprehensions),
+        # so we convert everything to lists, which also remain lists after slicing.
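+        # outer_inputs packs n_steps first, together with the sequence, MIT-MOT, MIT-SOT, SIT-SOT,
+        # shared, NIT-SOT-length, and non-sequence inputs; the op.outer_* helpers used below slice
+        # out each group by position.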
         outer_inputs = list(outer_inputs)
         n_steps = outer_inputs[0]  # JAX `length`
         seqs = [seq[:n_steps] for seq in op.outer_seqs(outer_inputs)]  # JAX `xs`
 
-        mit_sot_init = []
-        for tap, seq in zip(
-            op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs), strict=True
-        ):
-            init_slice = seq[: abs(min(tap))]
-            mit_sot_init.append(init_slice)
+        # MIT-MOTs don't have a concept of "initial state":
+        # the whole buffer is meaningful at the start of the Scan.
+        mit_mot_init = op.outer_mitmot(outer_inputs)
 
-        sit_sot_init = [seq[0] for seq in op.outer_sitsot(outer_inputs)]
+        # For MIT-SOT and SIT-SOT, extract the initial states from the outer input buffers
+        mit_sot_init = [
+            buff[: -min(tap)]
+            for buff, tap in zip(
+                op.outer_mitsot(outer_inputs), op.info.mit_sot_in_slices, strict=True
+            )
+        ]
+        sit_sot_init = [buff[0] for buff in op.outer_sitsot(outer_inputs)]
 
         init_carry = (
+            0,  # loop counter, needed for indexing the MIT-MOT buffers
+            mit_mot_init,
             mit_sot_init,
             sit_sot_init,
             op.outer_shared(outer_inputs),
@@ -50,11 +67,13 @@ def scan(*outer_inputs):
         def jax_args_to_inner_func_args(carry, x):
             """Convert JAX scan arguments into format expected by scan_inner_func.
 
-            scan(carry, x) -> scan_inner_func(seqs, mit_sot, sit_sot, shared, non_seqs)
+            scan(carry, x) -> scan_inner_func(seqs, MIT-MOT, MIT-SOT, SIT-SOT, shared, non_seqs)
             """
 
             # `carry` contains all inner taps, shared terms, and non_seqs
             (
+                i,
+                inner_mit_mot,
                 inner_mit_sot,
                 inner_sit_sot,
                 inner_shared,
@@ -64,21 +83,34 @@ def jax_args_to_inner_func_args(carry, x):
             # `x` contains the inner sequences
             inner_seqs = x
 
-            mit_sot_flatten = []
-            for array, index in zip(
-                inner_mit_sot, op.info.mit_sot_in_slices, strict=True
-            ):
-                mit_sot_flatten.extend(array[jnp.array(index)])
+            # chain.from_iterable is used to flatten the first dimension of each indexed buffer:
+            # [buf1[[idx0, idx1]], buf2[[idx0, idx1]]] -> [buf1[idx0], buf1[idx1], buf2[idx0], buf2[idx1]]
+            # Benchmarking suggests advanced indexing on all taps at once and then unpacking is faster than basic indexing one tap at a time
+            mit_mot_flatten = list(
+                chain.from_iterable(
+                    buffer[(i + np.array(taps))]
+                    for buffer, taps in zip(
+                        inner_mit_mot, info.mit_mot_in_slices, strict=True
+                    )
+                )
+            )
+            mit_sot_flatten = list(
+                chain.from_iterable(
+                    buffer[np.array(taps)]
+                    for buffer, taps in zip(
+                        inner_mit_sot, info.mit_sot_in_slices, strict=True
+                    )
+                )
+            )
 
-            inner_scan_inputs = [
+            return (
                 *inner_seqs,
+                *mit_mot_flatten,
                 *mit_sot_flatten,
                 *inner_sit_sot,
                 *inner_shared,
                 *inner_non_seqs,
-            ]
-
-            return inner_scan_inputs
+            )
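+
+        # The JAX carry is (step counter, MIT-MOT buffers, MIT-SOT buffers, SIT-SOT values,
+        # shared values, non-sequences). `jax_args_to_inner_func_args` above unpacks it into the
+        # flat argument list of the inner function; `inner_func_outs_to_jax_outs` below repacks
+        # the inner outputs into the next carry and the traced outputs.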
 
         def inner_func_outs_to_jax_outs(
             old_carry,
         ):
             """Convert inner_scan_func outputs into format expected by JAX scan.
 
-            old_carry + (mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys)
+            old_carry + (MIT-MOT_outs, MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, shared_outs) -> (new_carry, ys)
             """
             (
-                inner_mit_sot,
-                _inner_sit_sot,
-                inner_shared,
+                i,
+                old_mit_mot,
+                old_mit_sot,
+                _old_sit_sot,
+                _old_shared,
                 inner_non_seqs,
             ) = old_carry
 
-            inner_mit_sot_outs = op.inner_mitsot_outs(inner_scan_outs)
-            inner_sit_sot_outs = op.inner_sitsot_outs(inner_scan_outs)
-            inner_nit_sot_outs = op.inner_nitsot_outs(inner_scan_outs)
-            inner_shared_outs = op.inner_shared_outs(inner_scan_outs)
-
-            # Replace the oldest mit_sot tap by the newest value
-            inner_mit_sot_new = [
-                jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0)
-                for old_mit_sot, new_val in zip(
-                    inner_mit_sot, inner_mit_sot_outs, strict=True
+            new_mit_mot_vals = op.inner_mitmot_outs_grouped(inner_scan_outs)
+            new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs)
+            new_sit_sot = op.inner_sitsot_outs(inner_scan_outs)
+            new_nit_sot = op.inner_nitsot_outs(inner_scan_outs)
+            new_shared = op.inner_shared_outs(inner_scan_outs)
+
+            # New carry for the next step
+            # Update the MIT-MOT buffers at the positions indicated by the output taps
+            new_mit_mot = [
+                buffer.at[i + np.array(taps)].set(new_vals)
+                for buffer, new_vals, taps in zip(
+                    old_mit_mot, new_mit_mot_vals, info.mit_mot_out_slices, strict=True
                 )
             ]
-
-            # Nothing needs to be done with sit_sot
-            inner_sit_sot_new = inner_sit_sot_outs
-
-            inner_shared_new = inner_shared
-            # Replace old shared inputs by new shared outputs
-            inner_shared_new[: len(inner_shared_outs)] = inner_shared_outs
-
+            # Discard the oldest MIT-SOT entry and append the newest value
+            new_mit_sot = [
+                jnp.concatenate([old_buffer[1:], new_val[None, ...]], axis=0)
+                for old_buffer, new_val in zip(
+                    old_mit_sot, new_mit_sot_vals, strict=True
+                )
+            ]
+            # For SIT-SOT and shared outputs, just pass along the new values
+            # Non-sequences remain unchanged
             new_carry = (
-                inner_mit_sot_new,
-                inner_sit_sot_new,
-                inner_shared_new,
+                i + 1,
+                new_mit_mot,
+                new_mit_sot,
+                new_sit_sot,
+                new_shared,
                 inner_non_seqs,
             )
 
-            # Shared variables and non_seqs are not traced
+            # Select the new MIT-SOT, SIT-SOT, and NIT-SOT values for tracing
            traced_outs = [
-                *inner_mit_sot_outs,
-                *inner_sit_sot_outs,
-                *inner_nit_sot_outs,
+                *new_mit_sot_vals,
+                *new_sit_sot,
+                *new_nit_sot,
             ]
 
             return new_carry, traced_outs
@@ -138,9 +177,17 @@ def jax_inner_func(carry, x):
             return new_carry, traced_outs
 
         # Extract PyTensor scan outputs
-        final_carry, traces = jax.lax.scan(
-            jax_inner_func, init_carry, seqs, length=n_steps
-        )
+        (
+            (
+                _final_i,
+                final_mit_mot,
+                _final_mit_sot,
+                _final_sit_sot,
+                final_shared,
+                _final_non_seqs,
+            ),
+            traces,
+        ) = jax_scan(jax_inner_func, init_carry, seqs, length=n_steps)
 
         def get_partial_traces(traces):
             """Convert JAX scan traces to PyTensor traces.
@@ -162,38 +209,37 @@ def get_partial_traces(traces):
             ):
                 if init_state is not None:
                     # MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
-                    trace = jnp.atleast_1d(trace)
-                    init_state = jnp.expand_dims(
-                        init_state, range(trace.ndim - init_state.ndim)
-                    )
-                    full_trace = jnp.concatenate([init_state, trace], axis=0)
                     buffer_size = buffer.shape[0]
+                    if trace.shape[0] > buffer_size:
+                        # Trace is longer than the buffer, keep just the last `buffer.shape[0]` entries
+                        partial_trace = trace[-buffer_size:]
+                    else:
+                        # Trace is shorter than the buffer; this happens when we keep the initial_state
+                        if init_state.ndim < buffer.ndim:
+                            init_state = init_state[None]
+                        if (
+                            n_init_needed := buffer_size - trace.shape[0]
+                        ) < init_state.shape[0]:
+                            # We may not need to keep all the initial states
+                            init_state = init_state[-n_init_needed:]
+                        partial_trace = jnp.concatenate([init_state, trace], axis=0)
                 else:
                     # NIT-SOT: Buffer is just the number of entries that should be returned
-                    full_trace = jnp.atleast_1d(trace)
                     buffer_size = buffer
+                    partial_trace = (
+                        trace[-buffer_size:] if trace.shape[0] > buffer else trace
+                    )
 
-                partial_trace = full_trace[-buffer_size:]
+                assert partial_trace.shape[0] == buffer_size
                 partial_traces.append(partial_trace)
 
             return partial_traces
 
-        def get_shared_outs(final_carry):
-            """Retrive last state of shared_outs from final_carry.
-
-            These outputs cannot be traced in PyTensor Scan
-            """
-            (
-                _inner_out_mit_sot,
-                _inner_out_sit_sot,
-                inner_out_shared,
-                _inner_in_non_seqs,
-            ) = final_carry
-
-            shared_outs = inner_out_shared[: info.n_shared_outs]
-            return list(shared_outs)
-
-        scan_outs_final = get_partial_traces(traces) + get_shared_outs(final_carry)
+        scan_outs_final = [
+            *final_mit_mot,
+            *get_partial_traces(traces),
+            *final_shared,
+        ]
 
         if len(scan_outs_final) == 1:
             scan_outs_final = scan_outs_final[0]
diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py
index 80cfa0fcf3..eda97560b3 100644
--- a/pytensor/scan/op.py
+++ b/pytensor/scan/op.py
@@ -307,6 +307,17 @@ def inner_mitmot_outs(self, list_outputs):
         n_taps = sum(len(x) for x in self.info.mit_mot_out_slices)
         return list_outputs[:n_taps]
 
+    def inner_mitmot_outs_grouped(self, list_outputs):
+        # Like inner_mitmot_outs, but returns a list of lists (one per MIT-MOT)
+        # instead of a flat list
+        n_taps = [len(x) for x in self.info.mit_mot_out_slices]
+        grouped_outs = []
+        offset = 0
+        for nt in n_taps:
+            grouped_outs.append(list_outputs[offset : offset + nt])
+            offset += nt
+        return grouped_outs
+
     def outer_mitmot_outs(self, list_outputs):
         return list_outputs[: self.info.n_mit_mot]
 
diff --git a/tests/link/jax/test_scan.py b/tests/link/jax/test_scan.py
index 4ee95ab527..ff9f4893af 100644
--- a/tests/link/jax/test_scan.py
+++ b/tests/link/jax/test_scan.py
@@ -4,15 +4,16 @@
 import pytest
 
 import pytensor.tensor as pt
-from pytensor import function, shared
+from pytensor import function, ifelse, shared
 from pytensor.compile import get_mode
 from pytensor.configdefaults import config
+from pytensor.graph import Apply, Op
 from pytensor.scan import until
 from pytensor.scan.basic import scan
 from pytensor.scan.op import Scan
 from pytensor.tensor import random
 from pytensor.tensor.math import gammaln, log
-from pytensor.tensor.type import dmatrix, dvector, lscalar, matrix, scalar, vector
+from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, vector
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
@@ -98,16 +99,26 @@ def test_scan_nit_sot(view):
     assert len(scan_nodes) == 1
 
 
-@pytest.mark.xfail(raises=NotImplementedError)
 def test_scan_mit_mot():
-    xs = pt.vector("xs", shape=(10,))
-    ys, _ = scan(
-        lambda xtm2, xtm1: (xtm2 + xtm1),
-        outputs_info=[{"initial": xs, "taps": [-2, -1]}],
+    def step(xtm1, ytm3, ytm1, rho):
+        return (xtm1 + ytm1) * rho, ytm3 * (1 - rho) + ytm1 * rho
+
+    rho = pt.scalar("rho", dtype="float64")
+    x0 = pt.vector("xs", shape=(2,))
+    y0 = pt.vector("ys", shape=(3,))
+    [outs, _], _ = scan(
+        step,
+        outputs_info=[x0, {"initial": y0, "taps": [-3, -1]}],
+        non_sequences=[rho],
         n_steps=10,
     )
-    grads_wrt_xs = pt.grad(ys.sum(), wrt=xs)
-    compare_jax_and_py([xs], [grads_wrt_xs], [np.arange(10)])
+    grads = pt.grad(outs.sum(), wrt=[x0, y0, rho])
+    compare_jax_and_py(
+        [x0, y0, rho],
+        grads,
+        [np.arange(2), np.array([0.5, 0.5, 0.5]), np.array(0.95)],
+        jax_mode=get_mode("JAX"),
+    )
 
 
 def test_scan_update():
@@ -189,96 +200,6 @@ def test_scan_while():
     compare_jax_and_py([], [xs], [])
 
 
-def test_scan_SEIR():
-    """Test a scan implementation of a SEIR model.
-
-    SEIR model definition:
-    S[t+1] = S[t] - B[t]
-    E[t+1] = E[t] +B[t] - C[t]
-    I[t+1] = I[t+1] + C[t] - D[t]
-
-    B[t] ~ Binom(S[t], beta)
-    C[t] ~ Binom(E[t], gamma)
-    D[t] ~ Binom(I[t], delta)
-    """
-
-    def binomln(n, k):
-        return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
-
-    def binom_log_prob(n, p, value):
-        return binomln(n, value) + value * log(p) + (n - value) * log(1 - p)
-
-    # sequences
-    at_C = vector("C_t", dtype="int32", shape=(8,))
-    at_D = vector("D_t", dtype="int32", shape=(8,))
-    # outputs_info (initial conditions)
-    st0 = lscalar("s_t0")
-    et0 = lscalar("e_t0")
-    it0 = lscalar("i_t0")
-    logp_c = scalar("logp_c")
-    logp_d = scalar("logp_d")
-    # non_sequences
-    beta = scalar("beta")
-    gamma = scalar("gamma")
-    delta = scalar("delta")
-
-    # TODO: Use random streams when their JAX conversions are implemented.
-    # trng = pytensor.tensor.random.RandomStream(1234)
-
-    def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
-        # bt0 = trng.binomial(n=st0, p=beta)
-        bt0 = st0 * beta
-        bt0 = bt0.astype(st0.dtype)
-
-        logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype)
-        logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype)
-
-        st1 = st0 - bt0
-        et1 = et0 + bt0 - ct0
-        it1 = it0 + ct0 - dt0
-        return st1, et1, it1, logp_c1, logp_d1
-
-    (st, et, it, logp_c_all, logp_d_all), _ = scan(
-        fn=seir_one_step,
-        sequences=[at_C, at_D],
-        outputs_info=[st0, et0, it0, logp_c, logp_d],
-        non_sequences=[beta, gamma, delta],
-    )
-    st.name = "S_t"
-    et.name = "E_t"
-    it.name = "I_t"
-    logp_c_all.name = "C_t_logp"
-    logp_d_all.name = "D_t_logp"
-
-    s0, e0, i0 = 100, 50, 25
-    logp_c0 = np.array(0.0, dtype=config.floatX)
-    logp_d0 = np.array(0.0, dtype=config.floatX)
-    beta_val, gamma_val, delta_val = (
-        np.array(val, dtype=config.floatX) for val in [0.277792, 0.135330, 0.108753]
-    )
-    C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32)
-    D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32)
-
-    test_input_vals = [
-        C,
-        D,
-        s0,
-        e0,
-        i0,
-        logp_c0,
-        logp_d0,
-        beta_val,
-        gamma_val,
-        delta_val,
-    ]
-    compare_jax_and_py(
-        [at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
-        [st, et, it, logp_c_all, logp_d_all],
-        test_input_vals,
-        jax_mode="JAX",
-    )
-
-
 def test_scan_mitsot_with_nonseq():
     a_pt = scalar("a")
@@ -413,10 +334,275 @@ def test_default_mode_excludes_incompatible_rewrites():
 
 
 def test_dynamic_sequence_length():
-    x = pt.tensor("x", shape=(None,))
-    out, _ = scan(lambda x: x + 1, sequences=[x])
+    # Imported here to avoid importing JAX in non-JAX CI jobs
+    from pytensor.link.jax.dispatch.basic import jax_funcify
+
+    class IncWithoutStaticShape(Op):
+        def make_node(self, x):
+            x = pt.as_tensor_variable(x)
+            return Apply(self, [x], [pt.tensor(shape=(None,) * x.type.ndim)])
+
+        def perform(self, node, inputs, outputs):
+            outputs[0][0] = inputs[0] + 1
+
+    @jax_funcify.register(IncWithoutStaticShape)
+    def _(op, **kwargs):
+        return lambda x: x + 1
+    inc_without_static_shape = IncWithoutStaticShape()
+
+    x = pt.tensor("x", shape=(None, 3))
+
+    out, _ = scan(
+        lambda x: inc_without_static_shape(x), outputs_info=[None], sequences=[x]
+    )
     f = function([x], out, mode=get_mode("JAX").excluding("scan"))
     assert sum(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes) == 1
-    np.testing.assert_allclose(f([]), [])
-    np.testing.assert_allclose(f([1, 2, 3]), np.array([2, 3, 4]))
+    np.testing.assert_allclose(f([[1, 2, 3]]), np.array([[2, 3, 4]]))
+
+    # This works if we use JAX scan internally, but not if we use a fori_loop with a buffer allocated by us
+    np.testing.assert_allclose(f(np.zeros((0, 3))), np.empty((0, 3)))
+
+    # With a known static shape this should always work, regardless of the internal implementation
+    out2, _ = scan(
+        lambda x: pt.specify_shape(inc_without_static_shape(x), x.shape),
+        outputs_info=[None],
+        sequences=[x],
+    )
+    f2 = function([x], out2, mode=get_mode("JAX").excluding("scan"))
+    np.testing.assert_allclose(f2([[1, 2, 3]]), np.array([[2, 3, 4]]))
+    np.testing.assert_allclose(f2(np.zeros((0, 3))), np.empty((0, 3)))
+
+
+def SEIR_model_logp():
+    """Set up a Scan implementation of a SEIR model.
+
+    SEIR model definition:
+    S[t+1] = S[t] - B[t]
+    E[t+1] = E[t] + B[t] - C[t]
+    I[t+1] = I[t] + C[t] - D[t]
+
+    B[t] ~ Binom(S[t], beta)
+    C[t] ~ Binom(E[t], gamma)
+    D[t] ~ Binom(I[t], delta)
+    """
+
+    def binomln(n, k):
+        return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
+
+    def binom_log_prob(n, p, value):
+        return binomln(n, value) + value * log(p) + (n - value) * log(1 - p)
+
+    # sequences
+    C_t = vector("C_t", dtype="int32", shape=(1200,))
+    D_t = vector("D_t", dtype="int32", shape=(1200,))
+    # outputs_info (initial conditions)
+    st0 = scalar("s_t0")
+    et0 = scalar("e_t0")
+    it0 = scalar("i_t0")
+    # non_sequences
+    beta = scalar("beta")
+    gamma = scalar("gamma")
+    delta = scalar("delta")
+
+    def seir_one_step(ct0, dt0, st0, et0, it0, beta, gamma, delta):
+        # bt0 = trng.binomial(n=st0, p=beta)
+        bt0 = st0 * beta
+        bt0 = bt0.astype(st0.dtype)
+
+        logp_c1 = binom_log_prob(et0, gamma, ct0)
+        logp_d1 = binom_log_prob(it0, delta, dt0)
+
+        st1 = st0 - bt0
+        et1 = et0 + bt0 - ct0
+        it1 = it0 + ct0 - dt0
+        return st1, et1, it1, logp_c1, logp_d1
+
+    (st, et, it, logp_c_all, logp_d_all), _ = scan(
+        fn=seir_one_step,
+        sequences=[C_t, D_t],
+        outputs_info=[st0, et0, it0, None, None],
+        non_sequences=[beta, gamma, delta],
+    )
+    st.name = "S_t"
+    et.name = "E_t"
+    it.name = "I_t"
+    logp_c_all.name = "C_t_logp"
+    logp_d_all.name = "D_t_logp"
+
+    st0_val, et0_val, it0_val = np.array(100.0), np.array(50.0), np.array(25.0)
+    beta_val, gamma_val, delta_val = (
+        np.array(0.277792),
+        np.array(0.135330),
+        np.array(0.108753),
+    )
+    C_t_val = np.array([3, 5, 8, 13, 21, 26, 10, 3] * 150, dtype=np.int32)
+    D_t_val = np.array([1, 2, 3, 7, 9, 11, 5, 1] * 150, dtype=np.int32)
+    assert C_t_val.shape == D_t_val.shape == C_t.type.shape == D_t.type.shape
+
+    test_input_vals = [
+        C_t_val,
+        D_t_val,
+        st0_val,
+        et0_val,
+        it0_val,
+        beta_val,
+        gamma_val,
+        delta_val,
+    ]
+
+    loss_graph = logp_c_all.sum() + logp_d_all.sum()
+
+    return dict(
+        graph_inputs=[C_t, D_t, st0, et0, it0, beta, gamma, delta],
+        differentiable_vars=[st0, et0, it0, beta, gamma, delta],
+        test_input_vals=test_input_vals,
+        loss_graph=loss_graph,
+    )
+
+
+def cyclical_reduction():
+    """Set up a Scan implementation of the cyclical reduction algorithm.
+
+    This solves the matrix equation A @ X @ X + B @ X + C = 0 for X
+
+    Adapted from https://github.com/jessegrabowski/gEconpy/blob/da495b22ac383cb6cb5dec15f305506aebef7302/gEconpy/solvers/cycle_reduction.py#L187
+    """
+
+    def stabilize(x, jitter=1e-16):
+        return x + jitter * pt.eye(x.shape[0])
+
+    def step(A0, A1, A2, A1_hat, norm, step_num, tol):
+        def cycle_step(A0, A1, A2, A1_hat, _norm, step_num):
+            tmp = pt.dot(
+                pt.vertical_stack(A0, A2),
+                pt.linalg.solve(
+                    stabilize(A1),
+                    pt.horizontal_stack(A0, A2),
+                    assume_a="gen",
+                    check_finite=False,
+                ),
+            )
+
+            n = A0.shape[0]
+            idx_0 = pt.arange(n)
+            idx_1 = idx_0 + n
+            A1 = A1 - tmp[idx_0, :][:, idx_1] - tmp[idx_1, :][:, idx_0]
+            A0 = -tmp[idx_0, :][:, idx_0]
+            A2 = -tmp[idx_1, :][:, idx_1]
+            A1_hat = A1_hat - tmp[idx_1, :][:, idx_0]
+
+            A0_L1_norm = pt.linalg.norm(A0, ord=1)
+
+            return A0, A1, A2, A1_hat, A0_L1_norm, step_num + 1
+
+        return ifelse(
+            norm < tol,
+            (A0, A1, A2, A1_hat, norm, step_num),
+            cycle_step(A0, A1, A2, A1_hat, norm, step_num),
+        )
+
+    A = pt.matrix("A", shape=(20, 20))
+    B = pt.matrix("B", shape=(20, 20))
+    C = pt.matrix("C", shape=(20, 20))
+
+    norm = np.array(1e9, dtype="float64")
+    step_num = pt.zeros((), dtype="int32")
+    max_iter = 100
+    tol = 1e-7
+
+    (*_, A1_hat, norm, _n_steps), _ = scan(
+        step,
+        outputs_info=[A, B, C, B, norm, step_num],
+        non_sequences=[tol],
+        n_steps=max_iter,
+    )
+    A1_hat = A1_hat[-1]
+
+    T = -pt.linalg.solve(stabilize(A1_hat), A, assume_a="gen", check_finite=False)
+
+    rng = np.random.default_rng(sum(map(ord, "cycle_reduction")))
+    n = A.type.shape[0]
+    A_test = rng.standard_normal(size=(n, n))
+    C_test = rng.standard_normal(size=(n, n))
+    # B must be invertible, so we make it symmetric positive-definite
+    B_rand = rng.standard_normal(size=(n, n))
+    B_test = B_rand @ B_rand.T + np.eye(n) * 1e-3
+
+    return dict(
+        graph_inputs=[A, B, C],
+        differentiable_vars=[A, B, C],
+        test_input_vals=[A_test, B_test, C_test],
+        loss_graph=pt.sum(T),
+    )
+
+
+@pytest.mark.parametrize("gradient_backend", ["PYTENSOR", "JAX"])
+@pytest.mark.parametrize("mode", ("0forward", "1backward", "2both"))
+@pytest.mark.parametrize("model", [cyclical_reduction, SEIR_model_logp])
+def test_scan_benchmark(model, mode, gradient_backend, benchmark):
+    model_dict = model()
+    graph_inputs = model_dict["graph_inputs"]
+    differentiable_vars = model_dict["differentiable_vars"]
+    loss_graph = model_dict["loss_graph"]
+    test_input_vals = model_dict["test_input_vals"]
+
+    if gradient_backend == "PYTENSOR":
+        backward_loss = pt.grad(
+            loss_graph,
+            wrt=differentiable_vars,
+        )
+
+        match mode:
+            # TODO: Restore original test separately
+            case "0forward":
+                graph_outputs = [loss_graph]
+            case "1backward":
+                graph_outputs = backward_loss
+            case "2both":
+                graph_outputs = [loss_graph, *backward_loss]
+            case _:
+                raise ValueError(f"Unknown mode: {mode}")
+
+        jax_fn, _ = compare_jax_and_py(
+            graph_inputs,
+            graph_outputs,
+            test_input_vals,
+            jax_mode="JAX",
+        )
+        jax_fn.trust_input = True
+
+    else:  # gradient_backend == "JAX"
+        import jax
+
+        loss_fn_tuple = function(graph_inputs, loss_graph, mode="JAX").vm.jit_fn
+
+        def loss_fn(*args):
+            return loss_fn_tuple(*args)[0]
+
+        match mode:
+            case "0forward":
+                jax_fn = jax.jit(loss_fn_tuple)
+            case "1backward":
+                jax_fn = jax.jit(
+                    jax.grad(loss_fn, argnums=tuple(range(len(graph_inputs))[2:]))
+                )
+            case "2both":
+                value_and_grad_fn = jax.value_and_grad(
+                    loss_fn, argnums=tuple(range(len(graph_inputs))[2:])
+                )
+
+                @jax.jit
+                def jax_fn(*args):
+                    loss, grads = value_and_grad_fn(*args)
+                    return loss, *grads
+
+            case _:
+                raise ValueError(f"Unknown mode: {mode}")
+
+    def block_until_ready(*inputs, jax_fn=jax_fn):
+        return [o.block_until_ready() for o in jax_fn(*inputs)]
+
+    block_until_ready(*test_input_vals)  # Warmup
+
+    benchmark.pedantic(block_until_ready, test_input_vals, rounds=200, iterations=1)
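Illustrative sketch (not part of the patch): the MIT-MOT handling above keeps the whole outer buffer in the `jax.lax.scan` carry, gathers the input taps at `i + taps`, and scatters the freshly computed values back with `.at[...].set(...)`. A minimal standalone example of that carry pattern, with hypothetical taps and an ad-hoc `+ 2` offset standing in for the initial-state length:

import jax
import jax.numpy as jnp
import numpy as np

in_taps = np.array([-2, -1])  # hypothetical MIT-MOT input taps
out_taps = np.array([0])      # hypothetical MIT-MOT output tap
n_steps = 5

# Buffer holds the two initial values plus one slot per step
buffer = jnp.zeros(n_steps + 2).at[:2].set(jnp.array([1.0, 2.0]))

def step(carry, _):
    i, buf = carry
    xtm2, xtm1 = buf[i + 2 + in_taps]                # gather input taps from the shared buffer
    buf = buf.at[i + 2 + out_taps].set(xtm2 + xtm1)  # scatter the new value back in place
    return (i + 1, buf), None

(_, final_buffer), _ = jax.lax.scan(step, (0, buffer), None, length=n_steps)
print(final_buffer)  # [1. 2. 3. 5. 8. 13. 21.]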