diff --git a/jaxonnxruntime/experimental/export/exportable_utils.py b/jaxonnxruntime/experimental/export/exportable_utils.py
index ee997bf..7166f7b 100644
--- a/jaxonnxruntime/experimental/export/exportable_utils.py
+++ b/jaxonnxruntime/experimental/export/exportable_utils.py
@@ -100,7 +100,7 @@ def torch_tensor_to_jax_array(
   return jax.dlpack.from_dlpack(tensor)
 
 
-def serialize_stablehlo_mlir_str(mlir_str: str) -> bytes:
+def serialize_stablehlo_mlir_str(mlir_str: str | bytes) -> bytes:
   """Serializes a StableHLO MLIR module string to a bytecode."""
   # https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
   # `stablehlo.get_minimum_version()` returns `consumer_version_min`
diff --git a/jaxonnxruntime/experimental/export/tensorflow_exportable.py b/jaxonnxruntime/experimental/export/tensorflow_exportable.py
index 3905528..5fa397f 100644
--- a/jaxonnxruntime/experimental/export/tensorflow_exportable.py
+++ b/jaxonnxruntime/experimental/export/tensorflow_exportable.py
@@ -17,7 +17,6 @@
 import dataclasses
 
 import jax
-from jax.lib import xla_client
 from jax.lib import xla_extension
 from jaxonnxruntime.experimental.export import exportable
 from jaxonnxruntime.experimental.export import exportable_utils
@@ -155,7 +154,7 @@ def nr_devices(self) -> int:
     return len(jax.devices())
 
   @property
-  def mlir_module_str(self) -> str:
+  def mlir_module_str(self) -> bytes:
     """Returns the mlir module from TF."""
     args_tf_flat = jax.tree_util.tree_map(
         lambda x: x.tensor if isinstance(x, TensorWithSharding) else x,
@@ -165,8 +164,7 @@
         *args_tf_flat
     )(stage="hlo_serialized", platform_name=self.tf_platform)
 
-    xla_comp = xla_client.XlaComputation(func_tf_hlo)
-    mlir_str = xla_extension.mlir.xla_computation_to_mlir_module(xla_comp)
+    mlir_str = xla_extension.mlir.hlo_to_stablehlo(func_tf_hlo)
     return mlir_str
 
   @property