diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index 3732897f3d37..0a012b3215ba 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -15,11 +15,14 @@ // specific language governing permissions and limitations // under the License. -use arrow::compute::kernels::length::length; use arrow::datatypes::DataType; use std::any::Any; use crate::utils::utf8_to_int_type; +use arrow::array::{ + Array, Int32Array, Int32Builder, Int64Builder, LargeStringArray, StringArray, + StringViewArray, +}; use datafusion_common::types::logical_string; use datafusion_common::utils::take_function_args; use datafusion_common::{Result, ScalarValue}; @@ -28,6 +31,7 @@ use datafusion_expr::{ TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; +use std::sync::Arc; #[user_doc( doc_section(label = "String Functions"), @@ -90,7 +94,39 @@ impl ScalarUDFImpl for OctetLengthFunc { let [array] = take_function_args(self.name(), &args.args)?; match array { - ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), + ColumnarValue::Array(v) => { + if let Some(arr) = v.as_any().downcast_ref::() { + let mut builder = Int32Builder::with_capacity(arr.len()); + for i in 0..arr.len() { + if arr.is_null(i) { + builder.append_null(); + } else { + builder.append_value(arr.value_length(i)); + } + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } else if let Some(arr) = v.as_any().downcast_ref::() { + let mut builder = Int64Builder::with_capacity(arr.len()); + for i in 0..arr.len() { + if arr.is_null(i) { + builder.append_null(); + } else { + builder.append_value(arr.value_length(i)); + } + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } else if let Some(arr) = v.as_any().downcast_ref::() { + let result = arr + .iter() + .map(|s| s.map(|s| s.len() as i32)) + .collect::(); + + Ok(ColumnarValue::Array(Arc::new(result))) + } else { + unreachable!("octet_length expects string arrays") + } + } + ColumnarValue::Scalar(v) => match v { ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( v.as_ref().map(|x| x.len() as i32),