Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 112 additions & 52 deletions datafusion/spark/src/function/math/hex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::any::Any;
use std::str::from_utf8_unchecked;
use std::sync::Arc;

use arrow::array::{Array, BinaryArray, Int64Array, StringArray, StringBuilder};
use arrow::array::{Array, ArrayRef, StringBuilder};
use arrow::datatypes::DataType;
use arrow::{
array::{as_dictionary_array, as_largestring_array, as_string_array},
Expand Down Expand Up @@ -92,11 +92,13 @@ impl ScalarUDFImpl for SparkHex {
&self.signature
}

fn return_type(
&self,
_arg_types: &[DataType],
) -> datafusion_common::Result<DataType> {
Ok(DataType::Utf8)
fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
Ok(match &arg_types[0] {
DataType::Dictionary(key_type, _) => {
DataType::Dictionary(key_type.clone(), Box::new(DataType::Utf8))
}
_ => DataType::Utf8,
})
}

fn invoke_with_args(
Expand Down Expand Up @@ -136,7 +138,7 @@ fn hex_encode_bytes<'a, I, T>(
iter: I,
lowercase: bool,
len: usize,
) -> Result<ColumnarValue, DataFusionError>
) -> Result<ArrayRef, DataFusionError>
where
I: Iterator<Item = Option<T>>,
T: AsRef<[u8]> + 'a,
Expand Down Expand Up @@ -166,14 +168,14 @@ where
}
}

Ok(ColumnarValue::Array(Arc::new(builder.finish())))
Ok(Arc::new(builder.finish()))
}

/// Generic hex encoding for int64 type
fn hex_encode_int64<I>(iter: I, len: usize) -> Result<ColumnarValue, DataFusionError>
where
I: Iterator<Item = Option<i64>>,
{
fn hex_encode_int64(
iter: impl Iterator<Item = Option<i64>>,
len: usize,
) -> Result<ArrayRef, DataFusionError> {
let mut builder = StringBuilder::with_capacity(len, len * 16);

for v in iter {
Expand All @@ -189,7 +191,7 @@ where
}
}

Ok(ColumnarValue::Array(Arc::new(builder.finish())))
Ok(Arc::new(builder.finish()))
}

