diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml
index 466c17b789136..8a5c68a5d4e4b 100644
--- a/datafusion/spark/Cargo.toml
+++ b/datafusion/spark/Cargo.toml
@@ -92,3 +92,7 @@ name = "substring"
 [[bench]]
 harness = false
 name = "unhex"
+
+[[bench]]
+harness = false
+name = "sha2"
diff --git a/datafusion/spark/benches/sha2.rs b/datafusion/spark/benches/sha2.rs
new file mode 100644
index 0000000000000..3fa2220f158fd
--- /dev/null
+++ b/datafusion/spark/benches/sha2.rs
@@ -0,0 +1,116 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate criterion;
+
+use arrow::array::*;
+use arrow::datatypes::*;
+use criterion::{Criterion, criterion_group, criterion_main};
+use datafusion_common::ScalarValue;
+use datafusion_common::config::ConfigOptions;
+use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
+use datafusion_spark::function::hash::sha2::SparkSha2;
+use rand::rngs::StdRng;
+use rand::{Rng, SeedableRng};
+use std::hint::black_box;
+use std::sync::Arc;
+
+/// Returns an RNG with a fixed seed so every run hashes the same inputs.
+fn seedable_rng() -> StdRng {
+    StdRng::seed_from_u64(42)
+}
+
+/// Builds a `BinaryArray` with `size` entries: each entry is null with
+/// probability `null_density`, otherwise 1 to 100 random bytes.
+fn generate_binary_data(size: usize, null_density: f32) -> BinaryArray {
+    let mut rng = seedable_rng();
+    let mut builder = BinaryBuilder::new();
+    for _ in 0..size {
+        if rng.random::<f32>() < null_density {
+            builder.append_null();
+        } else {
+            let len = rng.random_range::<usize, _>(1..=100);
+            let bytes: Vec<u8> = (0..len).map(|_| rng.random()).collect();
+            builder.append_value(&bytes);
+        }
+    }
+    builder.finish()
+}
+
+/// Registers the benchmark `{name}/size={size}`, which invokes `SparkSha2`
+/// on `args` with `number_rows` set to `size`.
+fn run_benchmark(c: &mut Criterion, name: &str, size: usize, args: &[ColumnarValue]) {
+    let sha2_func = SparkSha2::new();
+    let arg_fields: Vec<_> = args
+        .iter()
+        .enumerate()
+        .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true).into())
+        .collect();
+    let config_options = Arc::new(ConfigOptions::default());
+    // Built once so the timed closure does not also measure `Field` allocation.
+    let return_field = Arc::new(Field::new("f", DataType::Utf8, true));
+
+    c.bench_function(&format!("{name}/size={size}"), |b| {
+        b.iter(|| {
+            black_box(
+                sha2_func
+                    .invoke_with_args(ScalarFunctionArgs {
+                        args: args.to_vec(),
+                        arg_fields: arg_fields.clone(),
+                        number_rows: size,
+                        return_field: Arc::clone(&return_field),
+                        config_options: Arc::clone(&config_options),
+                    })
+                    .unwrap(),
+            )
+        })
+    });
+}
+
+/// Benchmarks `sha2` with scalar inputs, then for each batch size with an
+/// array of values paired with an array and with a scalar bit length of 256.
+fn criterion_benchmark(c: &mut Criterion) {
+    // Scalar benchmark (avoid array expansion)
+    let scalar_args = vec![
+        ColumnarValue::Scalar(ScalarValue::Binary(Some(b"Spark".to_vec()))),
+        ColumnarValue::Scalar(ScalarValue::Int32(Some(256))),
+    ];
+    run_benchmark(c, "sha2/scalar", 1, &scalar_args);
+
+    let sizes = vec![1024, 4096, 8192];
+    let null_density = 0.1;
+
+    for &size in &sizes {
+        let values: ArrayRef = Arc::new(generate_binary_data(size, null_density));
+        let bit_lengths = Int32Array::from(vec![256; size]);
+        let array_args = vec![
+            ColumnarValue::Array(Arc::clone(&values)),
+            ColumnarValue::Array(Arc::new(bit_lengths)),
+        ];
+        run_benchmark(c, "sha2/array_binary_256", size, &array_args);
+
+        // Array values + scalar bit length, the `sha2(col, 256)` shape.
+        let array_scalar_args = vec![
+            ColumnarValue::Array(values),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(256))),
+        ];
+        run_benchmark(c, "sha2/array_binary_scalar_256", size, &array_scalar_args);
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/spark/src/function/hash/sha2.rs b/datafusion/spark/src/function/hash/sha2.rs
index a7ce5d7eb0ae0..0ffade4308e80 100644
--- a/datafusion/spark/src/function/hash/sha2.rs
+++ b/datafusion/spark/src/function/hash/sha2.rs
@@ -15,13 +15,15 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::array::{ArrayRef, AsArray, BinaryArrayType, Int32Array, StringArray};
+use arrow::array::{
+    ArrayRef, AsArray, BinaryArrayType, Int32Array, StringArray, new_null_array,
+};
 use arrow::datatypes::{DataType, Int32Type};
 use datafusion_common::types::{
     NativeType, logical_binary, logical_int32, logical_string,
 };
 use datafusion_common::utils::take_function_args;
