diff --git a/rust/sedona/src/record_batch_reader_provider.rs b/rust/sedona/src/record_batch_reader_provider.rs
index 1d90a8e61..1832e93e7 100644
--- a/rust/sedona/src/record_batch_reader_provider.rs
+++ b/rust/sedona/src/record_batch_reader_provider.rs
@@ -14,7 +14,6 @@
 // KIND, either express or implied. See the License for the
 // specific language governing permissions and limitations
 // under the License.
-use std::sync::RwLock;
 use std::{any::Any, fmt::Debug, sync::Arc};
 
 use arrow_array::RecordBatchReader;
@@ -33,6 +32,7 @@ use datafusion::{
     prelude::Expr,
 };
 use datafusion_common::DataFusionError;
+use parking_lot::Mutex;
 use sedona_common::sedona_internal_err;
 
 /// A [TableProvider] wrapping a [RecordBatchReader]
@@ -42,7 +42,7 @@ use sedona_common::sedona_internal_err;
 /// such that extension types are preserved in DataFusion internals (i.e.,
 /// it is intended for scanning external tables as SedonaDB).
 pub struct RecordBatchReaderProvider {
-    reader: RwLock<Option<Box<dyn RecordBatchReader + Send>>>,
+    reader: Mutex<Option<Box<dyn RecordBatchReader + Send>>>,
     schema: SchemaRef,
 }
 
@@ -52,7 +52,7 @@ impl RecordBatchReaderProvider {
     pub fn new(reader: Box<dyn RecordBatchReader + Send>) -> Self {
         let schema = reader.schema();
         Self {
-            reader: RwLock::new(Some(reader)),
+            reader: Mutex::new(Some(reader)),
             schema,
         }
     }
@@ -88,10 +88,8 @@ impl TableProvider for RecordBatchReaderProvider {
         _filters: &[Expr],
         limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let mut writable_reader = self.reader.try_write().map_err(|_| {
-            DataFusionError::Internal("Failed to acquire lock on RecordBatchReader".to_string())
-        })?;
-        if let Some(reader) = writable_reader.take() {
+        let mut reader_guard = self.reader.lock();
+        if let Some(reader) = reader_guard.take() {
             Ok(Arc::new(RecordBatchReaderExec::new(reader, limit)))
         } else {
             sedona_internal_err!("Can't scan RecordBatchReader provider more than once")
@@ -99,15 +97,69 @@ impl TableProvider for RecordBatchReaderProvider {
     }
 }
 
+/// An iterator that limits the number of rows produced by a [RecordBatchReader]
+struct RowLimitedIterator {
+    reader: Option<Box<dyn RecordBatchReader + Send>>,
+    limit: usize,
+    rows_consumed: usize,
+}
+
+impl RowLimitedIterator {
+    fn new(reader: Box<dyn RecordBatchReader + Send>, limit: usize) -> Self {
+        Self {
+            reader: Some(reader),
+            limit,
+            rows_consumed: 0,
+        }
+    }
+}
+
+impl Iterator for RowLimitedIterator {
+    type Item = Result<RecordBatch>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // Check if we have already consumed enough rows
+        if self.rows_consumed >= self.limit {
+            self.reader = None;
+            return None;
+        }
+
+        let reader = self.reader.as_mut()?;
+        match reader.next() {
+            Some(Ok(batch)) => {
+                let batch_rows = batch.num_rows();
+
+                if self.rows_consumed + batch_rows <= self.limit {
+                    // The batch fits within the limit; consume it entirely
+                    self.rows_consumed += batch_rows;
+                    Some(Ok(batch))
+                } else {
+                    // The batch would exceed the limit, so truncate it
+                    let rows_to_take = self.limit - self.rows_consumed;
+                    self.rows_consumed = self.limit;
+                    self.reader = None;
+                    Some(Ok(batch.slice(0, rows_to_take)))
+                }
+            }
+            Some(Err(e)) => {
+                self.reader = None;
+                Some(Err(DataFusionError::from(e)))
+            }
+            None => {
+                self.reader = None;
+                None
+            }
+        }
+    }
+}
+
 struct RecordBatchReaderExec {
-    reader: RwLock<Option<Box<dyn RecordBatchReader + Send>>>,
+    reader: Mutex<Option<Box<dyn RecordBatchReader + Send>>>,
     schema: SchemaRef,
     properties: PlanProperties,
     limit: Option<usize>,
 }
 
-unsafe impl Sync for RecordBatchReaderExec {}
-
 impl RecordBatchReaderExec {
     fn new(reader: Box<dyn RecordBatchReader + Send>, limit: Option<usize>) -> Self {
         let schema = reader.schema();
@@ -119,7 +171,7 @@ impl RecordBatchReaderExec {
         );
 
         Self {
-            reader: RwLock::new(Some(reader)),
+            reader: Mutex::new(Some(reader)),
             schema,
             properties,
             limit,
@@ -177,29 +229,35 @@ impl ExecutionPlan for RecordBatchReaderExec {
         _partition: usize,
         _context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
-        let mut writable_reader = self.reader.try_write().map_err(|_| {
-            DataFusionError::Internal("Failed to acquire lock on RecordBatchReader".to_string())
-        })?;
+        let mut reader_guard = self.reader.lock();
 
-        let reader = if let Some(reader) = writable_reader.take() {
+        let reader = if let Some(reader) = reader_guard.take() {
             reader
         } else {
             return sedona_internal_err!("Can't scan RecordBatchReader provider more than once");
         };
 
-        let limit = self.limit;
-
-        // Create a stream from the RecordBatchReader iterator
-        let iter = reader
-            .map(|item| match item {
-                Ok(batch) => Ok(batch),
-                Err(e) => Err(DataFusionError::from(e)),
-            })
-            .take(limit.unwrap_or(usize::MAX));
-
-        let stream = Box::pin(futures::stream::iter(iter));
-        let record_batch_stream = RecordBatchStreamAdapter::new(self.schema.clone(), stream);
-        Ok(Box::pin(record_batch_stream))
+        match self.limit {
+            Some(limit) => {
+                // Create a row-limited iterator that counts rows rather than batches
+                let iter = RowLimitedIterator::new(reader, limit);
+                let stream = Box::pin(futures::stream::iter(iter));
+                let record_batch_stream =
+                    RecordBatchStreamAdapter::new(self.schema.clone(), stream);
+                Ok(Box::pin(record_batch_stream))
+            }
+            None => {
+                // No limit; convert the reader directly into a stream
+                let iter = reader.map(|item| match item {
+                    Ok(batch) => Ok(batch),
+                    Err(e) => Err(DataFusionError::from(e)),
+                });
+                let stream = Box::pin(futures::stream::iter(iter));
+                let record_batch_stream =
+                    RecordBatchStreamAdapter::new(self.schema.clone(), stream);
+                Ok(Box::pin(record_batch_stream))
+            }
+        }
     }
 }
 
@@ -208,12 +266,40 @@ mod test {
 
     use arrow_array::{RecordBatch, RecordBatchIterator};
     use arrow_schema::{DataType, Field, Schema};
-    use datafusion::prelude::SessionContext;
+    use datafusion::prelude::{DataFrame, SessionContext};
+    use rstest::rstest;
     use sedona_schema::datatypes::WKB_GEOMETRY;
     use sedona_testing::create::create_array_storage;
 
     use super::*;
 
+    fn create_test_batch(size: usize, start_id: i32) -> RecordBatch {
+        let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
+        let ids: Vec<i32> = (start_id..start_id + size as i32).collect();
+        RecordBatch::try_new(
+            Arc::new(schema),
+            vec![Arc::new(arrow_array::Int32Array::from(ids))],
+        )
+        .unwrap()
+    }
+
+    fn create_test_reader(batch_sizes: Vec<usize>) -> Box<dyn RecordBatchReader + Send> {
+        let mut start_id = 0i32;
+        let batches: Vec<RecordBatch> = batch_sizes
+            .into_iter()
+            .map(|size| {
+                let batch = create_test_batch(size, start_id);
+                start_id += size as i32;
+                batch
+            })
+            .collect();
+        let schema = batches[0].schema();
+        Box::new(RecordBatchIterator::new(
+            batches.into_iter().map(Ok),
+            schema,
+        ))
+    }
+
     #[tokio::test]
     async fn provider() {
         let ctx = SessionContext::new();
@@ -244,4 +330,86 @@ mod test {
         let results = df.collect().await.unwrap();
         assert_eq!(results, vec![batch])
     }
+
+    #[rstest]
+    #[case(vec![10, 20, 30], None, 60)] // No limit
+    #[case(vec![10, 20, 30], Some(5), 5)] // Limit within first batch
+    #[case(vec![10, 20, 30], Some(10), 10)] // Limit exactly at first batch boundary
+    #[case(vec![10, 20, 30], Some(15), 15)] // Limit within second batch
+    #[case(vec![10, 20, 30], Some(30), 30)] // Limit at second batch boundary
+    #[case(vec![10, 20, 30], Some(45), 45)] // Limit within third batch
+    #[case(vec![10, 20, 30], Some(60), 60)] // Limit at total rows
+    #[case(vec![10, 20, 30], Some(100), 60)] // Limit exceeds total rows
+    #[case(vec![0, 5, 0, 3], Some(6), 6)] // Empty batches mixed in, limit within data
+    #[case(vec![0, 5, 0, 3], Some(8), 8)] // Empty batches mixed in, limit equals total
+    #[case(vec![0, 5, 0, 3], None, 8)] // Empty batches mixed in, no limit
+    #[tokio::test]
+    async fn test_scan_with_row_limit(
+        #[case] batch_sizes: Vec<usize>,
+        #[case] limit: Option<usize>,
+        #[case] expected_rows: usize,
+    ) {
+        let ctx = SessionContext::new();
+
+        // Verify that the RecordBatchReaderExec node in the execution plan carries the correct limit
+        let physical_plan = read_test_table_with_limit(&ctx, batch_sizes.clone(), limit)
+            .unwrap()
+            .create_physical_plan()
+            .await
+            .unwrap();
+        let reader_exec = find_record_batch_reader_exec(physical_plan.as_ref())
+            .expect("The plan should contain RecordBatchReaderExec");
+        assert_eq!(reader_exec.limit, limit);
+
+        let df = read_test_table_with_limit(&ctx, batch_sizes, limit).unwrap();
+        let results = df.collect().await.unwrap();
+        let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum();
+        assert_eq!(total_rows, expected_rows);
+
+        // Verify row values are correct (sequential IDs starting from 0)
+        if expected_rows > 0 {
+            let mut expected_id = 0i32;
+            for batch in results.iter() {
+                let id_array = batch
+                    .column(0)
+                    .as_any()
+                    .downcast_ref::<arrow_array::Int32Array>()
+                    .unwrap();
+                for i in 0..id_array.len() {
+                    assert_eq!(id_array.value(i), expected_id);
+                    expected_id += 1;
+                }
+            }
+        }
+    }
+
+    fn read_test_table_with_limit(
+        ctx: &SessionContext,
+        batch_sizes: Vec<usize>,
+        limit: Option<usize>,
+    ) -> Result<DataFrame> {
+        let reader = create_test_reader(batch_sizes);
+        let provider = Arc::new(RecordBatchReaderProvider::new(reader));
+        let df = ctx.read_table(provider)?;
+        if let Some(limit) = limit {
+            df.limit(0, Some(limit))
+        } else {
+            Ok(df)
+        }
+    }
+
+    // Navigate through the plan structure to find our RecordBatchReaderExec
+    fn find_record_batch_reader_exec(plan: &dyn ExecutionPlan) -> Option<&RecordBatchReaderExec> {
+        if let Some(reader_exec) = plan.as_any().downcast_ref::<RecordBatchReaderExec>() {
+            return Some(reader_exec);
+        }
+
+        // Recursively search children
+        for child in plan.children() {
+            if let Some(reader_exec) = find_record_batch_reader_exec(child.as_ref()) {
+                return Some(reader_exec);
+            }
+        }
+        None
+    }
 }
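
For reviewers who want to exercise the pushed-down limit by hand, here is a minimal, hypothetical usage sketch (not part of the patch). It assumes `RecordBatchReaderProvider` is publicly exported at `sedona::record_batch_reader_provider::RecordBatchReaderProvider`; the schema and batch contents are invented for illustration.

```rust
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow_schema::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::prelude::SessionContext;
// Assumed export path; adjust to wherever the crate actually exposes the provider.
use sedona::record_batch_reader_provider::RecordBatchReaderProvider;

#[tokio::main]
async fn main() -> Result<()> {
    // Build a one-shot reader over a single 5-row batch.
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
    )?;
    let reader: Box<dyn RecordBatchReader + Send> =
        Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema));

    // LIMIT 3 should be pushed into RecordBatchReaderExec, which now slices
    // the batch via RowLimitedIterator instead of counting whole batches.
    let ctx = SessionContext::new();
    let df = ctx
        .read_table(Arc::new(RecordBatchReaderProvider::new(reader)))?
        .limit(0, Some(3))?;
    df.show().await?; // prints 3 rows, not 5

    Ok(())
}
```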