Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jaxonnxruntime/experimental/export/exportable_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def torch_tensor_to_jax_array(
return jax.dlpack.from_dlpack(tensor)


def serialize_stablehlo_mlir_str(mlir_str: str) -> bytes:
def serialize_stablehlo_mlir_str(mlir_str: str | bytes) -> bytes:
"""Serializes a StableHLO MLIR module string to a bytecode."""
# https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
# `stablehlo.get_minimum_version()` returns `consumer_version_min`
Expand Down
6 changes: 2 additions & 4 deletions jaxonnxruntime/experimental/export/tensorflow_exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import dataclasses

import jax
from jax.lib import xla_client
from jax.lib import xla_extension
from jaxonnxruntime.experimental.export import exportable
from jaxonnxruntime.experimental.export import exportable_utils
Expand Down Expand Up @@ -155,7 +154,7 @@ def nr_devices(self) -> int:
return len(jax.devices())

@property
def mlir_module_str(self) -> str:
def mlir_module_str(self) -> bytes:
"""Returns the mlir module from TF."""
args_tf_flat = jax.tree_util.tree_map(
lambda x: x.tensor if isinstance(x, TensorWithSharding) else x,
Expand All @@ -165,8 +164,7 @@ def mlir_module_str(self) -> str:
*args_tf_flat
)(stage="hlo_serialized", platform_name=self.tf_platform)

xla_comp = xla_client.XlaComputation(func_tf_hlo)
mlir_str = xla_extension.mlir.xla_computation_to_mlir_module(xla_comp)
mlir_str = xla_extension.mlir.hlo_to_stablehlo(func_tf_hlo)
return mlir_str

@property
Expand Down