-use datafusion_common::{Result, internal_err};
+use datafusion_common::{Result, ScalarValue, internal_err};
 use datafusion_expr::{
     Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature,
     TypeSignatureClass, Volatility,
@@ -29,7 +31,6 @@ use datafusion_expr::{
 use datafusion_functions::utils::make_scalar_function;
 use sha2::{self, Digest};
 use std::any::Any;
-use std::fmt::Write;
 use std::sync::Arc;
 
 /// Differs from DataFusion version in allowing array input for bit lengths, and
@@ -87,7 +88,106 @@ impl ScalarUDFImpl for SparkSha2 {
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
-        make_scalar_function(sha2_impl, vec![])(&args.args)
+        // Destructure instead of indexing so a wrong argument count surfaces
+        // as an error rather than a panic.
+        let [values, bit_lengths] = args.args.as_slice() else {
+            return internal_err!(
+                "sha2 expects exactly 2 arguments, got {}",
+                args.args.len()
+            );
+        };
+
+        match (values, bit_lengths) {
+            // Scalar value + scalar bit length: hash once, return a scalar
+            (
+                ColumnarValue::Scalar(value_scalar),
+                ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))),
+            ) => {
+                if value_scalar.is_null() {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
+                }
+
+                // Accept both Binary and Utf8 scalars (depending on coercion)
+                let bytes = match value_scalar {
+                    ScalarValue::Binary(Some(b)) => b.as_slice(),
+                    ScalarValue::LargeBinary(Some(b)) => b.as_slice(),
+                    ScalarValue::BinaryView(Some(b)) => b.as_slice(),
+                    ScalarValue::Utf8(Some(s))
+                    | ScalarValue::LargeUtf8(Some(s))
+                    | ScalarValue::Utf8View(Some(s)) => s.as_bytes(),
+                    other => {
+                        return internal_err!(
+                            "Unsupported scalar datatype for sha2: {}",
+                            other.data_type()
+                        );
+                    }
+                };
+
+                let out = match bit_length {
+                    224 => {
+                        let mut digest = sha2::Sha224::default();
+                        digest.update(bytes);
+                        Some(hex_encode(digest.finalize()))
+                    }
+                    0 | 256 => {
+                        let mut digest = sha2::Sha256::default();
+                        digest.update(bytes);
+                        Some(hex_encode(digest.finalize()))
+                    }
+                    384 => {
+                        let mut digest = sha2::Sha384::default();
+                        digest.update(bytes);
+                        Some(hex_encode(digest.finalize()))
+                    }
+                    512 => {
+                        let mut digest = sha2::Sha512::default();
+                        digest.update(bytes);
+                        Some(hex_encode(digest.finalize()))
+                    }
+                    _ => None,
+                };
+
+                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(out)))
+            }
+            // Array values + scalar bit length (common case: sha2(col, 256))
+            (
+                ColumnarValue::Array(values_array),
+                ColumnarValue::Scalar(ScalarValue::Int32(Some(bit_length))),
+            ) => {
+                let output: ArrayRef = match values_array.data_type() {
+                    DataType::Binary => sha2_binary_scalar_bitlen(
+                        &values_array.as_binary::<i32>(),
+                        *bit_length,
+                    ),
+                    DataType::LargeBinary => sha2_binary_scalar_bitlen(
+                        &values_array.as_binary::<i64>(),
+                        *bit_length,
+                    ),
+                    DataType::BinaryView => sha2_binary_scalar_bitlen(
+                        &values_array.as_binary_view(),
+                        *bit_length,
+                    ),
+                    dt => return internal_err!("Unsupported datatype for sha2: {dt}"),
+                };
+                Ok(ColumnarValue::Array(output))
+            }
+            // A NULL bit length always produces NULL
+            (
+                ColumnarValue::Scalar(_),
+                ColumnarValue::Scalar(ScalarValue::Int32(None)),
+            ) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
+            (
+                ColumnarValue::Array(_),
+                ColumnarValue::Scalar(ScalarValue::Int32(None)),
+            ) => Ok(ColumnarValue::Array(new_null_array(
+                &DataType::Utf8,
+                args.number_rows,
+            ))),
+            _ => {
+                // Fallback to existing behavior for any array/mixed cases
+                make_scalar_function(sha2_impl, vec![])(&args.args)
+            }
+        }
     }
 }
 
