diff --git a/jaxonnxruntime/experimental/export/exportable_utils.py b/jaxonnxruntime/experimental/export/exportable_utils.py index a5d8a15..77176e6 100644 --- a/jaxonnxruntime/experimental/export/exportable_utils.py +++ b/jaxonnxruntime/experimental/export/exportable_utils.py @@ -14,20 +14,19 @@ """jax.export.Exported utils.""" -import io import os +from typing import Any import jax from jax import export as jax_export from jax import numpy as jnp -from jax.lib import xla_client from jaxlib.mlir import ir from mlir.dialects import stablehlo import tensorflow as tf import torch MLIRModule = ir.Module -HloSharding = xla_client.HloSharding | None +HloSharding = Any | None Sharding = jax.sharding.Sharding | None