diff --git a/jaxonnxruntime/__init__.py b/jaxonnxruntime/__init__.py index fe5349f..7fa74c1 100644 --- a/jaxonnxruntime/__init__.py +++ b/jaxonnxruntime/__init__.py @@ -50,6 +50,7 @@ from jaxonnxruntime.onnx_ops import div from jaxonnxruntime.onnx_ops import dropout from jaxonnxruntime.onnx_ops import einsum +from jaxonnxruntime.onnx_ops import elu from jaxonnxruntime.onnx_ops import equal from jaxonnxruntime.onnx_ops import erf from jaxonnxruntime.onnx_ops import exp @@ -61,6 +62,7 @@ from jaxonnxruntime.onnx_ops import globalaveragepool from jaxonnxruntime.onnx_ops import greater from jaxonnxruntime.onnx_ops import greaterorequal +from jaxonnxruntime.onnx_ops import hardsigmoid from jaxonnxruntime.onnx_ops import identity from jaxonnxruntime.onnx_ops import if_op from jaxonnxruntime.onnx_ops import leakyrelu @@ -94,7 +96,9 @@ from jaxonnxruntime.onnx_ops import scatternd from jaxonnxruntime.onnx_ops import selu from jaxonnxruntime.onnx_ops import shape +from jaxonnxruntime.onnx_ops import shrink from jaxonnxruntime.onnx_ops import sigmoid +from jaxonnxruntime.onnx_ops import sign from jaxonnxruntime.onnx_ops import sin from jaxonnxruntime.onnx_ops import sinh from jaxonnxruntime.onnx_ops import slice @@ -105,6 +109,7 @@ from jaxonnxruntime.onnx_ops import squeeze from jaxonnxruntime.onnx_ops import sub from jaxonnxruntime.onnx_ops import sum +from jaxonnxruntime.onnx_ops import tan from jaxonnxruntime.onnx_ops import tanh from jaxonnxruntime.onnx_ops import tile from jaxonnxruntime.onnx_ops import topk diff --git a/jaxonnxruntime/onnx_ops/elu.py b/jaxonnxruntime/onnx_ops/elu.py new file mode 100644 index 0000000..ead0ef0 --- /dev/null +++ b/jaxonnxruntime/onnx_ops/elu.py @@ -0,0 +1,48 @@ +"""Define ONNX Elu operator.""" +# pylint: disable=unused-argument +# pylint: disable=g-explicit-length-test +import functools +from collections.abc import Callable, Sequence +from typing import Any + +import jax +from jaxonnxruntime.core import handler +from jaxonnxruntime.core import onnx_node +from jaxonnxruntime.onnx_ops import onnx_ops_utils + + +@handler.register_op("Elu") +class Elu(handler.Handler): + """Implementation of the ONNX Elu operator.""" + + @classmethod + def _prepare(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any): + node.attrs_dict['alpha'] = node.attrs.get( + 'alpha', 1.0 + ) + + @classmethod + def version_1(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_1 Elu op.""" + cls._prepare(node, inputs, onnx_elu) + return onnx_elu + + @classmethod + def version_6(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_6 Elu op.""" + cls._prepare(node, inputs, onnx_elu) + return onnx_elu + + @classmethod + def version_22(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_22 Elu op.""" + cls._prepare(node, inputs, onnx_elu) + return onnx_elu + + +@functools.partial(jax.jit, static_argnames=()) +def onnx_elu(*input_args, alpha): + """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Elu for more details.""" + assert len(input_args) == 1 + data = input_args[0] + return jax.nn.elu(data, alpha) diff --git a/jaxonnxruntime/onnx_ops/hardsigmoid.py b/jaxonnxruntime/onnx_ops/hardsigmoid.py new file mode 100644 index 0000000..7877b6a --- /dev/null +++ b/jaxonnxruntime/onnx_ops/hardsigmoid.py @@ -0,0 +1,48 @@ +"""Define ONNX HardSigmoid operator.""" +# pylint: disable=unused-argument +# pylint: disable=g-explicit-length-test +import functools +from collections.abc import Callable, Sequence +from typing import Any + +import jax +from jax import numpy as jnp +from jaxonnxruntime.core import handler +from jaxonnxruntime.core import onnx_node +from jaxonnxruntime.onnx_ops import onnx_ops_utils + + +@handler.register_op("HardSigmoid") +class HardSigmoid(handler.Handler): + """Implementation of the ONNX HardSigmoid operator.""" + + @classmethod + def _prepare(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any): + node.attrs_dict["alpha"] = node.attrs.get("alpha", 0.2) + node.attrs_dict["beta"] = node.attrs.get("beta", 0.5) + + @classmethod + def version_1(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_1 HardSigmoid op.""" + cls._prepare(node, inputs, onnx_hardsigmoid) + return onnx_hardsigmoid + + @classmethod + def version_6(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_6 HardSigmoid op.""" + cls._prepare(node, inputs, onnx_hardsigmoid) + return onnx_hardsigmoid + + @classmethod + def version_22(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_22 HardSigmoid op.""" + cls._prepare(node, inputs, onnx_hardsigmoid) + return onnx_hardsigmoid + + +@functools.partial(jax.jit, static_argnames=()) +def onnx_hardsigmoid(*input_args, alpha, beta): + """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#HardSigmoid for more details.""" + assert len(input_args) == 1 + data = input_args[0] + return jnp.maximum(0, jnp.minimum(1, data * alpha + beta)).astype(data.dtype) diff --git a/jaxonnxruntime/onnx_ops/shrink.py b/jaxonnxruntime/onnx_ops/shrink.py new file mode 100644 index 0000000..02b4d54 --- /dev/null +++ b/jaxonnxruntime/onnx_ops/shrink.py @@ -0,0 +1,40 @@ +"""Define ONNX Shrink operator.""" +# pylint: disable=unused-argument +# pylint: disable=g-explicit-length-test +import functools +from collections.abc import Callable, Sequence +from typing import Any + +import jax +from jax import numpy as jnp +from jaxonnxruntime.core import handler +from jaxonnxruntime.core import onnx_node +from jaxonnxruntime.onnx_ops import onnx_ops_utils + + +@handler.register_op("Shrink") +class Shrink(handler.Handler): + """Implementation of the ONNX Shrink operator.""" + + @classmethod + def _prepare(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any): + node.attrs_dict['bias'] = node.attrs.get('bias', 0.0) + node.attrs_dict['lambd'] = node.attrs.get('lambd', 0.5) + + @classmethod + def version_9(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_9 Shrink op.""" + cls._prepare(node, inputs, onnx_shrink) + return onnx_shrink + + +@functools.partial(jax.jit, static_argnames=()) +def onnx_shrink(*input_args, bias, lambd): + """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Shrink for more details.""" + assert len(input_args) == 1 + data = input_args[0] + return jnp.where( + data < -lambd, + data + bias, + jnp.where(data > lambd, data - bias, 0), + ).astype(data.dtype) diff --git a/jaxonnxruntime/onnx_ops/sign.py b/jaxonnxruntime/onnx_ops/sign.py new file mode 100644 index 0000000..dfaa8c8 --- /dev/null +++ b/jaxonnxruntime/onnx_ops/sign.py @@ -0,0 +1,41 @@ +"""Define ONNX Sign operator.""" +# pylint: disable=unused-argument +# pylint: disable=g-explicit-length-test +import functools +from collections.abc import Callable, Sequence +from typing import Any + +import jax +from jax import numpy as jnp +from jaxonnxruntime.core import handler +from jaxonnxruntime.core import onnx_node +from jaxonnxruntime.onnx_ops import onnx_ops_utils + + +@handler.register_op("Sign") +class Sign(handler.Handler): + """Implementation of the ONNX Sign operator.""" + + @classmethod + def _prepare(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any): + onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) + + @classmethod + def version_9(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_9 Sign op.""" + cls._prepare(node, inputs, onnx_sign) + return onnx_sign + + @classmethod + def version_13(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_13 Sign op.""" + cls._prepare(node, inputs, onnx_sign) + return onnx_sign + + +@functools.partial(jax.jit, static_argnames=()) +def onnx_sign(*input_args): + """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Sign for more details.""" + assert len(input_args) == 1 + data = input_args[0] + return jnp.sign(data) diff --git a/jaxonnxruntime/onnx_ops/tan.py b/jaxonnxruntime/onnx_ops/tan.py new file mode 100644 index 0000000..f5f2e07 --- /dev/null +++ b/jaxonnxruntime/onnx_ops/tan.py @@ -0,0 +1,41 @@ +"""Define ONNX Tan operator.""" +# pylint: disable=unused-argument +# pylint: disable=g-explicit-length-test +import functools +from collections.abc import Callable, Sequence +from typing import Any + +import jax +from jax import numpy as jnp +from jaxonnxruntime.core import handler +from jaxonnxruntime.core import onnx_node +from jaxonnxruntime.onnx_ops import onnx_ops_utils + + +@handler.register_op("Tan") +class Tan(handler.Handler): + """Implementation of the ONNX Tan operator.""" + + @classmethod + def _prepare(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any], onnx_jax_impl: Any): + onnx_ops_utils.update_node_attrs_dict(node, onnx_jax_impl) + + @classmethod + def version_7(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_7 Tan op.""" + cls._prepare(node, inputs, onnx_tan) + return onnx_tan + + @classmethod + def version_22(cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]) -> Callable[..., Any]: + """ONNX version_22 Tan op.""" + cls._prepare(node, inputs, onnx_tan) + return onnx_tan + + +@functools.partial(jax.jit, static_argnames=()) +def onnx_tan(*input_args): + """https://github.com/onnx/onnx/blob/v1.12.0/docs/Operators.md#Tan for more details.""" + assert len(input_args) == 1 + data = input_args[0] + return jnp.tan(data) diff --git a/tests/onnx_ops_test.py b/tests/onnx_ops_test.py index 33e21b8..46779ba 100644 --- a/tests/onnx_ops_test.py +++ b/tests/onnx_ops_test.py @@ -95,6 +95,7 @@ class NodeTest(absltest.TestCase): include_patterns.append('test_div_') include_patterns.append('test_dropout_') include_patterns.append('test_einsum_') +include_patterns.append('test_elu_') include_patterns.append('test_equal_') include_patterns.append('test_erf_') include_patterns.append('test_exp_') @@ -106,6 +107,7 @@ class NodeTest(absltest.TestCase): include_patterns.append('test_globalaveragepool_') include_patterns.append('test_greater_') include_patterns.append('test_greaterorequal_') +include_patterns.append('test_hardsigmoid_') include_patterns.append('test_identity_') include_patterns.append('test_if_') include_patterns.append('test_leakyrelu_') @@ -137,7 +139,9 @@ class NodeTest(absltest.TestCase): include_patterns.append('test_scatternd_') include_patterns.append('test_selu_') include_patterns.append('test_shape_') +include_patterns.append('test_shrink') include_patterns.append('test_sigmoid_') +include_patterns.append('test_sign_') include_patterns.append('test_sin_') include_patterns.append('test_sinh_') include_patterns.append('test_slice_') @@ -148,6 +152,7 @@ class NodeTest(absltest.TestCase): include_patterns.append('test_squeeze_') include_patterns.append('test_sub_') include_patterns.append('test_sum_') +include_patterns.append('test_tan_') include_patterns.append('test_tanh_') include_patterns.append('test_tile_') include_patterns.append('test_top_k_')