@@ -112,10 +212,34 @@ fn sha2_binary_impl<'a, BinaryArrType>(
 ) -> ArrayRef
 where
     BinaryArrType: BinaryArrayType<'a>,
+{
+    sha2_binary_bitlen_iter(values, bit_lengths.iter())
+}
+
+/// Hashes every value in `values` using the same `bit_length`.
+fn sha2_binary_scalar_bitlen<'a, BinaryArrType>(
+    values: &BinaryArrType,
+    bit_length: i32,
+) -> ArrayRef
+where
+    BinaryArrType: BinaryArrayType<'a>,
+{
+    sha2_binary_bitlen_iter(values, std::iter::repeat(Some(bit_length)))
+}
+
+/// Pairs each value in `values` with the next item of `bit_lengths` and
+/// hashes it with that bit length.
+fn sha2_binary_bitlen_iter<'a, BinaryArrType, I>(
+    values: &BinaryArrType,
+    bit_lengths: I,
+) -> ArrayRef
+where
+    BinaryArrType: BinaryArrayType<'a>,
+    I: Iterator<Item = Option<i32>>,
 {
     let array = values
         .iter()
-        .zip(bit_lengths.iter())
+        .zip(bit_lengths)
         .map(|(value, bit_length)| match (value, bit_length) {
             (Some(value), Some(224)) => {
                 let mut digest = sha2::Sha224::default();
@@ -144,11 +268,21 @@ where
     Arc::new(array)
 }
 
+/// Lowercase hexadecimal digits, indexed by nibble value.
+const HEX_CHARS: [u8; 16] = *b"0123456789abcdef";
+
+/// Encodes `data` as a lowercase hexadecimal string, two digits per byte.
+#[inline]
 fn hex_encode<T: AsRef<[u8]>>(data: T) -> String {
-    let mut s = String::with_capacity(data.as_ref().len() * 2);
-    for b in data.as_ref() {
-        // Writing to a string never errors, so we can unwrap here.
-        write!(&mut s, "{b:02x}").unwrap();
+    let bytes = data.as_ref();
+    let mut out = Vec::with_capacity(bytes.len() * 2);
+    for &b in bytes {
+        let hi = b >> 4;
+        let lo = b & 0x0F;
+        out.push(HEX_CHARS[hi as usize]);
+        out.push(HEX_CHARS[lo as usize]);
     }
-    s
+    // SAFETY: every byte pushed to `out` comes from `HEX_CHARS`, which holds
+    // only ASCII characters, so `out` is valid UTF-8.
+    unsafe { String::from_utf8_unchecked(out) }
 }