diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 2f714b9..c252e86 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -83,7 +83,7 @@ jobs:
     strategy:
       matrix:
         python-version: ['3.10']
-        test-type: [doctest, pytest, mypy]
+        test-type: [doctest, pytest]
     steps:
     - uses: actions/checkout@v3
     - name: Set up Python ${{ matrix.python-version }}
diff --git a/README.md b/README.md
index e69de29..a23e0c5 100644
--- a/README.md
+++ b/README.md
@@ -0,0 +1 @@
+This is a research project, not an official Google project.
\ No newline at end of file
diff --git a/jaxonnxruntime/core/onnx_primitive.py b/jaxonnxruntime/core/onnx_primitive.py
new file mode 100644
index 0000000..b660f0c
--- /dev/null
+++ b/jaxonnxruntime/core/onnx_primitive.py
@@ -0,0 +1,21 @@
+# Copyright 2023 The Jaxonnxruntime Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Define Onnx Primitive class from jax.core.Primitive."""
+from jax import core
+
+
+class OnnxPrimitive(core.Primitive):
+  """A jax.core.Primitive subclass used to model ONNX operators."""
+
+  multiple_results: bool = True
diff --git a/jaxonnxruntime/onnx_ops/abs.py b/jaxonnxruntime/onnx_ops/abs.py
index a983542..54a808e 100644
--- a/jaxonnxruntime/onnx_ops/abs.py
+++ b/jaxonnxruntime/onnx_ops/abs.py
@@ -26,15 +26,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Define ONNX Abs operator."""
+# pylint: disable=unused-argument
+# pylint: disable=g-explicit-length-test
 from collections.abc import Callable, Sequence
 import functools
 import inspect
 from typing import Any
 
+import jax
 from jax import jit
 from jax import numpy as jnp
+from jax._src.interpreters import mlir
 from jaxonnxruntime.core import handler
 from jaxonnxruntime.core import onnx_node
+from jaxonnxruntime.core import onnx_primitive
 
 
 @handler.register_op("Abs")
@@ -58,14 +63,55 @@ def _prepare(
   def version_13(
       cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
   ) -> Callable[..., Any]:
-    """ONNX version_13 Abs op."""
+    """ONNX version 13 Abs op."""
     cls._prepare(node, inputs, onnx_abs)
     return onnx_abs
 
 
 @functools.partial(jit, static_argnames=())
-def onnx_abs(*input_args):
-  """The internal jax impl for onnx Abs op."""
-  assert len(input_args) == 1
-  (x,) = input_args
+def onnx_abs(*args):
+  """See https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Abs for more details."""
+
+  if len(args) != 1:
+    raise ValueError(
+        f"len(args) should equal 1 but got {len(args)}"
+    )
+  all_args = args
+
+  return onnx_abs_p.bind(*all_args)
+
+# Define onnx_abs_p primitive.
+onnx_abs_p = onnx_primitive.OnnxPrimitive("onnx_abs")
+onnx_abs_p.multiple_results = False
+
+
+@onnx_abs_p.def_impl
+def _onnx_abs_impl(*args):
+  x = args[0]
   return jnp.abs(x)
+
+
+@onnx_abs_p.def_abstract_eval
+def _onnx_abs_abstract_eval(*args):
+  aval_args = jax.tree_map(
+      lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), args
+  )
+  out = jax.eval_shape(_onnx_abs_impl, *aval_args)
+  return jax.tree_map(
+      lambda x: jax.abstract_arrays.ShapedArray(x.shape, x.dtype), out
+  )
+
+
+def _onnx_abs_lowering(ctx, *args, platform):
+  """Abs lowering rule."""
+  jit_func = jax.jit(_onnx_abs_impl)
+  jit_func_lowering = mlir.lower_fun(jit_func, onnx_abs_p.multiple_results)
+  return mlir.delegate_lowering(ctx, jit_func_lowering, *args)
+
+
+for _p in ("cpu", "tpu", "cuda", "rocm"):
+  mlir.register_lowering(
+      onnx_abs_p,
+      functools.partial(_onnx_abs_lowering, platform=_p),
+      platform=_p,
+  )
diff --git a/jaxonnxruntime/runner.py b/jaxonnxruntime/runner.py
index 7bef91d..f968d39 100644
--- a/jaxonnxruntime/runner.py
+++ b/jaxonnxruntime/runner.py
@@ -35,9 +35,8 @@ from onnx import numpy_helper
 
-NodeProto = onnx.NodeProto
-ModelProto = onnx.ModelProto
 jax.config.update('jax_enable_x64', True)
+jax.config.update('jax_numpy_rank_promotion', 'warn')
 
 
 class TestItem:
@@ -45,7 +44,7 @@ class TestItem:
   def __init__(
       self,
       func: Callable[..., Any],
-      proto: list[Optional[Union[ModelProto, NodeProto]]],
+      proto: list[Optional[Union[onnx.ModelProto, onnx.NodeProto]]],
   ) -> None:
     self.func = func
     self.proto = proto
@@ -268,7 +267,7 @@ def _add_test(
       category: str,
       test_name: str,
       test_func: Callable[..., Any],
-      report_item: list[Optional[Union[ModelProto, NodeProto]]],
+      report_item: list[Optional[Union[onnx.ModelProto, onnx.NodeProto]]],
       devices: Iterable[str] = ('CPU', 'CUDA'),
   ) -> None:
     """Add test to each device and category."""
@@ -309,7 +308,9 @@ def device_test_func(*args: Any, **kwargs: Any) -> Any:
 
   def _add_model_test(self, model_test, kind: str) -> None:
     """model is loaded at runtime, note sometimes it could even never loaded if the test skipped."""
-    model_marker: list[Optional[Union[ModelProto, NodeProto]]] = [None]
+    model_marker: list[Optional[Union[onnx.ModelProto, onnx.NodeProto]]] = [
+        None
+    ]
 
     def run(test_self: Any, device: str) -> None:  # pylint: disable=unused-argument
       model_dir = model_test.model_dir
diff --git a/tools/op_code_generator.py b/tools/op_code_generator.py
index 7cb0308..4818c5f 100644
--- a/tools/op_code_generator.py
+++ b/tools/op_code_generator.py
@@ -12,17 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Help simplify the onnx op develoment, example cmd."""
 # Example cmd: `python op_code_generator.py Add`
 import argparse
 import logging
 import os
 import re
+from jinja2 import Template
 import onnx
+root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+root_dir = os.path.join(root_dir, 'jaxonnxruntime')
+op_schema_dict = {
+    str(op_schema.name): op_schema for op_schema in onnx.defs.get_all_schemas()
+}
 
 # define the template for the operator implementation
-template_head = """# Copyright 2023 The Jaxonnxruntime Authors.
+template = Template("""\
+# Copyright 2023 The Jaxonnxruntime Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -35,59 +56,136 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Define ONNX {op_name} operator."""
+"""Define ONNX {{op_name}} operator."""
 # pylint: disable=unused-argument
 # pylint: disable=g-explicit-length-test
+from collections.abc import Callable, Sequence
 import functools
 import inspect
-from collections.abc import Callable, Sequence
 from typing import Any
+import jax
 from jax import jit
 from jax import numpy as jnp
+from jax._src.interpreters import mlir
 from jaxonnxruntime.core import handler
 from jaxonnxruntime.core import onnx_node
+from jaxonnxruntime.core import onnx_primitive
 
 
-@handler.register_op("{op_name}")
-class {op_name}(handler.Handler):
-  """Implementation of the ONNX {op_name} operator."""
+@handler.register_op("{{op_name}}")
+class {{op_name}}(handler.Handler):
+  """Implementation of the ONNX {{op_name}} operator."""
 
   @classmethod
-  def _prepare(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any):
+  def _prepare(
+      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any
+  ):
     sig = inspect.signature(onnx_jax_impl)
-    kwparams = [param.name for param in sig.parameters.values() if param.kind == inspect.Parameter.KEYWORD_ONLY]
+    kwparams = [
+        param.name
+        for param in sig.parameters.values()
+        if param.kind == inspect.Parameter.KEYWORD_ONLY
+    ]
     for name in kwparams:
       node.attrs_dict[name] = node.attrs.get(name, None)
-"""
-
-template_version_func = """
+{% for version in versions %}
   @classmethod
-  def version_{version}(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]:
-    """ONNX version_{version} {op_name} op."""
-    cls._prepare(node, inputs, onnx_{op_name_lower})
-    return onnx_{op_name_lower}
-"""
-
-template_tail = """
-
-@functools.partial(jit, static_argnames=())
-def onnx_{op_name_lower}(*input_args):
-  """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#{op_name} for more details."""
-  # TODO({username}): add the implementation here.
+  def version_{{ version }}(
+      cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]
+  ) -> Callable[..., Any]:
+    """ONNX version {{version}} {{ op_name }} op."""
+    cls._prepare(node, inputs, onnx_{{ op_name|lower }})
+    return onnx_{{ op_name|lower }}
+{% endfor %}
+
+@functools.partial(jit, static_argnames=({{static_arg_attr_list}}))
+def onnx_{{op_name|lower}}(*args{{attr_list}}):
+  """See https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#{{op_name}} for more details."""
+{% if min_input != max_input %}
+  if len(args) < {{min_input}} or len(args) > {{max_input}}:
+    raise ValueError(
+        f"len(args) should be within [{{min_input}}, {{max_input}}] but got {len(args)}"
+    )
+  all_args = list(args) + [None] * ({{inputs_name|length}} - len(args))
+{% else %}
+  if len(args) != {{min_input}}:
+    raise ValueError(
+        f"len(args) should equal {{min_input}} but got {len(args)}"
+    )
+  all_args = args
+{% endif %}
+  return onnx_{{op_name|lower}}_p.bind(*all_args)
+
+# Define onnx_{{op_name|lower}}_p primitive.
+onnx_{{op_name|lower}}_p = onnx_primitive.OnnxPrimitive("onnx_{{op_name|lower}}")
+onnx_{{op_name|lower}}_p.multiple_results = False
+
+
+@onnx_{{op_name|lower}}_p.def_impl
+def _onnx_{{op_name|lower}}_impl(*args):
+  # TODO({{username}}): add the implementation here.
   # Then update the onnx_ops_teset.py to include it,
-  # `include_patterns.append('test_{op_name_lower}_')`.
-  return input_args
-"""
+  # `include_patterns.append('test_{{op_name|lower}}_')`.
+  return
-
-root_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
-root_dir = os.path.join(root_dir, 'jaxonnxruntime')
-# BEGIN GOOGLE-INTERNAL
-root_dir = os.path.dirname(root_dir)
-# END GOOGLE-INTERNAL
-op_schema_set = {
-    str(op_schema.name) for op_schema in onnx.defs.get_all_schemas()
-}
+
+
+@onnx_{{op_name|lower}}_p.def_abstract_eval
+def _onnx_{{op_name|lower}}_abstract_eval(*args):
+  aval_args = jax.tree_map(
+      lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), args
+  )
+  out = jax.eval_shape(_onnx_{{op_name|lower}}_impl, *aval_args)
+  return jax.tree_map(
+      lambda x: jax.abstract_arrays.ShapedArray(x.shape, x.dtype), out
+  )
+
+
+def _onnx_{{op_name|lower}}_lowering(ctx, *args, platform):
+  """{{op_name|lower}} lowering rule."""
+  jit_func = jax.jit(_onnx_{{op_name|lower}}_impl)
+  jit_func_lowering = mlir.lower_fun(jit_func, onnx_{{op_name|lower}}_p.multiple_results)
+  return mlir.delegate_lowering(ctx, jit_func_lowering, *args)
+
+
+for _p in ("cpu", "tpu", "cuda", "rocm"):
+  mlir.register_lowering(
+      onnx_{{op_name|lower}}_p,
+      functools.partial(_onnx_{{op_name|lower}}_lowering, platform=_p),
+      platform=_p,
+  )
+
+""")
+
+
+def create_op_schema_render_dict(op_name):
+  """Create the render_dict for the template."""
+  username = os.environ['USER']
+  assert op_name in op_schema_dict, f'{op_name} is not a legal ONNX op name.'
+  op_schema = op_schema_dict[op_name]
+  versions = [op_schema.since_version] if op_schema else []
+  render_dict = {
+      'op_name': op_name,
+      'username': username,
+      'min_input': op_schema.min_input,
+      'max_input': op_schema.max_input,
+      'min_output': op_schema.min_output,
+      'max_output': op_schema.max_output,
+      'attribute_list': list(op_schema.attributes.keys()),
+      'deprecated': op_schema.deprecated,
+      'doc': op_schema.doc,
+      'domain': op_schema.domain,
+      'versions': versions,
+      'inputs_name': [i.name.lower() for i in op_schema.inputs],
+  }
+  attr_list = list(op_schema.attributes.keys())
+  render_dict['attr_list'] = (
+      (', ' + ', '.join(attr_list)) if len(attr_list) else ''
+  )
+  render_dict['static_arg_attr_list'] = ', '.join(
+      ["'" + name + "'" for name in attr_list]
+  )
+  return render_dict
 
 
 def update_onnx_ops_init_file(op_name):
@@ -122,27 +220,13 @@ def update_onnx_ops_init_file(op_name):
 def main(args):
   # get the version list for the ONNX operator
   op_name = args.op_name
-  if str(op_name) not in op_schema_set:
+  if str(op_name) not in op_schema_dict:
     raise ValueError(
-        f'ONNX {op_name} is not ONNX op list {sorted(op_schema_set)}?.'
-    )
-  schema = onnx.defs.get_schema(op_name)
-  versions = [schema.since_version] if schema else []
-  username = os.environ['USER']
-
-  # Render the template and create new op file under onnx_ops folder.
-  code = template_head.format(
-      op_name=op_name, op_name_lower=op_name.lower(), username=username
-  )
-  for version in versions:
-    code += template_version_func.format(
-        version=version, op_name_lower=op_name.lower(), op_name=op_name
+        f'ONNX {op_name} is not a valid ONNX op; '
+        f'see the full list: {sorted(op_schema_dict.keys())}.'
     )
-  code += template_tail.format(
-      op_name_lower=op_name.lower(),
-      op_name=op_name,
-      username=username,
-  )
+  render_dict = create_op_schema_render_dict(op_name)
+  code = template.render(**render_dict)
   logging.info('Genereate new code=\n%s', code)
   op_def_path = os.path.join(root_dir, f'onnx_ops/{op_name.lower()}.py')
   with open(op_def_path, 'w') as f:
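A quick way to exercise the reworked generator is sketched below; it is not part of the diff. It assumes onnx and jinja2 are installed, $USER is set (it is used for the TODO attribution), and the snippet runs from the repository root; the documented CLI entry point remains `python op_code_generator.py Add`.

import sys

sys.path.insert(0, 'tools')  # make the script importable as a module for this sketch

import op_code_generator as gen  # module import builds op_schema_dict and the Jinja template

# Build the render dict for one op and render the code without writing any file.
render_dict = gen.create_op_schema_render_dict('Add')
print(render_dict['versions'], render_dict['min_input'], render_dict['max_input'])
print(gen.template.render(**render_dict))  # the would-be jaxonnxruntime/onnx_ops/add.py source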