Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ jobs:
strategy:
matrix:
python-version: ['3.10']
test-type: [doctest, pytest, mypy]
test-type: [doctest, pytest]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This is research project, this is not official google project.
33 changes: 33 additions & 0 deletions jaxonnxruntime/core/onnx_primitive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2023 The Jaxonnxruntime Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2023 The Jaxonnxruntime Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Define Onnx Primitive class from jax.core.Primitive."""
from jax import core


class OnnxPrimitive(core.Primitive):
multiple_results: bool = True
56 changes: 51 additions & 5 deletions jaxonnxruntime/onnx_ops/abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Define ONNX Abs operator."""
# pylint: disable=unused-argument
# pylint: disable=g-explicit-length-test
from collections.abc import Callable, Sequence
import functools
import inspect
from typing import Any

import jax
from jax import jit
from jax import numpy as jnp
from jax._src.interpreters import mlir
from jaxonnxruntime.core import handler
from jaxonnxruntime.core import onnx_node
from jaxonnxruntime.core import onnx_primitive


@handler.register_op("Abs")
Expand All @@ -58,14 +63,55 @@ def _prepare(
def version_13(
cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
) -> Callable[..., Any]:
"""ONNX version_13 Abs op."""
"""ONNX version 13 Abs op."""
cls._prepare(node, inputs, onnx_abs)
return onnx_abs


@functools.partial(jit, static_argnames=())
def onnx_abs(*input_args):
"""The internal jax impl for onnx Abs op."""
assert len(input_args) == 1
(x,) = input_args
def onnx_abs(*args):
"""https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Abs for more details."""

if len(args) != 1:
raise ValueError(
f"len(args) should equal to 1 but got {len(args)}"
)
all_args = args

return onnx_abs_p.bind(*all_args)

# Define onnx_abs_p primitive.
onnx_abs_p = onnx_primitive.OnnxPrimitive("onnx_abs")
onnx_abs_p.multiple_results = False


@onnx_abs_p.def_impl
def _onnx_abs_impl(*args):
x = args[0]
return jnp.abs(x)


@onnx_abs_p.def_abstract_eval
def _onnx_abs_abstract_eval(*args):
aval_args = jax.tree_map(
lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), args
)
out = jax.eval_shape(_onnx_abs_impl, *aval_args)
return jax.tree_map(
lambda x: jax.abstract_arrays.ShapedArray(x.shape, x.dtype), out
)


def _onnx_abs_lowering(ctx, *args, platform):
"""abs lowering rule."""
jit_func = jax.jit(_onnx_abs_impl)
jit_func_lowering = mlir.lower_fun(jit_func, onnx_abs_p.multiple_results)
return mlir.delegate_lowering(ctx, jit_func_lowering, *args)


for _p in ("cpu", "tpu", "cuda", "rocm"):
mlir.register_lowering(
onnx_abs_p,
functools.partial(_onnx_abs_lowering, platform=_p),
platform=_p,
)
11 changes: 6 additions & 5 deletions jaxonnxruntime/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,16 @@
from onnx import numpy_helper


NodeProto = onnx.NodeProto
ModelProto = onnx.ModelProto
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_numpy_rank_promotion', 'warn')


class TestItem:

def __init__(
self,
func: Callable[..., Any],
proto: list[Optional[Union[ModelProto, NodeProto]]],
proto: list[Optional[Union[onnx.ModelProto, onnx.NodeProto]]],
) -> None:
self.func = func
self.proto = proto
Expand Down Expand Up @@ -268,7 +267,7 @@ def _add_test(
category: str,
test_name: str,
test_func: Callable[..., Any],
report_item: list[Optional[Union[ModelProto, NodeProto]]],
report_item: list[Optional[Union[onnx.ModelProto, onnx.NodeProto]]],
devices: Iterable[str] = ('CPU', 'CUDA'),
) -> None:
"""Add test to each device and category."""
Expand Down Expand Up @@ -309,7 +308,9 @@ def device_test_func(*args: Any, **kwargs: Any) -> Any:

def _add_model_test(self, model_test, kind: str) -> None:
"""model is loaded at runtime, note sometimes it could even never loaded if the test skipped."""
model_marker: list[Optional[Union[ModelProto, NodeProto]]] = [None]
model_marker: list[Optional[Union[onnx.ModelProto, onnx.NodeProto]]] = [
None
]

def run(test_self: Any, device: str) -> None: # pylint: disable=unused-argument
model_dir = model_test.model_dir
Expand Down
Loading