Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions python/sedonadb/python/sedonadb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,10 +253,9 @@ def to_arrow_table(self, schema: Any = None) -> "pyarrow.Table":
import pyarrow as pa
import geoarrow.pyarrow # noqa: F401

if schema is None:
return pa.table(self)
else:
return pa.table(self, schema=pa.schema(schema))
# Collects all batches into an object that exposes __arrow_c_stream__()
batches = self._impl.to_batches(schema)
return pa.table(batches)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume we do not need to pass schema to pa.table anymore.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, passing it via to_batches() ensures it makes it to Rust!


def to_pandas(
self, geometry: Optional[str] = None
Expand Down
107 changes: 82 additions & 25 deletions python/sedonadb/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ use std::ffi::CString;
use std::str::FromStr;
use std::sync::Arc;

use arrow_array::ffi::FFI_ArrowSchema;
use arrow_array::ffi_stream::FFI_ArrowArrayStream;
use arrow_array::RecordBatchReader;
use arrow_schema::Schema;
use arrow_array::{RecordBatch, RecordBatchReader};
use arrow_schema::{Schema, SchemaRef};
use datafusion::catalog::MemTable;
use datafusion::logical_expr::SortExpr;
use datafusion::prelude::DataFrame;
use datafusion_common::Column;
use datafusion_common::{Column, DataFusionError};
use datafusion_expr::{ExplainFormat, ExplainOption, Expr};
use datafusion_ffi::table_provider::FFI_TableProvider;
use futures::TryStreamExt;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;
use sedona::context::{SedonaDataFrame, SedonaWriteOptions};
Expand All @@ -38,7 +38,7 @@ use tokio::runtime::Runtime;

use crate::context::InternalContext;
use crate::error::PySedonaError;
use crate::import_from::check_pycapsule;
use crate::import_from::import_arrow_schema;
use crate::reader::PySedonaStreamReader;
use crate::runtime::wait_for_future;
use crate::schema::PySedonaSchema;
Expand Down Expand Up @@ -100,14 +100,17 @@ impl InternalDataFrame {
}

fn execute<'py>(&self, py: Python<'py>) -> Result<usize, PySedonaError> {
let mut c = 0;
let stream = wait_for_future(py, &self.runtime, self.inner.clone().execute_stream())??;
let reader = PySedonaStreamReader::new(self.runtime.clone(), stream);
for batch in reader {
c += batch?.num_rows();
}
let df = self.inner.clone();
let count = wait_for_future(py, &self.runtime, async move {
let mut stream = df.execute_stream().await?;
let mut c = 0usize;
while let Some(batch) = stream.try_next().await? {
c += batch.num_rows();
}
Ok::<_, DataFusionError>(c)
})??;

Ok(c)
Ok(count)
}

fn count<'py>(&self, py: Python<'py>) -> Result<usize, PySedonaError> {
Expand Down Expand Up @@ -149,6 +152,28 @@ impl InternalDataFrame {
))
}

fn to_batches<'py>(
&self,
py: Python<'py>,
requested_schema: Option<Bound<'py, PyAny>>,
) -> Result<Batches, PySedonaError> {
check_py_requested_schema(requested_schema, self.inner.schema().as_arrow())?;

let df = self.inner.clone();
let batches = wait_for_future(py, &self.runtime, async move {
let mut stream = df.execute_stream().await?;
let schema = stream.schema();
let mut batches = Vec::new();
while let Some(batch) = stream.try_next().await? {
batches.push(batch);
}

Ok::<_, DataFusionError>(Batches { schema, batches })
})??;

Ok(batches)
}

#[allow(clippy::too_many_arguments)]
fn to_parquet<'py>(
&self,
Expand Down Expand Up @@ -265,20 +290,9 @@ impl InternalDataFrame {
fn __arrow_c_stream__<'py>(
&self,
py: Python<'py>,
#[allow(unused_variables)] requested_schema: Option<Bound<'py, PyCapsule>>,
requested_schema: Option<Bound<'py, PyAny>>,
) -> Result<Bound<'py, PyCapsule>, PySedonaError> {
if let Some(requested_capsule) = requested_schema {
let contents = check_pycapsule(&requested_capsule, "arrow_schema")?;
let ffi_schema = unsafe { FFI_ArrowSchema::from_raw(contents as _) };
let requested_schema = Schema::try_from(&ffi_schema)?;
let actual_schema = self.inner.schema().as_arrow();
if &requested_schema != actual_schema {
// Eventually we can support this by inserting a cast
return Err(PySedonaError::SedonaPython(
"Requested schema != DataFrame schema not yet supported".to_string(),
));
}
}
check_py_requested_schema(requested_schema, self.inner.schema().as_arrow())?;

