diff --git a/jaxonnxruntime/experimental/call_torch/call_torch_xla.py b/jaxonnxruntime/experimental/call_torch/call_torch_xla.py index 7c321c1..1ac7d8b 100644 --- a/jaxonnxruntime/experimental/call_torch/call_torch_xla.py +++ b/jaxonnxruntime/experimental/call_torch/call_torch_xla.py @@ -204,19 +204,11 @@ def refine_polymorphic_shapes( Returns: The refined module. """ - if xc.mlir_api_version >= 53: - refined_module_str = xla_extension.mlir.refine_polymorphic_shapes( - mlir.module_to_bytecode(module), - enable_shape_assertions=validate_static_shapes, - validate_static_shapes=validate_static_shapes, - ) - elif xc.mlir_api_version >= 50: - refined_module_str = xla_extension.mlir.refine_polymorphic_shapes( - mlir.module_to_bytecode(module) - ) - else: - raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12") - + refined_module_str = xla_extension.mlir.refine_polymorphic_shapes( + mlir.module_to_bytecode(module), + enable_shape_assertions=validate_static_shapes, + validate_static_shapes=validate_static_shapes, + ) context = mlir.make_ir_context() with context: return ir.Module.parse(refined_module_str)