diff --git a/python/sedonadb/python/sedonadb/dataframe.py b/python/sedonadb/python/sedonadb/dataframe.py index 429800637..38ed5caf8 100644 --- a/python/sedonadb/python/sedonadb/dataframe.py +++ b/python/sedonadb/python/sedonadb/dataframe.py @@ -253,10 +253,9 @@ def to_arrow_table(self, schema: Any = None) -> "pyarrow.Table": import pyarrow as pa import geoarrow.pyarrow # noqa: F401 - if schema is None: - return pa.table(self) - else: - return pa.table(self, schema=pa.schema(schema)) + # Collects all batches into an object that exposes __arrow_c_stream__() + batches = self._impl.to_batches(schema) + return pa.table(batches) def to_pandas( self, geometry: Optional[str] = None diff --git a/python/sedonadb/src/dataframe.rs b/python/sedonadb/src/dataframe.rs index 3ae85b599..eb57f6db9 100644 --- a/python/sedonadb/src/dataframe.rs +++ b/python/sedonadb/src/dataframe.rs @@ -18,16 +18,16 @@ use std::ffi::CString; use std::str::FromStr; use std::sync::Arc; -use arrow_array::ffi::FFI_ArrowSchema; use arrow_array::ffi_stream::FFI_ArrowArrayStream; -use arrow_array::RecordBatchReader; -use arrow_schema::Schema; +use arrow_array::{RecordBatch, RecordBatchReader}; +use arrow_schema::{Schema, SchemaRef}; use datafusion::catalog::MemTable; use datafusion::logical_expr::SortExpr; use datafusion::prelude::DataFrame; -use datafusion_common::Column; +use datafusion_common::{Column, DataFusionError}; use datafusion_expr::{ExplainFormat, ExplainOption, Expr}; use datafusion_ffi::table_provider::FFI_TableProvider; +use futures::TryStreamExt; use pyo3::prelude::*; use pyo3::types::PyCapsule; use sedona::context::{SedonaDataFrame, SedonaWriteOptions}; @@ -38,7 +38,7 @@ use tokio::runtime::Runtime; use crate::context::InternalContext; use crate::error::PySedonaError; -use crate::import_from::check_pycapsule; +use crate::import_from::import_arrow_schema; use crate::reader::PySedonaStreamReader; use crate::runtime::wait_for_future; use crate::schema::PySedonaSchema; @@ -100,14 +100,17 @@ impl InternalDataFrame { } fn execute<'py>(&self, py: Python<'py>) -> Result { - let mut c = 0; - let stream = wait_for_future(py, &self.runtime, self.inner.clone().execute_stream())??; - let reader = PySedonaStreamReader::new(self.runtime.clone(), stream); - for batch in reader { - c += batch?.num_rows(); - } + let df = self.inner.clone(); + let count = wait_for_future(py, &self.runtime, async move { + let mut stream = df.execute_stream().await?; + let mut c = 0usize; + while let Some(batch) = stream.try_next().await? { + c += batch.num_rows(); + } + Ok::<_, DataFusionError>(c) + })??; - Ok(c) + Ok(count) } fn count<'py>(&self, py: Python<'py>) -> Result { @@ -149,6 +152,28 @@ impl InternalDataFrame { )) } + fn to_batches<'py>( + &self, + py: Python<'py>, + requested_schema: Option>, + ) -> Result { + check_py_requested_schema(requested_schema, self.inner.schema().as_arrow())?; + + let df = self.inner.clone(); + let batches = wait_for_future(py, &self.runtime, async move { + let mut stream = df.execute_stream().await?; + let schema = stream.schema(); + let mut batches = Vec::new(); + while let Some(batch) = stream.try_next().await? { + batches.push(batch); + } + + Ok::<_, DataFusionError>(Batches { schema, batches }) + })??; + + Ok(batches) + } + #[allow(clippy::too_many_arguments)] fn to_parquet<'py>( &self, @@ -265,20 +290,9 @@ impl InternalDataFrame { fn __arrow_c_stream__<'py>( &self, py: Python<'py>, - #[allow(unused_variables)] requested_schema: Option>, + requested_schema: Option>, ) -> Result, PySedonaError> { - if let Some(requested_capsule) = requested_schema { - let contents = check_pycapsule(&requested_capsule, "arrow_schema")?; - let ffi_schema = unsafe { FFI_ArrowSchema::from_raw(contents as _) }; - let requested_schema = Schema::try_from(&ffi_schema)?; - let actual_schema = self.inner.schema().as_arrow(); - if &requested_schema != actual_schema { - // Eventually we can support this by inserting a cast - return Err(PySedonaError::SedonaPython( - "Requested schema != DataFrame schema not yet supported".to_string(), - )); - } - } + check_py_requested_schema(requested_schema, self.inner.schema().as_arrow())?; let stream = wait_for_future(py, &self.runtime, self.inner.clone().execute_stream())??; let reader = PySedonaStreamReader::new(self.runtime.clone(), stream); @@ -289,3 +303,46 @@ impl InternalDataFrame { Ok(PyCapsule::new(py, ffi_stream, Some(stream_capsule_name))?) } } + +#[pyclass] +pub struct Batches { + schema: SchemaRef, + batches: Vec, +} + +#[pymethods] +impl Batches { + #[pyo3(signature = (requested_schema=None))] + fn __arrow_c_stream__<'py>( + &self, + py: Python<'py>, + requested_schema: Option>, + ) -> Result, PySedonaError> { + check_py_requested_schema(requested_schema, &self.schema)?; + + let reader = arrow_array::RecordBatchIterator::new( + self.batches.clone().into_iter().map(Ok), + self.schema.clone(), + ); + let reader: Box = Box::new(reader); + + let ffi_stream = FFI_ArrowArrayStream::new(reader); + let stream_capsule_name = CString::new("arrow_array_stream").unwrap(); + Ok(PyCapsule::new(py, ffi_stream, Some(stream_capsule_name))?) + } +} + +fn check_py_requested_schema<'py>( + requested_schema: Option>, + actual_schema: &Schema, +) -> Result<(), PySedonaError> { + if let Some(requested_obj) = requested_schema { + let requested = import_arrow_schema(&requested_obj)?; + if &requested != actual_schema { + return Err(PySedonaError::SedonaPython( + "Requested schema != actual schema not yet supported".to_string(), + )); + } + } + Ok(()) +} diff --git a/python/sedonadb/src/runtime.rs b/python/sedonadb/src/runtime.rs index d6db68c71..772f90b1e 100644 --- a/python/sedonadb/src/runtime.rs +++ b/python/sedonadb/src/runtime.rs @@ -22,13 +22,16 @@ use tokio::{runtime::Runtime, time::sleep}; use crate::error::PySedonaError; // Adapted from datafusion-python: -// https://github.com/apache/datafusion-python/blob/cbe845b1e840c78f7a9fc4d83d184a1e6f35f47c/src/utils.rs#L64 +// https://github.com/apache/datafusion-python/blob/7aff3635c93d5897d470642928c39c86e7851931/src/utils.rs#L80-L106 pub fn wait_for_future(py: Python, runtime: &Runtime, fut: F) -> Result where F: Future + Send, F::Output: Send, { - const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000); + const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(2_000); + + py.run(cr"pass", None, None)?; + py.check_signals()?; py.allow_threads(|| { runtime.block_on(async { @@ -37,7 +40,10 @@ where tokio::select! { res = &mut fut => break Ok(res), _ = sleep(INTERVAL_CHECK_SIGNALS) => { - Python::with_gil(|py| py.check_signals())?; + Python::with_gil(|py| { + py.run(cr"pass", None, None)?; + py.check_signals() + })?; } } } @@ -52,15 +58,17 @@ where F: Future + Send, F::Output: Send, { - const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000); - + const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(2_000); runtime.block_on(async { tokio::pin!(fut); loop { tokio::select! { res = &mut fut => break Ok(res), _ = sleep(INTERVAL_CHECK_SIGNALS) => { - Python::with_gil(|py| py.check_signals())?; + Python::with_gil(|py| { + py.run(cr"pass", None, None)?; + py.check_signals() + })?; } } } diff --git a/python/sedonadb/tests/functions/test_functions.py b/python/sedonadb/tests/functions/test_functions.py index 1c16b6d33..34159360d 100644 --- a/python/sedonadb/tests/functions/test_functions.py +++ b/python/sedonadb/tests/functions/test_functions.py @@ -16,9 +16,9 @@ # under the License. import math -import pyarrow import pytest import shapely +import sedonadb from sedonadb.testing import PostGIS, SedonaDB, geom_or_null, val_or_null @@ -1709,7 +1709,7 @@ def test_st_geomfromwkbunchecked_invalid_wkb(eng): ) # Using invalid WKB elsewhere may result in undefined behavior. - with pytest.raises(pyarrow.lib.ArrowInvalid, match="failed to fill whole buffer"): + with pytest.raises(sedonadb._lib.SedonaError, match="failed to fill whole buffer"): eng.execute_and_collect("SELECT ST_AsText(ST_GeomFromWKBUnchecked(0x01))") diff --git a/python/sedonadb/tests/functions/test_wkb.py b/python/sedonadb/tests/functions/test_wkb.py index 424d9a36b..02d56f0c8 100644 --- a/python/sedonadb/tests/functions/test_wkb.py +++ b/python/sedonadb/tests/functions/test_wkb.py @@ -70,6 +70,9 @@ ], ) def test_st_asewkb(eng, srid, geom): + if shapely.geos_version < (3, 12, 0): + pytest.skip("GEOS version 3.12+ required for EWKB tests") + eng = eng.create_or_skip() if geom is not None: diff --git a/python/sedonadb/tests/test_dataframe.py b/python/sedonadb/tests/test_dataframe.py index bb3ca7828..681d6d823 100644 --- a/python/sedonadb/tests/test_dataframe.py +++ b/python/sedonadb/tests/test_dataframe.py @@ -261,7 +261,7 @@ def test_dataframe_to_arrow(con): # ...but not otherwise (yet) with pytest.raises( sedonadb._lib.SedonaError, - match="Requested schema != DataFrame schema not yet supported", + match="Requested schema != actual schema not yet supported", ): df.to_arrow_table(schema=pa.schema({})) diff --git a/python/sedonadb/tests/test_udf.py b/python/sedonadb/tests/test_udf.py index ab019f9bd..4159c96ff 100644 --- a/python/sedonadb/tests/test_udf.py +++ b/python/sedonadb/tests/test_udf.py @@ -18,6 +18,7 @@ import pandas as pd import pyarrow as pa import pytest +import sedonadb from sedonadb import udf @@ -122,6 +123,19 @@ def shapely_udf(geom, distance): pd.DataFrame({"col": [3857]}, dtype=np.uint32), ) + # Ensure we can collect with >1 batch without hanging + con.funcs.table.sd_random_geometry("Point", 20000).to_view("pts", overwrite=True) + df = con.sql( + "SELECT ST_Area(shapely_udf(ST_Point(0, 0), 2.0)) as col FROM pts" + ).to_pandas() + assert len(df) == 20000 + + # Ensure we can execute with >1 batch without hanging + count = con.sql( + "SELECT ST_Area(shapely_udf(ST_Point(0, 0), 2.0)) as col FROM pts" + ).execute() + assert count == 20000 + def test_py_sedona_value(con): @udf.arrow_udf(pa.int64()) @@ -170,7 +184,7 @@ def questionable_udf(arg): con.register_udf(questionable_udf) with pytest.raises( - ValueError, + sedonadb._lib.SedonaError, match="Expected result of user-defined function to return an object implementing __arrow_c_array__", ): con.sql("SELECT questionable_udf(123) as col").to_pandas() @@ -183,7 +197,7 @@ def questionable_udf(arg): con.register_udf(questionable_udf) with pytest.raises( - ValueError, + sedonadb._lib.SedonaError, match=( "Expected result of user-defined function to " "return array of type Binary or its storage " @@ -200,7 +214,7 @@ def questionable_udf(arg): con.register_udf(questionable_udf) with pytest.raises( - ValueError, + sedonadb._lib.SedonaError, match="Expected result of user-defined function to return array of length 1 but got 2", ): con.sql("SELECT questionable_udf(123) as col").to_pandas()