diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs
index 1578331e57f89..ea58a95cb1c6e 100644
--- a/datafusion/functions/src/string/bit_length.rs
+++ b/datafusion/functions/src/string/bit_length.rs
@@ -28,6 +28,11 @@ use datafusion_expr::{
     TypeSignatureClass, Volatility,
 };
 use datafusion_macros::user_doc;
+use arrow::array::{
+    Array, Int32Builder, Int64Builder, LargeStringArray, StringArray,
+};
+use std::sync::Arc;
+
 
 #[user_doc(
     doc_section(label = "String Functions"),
@@ -90,17 +95,46 @@ impl ScalarUDFImpl for BitLengthFunc {
         let [array] = take_function_args(self.name(), &args.args)?;
 
         match array {
-            ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)),
+            ColumnarValue::Array(v) => {
+                // Utf8 fast path: Int32 results, same as the scalar Utf8 arm below.
+                if let Some(arr) = v.as_any().downcast_ref::<StringArray>() {
+                    let mut builder = Int32Builder::with_capacity(arr.len());
+                    for i in 0..arr.len() {
+                        if arr.is_null(i) {
+                            builder.append_null();
+                        } else {
+                            let byte_len = arr.value(i).as_bytes().len();
+                            builder.append_value((byte_len * 8) as i32);
+                        }
+                    }
+                    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+                } else if let Some(arr) = v.as_any().downcast_ref::<LargeStringArray>() {
+                    // LargeUtf8 must produce Int64, same as the scalar LargeUtf8 arm below.
+                    let mut builder = Int64Builder::with_capacity(arr.len());
+                    for i in 0..arr.len() {
+                        if arr.is_null(i) {
+                            builder.append_null();
+                        } else {
+                            let byte_len = arr.value(i).as_bytes().len();
+                            builder.append_value((byte_len * 8) as i64);
+                        }
+                    }
+                    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+                } else {
+                    // fallback for Utf8View, Dictionary, Binary, etc.
+                    Ok(ColumnarValue::Array(bit_length(v.as_ref())?))
+                }
+            }
             ColumnarValue::Scalar(v) => match v {
                 ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32(
                     v.as_ref().map(|x| (x.len() * 8) as i32),
                 ))),
-                ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar(
-                    ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)),
-                )),
-                ScalarValue::Utf8View(v) => Ok(ColumnarValue::Scalar(
-                    ScalarValue::Int32(v.as_ref().map(|x| (x.len() * 8) as i32)),
-                )),
+                ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
+                    v.as_ref().map(|x| (x.len() * 8) as i64),
+                ))),
+                ScalarValue::Utf8View(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32(
+                    v.as_ref().map(|x| (x.len() * 8) as i32),
+                ))),
                 _ => unreachable!("bit length"),
             },
         }