
Conversation


@Dekermanjian Dekermanjian commented Oct 5, 2025

added trust_inputs to pytensor functions, added multiple pytensor backends, added block_until_ready to jax calls

Description

This PR adds a benchmarking notebook that runs several algorithms in different backends and compares them to pure implementations in their respective backends.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • [x] Other (please specify): Example notebook

📚 Documentation preview 📚: https://pytensor--1632.org.readthedocs.build/en/1632/


@aseyboldt
Member

Thanks for working through all those examples, this is really nice!

I only had a quick look at the code so far, but I think there are two issues with the timing code:

I think the jax block_until_ready doesn't yet do what you want. You call jax.block_until_ready on the function itself, not on its return values. This means it will just wait until all computations inside the function object are done (of which there are none) and then return the original function.
You want something like

import jax

def block(func):
    # wrap func so that timing waits for the results, not just for async dispatch
    def inner(*args, **kwargs):
        return jax.block_until_ready(func(*args, **kwargs))
    return inner
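
For instance, timing would then look roughly like this (a sketch; fn and x are placeholders for one of the benchmarked functions and its input):

import timeit

import jax
import jax.numpy as jnp

@jax.jit
def fn(x):
    return jnp.tanh(x).sum()  # stand-in for one of the benchmarked functions

x = jnp.arange(1_000_000, dtype=jnp.float32)

timed_fn = block(fn)  # the wrapper above: timing now includes the actual computation
timed_fn(x)           # warm-up call, triggers compilation
print(timeit.timeit(lambda: timed_fn(x), number=100))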

The other thing is that I think in the fibonacci example jax just constant folds everything, so it doesn't actually do any computation, but just returns the pre-computed result.

If you want to avoid a jax loop (not sure if that's helpful though), you could do something like this to prevent constant folding:

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

@partial(jax.jit, static_argnums=0)  # n is static; a stays a traced input
def fibonacci_jax(n, a):
    b = jnp.array(1, dtype=np.int32)
    for _ in range(n):
        a, b = a + b, a
    return a

fibonacci_jax = partial(fibonacci_jax, a=np.array(1, dtype=np.int32))

@Dekermanjian
Author

Thank you, @aseyboldt, for your help with properly blocking JAX functions and with the constant-folding issue in the JAX fibonacci algorithm. I went ahead and updated both: I used a slightly modified version of the blocking function you shared above, and I replaced the for loop with a fori_loop to prevent the constant folding.
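
A rough sketch of the fori_loop variant (not the exact notebook code):

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def fibonacci_jax():
    # carry is (a, b); jax.lax.fori_loop keeps an actual loop in the traced graph
    def body(_, carry):
        a, b = carry
        return a + b, a

    init = (jnp.array(1, dtype=np.int32), jnp.array(1, dtype=np.int32))
    a, _ = jax.lax.fori_loop(0, 100_000, body, init)
    return a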

@ricardoV94
Member

ricardoV94 commented Oct 7, 2025

@Dekermanjian I pushed a notebook building on top of yours. I focused on a different variation of the fibonacci scan, where n_steps is constant but b is symbolic. This is because right now we have a very complex shape graph for symbolic steps, which just makes things messier and also sometimes prevents proper inplacing.
Related issues: #1283, #112
Related PRs: #1299

On my machine, the function evaluation with constant n_steps=10 is the following:

c-backend: 46 us
numba-backend: 5.5 us
numba-backend-skipping-pytensor: 3.3 us
direct_numba: 1.2 us
numba_that_looks_like_pytensor: 2.2 us

And with constant n_steps=1000:

c-backend: 2.22ms
numba-backend: 165 us
numba-backend-skipping-pytensor: 159 us
direct_numba: 3.2 us
numba_that_looks_like_pytensor: 55 us

The difference between numba-backend and numba-backend-skipping-pytensor is small, on the order of 3-5 us. It should come down further after #1351.

My hypothesis for why numba-backend-skipping-pytensor is slower than numba_that_looks_like_pytensor is how we implement Elemwise in PyTensor, which I didn't try to replicate. See below.

I'm a bit baffled that pure numba remains so fast for n=1000; I guess it just manages to stay in scalar land. See below.

@ricardoV94
Member

A small discrepancy in your original notebook that I don't know if it matters: you are using pt.constant(1.0), which is a float by default, but int32 in the numba/jax variations.
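
A matching integer constant would be something like this sketch, assuming pt is pytensor.tensor:

import pytensor.tensor as pt

one = pt.constant(1, dtype="int32")  # matches the int32 used in the numba/jax variants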

@ricardoV94
Member

I replaced the for loop with a fori_loop to prevent the constant folding.

What makes you think fori wouldn't constant-fold but scan would? The point @aseyboldt was trying to make is that for JAX the number of steps is always constant (even if you pass it at runtime, it will specialize on the specific value), and if you have constant a/b then all the variables are constant as far as JAX can see. The trick of using symbolic a/b is that those aren't treated as constants by JAX, so it won't constant-fold them. You can always inspect the jaxpr (or whatever it's called) to be sure, though.
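
For example, something along these lines (a sketch; make_jaxpr shows how the inputs enter the trace):

import jax
import jax.numpy as jnp
import numpy as np

def fib(n, a):
    b = jnp.array(1, dtype=np.int32)
    for _ in range(n):
        a, b = a + b, a
    return a

# with a passed as an array, it shows up in the jaxpr as a traced input
# rather than a baked-in constant (n is static, so the loop is unrolled in the trace)
print(jax.make_jaxpr(fib, static_argnums=0)(10, np.array(1, dtype=np.int32)))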

Anyway I prefer the constant n_steps because of the known issues related to Scan that I linked to above.

@ricardoV94
Member

ricardoV94 commented Oct 7, 2025

I'm a bit baffled that pure numba remains so fast for n=1000, I guess it just manages to stay in scalar land

This one behaves much more like the recreated pytensor version. Not sure if it's extra copies or if it doesn't manage to work with just scalars...

image

@ricardoV94
Member

ricardoV94 commented Oct 7, 2025

Your second example is not actually testing scan. PyTensor has a rewrite that converts pure elementwise scans to regular elementwise. You should always dprint your final version to be sure you're testing what you mean to.

image

In this case you can keep it by excluding the specific rewrite:
image
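
Something along these lines (a sketch; the rewrite name below is my guess at the registered name, check pytensor.scan.rewriting for the exact tag):

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

x = pt.vector("x")
ys, _ = pytensor.scan(lambda xi: xi + 1, sequences=[x])

# exclude the rewrite that turns pure-elemwise scans into a plain Elemwise
mode = get_default_mode().excluding("scan_pushout_seqs_ops")
f = pytensor.function([x], ys, mode=mode)
pytensor.dprint(f)  # the Scan should still show up here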

@ricardoV94
Member

ricardoV94 commented Oct 7, 2025

I hacked a direct dispatch for the add that shows up in the fib scan. It bridges the gap between the ~150 us of pytensor_numba and the ~50 us of my numba_that_looks_like_pytensor implementation benched above.

# imports assumed for this hack (paths as in pytensor.link.numba.dispatch)
from pytensor.link.numba.dispatch import basic as numba_basic, numba_funcify
from pytensor.scalar.basic import Add
from pytensor.tensor.elemwise import Elemwise

@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
    # hack: only handles the binary, non-inplace Add; the general fallback is omitted
    if isinstance(op.scalar_op, Add) and not op.inplace_pattern and len(node.inputs) == 2:
        print("USING OPTIMIZED ADD")

        @numba_basic.numba_njit
        def add(x, y):
            return x + y

        return add

However, if I do something a bit more fair (assuming the core scalar op is actually needed and expects scalars), it's a bit closer to our regular Elemwise (~130 us, so perhaps still ~20 us slower):

    # inside the same numba_funcify_Elemwise hack; scalar_op_fn is the compiled
    # scalar op (here the Add) taken from the regular scalar dispatch
    out_dtype = node.outputs[0].type.numpy_dtype
    if (
        (not op.inplace_pattern)
        and (len(node.inputs) == 2)
        and (len(node.outputs) == 1)
        and (node.outputs[0].type.shape == ())
    ):
        print("USING FAIRER OPTIMIZED ADD")

        @numba_basic.numba_njit
        def add(x, y):
            out = scalar_op_fn(x[()], y[()])
            return np.asarray(out, dtype=out_dtype)

        return add

No idea why it is so different. Both should need to allocate a new array, and both should need to access the inner value. Is this something we may want to investigate further?

Edit: Ugh, numba downcasts 0d add to scalar?? That seems to explain it: numba/numba#10266

import numpy as np
import numba

@numba.njit
def add_scalars_1(x, y):
    return x + y

a = np.zeros(())
b = np.ones(())
# numba returns a float scalar instead of a 0d array
assert isinstance(add_scalars_1(a, b), np.ndarray)  # AssertionError

@ricardoV94
Member

ricardoV94 commented Oct 7, 2025

Summary from the fibonacci bit on numba: we're not doing anything dumb; the difference is that we were basically comparing a scalar loop (due to the numba bug) against an array loop. This performance question can be more properly addressed by rewriting scalar scan loops as scalar loops, as mentioned in #301.

More broadly, we could try to modify the scan inner function to accept scalars as inputs and return scalars as outputs. Even more broadly, we could try to create a numba function that stores the results in the output buffer, like scan tries to do in the C backend. This is tricky because we have implemented Numba Ops in a way that they don't receive output storage.

Note that we do have inplace rewrites in the scan dispatch, so output buffers aren't always created. However, this doesn't prevent the copy back to the original tape. Inplace doesn't mean an Op will store the result in a specific input, just that it may (or may destroy the input for whatever other purpose it wants).

Even if it stores the output in an input, numba/llvm doesn't seem to be clever enough to understand that the reassignment into the tape is a no-op:

import numpy as np
import numba

@numba.njit
def foo(x):
    # negate each element through a length-1 view, then write it back into x
    for i in range(x.shape[0]):
        y = x[i:i+1]
        y *= -1
        x[i:i+1] = y
    return x

@numba.njit
def bar(x):
    # same work, but without the redundant write-back (y already aliases x)
    for i in range(x.shape[0]):
        y = x[i:i+1]
        y *= -1
    return x

x = np.ones(50)
assert foo(x.copy()).sum() == bar(x.copy()).sum() == -50.0
%timeit foo(x)  # 2.1 us
%timeit bar(x)  # 600 ns

…ed fori_loop to explicitly tell it not to constant fold
@Dekermanjian
Author

@ricardoV94 Thank you for all of the above! I updated the elemwise pytensor function to exclude the rewrite as you suggested above. I will do a pass on the other functions and make sure the scan shows up in the dprint. For the jax fori_loop, I thought that fori_loop prevented constant folding, and I added a little section printing the jaxpr of the compiled function where you can see (I believe, at least) that it is not constant folding. This may have just been a fluke due to the number of loop iterations, so I also added an argument that tells the fori_loop not to constant fold (I think).

@ricardoV94
Member

Was it actually constant folding the scan, or was that just a guess?

@Dekermanjian
Author

No, I don't think the jax fori_loop was constant folding, but that may only have been because I have it set to loop 100,000 times, and that may be why jax decided not to constant fold it. I thought that fori_loop doesn't constant fold, but upon reading the documentation again there is no mention of whether it does or doesn't when you don't pass anything to the unroll argument. I went ahead and explicitly passed unroll=False to the fori_loop to prevent constant folding. Please correct me if I am wrong, but when you say constant folding, do you mean unrolling of the scan during compilation?

@aseyboldt
Member

aseyboldt commented Oct 8, 2025

I don't think looking at the jaxpr can tell us if the jax/xla will constant fold the function. jaxpr is just the input to the optimization pipeline, not the result of it. I don't know of a way to inspect the final optimized code in jax or xla (if someone does, I'd be very happy to hear about it!).

Loop unrolling and constant folding are different things. Constant folding means that the compiler removes the whole loop and just directly returns the result. A sufficiently smart compiler can transform the code

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

@partial(jax.jit, static_argnums=0)
def fib(n):
    b = jnp.array(1, dtype=np.int32)
    a = b
    for _ in range(n):
        a, b = a + b, a
    return a

when called with e.g. n=2 into

def fib(n):
    return 3  # the precomputed result

So it will just not do any work at all anymore.

Loop unrolling means that it replaces the loop by a series of loop bodies, so the loop is gone, but it still does the work.

When benchmarking, I'd always make sure that it is impossible to constant fold the code.

I'm relatively sure it was constant folding the loop in the first example I saw; it was, I think, impossibly fast.

Author

Dekermanjian commented Oct 8, 2025

Thank you @aseyboldt for the detailed explanation. I want to make one correction to my last comment: it is not the jaxpr that I am presenting in that short section, it is the StableHLO; I don't know if this makes any difference. Here is a copy-paste of a section from a book about JAX's StableHLO:

“Then, exciting things begin. We JIT-compile our function. For the jitted function, we can create a lowered function. Lowering is a process of converting a higher-level representation to a lower-level representation. Here, we create a StableHLO IR code consisting of basic StableHLO operations. This code is pretty straightforward and almost resembles the original calculations. Then we move further, compiling this StableHLO code to HLO optimized for the target backend. This HLO code contains fused computation. The full code is available in the book’s repository; here, I highlight only the StableHLO and HLO parts.”

Excerpt from Deep Learning with JAX, by Grigory Sapunov.

Maybe I should also check the HLO code, if you think this makes a difference? I checked the HLO output (which I believe is after compilation, because I had to call .compile on the lowered function), and it still looks like it is not constant folding:

HloModule jit_fibonacci_jax, is_scheduled=true, entry_computation_layout={()->s32[]}, allow_spmd_sharding_propagation_to_output={true}

%region_0.2 (arg_tuple.1: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
  %arg_tuple.1 = (s32[], s32[], s32[]) parameter(0)
  %constant.3 = s32[] constant(1)
  %get-tuple-element.8 = s32[] get-tuple-element(%arg_tuple.1), index=2
  %get-tuple-element.7 = s32[] get-tuple-element(%arg_tuple.1), index=1
  %get-tuple-element.6 = s32[] get-tuple-element(%arg_tuple.1), index=0
  %copy.5 = s32[] copy(%get-tuple-element.8)
  %copy.4 = s32[] copy(%get-tuple-element.7), metadata={op_name="jit(fibonacci_jax)/while/body/closed_call" source_file="/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/ipykernel_44972/469706885.py" source_line=9 source_end_line=9 source_column=11 source_end_column=48}
  %add.3 = s32[] add(%get-tuple-element.6, %constant.3), metadata={op_name="jit(fibonacci_jax)/while/body/add" source_file="/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/ipykernel_44972/469706885.py" source_line=9 source_end_line=9 source_column=11 source_end_column=48}
  %add.0 = s32[] add(%copy.4, %copy.5), metadata={op_name="jit(fibonacci_jax)/while/body/closed_call/add" source_file="/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/ipykernel_44972/469706885.py" source_line=7 source_end_line=7 source_column=15 source_end_column=20}
  %copy.8 = s32[] copy(%copy.4), control-predecessors={%copy.5}
  ROOT %tuple.5 = (s32[], s32[], s32[]) tuple(%add.3, %add.0, %copy.8)
}

%region_1.3 (arg_tuple.3: (s32[], s32[], s32[])) -> pred[] {
  %arg_tuple.3 = (s32[], s32[], s32[]) parameter(0)
  %constant.5 = s32[] constant(100000)
  %get-tuple-element.9 = s32[] get-tuple-element(%arg_tuple.3), index=0
  ROOT %lt.1 = pred[] compare(%get-tuple-element.9, %constant.5), direction=LT, metadata={op_name="jit(fibonacci_jax)/while/cond/lt" source_file="/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/ipykernel_44972/469706885.py" source_line=9 source_end_line=9 source_column=11 source_end_column=48}
}

%while.6_computation (tuple.6: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
  %tuple.6 = (s32[], s32[], s32[]) parameter(0)
  ROOT %while.0 = (s32[], s32[], s32[]) while(%tuple.6), condition=%region_1.3, body=%region_0.2, metadata={op_name="jit(fibonacci_jax)/while" source_file="/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/ipykernel_44972/469706885.py" source_line=9 source_end_line=9 source_column=11 source_end_column=48}, backend_config={"known_trip_count":{"n":"100000"},"known_init_step":{"init":"0","step":"1"},"known_induction_variable":{"tuple_index":"0"}}
}

ENTRY %main.4 () -> s32[] {
  %constant.6 = s32[] constant(1)
  %constant.7 = s32[] constant(0)
  %copy.9 = s32[] copy(%constant.6)
  %copy.10 = s32[] copy(%constant.7)
  %copy.2 = s32[] copy(%copy.9)
  %tuple.2 = (s32[], s32[], s32[]) tuple(%copy.10, %copy.9, %copy.2)
  %call = (s32[], s32[], s32[]) call(%tuple.2), to_apply=%while.6_computation, frontend_attributes={xla_cpu_small_call="true"}, metadata={op_name="jit(fibonacci_jax)/while" source_file="/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/ipykernel_44972/469706885.py" source_line=9 source_end_line=9 source_column=11 source_end_column=48}
  ROOT %while.8 = s32[] get-tuple-element(%call), index=1, metadata={op_name="jit(fibonacci_jax)/while" source_file="/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/ipykernel_44972/469706885.py" source_line=9 source_end_line=9 source_column=11 source_end_column=48}
}
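
For reference, the chain of lowering/compile calls I mean is roughly this (a sketch with a stand-in function, not the notebook code):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x) ** 2  # stand-in for the notebook's fibonacci function

lowered = f.lower(jnp.arange(4.0))
print(lowered.as_text())            # StableHLO, before XLA optimization
print(lowered.compile().as_text())  # HLO after XLA has compiled for the target backend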

Please let me know if I am just not getting it. I am pretty new to this sort of low-level computational framework.

@ricardoV94
Member

ricardoV94 commented Oct 9, 2025

The cusum one may be worth exploring. We're doing 10x worse than direct numba, and 2x worse than direct JAX.

There are some differences: you're computing alarms inside the loop for numba, but outside it for pytensor, which requires another iteration. And the sneaky scalars may show up again.

@ricardoV94
Member

@Dekermanjian I would still suggest we focus on scan with a fixed number of steps but symbolic inputs. This is the fairest comparison across backends, since JAX, for instance, always compiles for a specific number of steps.
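
Concretely, something in this direction (a sketch; names are illustrative):

import pytensor
import pytensor.tensor as pt

a0 = pt.scalar("a0", dtype="int32")  # symbolic initial values
b0 = pt.scalar("b0", dtype="int32")

(a_seq, _), _ = pytensor.scan(
    lambda a, b: (a + b, a),
    outputs_info=[a0, b0],
    n_steps=10,  # constant number of steps
)
fib = pytensor.function([a0, b0], a_seq[-1])
print(fib(1, 1))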

@Dekermanjian
Author

@ricardoV94, I reread your earlier comments and realized I missed that detail. I’ll adjust the functions so they always use a fixed number of steps with symbolic inputs.

…_step arg, updated jax elemntwise to use fori_loop
@Dekermanjian
Author

@ricardoV94 I updated the PyTensor fibonacci algorithm to follow your method of fixing the number of iterations in the scan's n_steps argument. I wasn't sure whether this was also needed when sequences are passed to pytensor.scan, but I fixed the number of iterations in those cases as well, just in case.

@ricardoV94
Member

@ricardoV94 I updated the PyTensor fibonacci algorithm to follow your method of fixing the number of iterations in the scan's n_steps argument. I wasn't sure whether this was also needed when sequences are passed to pytensor.scan, but I fixed the number of iterations in those cases as well, just in case.

When sequences are passed, n_steps defaults to their length, but it doesn't hurt to specify it again (or it shouldn't!)
