diff --git a/datafusion/spark/src/function/math/hex.rs b/datafusion/spark/src/function/math/hex.rs index 134324f45f5b..06c77f37021b 100644 --- a/datafusion/spark/src/function/math/hex.rs +++ b/datafusion/spark/src/function/math/hex.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::str::from_utf8_unchecked; use std::sync::Arc; -use arrow::array::{Array, BinaryArray, Int64Array, StringArray, StringBuilder}; +use arrow::array::{Array, ArrayRef, StringBuilder}; use arrow::datatypes::DataType; use arrow::{ array::{as_dictionary_array, as_largestring_array, as_string_array}, @@ -92,11 +92,13 @@ impl ScalarUDFImpl for SparkHex { &self.signature } - fn return_type( - &self, - _arg_types: &[DataType], - ) -> datafusion_common::Result { - Ok(DataType::Utf8) + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + Ok(match &arg_types[0] { + DataType::Dictionary(key_type, _) => { + DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8)) + } + _ => DataType::Utf8, + }) } fn invoke_with_args( @@ -136,7 +138,7 @@ fn hex_encode_bytes<'a, I, T>( iter: I, lowercase: bool, len: usize, -) -> Result +) -> Result where I: Iterator>, T: AsRef<[u8]> + 'a, @@ -166,14 +168,14 @@ where } } - Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + Ok(Arc::new(builder.finish())) } /// Generic hex encoding for int64 type -fn hex_encode_int64(iter: I, len: usize) -> Result -where - I: Iterator>, -{ +fn hex_encode_int64( + iter: impl Iterator>, + len: usize, +) -> Result { let mut builder = StringBuilder::with_capacity(len, len * 16); for v in iter { @@ -189,7 +191,7 @@ where } } - Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + Ok(Arc::new(builder.finish())) } /// Spark-compatible `hex` function @@ -215,55 +217,109 @@ pub fn compute_hex( ColumnarValue::Array(array) => match array.data_type() { DataType::Int64 => { let array = as_int64_array(array)?; - hex_encode_int64(array.iter(), array.len()) + Ok(ColumnarValue::Array(hex_encode_int64( + array.iter(), + array.len(), + )?)) } DataType::Utf8 => { let array = as_string_array(array); - hex_encode_bytes(array.iter(), lowercase, array.len()) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::Utf8View => { let array = as_string_view_array(array)?; - hex_encode_bytes(array.iter(), lowercase, array.len()) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::LargeUtf8 => { let array = as_largestring_array(array); - hex_encode_bytes(array.iter(), lowercase, array.len()) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::Binary => { let array = as_binary_array(array)?; - hex_encode_bytes(array.iter(), lowercase, array.len()) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::LargeBinary => { let array = as_large_binary_array(array)?; - hex_encode_bytes(array.iter(), lowercase, array.len()) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } DataType::FixedSizeBinary(_) => { let array = as_fixed_size_binary_array(array)?; - hex_encode_bytes(array.iter(), lowercase, array.len()) + Ok(ColumnarValue::Array(hex_encode_bytes( + array.iter(), + lowercase, + array.len(), + )?)) } - DataType::Dictionary(_, value_type) => { + DataType::Dictionary(key_type, _) => { + if **key_type != DataType::Int32 { + return exec_err!( + "hex only supports Int32 dictionary keys, get: {}", + key_type + ); + } + let dict = as_dictionary_array::(&array); + let dict_values = dict.values(); - match **value_type { + let encoded_values = match dict_values.data_type() { DataType::Int64 => { - let arr = dict.downcast_dict::().unwrap(); - hex_encode_int64(arr.into_iter(), dict.len()) + let arr = as_int64_array(dict_values)?; + hex_encode_int64(arr.iter(), arr.len())? } DataType::Utf8 => { - let arr = dict.downcast_dict::().unwrap(); - hex_encode_bytes(arr.into_iter(), lowercase, dict.len()) + let arr = as_string_array(dict_values); + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::LargeUtf8 => { + let arr = as_largestring_array(dict_values); + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::Utf8View => { + let arr = as_string_view_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? } DataType::Binary => { - let arr = dict.downcast_dict::().unwrap(); - hex_encode_bytes(arr.into_iter(), lowercase, dict.len()) + let arr = as_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::LargeBinary => { + let arr = as_large_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? + } + DataType::FixedSizeBinary(_) => { + let arr = as_fixed_size_binary_array(dict_values)?; + hex_encode_bytes(arr.iter(), lowercase, arr.len())? } _ => { - exec_err!( + return exec_err!( "hex got an unexpected argument type: {}", - array.data_type() - ) + dict_values.data_type() + ); } - } + }; + + let new_dict = dict.with_values(encoded_values); + Ok(ColumnarValue::Array(Arc::new(new_dict))) } _ => exec_err!("hex got an unexpected argument type: {}", array.data_type()), }, @@ -279,11 +335,12 @@ mod test { use arrow::array::{DictionaryArray, Int32Array, Int64Array, StringArray}; use arrow::{ array::{ - BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringBuilder, - StringDictionaryBuilder, as_string_array, + BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder, + as_string_array, }, datatypes::{Int32Type, Int64Type}, }; + use datafusion_common::cast::as_dictionary_array; use datafusion_expr::ColumnarValue; #[test] @@ -295,12 +352,12 @@ mod test { input_builder.append_value("rust"); let input = input_builder.finish(); - let mut string_builder = StringBuilder::new(); - string_builder.append_value("6869"); - string_builder.append_value("627965"); - string_builder.append_null(); - string_builder.append_value("72757374"); - let expected = string_builder.finish(); + let mut expected_builder = StringDictionaryBuilder::::new(); + expected_builder.append_value("6869"); + expected_builder.append_value("627965"); + expected_builder.append_null(); + expected_builder.append_value("72757374"); + let expected = expected_builder.finish(); let columnar_value = ColumnarValue::Array(Arc::new(input)); let result = super::spark_hex(&[columnar_value]).unwrap(); @@ -310,7 +367,7 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } @@ -324,12 +381,12 @@ mod test { input_builder.append_value(3); let input = input_builder.finish(); - let mut string_builder = StringBuilder::new(); - string_builder.append_value("1"); - string_builder.append_value("2"); - string_builder.append_null(); - string_builder.append_value("3"); - let expected = string_builder.finish(); + let mut expected_builder = StringDictionaryBuilder::::new(); + expected_builder.append_value("1"); + expected_builder.append_value("2"); + expected_builder.append_null(); + expected_builder.append_value("3"); + let expected = expected_builder.finish(); let columnar_value = ColumnarValue::Array(Arc::new(input)); let result = super::spark_hex(&[columnar_value]).unwrap(); @@ -339,7 +396,7 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } @@ -353,7 +410,7 @@ mod test { input_builder.append_value("3"); let input = input_builder.finish(); - let mut expected_builder = StringBuilder::new(); + let mut expected_builder = StringDictionaryBuilder::::new(); expected_builder.append_value("31"); expected_builder.append_value("6A"); expected_builder.append_null(); @@ -368,7 +425,7 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); + let result = as_dictionary_array(&result).unwrap(); assert_eq!(result, &expected); } @@ -425,8 +482,11 @@ mod test { _ => panic!("Expected array"), }; - let result = as_string_array(&result); - let expected = StringArray::from(vec![Some("20"), None, None]); + let result = as_dictionary_array(&result).unwrap(); + + let keys = Int32Array::from(vec![Some(0), None, Some(1)]); + let vals = StringArray::from(vec![Some("20"), None]); + let expected = DictionaryArray::new(keys, Arc::new(vals)); assert_eq!(&expected, result); } diff --git a/datafusion/sqllogictest/test_files/spark/math/hex.slt b/datafusion/sqllogictest/test_files/spark/math/hex.slt index 05c9fb3f31b2..17e9ff432890 100644 --- a/datafusion/sqllogictest/test_files/spark/math/hex.slt +++ b/datafusion/sqllogictest/test_files/spark/math/hex.slt @@ -63,3 +63,23 @@ query T SELECT hex(arrow_cast('test', 'LargeBinary')) as lar_b; ---- 74657374 + +statement ok +CREATE TABLE t_dict_binary AS +SELECT arrow_cast(column1, 'Dictionary(Int32, Binary)') as dict_col +FROM VALUES ('foo'), ('bar'), ('foo'), (NULL), ('baz'), ('bar'); + +query T +SELECT hex(dict_col) FROM t_dict_binary; +---- +666F6F +626172 +666F6F +NULL +62617A +626172 + +query T +SELECT arrow_typeof(hex(dict_col)) FROM t_dict_binary LIMIT 1; +---- +Dictionary(Int32, Utf8)