let stream = wait_for_future(py, &self.runtime, self.inner.clone().execute_stream())??;
let reader = PySedonaStreamReader::new(self.runtime.clone(), stream);
Expand All @@ -289,3 +303,46 @@ impl InternalDataFrame {
Ok(PyCapsule::new(py, ffi_stream, Some(stream_capsule_name))?)
}
}

/// A fully-materialized set of query results exposed to Python.
///
/// Returned by `InternalDataFrame::to_batches`; implements
/// `__arrow_c_stream__` (see the `#[pymethods]` impl) so pyarrow can
/// consume it via the Arrow PyCapsule stream protocol.
#[pyclass]
pub struct Batches {
    // Schema shared by all batches, stored separately so an empty result
    // set still exposes the correct schema over FFI.
    schema: SchemaRef,
    // Collected record batches, replayed on each stream export.
    batches: Vec<RecordBatch>,
}

#[pymethods]
impl Batches {
    /// Export the collected batches through the Arrow PyCapsule stream
    /// protocol (`__arrow_c_stream__`), replaying them as a record-batch
    /// reader.
    ///
    /// A caller-requested schema must match the stored schema exactly;
    /// casting to a different schema is not yet supported.
    #[pyo3(signature = (requested_schema=None))]
    fn __arrow_c_stream__<'py>(
        &self,
        py: Python<'py>,
        requested_schema: Option<Bound<'py, PyAny>>,
    ) -> Result<Bound<'py, PyCapsule>, PySedonaError> {
        check_py_requested_schema(requested_schema, &self.schema)?;

        // Clone the materialized batches so the exported stream owns its
        // data (the FFI boundary requires 'static). Arrow buffers are
        // reference-counted, so this copies pointers rather than payloads.
        let owned_batches = self.batches.clone();
        let reader: Box<dyn RecordBatchReader + Send> =
            Box::new(arrow_array::RecordBatchIterator::new(
                owned_batches.into_iter().map(Ok),
                self.schema.clone(),
            ));

        let capsule_name = CString::new("arrow_array_stream").unwrap();
        Ok(PyCapsule::new(
            py,
            FFI_ArrowArrayStream::new(reader),
            Some(capsule_name),
        )?)
    }
}

/// Validate an optionally-requested Arrow schema against the actual schema.
///
/// `requested_schema` is a Python object implementing `__arrow_c_schema__`
/// (or `None`, in which case validation is skipped). A schema that differs
/// from `actual_schema` is rejected, since inserting a cast to honor a
/// requested schema is not implemented yet.
fn check_py_requested_schema<'py>(
    requested_schema: Option<Bound<'py, PyAny>>,
    actual_schema: &Schema,
) -> Result<(), PySedonaError> {
    // No schema requested: nothing to validate.
    let Some(requested_obj) = requested_schema else {
        return Ok(());
    };

    let requested = import_arrow_schema(&requested_obj)?;
    if requested == *actual_schema {
        Ok(())
    } else {
        Err(PySedonaError::SedonaPython(
            "Requested schema != actual schema not yet supported".to_string(),
        ))
    }
}
20 changes: 14 additions & 6 deletions python/sedonadb/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@ use tokio::{runtime::Runtime, time::sleep};
use crate::error::PySedonaError;

// Adapted from datafusion-python:
// https://github.com/apache/datafusion-python/blob/cbe845b1e840c78f7a9fc4d83d184a1e6f35f47c/src/utils.rs#L64
// https://github.com/apache/datafusion-python/blob/7aff3635c93d5897d470642928c39c86e7851931/src/utils.rs#L80-L106
pub fn wait_for_future<F>(py: Python, runtime: &Runtime, fut: F) -> Result<F::Output, PySedonaError>
where
F: Future + Send,
F::Output: Send,
{
const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);
const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(2_000);

py.run(cr"pass", None, None)?;
py.check_signals()?;

