From 5a0120c55fcf8463422833e317e1531fe6b9ce88 Mon Sep 17 00:00:00 2001 From: Shiv Bhatia Date: Wed, 19 Nov 2025 12:50:04 +0000 Subject: [PATCH 01/13] Write test in async_udf.rs --- Cargo.lock | 1 + datafusion-testing | 2 +- datafusion/expr/Cargo.toml | 1 + datafusion/expr/src/async_udf.rs | 134 ++++++++++++++++++++----------- 4 files changed, 89 insertions(+), 49 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1afa1e349167..ed2572f53c52 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2273,6 +2273,7 @@ dependencies = [ "recursive", "serde_json", "sqlparser", + "tokio", ] [[package]] diff --git a/datafusion-testing b/datafusion-testing index eccb0e4a4263..905df5f65cc9 160000 --- a/datafusion-testing +++ b/datafusion-testing @@ -1 +1 @@ -Subproject commit eccb0e4a426344ef3faf534cd60e02e9c3afd3ac +Subproject commit 905df5f65cc9d0851719c21f5a4dd5cd77621f19 diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 11d6ca1533db..759519da3f78 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -66,3 +66,4 @@ sqlparser = { workspace = true, optional = true } ctor = { workspace = true } env_logger = { workspace = true } insta = { workspace = true } +tokio = { workspace = true } diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs index 561ef1dc15e7..3152734fe253 100644 --- a/datafusion/expr/src/async_udf.rs +++ b/datafusion/expr/src/async_udf.rs @@ -140,87 +140,91 @@ mod tests { sync::Arc, }; - use arrow::datatypes::DataType; + use arrow::array::{Array, StringArray}; + use arrow::datatypes::{DataType, Field}; use async_trait::async_trait; + use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; - use datafusion_expr_common::{columnar_value::ColumnarValue, signature::Signature}; + use datafusion_common::ScalarValue; + use datafusion_expr_common::columnar_value::ColumnarValue; + use datafusion_expr_common::signature::{Signature, Volatility}; use crate::{ async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}, ScalarFunctionArgs, ScalarUDFImpl, }; - #[derive(Debug, PartialEq, Eq, Hash, Clone)] - struct TestAsyncUDFImpl1 { - a: i32, - } - - impl ScalarUDFImpl for TestAsyncUDFImpl1 { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn name(&self) -> &str { - todo!() + /// Helper function to convert ColumnarValue to Vec + fn columnar_to_vec_string(cv: &ColumnarValue) -> Result> { + match cv { + ColumnarValue::Array(arr) => { + let string_arr = arr.as_any().downcast_ref::().unwrap(); + Ok(string_arr + .iter() + .map(|s| s.unwrap_or("").to_string()) + .collect()) + } + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => Ok(vec![s.clone()]), + _ => panic!("Unexpected type"), } + } - fn signature(&self) -> &Signature { - todo!() - } + /// Simulates calling an async external service + async fn call_external_service(arg1: &ColumnarValue) -> Result> { + let vec1 = columnar_to_vec_string(arg1)?; - fn return_type(&self, _arg_types: &[DataType]) -> Result { - todo!() - } - - fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { - todo!() - } + Ok(vec1) } - #[async_trait] - impl AsyncScalarUDFImpl for TestAsyncUDFImpl1 { - async fn invoke_async_with_args( - &self, - _args: ScalarFunctionArgs, - ) -> Result { - todo!() - } + #[derive(Debug, PartialEq, Eq, Hash, Clone)] + struct TestAsyncUDFImpl { + batch_size: usize, + signature: Signature, } - #[derive(Debug, PartialEq, Eq, Hash, Clone)] - struct TestAsyncUDFImpl2 { - a: i32, + impl TestAsyncUDFImpl { + fn new(batch_size: usize) -> Self { + Self { + batch_size, + signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile), + } + } } - impl ScalarUDFImpl for TestAsyncUDFImpl2 { + impl ScalarUDFImpl for TestAsyncUDFImpl { fn as_any(&self) -> &dyn std::any::Any { self } fn name(&self) -> &str { - todo!() + "test_async_udf" } fn signature(&self) -> &Signature { - todo!() + &self.signature } fn return_type(&self, _arg_types: &[DataType]) -> Result { - todo!() + Ok(DataType::Utf8) } fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { - todo!() + panic!("Call invoke_async_with_args instead") } } #[async_trait] - impl AsyncScalarUDFImpl for TestAsyncUDFImpl2 { + impl AsyncScalarUDFImpl for TestAsyncUDFImpl { + fn ideal_batch_size(&self) -> Option { + Some(self.batch_size) + } async fn invoke_async_with_args( &self, - _args: ScalarFunctionArgs, + args: ScalarFunctionArgs, ) -> Result { - todo!() + let arg1 = &args.args[0]; + let results = call_external_service(arg1).await?; + Ok(ColumnarValue::Array(Arc::new(StringArray::from(results)))) } } @@ -233,7 +237,7 @@ mod tests { #[test] fn test_async_udf_partial_eq_and_hash() { // Inner is same cloned arc -> equal - let inner = Arc::new(TestAsyncUDFImpl1 { a: 1 }); + let inner = Arc::new(TestAsyncUDFImpl { a: 1 }); let a = AsyncScalarUDF::new(Arc::clone(&inner) as Arc); let b = AsyncScalarUDF::new(inner); assert_eq!(a, b); @@ -246,15 +250,49 @@ mod tests { assert_eq!(hash(&a), hash(&b)); // Negative case: inner is different value -> not equal - let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); - let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 2 })); + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(1))); + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(2))); assert_ne!(a, b); assert_ne!(hash(&a), hash(&b)); // Negative case: different functions -> not equal - let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); - let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl2 { a: 1 })); + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(1))); + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(1))); assert_ne!(a, b); assert_ne!(hash(&a), hash(&b)); } + + #[tokio::test] + async fn test_async_udf_with_ideal_batch_size() { + // Create async UDF with ideal batch size of 2 + let udf = TestAsyncUDFImpl::new(2); + assert_eq!(udf.ideal_batch_size(), Some(2)); + + // Create test data with 3 rows, because 3 % 2 != 0 + let test_data = vec!["a", "b", "c"]; + let input = ColumnarValue::Array(Arc::new(StringArray::from(test_data.clone()))); + + let args = ScalarFunctionArgs { + args: vec![input], + arg_fields: vec![Arc::new(Field::new("arg", DataType::Utf8, false))], + number_rows: test_data.len(), + return_field: Arc::new(Field::new("result", DataType::Utf8, false)), + config_options: Arc::new(ConfigOptions::default()), + }; + + // Invoke the async function - it should handle all rows + let result = udf.invoke_async_with_args(args).await.unwrap(); + + // Verify all rows are processed + match result { + ColumnarValue::Array(arr) => { + let string_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(string_arr.len(), 3); + for (i, expected) in test_data.iter().enumerate() { + assert_eq!(string_arr.value(i), *expected); + } + } + _ => panic!("Expected array result"), + } + } } From b77c21b2c281bb1e1e16632e21347680446b9b4e Mon Sep 17 00:00:00 2001 From: Shiv Bhatia Date: Wed, 19 Nov 2025 14:40:57 +0000 Subject: [PATCH 02/13] move test --- datafusion/core/tests/user_defined/mod.rs | 3 + .../user_defined_async_scalar_functions.rs | 150 ++++++++++++++++++ datafusion/expr/src/async_udf.rs | 16 +- 3 files changed, 164 insertions(+), 5 deletions(-) create mode 100644 datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs index 5d84cdb69283..515a82ab6c30 100644 --- a/datafusion/core/tests/user_defined/mod.rs +++ b/datafusion/core/tests/user_defined/mod.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +/// Tests for user defined Async Scalar functions +mod user_defined_async_scalar_functions; + /// Tests for user defined Scalar functions mod user_defined_scalar_functions; diff --git a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs new file mode 100644 index 000000000000..72a1bfc25a84 --- /dev/null +++ b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; +use std::time::Duration; + +use arrow::array::{Int32Array, RecordBatch, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use async_trait::async_trait; +use datafusion::prelude::*; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; + +#[tokio::test] +async fn test_async_udf_with_non_modular_batch_size() -> Result<()> { + let num_rows = 3; + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("prompt", DataType::Utf8, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from((0..num_rows).collect::>())), + Arc::new(StringArray::from( + (0..num_rows) + .map(|i| format!("prompt{}", i)) + .collect::>(), + )), + ], + )?; + + println!("Created test data with {} rows\n", batch.num_rows()); + + // Create context and register UDF + let ctx = SessionContext::new(); + ctx.register_batch("test_table", batch)?; + + ctx.register_udf( + AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(2))).into_scalar_udf(), + ); + + // Execute query + println!("Executing query...\n"); + let df = ctx + .sql("SELECT id, test_async_udf(prompt) as result FROM test_table") + .await?; + + let results = df.collect().await?; + + println!("=== Final Results ==="); + for batch in results { + println!("Result batch has {} rows", batch.num_rows()); + println!("{:?}", batch); + } + + Ok(()) +} + +/// Helper function to convert ColumnarValue to Vec +fn columnar_to_vec_string(cv: &ColumnarValue) -> Result> { + match cv { + ColumnarValue::Array(arr) => { + let string_arr = arr.as_any().downcast_ref::().unwrap(); + Ok(string_arr + .iter() + .map(|s| s.unwrap_or("").to_string()) + .collect()) + } + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => Ok(vec![s.clone()]), + _ => panic!("Unexpected type"), + } +} + +/// Simulates calling an async external service +async fn call_external_service(arg1: &ColumnarValue) -> Result> { + let vec1 = columnar_to_vec_string(arg1)?; + tokio::time::sleep(Duration::from_millis(10)).await; + Ok(vec1) +} + +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +struct TestAsyncUDFImpl { + batch_size: usize, + signature: Signature, +} + +impl TestAsyncUDFImpl { + fn new(batch_size: usize) -> Self { + Self { + batch_size, + signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile), + } + } +} + +impl ScalarUDFImpl for TestAsyncUDFImpl { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "test_async_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + panic!("Call invoke_async_with_args instead") + } +} + +#[async_trait] +impl AsyncScalarUDFImpl for TestAsyncUDFImpl { + fn ideal_batch_size(&self) -> Option { + Some(self.batch_size) + } + async fn invoke_async_with_args( + &self, + args: ScalarFunctionArgs, + ) -> Result { + let arg1 = &args.args[0]; + let results = call_external_service(arg1).await?; + Ok(ColumnarValue::Array(Arc::new(StringArray::from(results)))) + } +} diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs index 3152734fe253..059154f7e01f 100644 --- a/datafusion/expr/src/async_udf.rs +++ b/datafusion/expr/src/async_udf.rs @@ -140,8 +140,14 @@ mod tests { sync::Arc, }; - use arrow::array::{Array, StringArray}; - use arrow::datatypes::{DataType, Field}; + use arrow::{ + array::Int32Array, + datatypes::{DataType, Field}, + }; + use arrow::{ + array::{Array, RecordBatch, StringArray}, + datatypes::Schema, + }; use async_trait::async_trait; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; @@ -237,15 +243,15 @@ mod tests { #[test] fn test_async_udf_partial_eq_and_hash() { // Inner is same cloned arc -> equal - let inner = Arc::new(TestAsyncUDFImpl { a: 1 }); + let inner = Arc::new(TestAsyncUDFImpl::new(1)); let a = AsyncScalarUDF::new(Arc::clone(&inner) as Arc); let b = AsyncScalarUDF::new(inner); assert_eq!(a, b); assert_eq!(hash(&a), hash(&b)); // Inner is distinct arc -> still equal - let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); - let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(1))); + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(1))); assert_eq!(a, b); assert_eq!(hash(&a), hash(&b)); From 2a026b9adb0c3fd2f5fc966996010ed966f57340 Mon Sep 17 00:00:00 2001 From: Shiv Bhatia Date: Wed, 19 Nov 2025 14:42:51 +0000 Subject: [PATCH 03/13] revert async_udf changes --- Cargo.lock | 1 - datafusion/expr/Cargo.toml | 1 - datafusion/expr/src/async_udf.rs | 144 +++++++++++-------------------- 3 files changed, 50 insertions(+), 96 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ed2572f53c52..1afa1e349167 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2273,7 +2273,6 @@ dependencies = [ "recursive", "serde_json", "sqlparser", - "tokio", ] [[package]] diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 759519da3f78..11d6ca1533db 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -66,4 +66,3 @@ sqlparser = { workspace = true, optional = true } ctor = { workspace = true } env_logger = { workspace = true } insta = { workspace = true } -tokio = { workspace = true } diff --git a/datafusion/expr/src/async_udf.rs b/datafusion/expr/src/async_udf.rs index 059154f7e01f..561ef1dc15e7 100644 --- a/datafusion/expr/src/async_udf.rs +++ b/datafusion/expr/src/async_udf.rs @@ -140,97 +140,87 @@ mod tests { sync::Arc, }; - use arrow::{ - array::Int32Array, - datatypes::{DataType, Field}, - }; - use arrow::{ - array::{Array, RecordBatch, StringArray}, - datatypes::Schema, - }; + use arrow::datatypes::DataType; use async_trait::async_trait; - use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; - use datafusion_common::ScalarValue; - use datafusion_expr_common::columnar_value::ColumnarValue; - use datafusion_expr_common::signature::{Signature, Volatility}; + use datafusion_expr_common::{columnar_value::ColumnarValue, signature::Signature}; use crate::{ async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}, ScalarFunctionArgs, ScalarUDFImpl, }; - /// Helper function to convert ColumnarValue to Vec - fn columnar_to_vec_string(cv: &ColumnarValue) -> Result> { - match cv { - ColumnarValue::Array(arr) => { - let string_arr = arr.as_any().downcast_ref::().unwrap(); - Ok(string_arr - .iter() - .map(|s| s.unwrap_or("").to_string()) - .collect()) - } - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => Ok(vec![s.clone()]), - _ => panic!("Unexpected type"), - } + #[derive(Debug, PartialEq, Eq, Hash, Clone)] + struct TestAsyncUDFImpl1 { + a: i32, } - /// Simulates calling an async external service - async fn call_external_service(arg1: &ColumnarValue) -> Result> { - let vec1 = columnar_to_vec_string(arg1)?; + impl ScalarUDFImpl for TestAsyncUDFImpl1 { + fn as_any(&self) -> &dyn std::any::Any { + self + } - Ok(vec1) - } + fn name(&self) -> &str { + todo!() + } - #[derive(Debug, PartialEq, Eq, Hash, Clone)] - struct TestAsyncUDFImpl { - batch_size: usize, - signature: Signature, + fn signature(&self) -> &Signature { + todo!() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + todo!() + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + todo!() + } } - impl TestAsyncUDFImpl { - fn new(batch_size: usize) -> Self { - Self { - batch_size, - signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile), - } + #[async_trait] + impl AsyncScalarUDFImpl for TestAsyncUDFImpl1 { + async fn invoke_async_with_args( + &self, + _args: ScalarFunctionArgs, + ) -> Result { + todo!() } } - impl ScalarUDFImpl for TestAsyncUDFImpl { + #[derive(Debug, PartialEq, Eq, Hash, Clone)] + struct TestAsyncUDFImpl2 { + a: i32, + } + + impl ScalarUDFImpl for TestAsyncUDFImpl2 { fn as_any(&self) -> &dyn std::any::Any { self } fn name(&self) -> &str { - "test_async_udf" + todo!() } fn signature(&self) -> &Signature { - &self.signature + todo!() } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Utf8) + todo!() } fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { - panic!("Call invoke_async_with_args instead") + todo!() } } #[async_trait] - impl AsyncScalarUDFImpl for TestAsyncUDFImpl { - fn ideal_batch_size(&self) -> Option { - Some(self.batch_size) - } + impl AsyncScalarUDFImpl for TestAsyncUDFImpl2 { async fn invoke_async_with_args( &self, - args: ScalarFunctionArgs, + _args: ScalarFunctionArgs, ) -> Result { - let arg1 = &args.args[0]; - let results = call_external_service(arg1).await?; - Ok(ColumnarValue::Array(Arc::new(StringArray::from(results)))) + todo!() } } @@ -243,62 +233,28 @@ mod tests { #[test] fn test_async_udf_partial_eq_and_hash() { // Inner is same cloned arc -> equal - let inner = Arc::new(TestAsyncUDFImpl::new(1)); + let inner = Arc::new(TestAsyncUDFImpl1 { a: 1 }); let a = AsyncScalarUDF::new(Arc::clone(&inner) as Arc); let b = AsyncScalarUDF::new(inner); assert_eq!(a, b); assert_eq!(hash(&a), hash(&b)); // Inner is distinct arc -> still equal - let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(1))); - let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(1))); + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); assert_eq!(a, b); assert_eq!(hash(&a), hash(&b)); // Negative case: inner is different value -> not equal - let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(1))); - let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(2))); + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 2 })); assert_ne!(a, b); assert_ne!(hash(&a), hash(&b)); // Negative case: different functions -> not equal - let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(1))); - let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(1))); + let a = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl1 { a: 1 })); + let b = AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl2 { a: 1 })); assert_ne!(a, b); assert_ne!(hash(&a), hash(&b)); } - - #[tokio::test] - async fn test_async_udf_with_ideal_batch_size() { - // Create async UDF with ideal batch size of 2 - let udf = TestAsyncUDFImpl::new(2); - assert_eq!(udf.ideal_batch_size(), Some(2)); - - // Create test data with 3 rows, because 3 % 2 != 0 - let test_data = vec!["a", "b", "c"]; - let input = ColumnarValue::Array(Arc::new(StringArray::from(test_data.clone()))); - - let args = ScalarFunctionArgs { - args: vec![input], - arg_fields: vec![Arc::new(Field::new("arg", DataType::Utf8, false))], - number_rows: test_data.len(), - return_field: Arc::new(Field::new("result", DataType::Utf8, false)), - config_options: Arc::new(ConfigOptions::default()), - }; - - // Invoke the async function - it should handle all rows - let result = udf.invoke_async_with_args(args).await.unwrap(); - - // Verify all rows are processed - match result { - ColumnarValue::Array(arr) => { - let string_arr = arr.as_any().downcast_ref::().unwrap(); - assert_eq!(string_arr.len(), 3); - for (i, expected) in test_data.iter().enumerate() { - assert_eq!(string_arr.value(i), *expected); - } - } - _ => panic!("Expected array result"), - } - } } From a9f56d3881a91affdbcb5d8790aca59068cdb2d2 Mon Sep 17 00:00:00 2001 From: Shiv Bhatia Date: Wed, 19 Nov 2025 14:56:50 +0000 Subject: [PATCH 04/13] fix --- .../physical-expr/src/async_scalar_function.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs index b434694a20cc..eb240c2cbda7 100644 --- a/datafusion/physical-expr/src/async_scalar_function.rs +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -192,10 +192,18 @@ impl AsyncFuncExpr { ); } - let datas = ColumnarValue::values_to_arrays(&result_batches)? + let datas = result_batches .iter() - .map(|b| b.to_data()) - .collect::>(); + .map(|cv| match cv { + ColumnarValue::Array(arr) => Ok(arr.to_data()), + ColumnarValue::Scalar(scalar) => { + // This shouldn't happen in practice since async UDFs should return arrays, + // but handle it for completeness + Ok(scalar.to_array_of_size(1)?.to_data()) + } + }) + .collect::>>()?; + let total_len = datas.iter().map(|d| d.len()).sum(); let mut mutable = MutableArrayData::new(datas.iter().collect(), false, total_len); datas.iter().enumerate().for_each(|(i, data)| { From a4c3cf6878746b5566f3777715692b721221d901 Mon Sep 17 00:00:00 2001 From: Shiv Bhatia Date: Wed, 19 Nov 2025 15:17:54 +0000 Subject: [PATCH 05/13] improve test --- .../user_defined_async_scalar_functions.rs | 56 ++++++++++--------- .../src/async_scalar_function.rs | 2 - 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs index 72a1bfc25a84..d04bcf0227b8 100644 --- a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs @@ -22,15 +22,19 @@ use arrow::array::{Int32Array, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use async_trait::async_trait; use datafusion::prelude::*; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{assert_batches_eq, exec_err, Result, ScalarValue}; use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; +// This test checks the case where batch_size doesn't evenly divide +// the number of rows. #[tokio::test] async fn test_async_udf_with_non_modular_batch_size() -> Result<()> { let num_rows = 3; + let batch_size = 2; + let schema = Arc::new(Schema::new(vec![ Field::new("id", DataType::Int32, false), Field::new("prompt", DataType::Utf8, false), @@ -48,51 +52,49 @@ async fn test_async_udf_with_non_modular_batch_size() -> Result<()> { ], )?; - println!("Created test data with {} rows\n", batch.num_rows()); - - // Create context and register UDF let ctx = SessionContext::new(); ctx.register_batch("test_table", batch)?; ctx.register_udf( - AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(2))).into_scalar_udf(), + AsyncScalarUDF::new(Arc::new(TestAsyncUDFImpl::new(batch_size))) + .into_scalar_udf(), ); - // Execute query - println!("Executing query...\n"); let df = ctx .sql("SELECT id, test_async_udf(prompt) as result FROM test_table") .await?; - let results = df.collect().await?; - - println!("=== Final Results ==="); - for batch in results { - println!("Result batch has {} rows", batch.num_rows()); - println!("{:?}", batch); - } + let result = df.collect().await?; + + assert_batches_eq!( + &[ + "+----+---------+", + "| id | result |", + "+----+---------+", + "| 0 | prompt0 |", + "| 1 | prompt1 |", + "| 2 | prompt2 |", + "+----+---------+" + ], + &result + ); Ok(()) } -/// Helper function to convert ColumnarValue to Vec -fn columnar_to_vec_string(cv: &ColumnarValue) -> Result> { - match cv { +/// Simulates calling an async external service +async fn call_external_service(arg1: &ColumnarValue) -> Result> { + let vec1 = match arg1 { ColumnarValue::Array(arr) => { let string_arr = arr.as_any().downcast_ref::().unwrap(); - Ok(string_arr + string_arr .iter() .map(|s| s.unwrap_or("").to_string()) - .collect()) + .collect() } - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => Ok(vec![s.clone()]), - _ => panic!("Unexpected type"), - } -} - -/// Simulates calling an async external service -async fn call_external_service(arg1: &ColumnarValue) -> Result> { - let vec1 = columnar_to_vec_string(arg1)?; + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => vec![s.clone()], + _ => return exec_err!("Unexpected data type for arg1"), + }; tokio::time::sleep(Duration::from_millis(10)).await; Ok(vec1) } diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs index eb240c2cbda7..1a794f411bf0 100644 --- a/datafusion/physical-expr/src/async_scalar_function.rs +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -197,8 +197,6 @@ impl AsyncFuncExpr { .map(|cv| match cv { ColumnarValue::Array(arr) => Ok(arr.to_data()), ColumnarValue::Scalar(scalar) => { - // This shouldn't happen in practice since async UDFs should return arrays, - // but handle it for completeness Ok(scalar.to_array_of_size(1)?.to_data()) } }) From d5bdbdeca94d826bd390db333cde754082cb68c6 Mon Sep 17 00:00:00 2001 From: Shiv Bhatia Date: Wed, 19 Nov 2025 15:19:51 +0000 Subject: [PATCH 06/13] clean up --- .../user_defined_async_scalar_functions.rs | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs index d04bcf0227b8..c77ce6a32d2c 100644 --- a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs @@ -82,23 +82,6 @@ async fn test_async_udf_with_non_modular_batch_size() -> Result<()> { Ok(()) } -/// Simulates calling an async external service -async fn call_external_service(arg1: &ColumnarValue) -> Result> { - let vec1 = match arg1 { - ColumnarValue::Array(arr) => { - let string_arr = arr.as_any().downcast_ref::().unwrap(); - string_arr - .iter() - .map(|s| s.unwrap_or("").to_string()) - .collect() - } - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => vec![s.clone()], - _ => return exec_err!("Unexpected data type for arg1"), - }; - tokio::time::sleep(Duration::from_millis(10)).await; - Ok(vec1) -} - #[derive(Debug, PartialEq, Eq, Hash, Clone)] struct TestAsyncUDFImpl { batch_size: usize, @@ -150,3 +133,20 @@ impl AsyncScalarUDFImpl for TestAsyncUDFImpl { Ok(ColumnarValue::Array(Arc::new(StringArray::from(results)))) } } + +/// Simulates calling an async external service +async fn call_external_service(arg1: &ColumnarValue) -> Result> { + let vec1 = match arg1 { + ColumnarValue::Array(arr) => { + let string_arr = arr.as_any().downcast_ref::().unwrap(); + string_arr + .iter() + .map(|s| s.unwrap_or("").to_string()) + .collect() + } + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => vec![s.clone()], + _ => return exec_err!("Unexpected data type for arg1"), + }; + tokio::time::sleep(Duration::from_millis(10)).await; + Ok(vec1) +} From 2dd8b25ca9508de8a30525e8339dd2d9032ade49 Mon Sep 17 00:00:00 2001 From: Shiv Bhatia Date: Wed, 19 Nov 2025 15:32:13 +0000 Subject: [PATCH 07/13] simplify test --- .../user_defined_async_scalar_functions.rs | 21 +++++-------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs index c77ce6a32d2c..f4fcf3959f82 100644 --- a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs @@ -22,7 +22,7 @@ use arrow::array::{Int32Array, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use async_trait::async_trait; use datafusion::prelude::*; -use datafusion_common::{assert_batches_eq, exec_err, Result, ScalarValue}; +use datafusion_common::{assert_batches_eq, Result}; use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl}; use datafusion_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -129,24 +129,13 @@ impl AsyncScalarUDFImpl for TestAsyncUDFImpl { args: ScalarFunctionArgs, ) -> Result { let arg1 = &args.args[0]; - let results = call_external_service(arg1).await?; - Ok(ColumnarValue::Array(Arc::new(StringArray::from(results)))) + let results = call_external_service(arg1.clone()).await?; + Ok(results) } } /// Simulates calling an async external service -async fn call_external_service(arg1: &ColumnarValue) -> Result> { - let vec1 = match arg1 { - ColumnarValue::Array(arr) => { - let string_arr = arr.as_any().downcast_ref::().unwrap(); - string_arr - .iter() - .map(|s| s.unwrap_or("").to_string()) - .collect() - } - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => vec![s.clone()], - _ => return exec_err!("Unexpected data type for arg1"), - }; +async fn call_external_service(arg1: ColumnarValue) -> Result { tokio::time::sleep(Duration::from_millis(10)).await; - Ok(vec1) + Ok(arg1) } From bed69334108722f5a94014ff9ee54acb22162db8 Mon Sep 17 00:00:00 2001 From: Shiv Bhatia Date: Thu, 20 Nov 2025 09:31:33 +0000 Subject: [PATCH 08/13] remove sleep --- .../tests/user_defined/user_defined_async_scalar_functions.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs index f4fcf3959f82..b0e6563e6620 100644 --- a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs @@ -136,6 +136,5 @@ impl AsyncScalarUDFImpl for TestAsyncUDFImpl { /// Simulates calling an async external service async fn call_external_service(arg1: ColumnarValue) -> Result { - tokio::time::sleep(Duration::from_millis(10)).await; Ok(arg1) } From f3efa4a3a3033af708f00d939da5437a4d94c374 Mon Sep 17 00:00:00 2001 From: Shiv Bhatia Date: Thu, 20 Nov 2025 10:38:30 +0000 Subject: [PATCH 09/13] clippy --- .../tests/user_defined/user_defined_async_scalar_functions.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs index f4fcf3959f82..09e9564f8472 100644 --- a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs @@ -46,7 +46,7 @@ async fn test_async_udf_with_non_modular_batch_size() -> Result<()> { Arc::new(Int32Array::from((0..num_rows).collect::>())), Arc::new(StringArray::from( (0..num_rows) - .map(|i| format!("prompt{}", i)) + .map(|i| format!("prompt{i}")) .collect::>(), )), ], From 7c977f650f29a2d068f84bdc7597a04bb3c1eebe Mon Sep 17 00:00:00 2001 From: Shiv Bhatia Date: Thu, 20 Nov 2025 11:25:28 +0000 Subject: [PATCH 10/13] unused import --- .../tests/user_defined/user_defined_async_scalar_functions.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs index 00b0e28333ef..5b9585170a44 100644 --- a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs @@ -16,7 +16,6 @@ // under the License. use std::sync::Arc; -use std::time::Duration; use arrow::array::{Int32Array, RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; From 59d407c86496f7af9145c6501be6075a57119e70 Mon Sep 17 00:00:00 2001 From: Shiv Bhatia Date: Thu, 20 Nov 2025 12:02:09 +0000 Subject: [PATCH 11/13] revert datafusion-testing changes --- datafusion-testing | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion-testing b/datafusion-testing index 905df5f65cc9..eccb0e4a4263 160000 --- a/datafusion-testing +++ b/datafusion-testing @@ -1 +1 @@ -Subproject commit 905df5f65cc9d0851719c21f5a4dd5cd77621f19 +Subproject commit eccb0e4a426344ef3faf534cd60e02e9c3afd3ac From 1698f9323598509b34fee78110bb453857688b9b Mon Sep 17 00:00:00 2001 From: Shiv Bhatia Date: Fri, 21 Nov 2025 08:43:59 +0000 Subject: [PATCH 12/13] use concat --- .../src/async_scalar_function.rs | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs index 1a794f411bf0..6acc67acd829 100644 --- a/datafusion/physical-expr/src/async_scalar_function.rs +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -16,7 +16,6 @@ // under the License. use crate::ScalarFunctionExpr; -use arrow::array::{make_array, MutableArrayData, RecordBatch}; use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::Result; @@ -29,6 +28,8 @@ use std::any::Any; use std::fmt::Display; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use arrow::array::RecordBatch; +use arrow::compute::concat; /// Wrapper around a scalar function that can be evaluated asynchronously #[derive(Debug, Clone, Eq)] @@ -193,22 +194,20 @@ impl AsyncFuncExpr { } let datas = result_batches - .iter() + .into_iter() .map(|cv| match cv { - ColumnarValue::Array(arr) => Ok(arr.to_data()), - ColumnarValue::Scalar(scalar) => { - Ok(scalar.to_array_of_size(1)?.to_data()) - } + ColumnarValue::Array(arr) => Ok(arr), + ColumnarValue::Scalar(scalar) => Ok(scalar.to_array_of_size(1)?), }) .collect::>>()?; - let total_len = datas.iter().map(|d| d.len()).sum(); - let mut mutable = MutableArrayData::new(datas.iter().collect(), false, total_len); - datas.iter().enumerate().for_each(|(i, data)| { - mutable.extend(i, 0, data.len()); - }); - let array_ref = make_array(mutable.freeze()); - Ok(ColumnarValue::Array(array_ref)) + // Get references to the arrays as dyn Array to call concat + let dyn_arrays = datas + .iter() + .map(|arr| arr as &dyn arrow::array::Array) + .collect::>(); + let result_array = concat(&dyn_arrays)?; + Ok(ColumnarValue::Array(result_array)) } } From 33dfd8b2f676e9c807d6b7e1e81f7dc882a7ca0c Mon Sep 17 00:00:00 2001 From: Shiv Bhatia Date: Fri, 21 Nov 2025 08:51:08 +0000 Subject: [PATCH 13/13] cargo fmt --- datafusion/physical-expr/src/async_scalar_function.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs index 6acc67acd829..f1833666d6bb 100644 --- a/datafusion/physical-expr/src/async_scalar_function.rs +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -16,6 +16,8 @@ // under the License. use crate::ScalarFunctionExpr; +use arrow::array::RecordBatch; +use arrow::compute::concat; use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::Result; @@ -28,8 +30,6 @@ use std::any::Any; use std::fmt::Display; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use arrow::array::RecordBatch; -use arrow::compute::concat; /// Wrapper around a scalar function that can be evaluated asynchronously #[derive(Debug, Clone, Eq)]