diff --git a/jaxonnxruntime/onnx_ops/flatten.py b/jaxonnxruntime/onnx_ops/flatten.py index ca644b8..f777392 100644 --- a/jaxonnxruntime/onnx_ops/flatten.py +++ b/jaxonnxruntime/onnx_ops/flatten.py @@ -51,6 +51,14 @@ def version_1( cls._prepare(node, inputs, onnx_flatten) return onnx_flatten + @classmethod + def version_9( + cls, node: onnx_node.OnnxNode, inputs: Sequence[Any] + ) -> Callable[..., Any]: + """ONNX version_9 Flatten op.""" + cls._prepare(node, inputs, onnx_flatten) + return onnx_flatten + @classmethod def version_11( cls, node: onnx_node.OnnxNode, inputs: Sequence[Any]