From b2a1c225549fb3f12927eef41690db1dcdc1baf7 Mon Sep 17 00:00:00 2001 From: soffer-anyscale Date: Fri, 21 Nov 2025 12:28:05 -0700 Subject: [PATCH] Add Polars batch format support to map_batches - Add 'polars' to VALID_BATCH_FORMATS - Implement to_polars() methods in ArrowBlockAccessor and PandasBlockAccessor - Add batch_to_block_from_polars() for converting Polars DataFrames to blocks - Update _validate_batch_output() to accept Polars DataFrames - Add comprehensive validation for LazyFrame, Series, and invalid types - Update documentation with Polars examples and performance notes - Add tests for Polars format in map_batches, iter_batches, and take_batch - Include URLs to Polars documentation in docstrings - Document that Polars format always creates copies (no zero-copy support) Signed-off-by: soffer-anyscale --- doc/source/data/inspecting-data.rst | 25 +++ doc/source/data/iterating-over-data.rst | 25 +++ doc/source/data/transforming-data.rst | 19 +- python/ray/data/_internal/arrow_block.py | 34 ++++ python/ray/data/_internal/pandas_block.py | 50 ++++++ .../data/_internal/planner/plan_udf_map_op.py | 50 +++++- python/ray/data/block.py | 162 +++++++++++++++++- python/ray/data/dataset.py | 43 ++++- .../data/tests/block_batching/test_util.py | 11 +- python/ray/data/tests/test_consumption.py | 24 +++ python/ray/data/tests/test_map_batches.py | 40 +++++ 11 files changed, 470 insertions(+), 13 deletions(-) diff --git a/doc/source/data/inspecting-data.rst b/doc/source/data/inspecting-data.rst index 0936204fc655..caddd8a13343 100644 --- a/doc/source/data/inspecting-data.rst +++ b/doc/source/data/inspecting-data.rst @@ -129,6 +129,31 @@ of the returned batch, set ``batch_format``. 0 5.1 3.5 ... 0.2 0 1 4.9 3.0 ... 0.2 0 + .. tab-item:: Polars + + .. testcode:: + + import ray + import polars as pl + + ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv") + + batch = ds.take_batch(batch_size=2, batch_format="polars") + print(batch) + + .. 
testoutput:: + :options: +MOCK + + shape: (2, 5) + ┌─────────────────┬────────────────┬─────────────────┬────────────────┬────────┐ + │ sepal length... │ sepal width... │ petal length... │ petal width... │ target │ + │ --- │ --- │ --- │ --- │ --- │ + │ f64 │ f64 │ f64 │ f64 │ i64 │ + ╞═════════════════╪════════════════╪═════════════════╪════════════════╪════════╡ + │ 5.1 │ 3.5 │ 1.4 │ 0.2 │ 0 │ + │ 4.9 │ 3.0 │ 1.4 │ 0.2 │ 0 │ + └─────────────────┴────────────────┴─────────────────┴────────────────┴────────┘ + For more information on working with batches, see :ref:`Transforming batches <transforming_batches>` and :ref:`Iterating over batches <iterating-over-batches>`. diff --git a/doc/source/data/iterating-over-data.rst b/doc/source/data/iterating-over-data.rst index 52af861eae43..f73c43b43587 100644 --- a/doc/source/data/iterating-over-data.rst +++ b/doc/source/data/iterating-over-data.rst @@ -94,6 +94,31 @@ formats by calling one of the following methods: sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) target 0 5.1 3.5 1.4 0.2 0 1 4.9 3.0 1.4 0.2 0 + + .. tab-item:: Polars + + .. testcode:: + + import ray + import polars as pl + + ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv") + + for batch in ds.iter_batches(batch_size=2, batch_format="polars"): + print(batch) + + .. testoutput:: + :options: +MOCK + + shape: (2, 5) + ┌─────────────────┬────────────────┬─────────────────┬────────────────┬────────┐ + │ sepal length... │ sepal width... │ petal length... │ petal width... │ target │ + │ --- │ --- │ --- │ --- │ --- │ + │ f64 │ f64 │ f64 │ f64 │ i64 │ + ╞═════════════════╪════════════════╪═════════════════╪════════════════╪════════╡ + │ 5.1 │ 3.5 │ 1.4 │ 0.2 │ 0 │ + │ 4.9 │ 3.0 │ 1.4 │ 0.2 │ 0 │ + └─────────────────┴────────────────┴─────────────────┴────────────────┴────────┘ ... 
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) target 0 6.2 3.4 5.4 2.3 2 diff --git a/doc/source/data/transforming-data.rst b/doc/source/data/transforming-data.rst index a06bce0cd238..b4550008badb 100644 --- a/doc/source/data/transforming-data.rst +++ b/doc/source/data/transforming-data.rst @@ -142,7 +142,7 @@ batches is more performant than transforming rows. Configuring batch format ~~~~~~~~~~~~~~~~~~~~~~~~ -Ray Data represents batches as dicts of NumPy ndarrays or pandas DataFrames. By +Ray Data represents batches as dicts of NumPy ndarrays, pandas DataFrames, or Polars DataFrames. By default, Ray Data represents batches as dicts of NumPy ndarrays. To configure the batch type, specify ``batch_format`` in :meth:`~ray.data.Dataset.map_batches`. You can return either format from your function, but ``batch_format`` should match the input of your function. @@ -181,9 +181,24 @@ format from your function, but ``batch_format`` should match the input of your f .map_batches(drop_nas, batch_format="pandas") ) + .. tab-item:: Polars + + .. testcode:: + + import polars as pl + import ray + + def drop_nas(batch: pl.DataFrame) -> pl.DataFrame: + return batch.drop_nulls() + + ds = ( + ray.data.read_csv("s3://anonymous@air-example-data/iris.csv") + .map_batches(drop_nas, batch_format="polars") + ) + The user defined function you pass to :meth:`~ray.data.Dataset.map_batches` is more flexible. Because you can represent batches in multiple ways (see :ref:`Configuring batch format <configure_batch_format>`), the function should be of type -``Callable[DataBatch, DataBatch]``, where ``DataBatch = Union[pd.DataFrame, Dict[str, np.ndarray]]``. In +``Callable[DataBatch, DataBatch]``, where ``DataBatch = Union[pd.DataFrame, pl.DataFrame, Dict[str, np.ndarray]]``. In other words, your function should take as input and output a batch of data which you can represent as a pandas DataFrame or a dictionary with string keys and NumPy ndarrays values. 
For example, your function might look like: diff --git a/python/ray/data/_internal/arrow_block.py b/python/ray/data/_internal/arrow_block.py index 99a8bc4dc57d..206610f5ac4f 100644 --- a/python/ray/data/_internal/arrow_block.py +++ b/python/ray/data/_internal/arrow_block.py @@ -49,6 +49,7 @@ if TYPE_CHECKING: import pandas + import polars from ray.data._internal.planner.exchange.sort_task_spec import SortKey @@ -305,6 +306,39 @@ def to_numpy( def to_arrow(self) -> "pyarrow.Table": return self._table + def to_polars(self) -> "polars.DataFrame": + """Convert this Arrow block into a Polars DataFrame. + + Converts a PyArrow Table to a Polars DataFrame. See + https://docs.pola.rs/ for Polars documentation. + + Note: This conversion creates a copy of the data. Zero-copy conversion + from Arrow to Polars is not possible. + + Returns: + A Polars DataFrame containing the data. + + Raises: + ImportError: If Polars is not installed. + """ + try: + import polars as pl + except ImportError: + raise ImportError( + "Polars is not installed. Install with `pip install polars`. " + "See https://docs.pola.rs/ for more information." + ) + + # Combine chunks for better performance and compatibility + # Polars works better with contiguous arrays + from ray.data._internal.arrow_ops import transform_pyarrow + + combined_table = transform_pyarrow.combine_chunks(self._table, copy=False) + + # Convert to Polars DataFrame using from_arrow() + # See https://docs.pola.rs/api/dataframe/#polars.DataFrame.from_arrow + return pl.from_arrow(combined_table) + def num_rows(self) -> int: # Arrow may represent an empty table via an N > 0 row, 0-column table, e.g. when # slicing an empty table, so we return 0 if num_columns == 0. 
diff --git a/python/ray/data/_internal/pandas_block.py b/python/ray/data/_internal/pandas_block.py index bef681f70ad4..f99841bee7de 100644 --- a/python/ray/data/_internal/pandas_block.py +++ b/python/ray/data/_internal/pandas_block.py @@ -38,6 +38,7 @@ if TYPE_CHECKING: import pandas + import polars import pyarrow from ray.data._internal.planner.exchange.sort_task_spec import SortKey @@ -479,6 +480,55 @@ def to_arrow(self) -> "pyarrow.Table": return arrow_table + def to_polars(self) -> "polars.DataFrame": + """Convert this Pandas block into a Polars DataFrame. + + Converts a Pandas DataFrame to a Polars DataFrame. See + https://docs.pola.rs/ for Polars documentation. + + Note: This conversion creates a copy of the data. Zero-copy conversion + from Pandas to Polars is not possible. + + Returns: + A Polars DataFrame containing the data. + + Raises: + ImportError: If Polars is not installed. + ValueError: If the Pandas DataFrame has duplicate column names or + invalid column names. + """ + try: + import polars as pl + except ImportError: + raise ImportError( + "Polars is not installed. Install with `pip install polars`. " + "See https://docs.pola.rs/ for more information." + ) + + # Validate column names before conversion + # Polars doesn't allow duplicate column names + if len(self._table.columns) != len(set(self._table.columns)): + duplicates = [ + col for col in self._table.columns + if list(self._table.columns).count(col) > 1 + ] + raise ValueError( + f"Pandas DataFrame has duplicate column names: {duplicates}. " + "Rename duplicate columns before converting to Polars." + ) + + # Validate column names are strings + for col in self._table.columns: + if not isinstance(col, str): + raise ValueError( + f"Pandas DataFrame has non-string column name: {col} (type: {type(col)}). " + "All column names must be strings for Polars conversion." 
+ ) + + # Convert to Polars DataFrame using from_pandas() + # See https://docs.pola.rs/api/dataframe/#polars.DataFrame.from_pandas + return pl.from_pandas(self._table) + def num_rows(self) -> int: return self._table.shape[0] diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index e95580571a06..518f2b2eff50 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -420,6 +420,29 @@ def _try_wrap_udf_exception(e: Exception, item: Any = None): def _validate_batch_output(batch: Block) -> None: + """Validate that a batch output from a UDF is a supported type. + + See https://docs.pola.rs/ for Polars documentation. + """ + # Check for Polars DataFrame + # Polars is an optional dependency, so we check for it here + try: + import polars as pl + + if isinstance(batch, pl.DataFrame): + # Polars DataFrames are valid - DataFrame is always eager + # LazyFrame is a separate class, so if we get here it's already a DataFrame + return + elif isinstance(batch, pl.LazyFrame): + raise ValueError( + "The `fn` you passed to `map_batches` returned a Polars LazyFrame. " + "LazyFrames must be collected before returning. Use `.collect()` to " + "materialize the LazyFrame into a DataFrame. " + "See https://docs.pola.rs/api/lazyframe/#collect for details." + ) + except ImportError: + pass + if not isinstance( batch, ( @@ -434,7 +457,7 @@ def _validate_batch_output(batch: Block) -> None: raise ValueError( "The `fn` you passed to `map_batches` returned a value of type " f"{type(batch)}. This isn't allowed -- `map_batches` expects " - "`fn` to return a `pandas.DataFrame`, `pyarrow.Table`, " + "`fn` to return a `pandas.DataFrame`, `polars.DataFrame`, `pyarrow.Table`, " "`numpy.ndarray`, `list`, or `dict[str, numpy.ndarray]`." 
) @@ -510,8 +533,33 @@ def transform_fn( else: raise e from None else: + # Validate all yielded batches (for generators, validate each item) for out_batch in res: _validate_batch_output(out_batch) + # Additional validation: ensure Polars DataFrames are eager + # See https://docs.pola.rs/ for Polars documentation + try: + import polars as pl + + if isinstance(out_batch, pl.LazyFrame): + raise ValueError( + "Generator yielded a Polars LazyFrame. " + "All yielded frames must be materialized. " + "Call .collect() on LazyFrames before yielding. " + "See https://docs.pola.rs/api/lazyframe/#collect for details." + ) + elif isinstance(out_batch, pl.DataFrame): + # DataFrame is always eager, but verify it's valid + try: + # Access schema to ensure DataFrame is valid + _ = out_batch.schema + except Exception as e: + raise ValueError( + f"Polars DataFrame is in invalid state: {e}. " + "Ensure the DataFrame is properly constructed." + ) from e + except ImportError: + pass yield out_batch return transform_fn diff --git a/python/ray/data/block.py b/python/ray/data/block.py index cdc83c9fc205..d415d68c63f9 100644 --- a/python/ray/data/block.py +++ b/python/ray/data/block.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: import pandas + import polars import pyarrow from ray.data._internal.block_builder import BlockBuilder @@ -58,7 +59,11 @@ # Represents a single column of the ``Batch`` BatchColumn = Union[ - "pandas.Series", "np.ndarray", "pyarrow.Array", "pyarrow.ChunkedArray" + "pandas.Series", + "polars.Series", + "np.ndarray", + "pyarrow.Array", + "pyarrow.ChunkedArray", ] @@ -77,11 +82,17 @@ class BatchFormat(str, Enum): ARROW = "pyarrow" PANDAS = "pandas" NUMPY = "numpy" + POLARS = "polars" # User-facing data batch type. This is the data type for data that is supplied to and # returned from batch UDFs. 
-DataBatch = Union["pyarrow.Table", "pandas.DataFrame", Dict[str, np.ndarray]] +DataBatch = Union[ + "pyarrow.Table", + "pandas.DataFrame", + "polars.DataFrame", + Dict[str, np.ndarray], +] # User-facing data column type. This is the data type for data that is supplied to and # returned from column UDFs. @@ -112,7 +123,7 @@ def __call__(self, __arg: T) -> Union[U, Iterator[U]]: # same type as the metadata that describes each block in the partition. BlockPartitionMetadata = List["BlockMetadata"] -VALID_BATCH_FORMATS = ["pandas", "pyarrow", "numpy", None] +VALID_BATCH_FORMATS = ["pandas", "pyarrow", "numpy", "polars", None] DEFAULT_BATCH_FORMAT = "numpy" @@ -389,6 +400,10 @@ def to_arrow(self) -> "pyarrow.Table": """Convert this block into an Arrow table.""" raise NotImplementedError + def to_polars(self) -> "polars.DataFrame": + """Convert this block into a Polars DataFrame.""" + raise NotImplementedError + def to_block(self) -> Block: """Return the base block that this accessor wraps.""" raise NotImplementedError @@ -416,6 +431,8 @@ def to_batch_format(self, batch_format: Optional[str]) -> DataBatch: return self.to_arrow() elif batch_format == "numpy": return self.to_numpy() + elif batch_format == "polars": + return self.to_polars() else: raise ValueError( f"The batch format must be one of {VALID_BATCH_FORMATS}, got: " @@ -488,6 +505,33 @@ def batch_to_block( else: assert block_type == BlockType.PANDAS return cls.batch_to_pandas_block(batch) + else: + # Check if batch is a Polars DataFrame or related type + try: + import polars as pl + + if isinstance(batch, pl.DataFrame): + return cls.batch_to_block_from_polars(batch, block_type) + elif isinstance(batch, (pl.LazyFrame, pl.Series)): + # Provide helpful error messages for common mistakes + if isinstance(batch, pl.LazyFrame): + raise ValueError( + "Cannot convert Polars LazyFrame to block. " + "Call .collect() first to materialize it into a DataFrame." 
+ ) + elif isinstance(batch, pl.Series): + raise ValueError( + "Cannot convert Polars Series to block. " + "Return a DataFrame instead. Use pl.DataFrame({column_name: series})." + ) + except ImportError: + # If Polars is not installed but we got here, it means the batch + # might be a Polars DataFrame. Raise a clear error. + if hasattr(batch, "__class__") and "polars" in str(type(batch)).lower(): + raise ImportError( + "Polars is not installed but a Polars object was returned. " + "Install with `pip install polars`." + ) return batch @classmethod @@ -504,6 +548,118 @@ def batch_to_pandas_block(cls, batch: Dict[str, Any]) -> Block: return PandasBlockBuilder._table_from_pydict(batch) + @classmethod + def batch_to_block_from_polars( + cls, batch: "polars.DataFrame", block_type: Optional[BlockType] = None + ) -> Block: + """Create a block from a Polars DataFrame. + + Converts a Polars DataFrame to an Arrow Table or Pandas DataFrame block. + See https://docs.pola.rs/ for Polars documentation. + + Note: This conversion always creates a copy of the data. Polars DataFrames + cannot be zero-copy converted to Arrow/Pandas blocks. + + Args: + batch: A Polars DataFrame to convert. Must be an eager DataFrame, + not a LazyFrame. See https://docs.pola.rs/api/lazyframe/ for + LazyFrame documentation. + block_type: The target block type. If None, defaults to Arrow. + + Returns: + A Block (Arrow Table or Pandas DataFrame) containing the data. + + Raises: + ImportError: If Polars is not installed. + ValueError: If batch is not a Polars DataFrame or has invalid state. + """ + try: + import polars as pl + except ImportError: + raise ImportError( + "Polars is not installed. Install with `pip install polars`. " + "See https://docs.pola.rs/ for more information." + ) + + # Validate input type + if isinstance(batch, pl.LazyFrame): + raise ValueError( + "Cannot convert Polars LazyFrame. Call .collect() first to " + "materialize the LazyFrame into a DataFrame. 
" + "See https://docs.pola.rs/api/lazyframe/#collect for details." + ) + + if isinstance(batch, pl.Series): + raise ValueError( + "Cannot convert Polars Series. Return a DataFrame instead. " + "Use pl.DataFrame({column_name: series}) to convert Series to DataFrame." + ) + + if not isinstance(batch, pl.DataFrame): + raise ValueError( + f"Expected polars.DataFrame, got {type(batch)}. " + "If you have a LazyFrame, call .collect() first to materialize it." + ) + + # Handle empty DataFrame + if batch.height == 0: + # Empty DataFrame - create empty Arrow table with same schema + try: + arrow_table = batch.to_arrow() + if block_type is None or block_type == BlockType.ARROW: + return arrow_table + else: + return batch.to_pandas() + except Exception as e: + raise ValueError( + f"Failed to convert empty Polars DataFrame: {e}" + ) from e + + # Validate column names (Polars doesn't allow duplicates) + # See https://docs.pola.rs/api/dataframe/#polars.DataFrame.columns + if len(batch.columns) != len(set(batch.columns)): + duplicates = [ + col for col in batch.columns if batch.columns.count(col) > 1 + ] + raise ValueError( + f"Polars DataFrame has duplicate column names: {duplicates}. " + "Rename duplicate columns before converting." + ) + + # Validate column names are strings + for col in batch.columns: + if not isinstance(col, str): + raise ValueError( + f"Polars DataFrame has non-string column name: {col} (type: {type(col)}). " + "All column names must be strings." + ) + + # Convert Polars DataFrame to Arrow Table + # See https://docs.pola.rs/api/dataframe/#polars.DataFrame.to_arrow + # Note: This always creates a copy - Polars cannot zero-copy to Arrow + try: + arrow_table = batch.to_arrow() + except Exception as e: + raise ValueError( + f"Failed to convert Polars DataFrame to Arrow Table: {e}. " + "This may be due to unsupported data types or schema issues." 
+ ) from e + + if block_type is None or block_type == BlockType.ARROW: + return arrow_table + else: + # Convert to Pandas if needed + # See https://docs.pola.rs/api/dataframe/#polars.DataFrame.to_pandas + # Note: This creates another copy (Polars -> Arrow -> Pandas) + try: + pandas_df = batch.to_pandas() + return pandas_df + except Exception as e: + raise ValueError( + f"Failed to convert Polars DataFrame to Pandas DataFrame: {e}. " + "This may be due to unsupported data types." + ) from e + @staticmethod def for_block(block: Block) -> "BlockAccessor[T]": """Create a block accessor for the given block.""" diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 4a2cbdc72a0f..e8a221097be4 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -500,6 +500,12 @@ def map_batches( *whole* batch for you, providing ``fn`` with a copy rather than a zero-copy view. + .. note:: + When using ``batch_format="polars"``, data is always copied during + conversion, so modifying the input Polars DataFrame is safe. However, + be aware that Polars DataFrames returned from ``fn`` will also be + copied when converting back to blocks, which may increase memory usage. + .. warning:: Specifying both ``num_cpus`` and ``num_gpus`` for map tasks is experimental, and may result in scheduling or stability issues. Please @@ -616,8 +622,16 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: batch_format: If ``"default"`` or ``"numpy"``, batches are ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are ``pandas.DataFrame``. If ``"pyarrow"``, batches are - ``pyarrow.Table``. If ``batch_format`` is set to ``None`` input + ``pyarrow.Table``. If ``"polars"``, batches are + ``polars.DataFrame``. If ``batch_format`` is set to ``None`` input block format will be used. + + .. note:: + When using ``batch_format="polars"``, the conversion always + creates a copy of the data. 
The ``zero_copy_batch`` parameter + has no effect for Polars format, as zero-copy conversion is + not possible. This may result in higher memory usage compared + to Arrow or Pandas formats. zero_copy_batch: Whether ``fn`` should be provided zero-copy, read-only batches. If this is ``True`` and no copy is required for the ``batch_format`` conversion, the batch is a zero-copy, read-only @@ -627,6 +641,10 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: modify underlying data buffers (like tensors, binary arrays, etc) in place. It's recommended to copy only the data you need to modify instead of resorting to copying the whole batch. + + .. note:: + This parameter has no effect when ``batch_format="polars"``, + as Polars DataFrames always require data copying during conversion. fn_args: Positional arguments to pass to ``fn`` after the first argument. These arguments are top-level arguments to the underlying Ray task. fn_kwargs: Keyword arguments to pass to ``fn``. These arguments are @@ -922,7 +940,7 @@ def add_column( :func:`ray.remote` for details. """ # Check that batch_format - accepted_batch_formats = ["pandas", "pyarrow", "numpy"] + accepted_batch_formats = ["pandas", "pyarrow", "numpy", "polars"] if batch_format not in accepted_batch_formats: raise ValueError( f"batch_format argument must be on of {accepted_batch_formats}, " @@ -3403,9 +3421,13 @@ def take_batch( ) -> DataBatch: """Return up to ``batch_size`` rows from the :class:`Dataset` in a batch. - Ray Data represents batches as NumPy arrays or pandas DataFrames. You can + Ray Data represents batches as NumPy arrays, pandas DataFrames, or Polars DataFrames. You can configure the batch type by specifying ``batch_format``. + .. note:: + When using ``batch_format="polars"``, the conversion always creates + a copy of the data, which may increase memory usage compared to other formats. + This method is useful for inspecting inputs to :meth:`~Dataset.map_batches`. .. 
warning:: @@ -3427,7 +3449,9 @@ def take_batch( batch_size: The maximum number of rows to return. batch_format: If ``"default"`` or ``"numpy"``, batches are ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are - ``pandas.DataFrame``. + ``pandas.DataFrame``. If ``"polars"``, batches are + ``polars.DataFrame``. If ``"pyarrow"``, batches are + ``pyarrow.Table``. Returns: A batch of up to ``batch_size`` rows from the dataset. @@ -5295,7 +5319,16 @@ def iter_batches( ``drop_last`` is ``False``. Defaults to 256. batch_format: If ``"default"`` or ``"numpy"``, batches are ``Dict[str, numpy.ndarray]``. If ``"pandas"``, batches are - ``pandas.DataFrame``. + ``pandas.DataFrame``. If ``"polars"``, batches are + ``polars.DataFrame``. If ``"pyarrow"``, batches are + ``pyarrow.Table``. + + .. note:: + When using ``batch_format="polars"``, each batch conversion + creates a copy of the data. This may result in 2-3x memory usage + compared to Arrow format (input conversion + output conversion). + For large datasets, consider using ``batch_format="pyarrow"`` for + better memory efficiency. drop_last: Whether to drop the last batch if it's incomplete. 
local_shuffle_buffer_size: If not ``None``, the data is randomly shuffled using a local in-memory shuffle buffer, and this value serves as the diff --git a/python/ray/data/tests/block_batching/test_util.py b/python/ray/data/tests/block_batching/test_util.py index f8be82e43281..1a765a022dd1 100644 --- a/python/ray/data/tests/block_batching/test_util.py +++ b/python/ray/data/tests/block_batching/test_util.py @@ -69,7 +69,7 @@ def test_blocks_to_batches(block_size, drop_last): ) -@pytest.mark.parametrize("batch_format", ["pandas", "numpy", "pyarrow"]) +@pytest.mark.parametrize("batch_format", ["pandas", "numpy", "pyarrow", "polars"]) def test_format_batches(batch_format): block_iter = block_generator(num_rows=2, num_blocks=2) batch_iter = ( @@ -80,11 +80,18 @@ def test_format_batches(batch_format): for batch in batch_iter: if batch_format == "pandas": assert isinstance(batch.data, pd.DataFrame) - elif batch_format == "arrow": + elif batch_format == "arrow" or batch_format == "pyarrow": assert isinstance(batch.data, pa.Table) elif batch_format == "numpy": assert isinstance(batch.data, dict) assert isinstance(batch.data["foo"], np.ndarray) + elif batch_format == "polars": + try: + import polars as pl + + assert isinstance(batch.data, pl.DataFrame) + except ImportError: + pytest.skip("Polars not installed") assert [batch.metadata.batch_idx for batch in batch_iter] == list( range(len(batch_iter)) diff --git a/python/ray/data/tests/test_consumption.py b/python/ray/data/tests/test_consumption.py index 721daa92b7fe..a0b53be51468 100644 --- a/python/ray/data/tests/test_consumption.py +++ b/python/ray/data/tests/test_consumption.py @@ -830,6 +830,14 @@ def test_batch_formats(shutdown_only): assert isinstance(next(iter(ds.iter_batches(batch_format="pandas"))), pd.DataFrame) assert isinstance(next(iter(ds.iter_batches(batch_format="pyarrow"))), pa.Table) assert isinstance(next(iter(ds.iter_batches(batch_format="numpy"))), dict) + try: + import polars as pl + + assert isinstance( 
+ next(iter(ds.iter_batches(batch_format="polars"))), pl.DataFrame + ) + except ImportError: + pass ds = ray.data.range_tensor(100) assert isinstance(next(iter(ds.iter_batches(batch_format=None))), pa.Table) @@ -837,6 +845,14 @@ def test_batch_formats(shutdown_only): assert isinstance(next(iter(ds.iter_batches(batch_format="pandas"))), pd.DataFrame) assert isinstance(next(iter(ds.iter_batches(batch_format="pyarrow"))), pa.Table) assert isinstance(next(iter(ds.iter_batches(batch_format="numpy"))), dict) + try: + import polars as pl + + assert isinstance( + next(iter(ds.iter_batches(batch_format="polars"))), pl.DataFrame + ) + except ImportError: + pass df = pd.DataFrame({"foo": ["a", "b"], "bar": [0, 1]}) ds = ray.data.from_pandas(df) @@ -845,6 +861,14 @@ def test_batch_formats(shutdown_only): assert isinstance(next(iter(ds.iter_batches(batch_format="pandas"))), pd.DataFrame) assert isinstance(next(iter(ds.iter_batches(batch_format="pyarrow"))), pa.Table) assert isinstance(next(iter(ds.iter_batches(batch_format="numpy"))), dict) + try: + import polars as pl + + assert isinstance( + next(iter(ds.iter_batches(batch_format="polars"))), pl.DataFrame + ) + except ImportError: + pass def test_dataset_schema_after_read_stats(ray_start_cluster): diff --git a/python/ray/data/tests/test_map_batches.py b/python/ray/data/tests/test_map_batches.py index 2247b07d0603..fe8a1b033393 100644 --- a/python/ray/data/tests/test_map_batches.py +++ b/python/ray/data/tests/test_map_batches.py @@ -96,6 +96,26 @@ def test_map_batches_basic( values = [s["two"] for s in ds_list] assert values == [2, 3, 4] + # Test Polars + try: + import polars as pl + + ds = ray.data.read_parquet(str(tmp_path)) + ds2 = ds.map_batches( + lambda polars_df: polars_df.with_columns( + [pl.col("one") + 1, pl.col("two") + 1] + ), + batch_size=1, + batch_format="polars", + ) + ds_list = ds2.take() + values = [s["one"] for s in ds_list] + assert values == [2, 3, 4] + values = [s["two"] for s in ds_list] + assert values 
== [3, 4, 5] + except ImportError: + pytest.skip("Polars not installed") + # Test batch size = 300 ds = ray.data.range(size) @@ -733,6 +753,26 @@ def test_map_batches_timestamp_nanosecs( ) pd.testing.assert_frame_equal(processed_df_pandas, expected_df) + # Using polars format + try: + import polars as pl + + def process_timestamp_data_batch_polars(batch: pl.DataFrame) -> pl.DataFrame: + return batch.with_columns( + (pl.col("timestamp") + pl.duration(nanoseconds=1)).alias("timestamp") + ) + + result_polars = ray_data.map_batches( + process_timestamp_data_batch_polars, batch_format="polars" + ) + processed_df_polars = result_polars.to_pandas() + processed_df_polars["timestamp"] = processed_df_polars["timestamp"].astype( + "datetime64[ns]" + ) + pd.testing.assert_frame_equal(processed_df_polars, expected_df) + except ImportError: + pytest.skip("Polars not installed") + def test_map_batches_async_exception_propagation(shutdown_only): ray.shutdown()