-
Notifications
You must be signed in to change notification settings - Fork 21
Open
Description
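I wrote a program, jax-onnx-rs.py, that compares the elapsed time of a ReduceSum in native (jitted) JAX with the same ReduceSum run through jaxonnxruntime's BackendRep. jaxonnxruntime is much slower, at least 5x as slow: in the run below it takes about 0.58 s versus about 0.098 s for 1000 calls, roughly 6x. Additionally, on a real-world ONNX network (not included here), jaxonnxruntime was 10x slower than onnxruntime.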
$ taskset -c 3 ./jax-onnx-rs.py 2
BENCHMARKING NATIVE JAX
TEST 2.0
warmup
total time 0.09750587900634855
BENCHMARKING BACKENDREP
warmup
total time 0.5798414099845104
jax-onnx-rs.py
#!/bin/env python
import argparse
from jax import jit
import jax.numpy as jnp
import numpy as np
from onnx import helper, TensorProto
import jaxonnxruntime
from timeit import timeit
@jit
def reduce_sum(x):
    """Return the sum of every element of ``x`` as a scalar (JIT-compiled)."""
    # axis=None collapses all axes, which is jnp.sum's default.
    total = jnp.sum(x, axis=None)
    return total
def benchmark_jax(size: int, *, number: int = 1000, warmup: int = 5) -> float:
    """Time the jitted native-JAX ``reduce_sum`` on a float32 vector of length *size*.

    Prints a sanity-check sum over a vector of ones; because ``reduce_sum`` is
    jitted, this first call also compiles it. Then runs *warmup* untimed calls,
    and finally prints and returns the total wall-clock time of *number* timed
    calls on a vector of zeros.

    Args:
        size: Length of the 1-D float32 input vector.
        number: Number of timed calls passed to ``timeit`` (default 1000).
        warmup: Number of untimed calls made before measuring (default 5).

    Returns:
        Total elapsed seconds for the *number* timed calls (not per call).
    """
    # x is a host NumPy array, so each timed call includes converting it for
    # JAX -- the same input kind that benchmark_backend_rep feeds BackendRep.
    x = np.zeros((size,), dtype=np.float32)
    print('TEST', reduce_sum(np.ones((size,), dtype=np.float32)))

    def run():
        # JAX dispatches asynchronously; block so the computation itself is timed.
        reduce_sum(x).block_until_ready()

    print('warmup')
    for _ in range(warmup):
        run()
    total = timeit(run, number=number)
    print('total time', total)
    return total
def _make_reduce_sum_model(size: int):
    """Build a one-node ONNX model applying ReduceSum to a float32 input of shape [size]."""
    graph = helper.make_graph(
        name='SimpleModel',
        # No axes input or attributes are given, so the ONNX defaults apply
        # (reduce over all axes; keepdims presumably 1 -> output shape (1,)).
        nodes=[helper.make_node("ReduceSum", inputs=['x'], name='y', outputs=['y'])],
        inputs=[
            helper.make_tensor_value_info('x', TensorProto.FLOAT, [size]),
        ],
        # The output shape is deliberately left unspecified.
        outputs=[
            helper.make_tensor_value_info('y', TensorProto.FLOAT, None),
        ],
    )
    return helper.make_model(graph, producer_name='simple_model')


def benchmark_backend_rep(size: int, *, number: int = 1000, warmup: int = 5) -> float:
    """Time a ReduceSum ONNX model run through ``jaxonnxruntime``'s ``BackendRep``.

    Builds the model once, runs *warmup* untimed calls, then prints and
    returns the total wall-clock time of *number* timed calls on a float32
    vector of zeros of length *size*.

    Args:
        size: Length of the 1-D float32 input vector.
        number: Number of timed calls passed to ``timeit`` (default 1000).
        warmup: Number of untimed calls made before measuring (default 5).

    Returns:
        Total elapsed seconds for the *number* timed calls (not per call).
    """
    backend_rep = jaxonnxruntime.backend.BackendRep(_make_reduce_sum_model(size))
    x = np.zeros((size,), dtype=np.float32)

    def run():
        # The model has exactly one output; block on it so the work is timed
        # rather than just the asynchronous dispatch.
        y, = backend_rep.run((x,))
        y.block_until_ready()

    print('warmup')
    for _ in range(warmup):
        run()
    total = timeit(run, number=number)
    print('total time', total)
    return total
def main():
    """Parse the SIZE argument and run the native-JAX and BackendRep benchmarks on it."""
    parser = argparse.ArgumentParser(description='Benchmark native JAX against jaxonnxruntime for ReduceSum')
    # A plain positional int (no nargs=1), so args.size is the value itself
    # rather than a one-element list that has to be indexed.
    parser.add_argument('size', metavar='SIZE', type=int, help='size of ReduceSum')
    args = parser.parse_args()
    # Reject impossible sizes up front instead of failing inside np.zeros.
    if args.size < 0:
        parser.error('SIZE must be non-negative')
    print("BENCHMARKING NATIVE JAX")
    benchmark_jax(args.size)
    print("\nBENCHMARKING BACKENDREP")
    benchmark_backend_rep(args.size)
# Only run the benchmarks when executed as a script, not when imported.
if __name__ == '__main__':
    main()
sotetsuk
Metadata
Metadata
Assignees
Labels
No labels