Example of benchmarking backends against pure implementations in the respective backend #1632
base: main
Conversation
…ends, added block_until_ready to jax calls
Thanks for working through all those examples, this is really nice! I only had a quick look at the code so far, but I think there are two issues with the timing code. The first is that the jax calls need to block until the result is actually ready before the timer stops, e.g. with a wrapper like:

```python
def block(func):
    def inner(*args, **kwargs):
        return jax.block_until_ready(func(*args, **kwargs))
    return inner
```

The other thing is that I think in the fibonacci example jax just constant folds everything, so it doesn't actually do any computation, but just returns the pre-computed result. If you want to avoid a jax loop (not sure if that's helpful though), you could do something like this to prevent constant folding:

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

@partial(jax.jit, static_argnums=0)
def fibonacci_jax(n, a):
    b = jnp.array(1, dtype=np.int32)
    for _ in range(n):
        a, b = a + b, a
    return a

fibonacci_jax = partial(fibonacci_jax, a=np.array(1, dtype=np.int32))
```
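As a usage sketch (my addition, not part of the suggestion above; the function `f` and the timing setup are made up for illustration), the wrapper would be applied before timing so the measurement includes the device computation:

```python
# Hypothetical usage of the block() wrapper defined above: without blocking,
# the jitted call can return before the asynchronous device computation
# finishes, making timings look too fast.
import timeit

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sin(x).sum()

x = jnp.arange(1_000_000, dtype=jnp.float32)

blocked_f = block(f)
blocked_f(x)  # warm-up call so compilation is not included in the timing

print(timeit.timeit(lambda: blocked_f(x), number=100))
```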
…ithm to prevent constant folding
Thank you @aseyboldt for your help with properly blocking JAX functions and with the constant folding issue in the JAX fibonacci algorithm. I went ahead and updated both: I used a slightly modified version of the blocking function you shared above, and I replaced the for loop with a fori_loop to prevent the constant folding.
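For reference, a minimal sketch (my own illustration, not necessarily the exact code in the notebook; the argument names and step count are assumptions) of what a fori_loop-based fibonacci can look like, with the initial values kept as traced inputs so they are not constants:

```python
# Sketch of a fibonacci computed with jax.lax.fori_loop. Keeping a and b as
# (traced) function arguments means JAX cannot treat them as compile-time
# constants and fold the whole loop away.
import jax
import jax.numpy as jnp

@jax.jit
def fibonacci_jax(a, b, n_steps=1000):
    def body(_, carry):
        a, b = carry
        return a + b, a

    a, _ = jax.lax.fori_loop(0, n_steps, body, (a, b))
    return a

print(fibonacci_jax(jnp.int32(1), jnp.int32(1)))
```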
@Dekermanjian I pushed a notebook building on top of yours. I focused on a different variation of the fibonacci scan, where n_steps is constant but b is symbolic. This is because right now we have a very complex shape graph for symbolic steps that just makes things messier, and also sometimes prevents proper inplacing. On my machine, the function evaluation with constant n_steps=10 is the following:
And with constant n_steps=1000:
The difference between … My hypothesis for why …
Small discrepancy that I don't know if it matters: in your original notebook you are using …
What makes you think fori wouldn't constant-fold but scan would? The point @aseyboldt was trying to make is that for jax the number of steps is always constant (even if you pass it at runtime, it's going to specialize on the specific value), and if you have constant a/b then all the variables are constant as far as JAX can see. The trick of using symbolic a/b is that those aren't treated as constants by jax, so it can't constant-fold them. You can always inspect the jaxpr (or whatever it's called) to be sure, though. Anyway, I prefer the constant n_steps because of the known issues related to Scan that I linked to above.
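For what it's worth, here is a small sketch (my addition; the function and values are illustrative, not from the notebook) of checking what JAX traces when the initial values are symbolic arguments rather than constants:

```python
# With a and b passed as (traced) arguments, the jaxpr retains the scan/while
# primitive over symbolic carries; only n_steps is baked in as a constant.
import jax
import jax.numpy as jnp

def fib(a, b, n_steps=1000):
    def step(carry, _):
        a, b = carry
        return (a + b, a), None

    (a, _), _ = jax.lax.scan(step, (a, b), None, length=n_steps)
    return a

print(jax.make_jaxpr(fib)(jnp.int32(1), jnp.int32(1)))
```

Note the caveat later in this thread that the jaxpr is only the input to XLA's optimization pipeline, so it shows what JAX itself treats as constant but cannot prove that the compiled code isn't folded.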
I hacked a direct dispatch for the add that shows up in the fib scan. It bridges the gap between the pytensor_numba 150ms and the 50ms of my direct numba implementation:

```python
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
    if isinstance(op.scalar_op, Add) and not op.inplace_pattern and len(node.inputs) == 2:
        print("USING OPTIMIZED ADD")

        @numba_basic.numba_njit
        def add(x, y):
            return x + y

        return add
```

However if I do something a bit more fair (assuming the core scalar op is actually needed and expects scalars), it's a bit closer to our regular Elemwise (130ms, so perhaps still 20ms slower):

```python
out_dtype = node.outputs[0].type.numpy_dtype
if (
    (not op.inplace_pattern)
    and (len(node.inputs) == 2)
    and (len(node.outputs) == 1)
    and (node.outputs[0].type.shape == ())
):
    print("USING FAIRER OPTIMIZED ADD")

    @numba_basic.numba_njit
    def add(x, y):
        out = scalar_op_fn(x[()], y[()])
        return np.asarray(out, dtype=out_dtype)

    return add
```

No idea why it is so different. Both should need to allocate a new array, and both should need to access the inner value. Is this something we may want to investigate further?

Edit: Ugh, numba downcasts 0d add to scalar?? That seems to explain it: numba/numba#10266

```python
import numpy as np
import numba

@numba.njit
def add_scalars_1(x, y):
    return x + y

a = np.zeros(())
b = np.ones(())

# It returns a float, not a 0d array
assert isinstance(add_scalars_1(a, b), np.ndarray)  # AssertionError
```
Summary from the fibonacci bit on numba. We're not doing anything dumb; the difference is we were basically comparing a scalar loop (due to the numba bug) and an array loop. This performance question can be more properly addressed by rewriting scalar scan loops as scalar loops, as mentioned in #301. More broadly, we could try to modify the scan inner function to accept scalars as inputs and return scalars as outputs. Even more broadly, we could try to create a numba function that will store the results in the output buffer like scan tries in the C backend. This is tricky because we have implemented Numba Ops in a way that they don't receive output storage.

Note that we do have inplace rewrites in the scan dispatch, so output buffers aren't always created. However this doesn't prevent the copy back to the original tape. Inplace doesn't mean an Op will store the result in a specific input, just that it can (or destroy the input for whatever other purposes it wants). Even if it stores the output in an input, numba/llvm doesn't seem to be clever enough to understand that the reassignment in the tape is a no-op:

```python
import numpy as np
import numba

@numba.njit
def foo(x):
    for i in range(x.shape[0]):
        y = x[i:i+1]
        y *= -1
        x[i:i+1] = y  # write the (already in-place) result back into the tape
    return x

@numba.njit
def bar(x):
    for i in range(x.shape[0]):
        y = x[i:i+1]
        y *= -1
    return x

x = np.ones(50)
assert foo(x.copy()).sum() == bar(x.copy()).sum() == -50.0

%timeit foo(x)  # 2.1 us
%timeit bar(x)  # 600 ns
```
…ed fori_loop to explicitly tell it not to constant fold
@ricardoV94 Thank you for all of the above! I updated the elemwise pytensor function to exclude the rewrite as you suggested above. I will do a pass on the other functions and make sure the scan shows up in the dprint. For the jax fori_loop, I thought the fori_loop prevented constant folding, and I added a little section printing the jaxpr of the compiled function where you can (I believe at least) see that it is not constant folding. This may have just been a fluke due to the large number of loop iterations, so I also added an argument that tells the fori_loop not to constant fold (I think).
Was it actually constant folding the scan, or was that just a guess?
No, I don't think the jax fori_loop was constant folding, but that may only be because I have it set to loop 100,000 times, and that may have been why jax decided not to constant fold it. I thought that fori_loop doesn't constant fold, but upon reading the documentation again there is no mention of whether it does or doesn't when you don't pass anything to the unroll argument. I went ahead and explicitly passed unroll=False to the fori_loop to prevent constant folding. Please correct me if I am wrong, but when you say constant folding do you mean unrolling of the scan during compilation?
I don't think looking at the jaxpr can tell us if jax/xla will constant fold the function. The jaxpr is just the input to the optimization pipeline, not the result of it. I don't know of a way to inspect the final optimized code in jax or xla (if someone does, I'd be very happy to hear about it!).

Loop unrolling and constant folding are different things. Constant folding means that the compiler removes the whole loop and just directly returns the result: a sufficiently smart compiler, when the function is called with e.g. 2, can transform the code so that it simply returns the precomputed value. It will then not do any work at all anymore. Loop unrolling means that it replaces the loop by a series of loop bodies, so the loop is gone, but it still does the work.

When benchmarking, I'd always make sure that it is impossible to constant fold the code. I'm relatively sure it was constant folding the loop in the first example I saw; it was, I think, impossibly fast.
Thank you @aseyboldt for the detailed explanation. I want to make one correction to my last comment: it is not the jaxpr that I am presenting in that short section, it is the StableHLO. I don't know if this makes any difference. Here is a copy-paste of a section from a book about JAX's StableHLO:

“Then, exciting things begin. We JIT-compile our function. For the jitted function, we can create a lowered function. Lowering is a process of converting a higher-level representation to a lower-level representation. Here, we create a StableHLO IR code consisting of basic StableHLO operations. This code is pretty straightforward and almost resembles the original calculations. Then we move further, compiling this StableHLO code to HLO optimized for the target backend. This HLO code contains fused computation. The full code is available in the book’s repository; here, I highlight only the StableHLO and HLO parts.”

(Excerpt from Deep Learning with JAX, Grigory Sapunov)
Maybe I should also check the HLO code, if you think this makes a difference? I checked the HLO output (which I believe is after compilation, because I had to call .compile on the lowered function) and it still looks like it is not constant folding:

```
HloModule jit_fibonacci_jax, is_scheduled=true, entry_computation_layout={()->s32[]}, allow_spmd_sharding_propagation_to_output={true}

%region_0.2 (arg_tuple.1: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
  %arg_tuple.1 = (s32[], s32[], s32[]) parameter(0)
  %constant.3 = s32[] constant(1)
  %get-tuple-element.8 = s32[] get-tuple-element(%arg_tuple.1), index=2
  %get-tuple-element.7 = s32[] get-tuple-element(%arg_tuple.1), index=1
  %get-tuple-element.6 = s32[] get-tuple-element(%arg_tuple.1), index=0
  %copy.5 = s32[] copy(%get-tuple-element.8)
  %copy.4 = s32[] copy(%get-tuple-element.7), metadata={op_name="jit(fibonacci_jax)/while/body/closed_call" source_file="/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/ipykernel_44972/469706885.py" source_line=9 source_end_line=9 source_column=11 source_end_column=48}
  %add.3 = s32[] add(%get-tuple-element.6, %constant.3), metadata={op_name="jit(fibonacci_jax)/while/body/add" source_file="/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/ipykernel_44972/469706885.py" source_line=9 source_end_line=9 source_column=11 source_end_column=48}
  %add.0 = s32[] add(%copy.4, %copy.5), metadata={op_name="jit(fibonacci_jax)/while/body/closed_call/add" source_file="/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/ipykernel_44972/469706885.py" source_line=7 source_end_line=7 source_column=15 source_end_column=20}
  %copy.8 = s32[] copy(%copy.4), control-predecessors={%copy.5}
  ROOT %tuple.5 = (s32[], s32[], s32[]) tuple(%add.3, %add.0, %copy.8)
}

%region_1.3 (arg_tuple.3: (s32[], s32[], s32[])) -> pred[] {
  %arg_tuple.3 = (s32[], s32[], s32[]) parameter(0)
  %constant.5 = s32[] constant(100000)
  %get-tuple-element.9 = s32[] get-tuple-element(%arg_tuple.3), index=0
  ROOT %lt.1 = pred[] compare(%get-tuple-element.9, %constant.5), direction=LT, metadata={op_name="jit(fibonacci_jax)/while/cond/lt" source_file="/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/ipykernel_44972/469706885.py" source_line=9 source_end_line=9 source_column=11 source_end_column=48}
}

%while.6_computation (tuple.6: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
  %tuple.6 = (s32[], s32[], s32[]) parameter(0)
  ROOT %while.0 = (s32[], s32[], s32[]) while(%tuple.6), condition=%region_1.3, body=%region_0.2, metadata={op_name="jit(fibonacci_jax)/while" source_file="/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/ipykernel_44972/469706885.py" source_line=9 source_end_line=9 source_column=11 source_end_column=48}, backend_config={"known_trip_count":{"n":"100000"},"known_init_step":{"init":"0","step":"1"},"known_induction_variable":{"tuple_index":"0"}}
}

ENTRY %main.4 () -> s32[] {
  %constant.6 = s32[] constant(1)
  %constant.7 = s32[] constant(0)
  %copy.9 = s32[] copy(%constant.6)
  %copy.10 = s32[] copy(%constant.7)
  %copy.2 = s32[] copy(%copy.9)
  %tuple.2 = (s32[], s32[], s32[]) tuple(%copy.10, %copy.9, %copy.2)
  %call = (s32[], s32[], s32[]) call(%tuple.2), to_apply=%while.6_computation, frontend_attributes={xla_cpu_small_call="true"}, metadata={op_name="jit(fibonacci_jax)/while" source_file="/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/ipykernel_44972/469706885.py" source_line=9 source_end_line=9 source_column=11 source_end_column=48}
  ROOT %while.8 = s32[] get-tuple-element(%call), index=1, metadata={op_name="jit(fibonacci_jax)/while" source_file="/var/folders/qv/63yqp4p50630y7pqfgcbgtq80000gn/T/ipykernel_44972/469706885.py" source_line=9 source_end_line=9 source_column=11 source_end_column=48}
}
```

Please let me know if I am just not getting it. I am pretty new to this sort of low-level computational framework.
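For completeness, this is roughly how the two dumps can be produced with JAX's ahead-of-time APIs (my sketch, not code from the notebook; `fibonacci_jax` here stands in for whatever jitted function is being inspected):

```python
# Dump the lowered StableHLO (pre-optimization) and the compiled, optimized
# HLO for the target backend.
import jax
import jax.numpy as jnp

@jax.jit
def fibonacci_jax(a, b, n_steps=100_000):
    def body(_, carry):
        a, b = carry
        return a + b, a

    a, _ = jax.lax.fori_loop(0, n_steps, body, (a, b))
    return a

lowered = fibonacci_jax.lower(jnp.int32(0), jnp.int32(1))
print(lowered.as_text())            # StableHLO, before XLA runs its passes
print(lowered.compile().as_text())  # optimized HLO, after compilation
```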
The cusum one may be worth exploring. We're doing 10x worse than direct numba, and 2x worse than direct JAX. There are some differences: you're computing the alarms inside the loop for numba, but outside the loop for pytensor, which requires another iteration. And the sneaky scalars may show up again.
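To illustrate the difference being described (this is my own sketch of the two patterns, not the notebook's implementation; the statistic, threshold handling, and names are assumptions), computing the alarms inside the recursion avoids a second pass over the data:

```python
# Two equivalent CUSUM variants: alarms computed inside the loop vs. in a
# separate pass over the accumulated statistic.
import numba
import numpy as np

@numba.njit
def cusum_alarms_inside(x, drift, threshold):
    s = 0.0
    alarms = np.zeros(x.shape[0], dtype=np.bool_)
    for i in range(x.shape[0]):
        s = max(0.0, s + x[i] - drift)
        alarms[i] = s > threshold
    return alarms

@numba.njit
def cusum_alarms_outside(x, drift, threshold):
    s = 0.0
    stat = np.empty(x.shape[0])
    for i in range(x.shape[0]):
        s = max(0.0, s + x[i] - drift)
        stat[i] = s
    return stat > threshold  # extra iteration over the whole series

x = np.random.normal(size=10_000)
assert np.array_equal(
    cusum_alarms_inside(x, 0.5, 5.0), cusum_alarms_outside(x, 0.5, 5.0)
)
```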
…nsor dprint included scan in the graph
@Dekermanjian I would still suggest we focus on scan with a fixed number of steps but symbolic inputs. This is the most fair across backends, as for instance JAX always compiles to a specific number of steps.
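For concreteness, this is roughly what that setup looks like in PyTensor (my sketch, not the notebook's exact code; variable names and the step count are illustrative):

```python
# Fibonacci scan with a fixed number of steps but symbolic initial values,
# so no backend can constant-fold the result.
import numpy as np
import pytensor
import pytensor.tensor as pt

a0 = pt.scalar("a0", dtype="int64")
b0 = pt.scalar("b0", dtype="int64")

def step(a, b):
    return a + b, a

(a_seq, _), _ = pytensor.scan(step, outputs_info=[a0, b0], n_steps=1000)

fib = pytensor.function([a0, b0], a_seq[-1])
fib.trust_input = True  # skip per-call input validation before timing

print(fib(np.array(1, dtype="int64"), np.array(1, dtype="int64")))
```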
@ricardoV94, I reread your earlier comments and realized I missed that detail. I'll adjust the functions so they always use a fixed number of steps with symbolic inputs.
…_step arg, updated jax elemntwise to use fori_loop
@ricardoV94 I updated the PyTensor fibonacci algorithm to follow your method of fixing the number of iterations in the scans.
When sequences are passed, n_steps defaults to their length, but it doesn't hurt to specify it again (or shouldn't!).
added trust_inputs to pytensor functions, added multiple pytensor backends, added block_until_ready to jax calls
Description
This PR adds a notebook that benchmarks several algorithms in the different backends and compares them to pure implementations in the respective backend.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1632.org.readthedocs.build/en/1632/