/// Spark-compatible `hex` function
Expand All @@ -215,55 +217,109 @@ pub fn compute_hex(
ColumnarValue::Array(array) => match array.data_type() {
DataType::Int64 => {
let array = as_int64_array(array)?;
hex_encode_int64(array.iter(), array.len())
Ok(ColumnarValue::Array(hex_encode_int64(
array.iter(),
array.len(),
)?))
}
DataType::Utf8 => {
let array = as_string_array(array);
hex_encode_bytes(array.iter(), lowercase, array.len())
Ok(ColumnarValue::Array(hex_encode_bytes(
array.iter(),
lowercase,
array.len(),
)?))
}
DataType::Utf8View => {
let array = as_string_view_array(array)?;
hex_encode_bytes(array.iter(), lowercase, array.len())
Ok(ColumnarValue::Array(hex_encode_bytes(
array.iter(),
lowercase,
array.len(),
)?))
}
DataType::LargeUtf8 => {
let array = as_largestring_array(array);
hex_encode_bytes(array.iter(), lowercase, array.len())
Ok(ColumnarValue::Array(hex_encode_bytes(
array.iter(),
lowercase,
array.len(),
)?))
}
DataType::Binary => {
let array = as_binary_array(array)?;
hex_encode_bytes(array.iter(), lowercase, array.len())
Ok(ColumnarValue::Array(hex_encode_bytes(
array.iter(),
lowercase,
array.len(),
)?))
}
DataType::LargeBinary => {
let array = as_large_binary_array(array)?;
hex_encode_bytes(array.iter(), lowercase, array.len())
Ok(ColumnarValue::Array(hex_encode_bytes(
array.iter(),
lowercase,
array.len(),
)?))
}
DataType::FixedSizeBinary(_) => {
let array = as_fixed_size_binary_array(array)?;
hex_encode_bytes(array.iter(), lowercase, array.len())
Ok(ColumnarValue::Array(hex_encode_bytes(
array.iter(),
lowercase,
array.len(),
)?))
}
DataType::Dictionary(_, value_type) => {
DataType::Dictionary(key_type, _) => {
if **key_type != DataType::Int32 {
return exec_err!(
"hex only supports Int32 dictionary keys, get: {}",
key_type
);
}

let dict = as_dictionary_array::<Int32Type>(&array);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we should have some check that the dictionary has i32 key type, otherwise this will panic

let dict_values = dict.values();

match **value_type {
let encoded_values = match dict_values.data_type() {
DataType::Int64 => {
let arr = dict.downcast_dict::<Int64Array>().unwrap();
hex_encode_int64(arr.into_iter(), dict.len())
let arr = as_int64_array(dict_values)?;
hex_encode_int64(arr.iter(), arr.len())?
}
DataType::Utf8 => {
let arr = dict.downcast_dict::<StringArray>().unwrap();
hex_encode_bytes(arr.into_iter(), lowercase, dict.len())
let arr = as_string_array(dict_values);
hex_encode_bytes(arr.iter(), lowercase, arr.len())?
}
DataType::LargeUtf8 => {
let arr = as_largestring_array(dict_values);
hex_encode_bytes(arr.iter(), lowercase, arr.len())?
}
DataType::Utf8View => {
let arr = as_string_view_array(dict_values)?;
hex_encode_bytes(arr.iter(), lowercase, arr.len())?
}
DataType::Binary => {
let arr = dict.downcast_dict::<BinaryArray>().unwrap();
hex_encode_bytes(arr.into_iter(), lowercase, dict.len())
let arr = as_binary_array(dict_values)?;
hex_encode_bytes(arr.iter(), lowercase, arr.len())?
}
DataType::LargeBinary => {
let arr = as_large_binary_array(dict_values)?;
hex_encode_bytes(arr.iter(), lowercase, arr.len())?
}
DataType::FixedSizeBinary(_) => {
let arr = as_fixed_size_binary_array(dict_values)?;
hex_encode_bytes(arr.iter(), lowercase, arr.len())?
}
_ => {
exec_err!(
return exec_err!(
"hex got an unexpected argument type: {}",
array.data_type()
)
dict_values.data_type()
);
}
}
};

let new_dict = dict.with_values(encoded_values);
Ok(ColumnarValue::Array(Arc::new(new_dict)))
}
_ => exec_err!("hex got an unexpected argument type: {}", array.data_type()),
},
Expand All @@ -279,11 +335,12 @@ mod test {
use arrow::array::{DictionaryArray, Int32Array, Int64Array, StringArray};
use arrow::{
array::{
BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringBuilder,
StringDictionaryBuilder, as_string_array,
BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringDictionaryBuilder,
as_string_array,
},
datatypes::{Int32Type, Int64Type},
};
use datafusion_common::cast::as_dictionary_array;
use datafusion_expr::ColumnarValue;

#[test]
Expand All @@ -295,12 +352,12 @@ mod test {
input_builder.append_value("rust");
let input = input_builder.finish();

let mut string_builder = StringBuilder::new();
string_builder.append_value("6869");
string_builder.append_value("627965");
string_builder.append_null();
string_builder.append_value("72757374");
let expected = string_builder.finish();
let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new();
expected_builder.append_value("6869");
expected_builder.append_value("627965");
expected_builder.append_null();
expected_builder.append_value("72757374");
let expected = expected_builder.finish();

let columnar_value = ColumnarValue::Array(Arc::new(input));
let result = super::spark_hex(&[columnar_value]).unwrap();
Expand All @@ -310,7 +367,7 @@ mod test {
_ => panic!("Expected array"),
};

let result = as_string_array(&result);
let result = as_dictionary_array(&result).unwrap();

assert_eq!(result, &expected);
}
Expand All @@ -324,12 +381,12 @@ mod test {
input_builder.append_value(3);
let input = input_builder.finish();

let mut string_builder = StringBuilder::new();
string_builder.append_value("1");
string_builder.append_value("2");
string_builder.append_null();
string_builder.append_value("3");
let expected = string_builder.finish();
let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new();
expected_builder.append_value("1");
expected_builder.append_value("2");
expected_builder.append_null();
expected_builder.append_value("3");
let expected = expected_builder.finish();

let columnar_value = ColumnarValue::Array(Arc::new(input));
let result = super::spark_hex(&[columnar_value]).unwrap();
Expand All @@ -339,7 +396,7 @@ mod test {
_ => panic!("Expected array"),
};

let result = as_string_array(&result);
let result = as_dictionary_array(&result).unwrap();

assert_eq!(result, &expected);
}
Expand All @@ -353,7 +410,7 @@ mod test {
input_builder.append_value("3");
let input = input_builder.finish();

let mut expected_builder = StringBuilder::new();
let mut expected_builder = StringDictionaryBuilder::<Int32Type>::new();
expected_builder.append_value("31");
expected_builder.append_value("6A");
expected_builder.append_null();
Expand All @@ -368,7 +425,7 @@ mod test {
_ => panic!("Expected array"),
};

let result = as_string_array(&result);
let result = as_dictionary_array(&result).unwrap();

assert_eq!(result, &expected);
}
Expand Down Expand Up @@ -425,8 +482,11 @@ mod test {
_ => panic!("Expected array"),
};

let result = as_string_array(&result);
let expected = StringArray::from(vec![Some("20"), None, None]);
let result = as_dictionary_array(&result).unwrap();

let keys = Int32Array::from(vec![Some(0), None, Some(1)]);
let vals = StringArray::from(vec![Some("20"), None]);
let expected = DictionaryArray::new(keys, Arc::new(vals));

assert_eq!(&expected, result);
}
Expand Down
20 changes: 20 additions & 0 deletions datafusion/sqllogictest/test_files/spark/math/hex.slt
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,23 @@ query T
SELECT hex(arrow_cast('test', 'LargeBinary')) as lar_b;
----
74657374

statement ok
CREATE TABLE t_dict_binary AS
SELECT arrow_cast(column1, 'Dictionary(Int32, Binary)') as dict_col
FROM VALUES ('foo'), ('bar'), ('foo'), (NULL), ('baz'), ('bar');

query T
SELECT hex(dict_col) FROM t_dict_binary;
----
666F6F
626172
666F6F
NULL
62617A
626172

query T
SELECT arrow_typeof(hex(dict_col)) FROM t_dict_binary LIMIT 1;
----
Dictionary(Int32, Utf8)