py.allow_threads(|| {
runtime.block_on(async {
Expand All @@ -37,7 +40,10 @@ where
tokio::select! {
res = &mut fut => break Ok(res),
_ = sleep(INTERVAL_CHECK_SIGNALS) => {
Python::with_gil(|py| py.check_signals())?;
Python::with_gil(|py| {
py.run(cr"pass", None, None)?;
py.check_signals()
})?;
}
}
}
Expand All @@ -52,15 +58,17 @@ where
F: Future + Send,
F::Output: Send,
{
const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);

const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(2_000);
runtime.block_on(async {
tokio::pin!(fut);
loop {
tokio::select! {
res = &mut fut => break Ok(res),
_ = sleep(INTERVAL_CHECK_SIGNALS) => {
Python::with_gil(|py| py.check_signals())?;
Python::with_gil(|py| {
py.run(cr"pass", None, None)?;
py.check_signals()
})?;
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions python/sedonadb/tests/functions/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
# under the License.
import math

import pyarrow
import pytest
import shapely
import sedonadb
from sedonadb.testing import PostGIS, SedonaDB, geom_or_null, val_or_null


Expand Down Expand Up @@ -1709,7 +1709,7 @@ def test_st_geomfromwkbunchecked_invalid_wkb(eng):
)

# Using invalid WKB elsewhere may result in undefined behavior.
with pytest.raises(pyarrow.lib.ArrowInvalid, match="failed to fill whole buffer"):
with pytest.raises(sedonadb._lib.SedonaError, match="failed to fill whole buffer"):
eng.execute_and_collect("SELECT ST_AsText(ST_GeomFromWKBUnchecked(0x01))")


Expand Down
3 changes: 3 additions & 0 deletions python/sedonadb/tests/functions/test_wkb.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@
],
)
def test_st_asewkb(eng, srid, geom):
if shapely.geos_version < (3, 12, 0):
pytest.skip("GEOS version 3.12+ required for EWKB tests")

eng = eng.create_or_skip()

if geom is not None:
Expand Down
2 changes: 1 addition & 1 deletion python/sedonadb/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def test_dataframe_to_arrow(con):
# ...but not otherwise (yet)
with pytest.raises(
sedonadb._lib.SedonaError,
match="Requested schema != DataFrame schema not yet supported",
match="Requested schema != actual schema not yet supported",
):
df.to_arrow_table(schema=pa.schema({}))

Expand Down
20 changes: 17 additions & 3 deletions python/sedonadb/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pandas as pd
import pyarrow as pa
import pytest
import sedonadb
from sedonadb import udf


Expand Down Expand Up @@ -122,6 +123,19 @@ def shapely_udf(geom, distance):
pd.DataFrame({"col": [3857]}, dtype=np.uint32),
)

# Ensure we can collect with >1 batch without hanging
con.funcs.table.sd_random_geometry("Point", 20000).to_view("pts", overwrite=True)
df = con.sql(
"SELECT ST_Area(shapely_udf(ST_Point(0, 0), 2.0)) as col FROM pts"
).to_pandas()
assert len(df) == 20000

# Ensure we can execute with >1 batch without hanging
count = con.sql(
"SELECT ST_Area(shapely_udf(ST_Point(0, 0), 2.0)) as col FROM pts"
).execute()
assert count == 20000


def test_py_sedona_value(con):
@udf.arrow_udf(pa.int64())
Expand Down Expand Up @@ -170,7 +184,7 @@ def questionable_udf(arg):

con.register_udf(questionable_udf)
with pytest.raises(
ValueError,
sedonadb._lib.SedonaError,
match="Expected result of user-defined function to return an object implementing __arrow_c_array__",
):
con.sql("SELECT questionable_udf(123) as col").to_pandas()
Expand All @@ -183,7 +197,7 @@ def questionable_udf(arg):

con.register_udf(questionable_udf)
with pytest.raises(
ValueError,
sedonadb._lib.SedonaError,
match=(
"Expected result of user-defined function to "
"return array of type Binary or its storage "
Expand All @@ -200,7 +214,7 @@ def questionable_udf(arg):

con.register_udf(questionable_udf)
with pytest.raises(
ValueError,
sedonadb._lib.SedonaError,
match="Expected result of user-defined function to return array of length 1 but got 2",
):
con.sql("SELECT questionable_udf(123) as col").to_pandas()
Expand Down