diff --git a/jaxonnxruntime/experimental/call_torch/call_torch_xla.py b/jaxonnxruntime/experimental/call_torch/call_torch_xla.py index e6caa85..edf5c2f 100644 --- a/jaxonnxruntime/experimental/call_torch/call_torch_xla.py +++ b/jaxonnxruntime/experimental/call_torch/call_torch_xla.py @@ -201,18 +201,11 @@ def refine_polymorphic_shapes( Returns: The refined module. """ - if xc.mlir_api_version >= 53: - refined_module_str = xc._xla.mlir.refine_polymorphic_shapes( # pylint: disable=protected-access - mlir.module_to_bytecode(module), - enable_shape_assertions=validate_static_shapes, - validate_static_shapes=validate_static_shapes, - ) - elif xc.mlir_api_version >= 50: - refined_module_str = xc._xla.mlir.refine_polymorphic_shapes( # pylint: disable=protected-access - mlir.module_to_bytecode(module) - ) - else: - raise NotImplementedError("refine_polymorphic_shapes needs jaxlib 0.4.12") + refined_module_str = xc._xla.mlir.refine_polymorphic_shapes( # pylint: disable=protected-access + mlir.module_to_bytecode(module), + enable_shape_assertions=validate_static_shapes, + validate_static_shapes=validate_static_shapes, + ) context = mlir.make_ir_context() with context: