diff --git a/rust/sedona/src/record_batch_reader_provider.rs b/rust/sedona/src/record_batch_reader_provider.rs
index 1d90a8e61..1832e93e7 100644
--- a/rust/sedona/src/record_batch_reader_provider.rs
+++ b/rust/sedona/src/record_batch_reader_provider.rs
@@ -14,7 +14,6 @@
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
-use std::sync::RwLock;
use std::{any::Any, fmt::Debug, sync::Arc};
use arrow_array::RecordBatchReader;
@@ -33,6 +32,7 @@ use datafusion::{
prelude::Expr,
};
use datafusion_common::DataFusionError;
+use parking_lot::Mutex;
use sedona_common::sedona_internal_err;
/// A [TableProvider] wrapping a [RecordBatchReader]
@@ -42,7 +42,7 @@ use sedona_common::sedona_internal_err;
/// such that extension types are preserved in DataFusion internals (i.e.,
/// it is intended for scanning external tables as SedonaDB).
pub struct RecordBatchReaderProvider {
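+ // parking_lot::Mutex::lock() is infallible (no lock poisoning), so taking the
+ // reader out of the Option requires no error handling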
- reader: RwLock<Option<Box<dyn RecordBatchReader + Send>>>,
+ reader: Mutex<Option<Box<dyn RecordBatchReader + Send>>>,
schema: SchemaRef,
}
@@ -52,7 +52,7 @@ impl RecordBatchReaderProvider {
pub fn new(reader: Box<dyn RecordBatchReader + Send>) -> Self {
let schema = reader.schema();
Self {
- reader: RwLock::new(Some(reader)),
+ reader: Mutex::new(Some(reader)),
schema,
}
}
@@ -88,10 +88,8 @@ impl TableProvider for RecordBatchReaderProvider {
_filters: &[Expr],
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
- let mut writable_reader = self.reader.try_write().map_err(|_| {
- DataFusionError::Internal("Failed to acquire lock on RecordBatchReader".to_string())
- })?;
- if let Some(reader) = writable_reader.take() {
+ let mut reader_guard = self.reader.lock();
+ if let Some(reader) = reader_guard.take() {
Ok(Arc::new(RecordBatchReaderExec::new(reader, limit)))
} else {
sedona_internal_err!("Can't scan RecordBatchReader provider more than once")
@@ -99,15 +97,69 @@ impl TableProvider for RecordBatchReaderProvider {
}
}
+/// An iterator that limits the number of rows from a RecordBatchReader
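+/// (unlike `Iterator::take` on the batch iterator, which caps batches, not rows)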
+struct RowLimitedIterator {
+ reader: Option<Box<dyn RecordBatchReader + Send>>,
+ limit: usize,
+ rows_consumed: usize,
+}
+
+impl RowLimitedIterator {
+ fn new(reader: Box<dyn RecordBatchReader + Send>, limit: usize) -> Self {
+ Self {
+ reader: Some(reader),
+ limit,
+ rows_consumed: 0,
+ }
+ }
+}
+
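+// For example, with batches of [10, 20, 30] rows and a limit of 15, `next()` yields
+// the full 10-row batch, then `slice(0, 5)` of the 20-row batch, and then drops the
+// reader without pulling the final batch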
+impl Iterator for RowLimitedIterator {
+ type Item = Result<RecordBatch>;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ // Check if we have already consumed enough rows
+ if self.rows_consumed >= self.limit {
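+ // Drop the reader eagerly so the underlying source is released as soon as the limit is hit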
+ self.reader = None;
+ return None;
+ }
+
+ let reader = self.reader.as_mut()?;
+ match reader.next() {
+ Some(Ok(batch)) => {
+ let batch_rows = batch.num_rows();
+
+ if self.rows_consumed + batch_rows <= self.limit {
+ // Batch fits within limit, consume it entirely
+ self.rows_consumed += batch_rows;
+ Some(Ok(batch))
+ } else {
+ // Batch would exceed limit, need to truncate it
+ let rows_to_take = self.limit - self.rows_consumed;
+ self.rows_consumed = self.limit;
+ self.reader = None;
+ Some(Ok(batch.slice(0, rows_to_take)))
+ }
+ }
+ Some(Err(e)) => {
+ self.reader = None;
+ Some(Err(DataFusionError::from(e)))
+ }
+ None => {
+ self.reader = None;
+ None
+ }
+ }
+ }
+}
+
struct RecordBatchReaderExec {
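+ // parking_lot::Mutex<T> is Sync whenever T: Send, so this type no longer needs
+ // a manual `unsafe impl Sync`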
- reader: RwLock<Option<Box<dyn RecordBatchReader + Send>>>,
+ reader: Mutex<Option<Box<dyn RecordBatchReader + Send>>>,
schema: SchemaRef,
properties: PlanProperties,
limit: Option<usize>,
}
-unsafe impl Sync for RecordBatchReaderExec {}
-
impl RecordBatchReaderExec {
fn new(reader: Box<dyn RecordBatchReader + Send>, limit: Option<usize>) -> Self {
let schema = reader.schema();
@@ -119,7 +171,7 @@ impl RecordBatchReaderExec {
);
Self {
- reader: RwLock::new(Some(reader)),
+ reader: Mutex::new(Some(reader)),
schema,
properties,
limit,
@@ -177,29 +229,35 @@ impl ExecutionPlan for RecordBatchReaderExec {
_partition: usize,
_context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
- let mut writable_reader = self.reader.try_write().map_err(|_| {
- DataFusionError::Internal("Failed to acquire lock on RecordBatchReader".to_string())
- })?;
+ let mut reader_guard = self.reader.lock();
- let reader = if let Some(reader) = writable_reader.take() {
+ let reader = if let Some(reader) = reader_guard.take() {
reader
} else {
return sedona_internal_err!("Can't scan RecordBatchReader provider more than once");
};
- let limit = self.limit;
-
- // Create a stream from the RecordBatchReader iterator
- let iter = reader
- .map(|item| match item {
- Ok(batch) => Ok(batch),
- Err(e) => Err(DataFusionError::from(e)),
- })
- .take(limit.unwrap_or(usize::MAX));
-
- let stream = Box::pin(futures::stream::iter(iter));
- let record_batch_stream = RecordBatchStreamAdapter::new(self.schema.clone(), stream);
- Ok(Box::pin(record_batch_stream))
+ match self.limit {
+ Some(limit) => {
+ // Create a row-limited iterator that properly handles row counting
+ let iter = RowLimitedIterator::new(reader, limit);
+ let stream = Box::pin(futures::stream::iter(iter));
+ let record_batch_stream =
+ RecordBatchStreamAdapter::new(self.schema.clone(), stream);
+ Ok(Box::pin(record_batch_stream))
+ }
+ None => {
+ // No limit, just convert the reader directly to a stream
+ let iter = reader.map(|item| match item {
+ Ok(batch) => Ok(batch),
+ Err(e) => Err(DataFusionError::from(e)),
+ });
+ let stream = Box::pin(futures::stream::iter(iter));
+ let record_batch_stream =
+ RecordBatchStreamAdapter::new(self.schema.clone(), stream);
+ Ok(Box::pin(record_batch_stream))
+ }
+ }
}
}
@@ -208,12 +266,40 @@ mod test {
use arrow_array::{RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
- use datafusion::prelude::SessionContext;
+ use datafusion::prelude::{DataFrame, SessionContext};
+ use rstest::rstest;
use sedona_schema::datatypes::WKB_GEOMETRY;
use sedona_testing::create::create_array_storage;
use super::*;
+ fn create_test_batch(size: usize, start_id: i32) -> RecordBatch {
+ let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
+ let ids: Vec<i32> = (start_id..start_id + size as i32).collect();
+ RecordBatch::try_new(
+ Arc::new(schema),
+ vec![Arc::new(arrow_array::Int32Array::from(ids))],
+ )
+ .unwrap()
+ }
+
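+ // Builds a reader of Int32 "id" batches with sequential values, e.g. batch_sizes
+ // [2, 3] yields ids [0, 1] and [2, 3, 4]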
+ fn create_test_reader(batch_sizes: Vec<usize>) -> Box<dyn RecordBatchReader + Send> {
+ let mut start_id = 0i32;
+ let batches: Vec<RecordBatch> = batch_sizes
+ .into_iter()
+ .map(|size| {
+ let batch = create_test_batch(size, start_id);
+ start_id += size as i32;
+ batch
+ })
+ .collect();
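+ // Assumes at least one batch so the schema can be taken from the first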
+ let schema = batches[0].schema();
+ Box::new(RecordBatchIterator::new(
+ batches.into_iter().map(Ok),
+ schema,
+ ))
+ }
+
#[tokio::test]
async fn provider() {
let ctx = SessionContext::new();
@@ -244,4 +330,86 @@ mod test {
let results = df.collect().await.unwrap();
assert_eq!(results, vec![batch])
}
+
+ #[rstest]
+ #[case(vec![10, 20, 30], None, 60)] // No limit
+ #[case(vec![10, 20, 30], Some(5), 5)] // Limit within first batch
+ #[case(vec![10, 20, 30], Some(10), 10)] // Limit exactly at first batch boundary
+ #[case(vec![10, 20, 30], Some(15), 15)] // Limit within second batch
+ #[case(vec![10, 20, 30], Some(30), 30)] // Limit at second batch boundary
+ #[case(vec![10, 20, 30], Some(45), 45)] // Limit within third batch
+ #[case(vec![10, 20, 30], Some(60), 60)] // Limit at total rows
+ #[case(vec![10, 20, 30], Some(100), 60)] // Limit exceeds total rows
+ #[case(vec![0, 5, 0, 3], Some(6), 6)] // Empty batches mixed in, limit within data
+ #[case(vec![0, 5, 0, 3], Some(8), 8)] // Empty batches mixed in, limit equals total
+ #[case(vec![0, 5, 0, 3], None, 8)] // Empty batches mixed in, no limit
+ #[tokio::test]
+ async fn test_scan_with_row_limit(
+ #[case] batch_sizes: Vec<usize>,
+ #[case] limit: Option<usize>,
+ #[case] expected_rows: usize,
+ ) {
+ let ctx = SessionContext::new();
+
+ // Verify that the RecordBatchReaderExec node in the execution plan carries the correct limit
+ let physical_plan = read_test_table_with_limit(&ctx, batch_sizes.clone(), limit)
+ .unwrap()
+ .create_physical_plan()
+ .await
+ .unwrap();
+ let reader_exec = find_record_batch_reader_exec(physical_plan.as_ref())
+ .expect("The plan should contain RecordBatchReaderExec");
+ assert_eq!(reader_exec.limit, limit);
+
+ let df = read_test_table_with_limit(&ctx, batch_sizes, limit).unwrap();
+ let results = df.collect().await.unwrap();
+ let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum();
+ assert_eq!(total_rows, expected_rows);
+
+ // Verify row values are correct (sequential IDs starting from 0)
+ if expected_rows > 0 {
+ let mut expected_id = 0i32;
+ for batch in results.iter() {
+ let id_array = batch
+ .column(0)
+ .as_any()
+ .downcast_ref::<arrow_array::Int32Array>()
+ .unwrap();
+ for i in 0..id_array.len() {
+ assert_eq!(id_array.value(i), expected_id);
+ expected_id += 1;
+ }
+ }
+ }
+ }
+
+ fn read_test_table_with_limit(
+ ctx: &SessionContext,
+ batch_sizes: Vec<usize>,
+ limit: Option<usize>,
+ ) -> Result<DataFrame> {
+ let reader = create_test_reader(batch_sizes);
+ let provider = Arc::new(RecordBatchReaderProvider::new(reader));
+ let df = ctx.read_table(provider)?;
+ if let Some(limit) = limit {
+ df.limit(0, Some(limit))
+ } else {
+ Ok(df)
+ }
+ }
+
+ // Navigate through the plan structure to find our RecordBatchReaderExec
+ fn find_record_batch_reader_exec(plan: &dyn ExecutionPlan) -> Option<&RecordBatchReaderExec> {
+ if let Some(reader_exec) = plan.as_any().downcast_ref::<RecordBatchReaderExec>() {
+ return Some(reader_exec);
+ }
+
+ // Recursively search children
+ for child in plan.children() {
+ if let Some(reader_exec) = find_record_batch_reader_exec(child.as_ref()) {
+ return Some(reader_exec);
+ }
+ }
+ None
+ }
}