rust/sedona/src/record_batch_reader_provider.rs (197 additions, 29 deletions)
@@ -14,7 +14,6 @@
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::sync::RwLock;
use std::{any::Any, fmt::Debug, sync::Arc};

use arrow_array::RecordBatchReader;
@@ -33,6 +32,7 @@ use datafusion::{
prelude::Expr,
};
use datafusion_common::DataFusionError;
use parking_lot::Mutex;
use sedona_common::sedona_internal_err;

/// A [TableProvider] wrapping a [RecordBatchReader]
@@ -42,7 +42,7 @@ use sedona_common::sedona_internal_err;
/// such that extension types are preserved in DataFusion internals (i.e.,
/// it is intended for scanning external tables as SedonaDB).
pub struct RecordBatchReaderProvider {
reader: RwLock<Option<Box<dyn RecordBatchReader + Send>>>,
reader: Mutex<Option<Box<dyn RecordBatchReader + Send>>>,
schema: SchemaRef,
}

@@ -52,7 +52,7 @@ impl RecordBatchReaderProvider {
pub fn new(reader: Box<dyn RecordBatchReader + Send>) -> Self {
let schema = reader.schema();
Self {
reader: RwLock::new(Some(reader)),
reader: Mutex::new(Some(reader)),
schema,
}
}
@@ -88,26 +88,78 @@ impl TableProvider for RecordBatchReaderProvider {
_filters: &[Expr],
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let mut writable_reader = self.reader.try_write().map_err(|_| {
DataFusionError::Internal("Failed to acquire lock on RecordBatchReader".to_string())
})?;
if let Some(reader) = writable_reader.take() {
let mut reader_guard = self.reader.lock();
if let Some(reader) = reader_guard.take() {
Ok(Arc::new(RecordBatchReaderExec::new(reader, limit)))
} else {
sedona_internal_err!("Can't scan RecordBatchReader provider more than once")
}
}
}

/// An iterator that limits the number of rows from a RecordBatchReader
struct RowLimitedIterator {
reader: Option<Box<dyn RecordBatchReader + Send>>,
limit: usize,
rows_consumed: usize,
}

impl RowLimitedIterator {
fn new(reader: Box<dyn RecordBatchReader + Send>, limit: usize) -> Self {
Self {
reader: Some(reader),
limit,
rows_consumed: 0,
}
}
}

impl Iterator for RowLimitedIterator {
type Item = Result<arrow_array::RecordBatch>;

fn next(&mut self) -> Option<Self::Item> {
// Check if we have already consumed enough rows
if self.rows_consumed >= self.limit {
self.reader = None;
return None;
}

let reader = self.reader.as_mut()?;
match reader.next() {
Some(Ok(batch)) => {
let batch_rows = batch.num_rows();

if self.rows_consumed + batch_rows <= self.limit {
// Batch fits within limit, consume it entirely
self.rows_consumed += batch_rows;
Some(Ok(batch))
} else {
// Batch would exceed limit, need to truncate it
let rows_to_take = self.limit - self.rows_consumed;
self.rows_consumed = self.limit;
self.reader = None;
Some(Ok(batch.slice(0, rows_to_take)))
}
}
Some(Err(e)) => {
self.reader = None;
Some(Err(DataFusionError::from(e)))
}
None => {
self.reader = None;
None
}
}
}
}

struct RecordBatchReaderExec {
reader: RwLock<Option<Box<dyn RecordBatchReader + Send>>>,
reader: Mutex<Option<Box<dyn RecordBatchReader + Send>>>,
schema: SchemaRef,
properties: PlanProperties,
limit: Option<usize>,
}

unsafe impl Sync for RecordBatchReaderExec {}

impl RecordBatchReaderExec {
fn new(reader: Box<dyn RecordBatchReader + Send>, limit: Option<usize>) -> Self {
let schema = reader.schema();
@@ -119,7 +171,7 @@ impl RecordBatchReaderExec {
);

Self {
reader: RwLock::new(Some(reader)),
reader: Mutex::new(Some(reader)),
schema,
properties,
limit,
@@ -177,29 +229,35 @@ impl ExecutionPlan for RecordBatchReaderExec {
_partition: usize,
_context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let mut writable_reader = self.reader.try_write().map_err(|_| {
DataFusionError::Internal("Failed to acquire lock on RecordBatchReader".to_string())
})?;
let mut reader_guard = self.reader.lock();

let reader = if let Some(reader) = writable_reader.take() {
let reader = if let Some(reader) = reader_guard.take() {
reader
} else {
return sedona_internal_err!("Can't scan RecordBatchReader provider more than once");
};

let limit = self.limit;

// Create a stream from the RecordBatchReader iterator
let iter = reader
.map(|item| match item {
Ok(batch) => Ok(batch),
Err(e) => Err(DataFusionError::from(e)),
})
.take(limit.unwrap_or(usize::MAX));

let stream = Box::pin(futures::stream::iter(iter));
let record_batch_stream = RecordBatchStreamAdapter::new(self.schema.clone(), stream);
Ok(Box::pin(record_batch_stream))
match self.limit {
Some(limit) => {
// Create a row-limited iterator that properly handles row counting
let iter = RowLimitedIterator::new(reader, limit);
let stream = Box::pin(futures::stream::iter(iter));
let record_batch_stream =
RecordBatchStreamAdapter::new(self.schema.clone(), stream);
Ok(Box::pin(record_batch_stream))
}
None => {
// No limit, just convert the reader directly to a stream
let iter = reader.map(|item| match item {
Ok(batch) => Ok(batch),
Err(e) => Err(DataFusionError::from(e)),
});
let stream = Box::pin(futures::stream::iter(iter));
let record_batch_stream =
RecordBatchStreamAdapter::new(self.schema.clone(), stream);
Ok(Box::pin(record_batch_stream))
}
}
}
}

Expand All @@ -208,12 +266,40 @@ mod test {

use arrow_array::{RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::SessionContext;
use datafusion::prelude::{DataFrame, SessionContext};
use rstest::rstest;
use sedona_schema::datatypes::WKB_GEOMETRY;
use sedona_testing::create::create_array_storage;

use super::*;

fn create_test_batch(size: usize, start_id: i32) -> RecordBatch {
let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
let ids: Vec<i32> = (start_id..start_id + size as i32).collect();
RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(arrow_array::Int32Array::from(ids))],
)
.unwrap()
}

fn create_test_reader(batch_sizes: Vec<usize>) -> Box<dyn RecordBatchReader + Send> {
let mut start_id = 0i32;
let batches: Vec<RecordBatch> = batch_sizes
.into_iter()
.map(|size| {
let batch = create_test_batch(size, start_id);
start_id += size as i32;
batch
})
.collect();
let schema = batches[0].schema();
Box::new(RecordBatchIterator::new(
batches.into_iter().map(Ok),
schema,
))
}

#[tokio::test]
async fn provider() {
let ctx = SessionContext::new();
@@ -244,4 +330,86 @@
let results = df.collect().await.unwrap();
assert_eq!(results, vec![batch])
}

#[rstest]
#[case(vec![10, 20, 30], None, 60)] // No limit
#[case(vec![10, 20, 30], Some(5), 5)] // Limit within first batch
#[case(vec![10, 20, 30], Some(10), 10)] // Limit exactly at first batch boundary
#[case(vec![10, 20, 30], Some(15), 15)] // Limit within second batch
#[case(vec![10, 20, 30], Some(30), 30)] // Limit at second batch boundary
#[case(vec![10, 20, 30], Some(45), 45)] // Limit within third batch
#[case(vec![10, 20, 30], Some(60), 60)] // Limit at total rows
#[case(vec![10, 20, 30], Some(100), 60)] // Limit exceeds total rows
#[case(vec![0, 5, 0, 3], Some(6), 6)] // Empty batches mixed in, limit within data
#[case(vec![0, 5, 0, 3], Some(8), 8)] // Empty batches mixed in, limit equals total
#[case(vec![0, 5, 0, 3], None, 8)] // Empty batches mixed in, no limit
#[tokio::test]
async fn test_scan_with_row_limit(
#[case] batch_sizes: Vec<usize>,
#[case] limit: Option<usize>,
#[case] expected_rows: usize,
) {
let ctx = SessionContext::new();

// Verify that the RecordBatchReaderExec node in the execution plan contains the correct limit
let physical_plan = read_test_table_with_limit(&ctx, batch_sizes.clone(), limit)
.unwrap()
.create_physical_plan()
.await
.unwrap();
let reader_exec = find_record_batch_reader_exec(physical_plan.as_ref())
.expect("The plan should contain RecordBatchReaderExec");
assert_eq!(reader_exec.limit, limit);

let df = read_test_table_with_limit(&ctx, batch_sizes, limit).unwrap();
let results = df.collect().await.unwrap();
let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum();
assert_eq!(total_rows, expected_rows);

// Verify row values are correct (sequential IDs starting from 0)
if expected_rows > 0 {
let mut expected_id = 0i32;
for batch in results.iter() {
let id_array = batch
.column(0)
.as_any()
.downcast_ref::<arrow_array::Int32Array>()
.unwrap();
for i in 0..id_array.len() {
assert_eq!(id_array.value(i), expected_id);
expected_id += 1;
}
}
}
}

fn read_test_table_with_limit(
ctx: &SessionContext,
batch_sizes: Vec<usize>,
limit: Option<usize>,
) -> Result<DataFrame> {
let reader = create_test_reader(batch_sizes);
let provider = Arc::new(RecordBatchReaderProvider::new(reader));
let df = ctx.read_table(provider)?;
if let Some(limit) = limit {
df.limit(0, Some(limit))
} else {
Ok(df)
}
}

// Navigate through the plan structure to find our RecordBatchReaderExec
fn find_record_batch_reader_exec(plan: &dyn ExecutionPlan) -> Option<&RecordBatchReaderExec> {
if let Some(reader_exec) = plan.as_any().downcast_ref::<RecordBatchReaderExec>() {
return Some(reader_exec);
}

// Recursively search children
for child in plan.children() {
if let Some(reader_exec) = find_record_batch_reader_exec(child.as_ref()) {
return Some(reader_exec);
}
}
None
}
}
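
For reference, a minimal usage sketch of the provider touched by this diff, mirroring the test setup above (an in-memory RecordBatchIterator wrapped in RecordBatchReaderProvider, read through a DataFusion SessionContext with an optional LIMIT). This is not part of the change; the function name, column name, and batch contents are illustrative only.

    use std::sync::Arc;

    use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;

    // Illustrative driver: wrap an in-memory reader in the provider and let
    // DataFusion push the LIMIT down into RecordBatchReaderExec, where it is
    // applied by RowLimitedIterator.
    async fn scan_with_limit() -> datafusion::error::Result<Vec<RecordBatch>> {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4]))],
        )?;
        let reader = Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema));

        let ctx = SessionContext::new();
        let provider = Arc::new(RecordBatchReaderProvider::new(reader));

        // The provider hands its reader out once, so it can only be scanned once.
        let df = ctx.read_table(provider)?.limit(0, Some(3))?;
        df.collect().await
    }

In a plan built this way, the LIMIT is expected to surface as `limit: Some(3)` on RecordBatchReaderExec, which is what the rstest cases above assert via find_record_batch_reader_exec.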