High overhead compared to native JAX #137

@jchia

Description

I wrote a program, jax-onnx-rs.py, to compare the elapsed time of native JAX against jaxonnxruntime for ReduceSum. jaxonnxruntime is much slower: at least 5x as slow. Additionally, on a real-world ONNX network (not included here), jaxonnxruntime was 10x slower than onnxruntime.

$ taskset -c 3 ./jax-onnx-rs.py 2
BENCHMARKING NATIVE JAX
TEST 2.0
warmup
total time 0.09750587900634855

BENCHMARKING BACKENDREP
warmup
total time 0.5798414099845104
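
For size 2, this works out to about 0.098 ms per call for native JAX versus about 0.58 ms per call through BackendRep, roughly a 6x difference.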

jax-onnx-rs.py

#!/usr/bin/env python

import argparse
from jax import jit
import jax.numpy as jnp
import numpy as np
from onnx import helper, TensorProto
import jaxonnxruntime
from timeit import timeit

@jit
def reduce_sum(x):
    return jnp.sum(x)

def benchmark_jax(size: int):
    x = np.zeros((size,), dtype=np.float32)

    print('TEST', reduce_sum(np.ones((size,), dtype=np.float32)))

    def run():
        reduce_sum(x).block_until_ready()

    print('warmup')
    for _ in range(5):
        run()
    print('total time', timeit(run, number=1000))

def benchmark_backend_rep(size: int):
    graph = helper.make_graph(
        name='SimpleModel',
        nodes=[helper.make_node("ReduceSum", inputs=['x'], name='y', outputs=['y'])],
        inputs=[
            helper.make_tensor_value_info('x', TensorProto.FLOAT, [size]),
        ],
        outputs=[
            helper.make_tensor_value_info('y', TensorProto.FLOAT, None),
        ]
    )

    model = helper.make_model(graph, producer_name='simple_model')
    backend_rep = jaxonnxruntime.backend.BackendRep(model)

    x = np.zeros((size,), dtype=np.float32)

    def run():
        y, = backend_rep.run((x,))
        y.block_until_ready()

    print('warmup')
    for _ in range(5):
        run()
    print('total time', timeit(run, number=1000))

def main():
    parser = argparse.ArgumentParser(description='Benchmark native JAX against jaxonnxruntime for ReduceSum')
    parser.add_argument('size', metavar='SIZE', nargs=1, type=int, help='size of ReduceSum')
    args = parser.parse_args()
    print("BENCHMARKING NATIVE JAX")
    benchmark_jax(args.size[0])
    print("\nBENCHMARKING BACKENDREP")
    benchmark_backend_rep(args.size[0])

main()
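
For comparison with the 10x figure observed on the real-world network, a rough onnxruntime baseline for the same toy model could look like the sketch below. It is not part of the original script; benchmark_onnxruntime is a hypothetical helper, and it assumes onnxruntime is installed.

import numpy as np
import onnxruntime as ort
from onnx import helper, TensorProto
from timeit import timeit

def benchmark_onnxruntime(size: int):
    # Build the same single-node ReduceSum model as benchmark_backend_rep.
    graph = helper.make_graph(
        name='SimpleModel',
        nodes=[helper.make_node('ReduceSum', inputs=['x'], name='y', outputs=['y'])],
        inputs=[helper.make_tensor_value_info('x', TensorProto.FLOAT, [size])],
        outputs=[helper.make_tensor_value_info('y', TensorProto.FLOAT, None)],
    )
    model = helper.make_model(graph, producer_name='simple_model')
    session = ort.InferenceSession(model.SerializeToString(),
                                   providers=['CPUExecutionProvider'])

    x = np.zeros((size,), dtype=np.float32)

    def run():
        session.run(None, {'x': x})

    print('warmup')
    for _ in range(5):
        run()
    print('total time', timeit(run, number=1000))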
