diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 5ceeee57b0be4..8a99ddc5383bd 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -294,3 +294,8 @@ required-features = ["unicode_expressions"] harness = false name = "levenshtein" required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "factorial" +required-features = ["math_expressions"] diff --git a/datafusion/functions/benches/factorial.rs b/datafusion/functions/benches/factorial.rs new file mode 100644 index 0000000000000..5c5ff991d7453 --- /dev/null +++ b/datafusion/functions/benches/factorial.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::array::Int64Array; +use arrow::datatypes::{DataType, Field}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_common::ScalarValue; +use datafusion_common::config::ConfigOptions; +use datafusion_expr::ScalarFunctionArgs; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_functions::math::factorial; +use std::hint::black_box; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let factorial = factorial(); + let config_options = Arc::new(ConfigOptions::default()); + + let arr_args = vec![ColumnarValue::Array(Arc::new(Int64Array::from_iter( + (0..1024).map(|i| Some(i % 21)), + )))]; + c.bench_function(&format!("{}_array", factorial.name()), |b| { + b.iter(|| { + let args_cloned = arr_args.clone(); + black_box(factorial.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Utf8, true).into()], + number_rows: arr_args.len(), + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); + + let scalar_args = vec![ColumnarValue::Scalar(ScalarValue::Int64(Some(20)))]; + c.bench_function(&format!("{}_scalar", factorial.name()), |b| { + b.iter(|| { + let args_cloned = scalar_args.clone(); + black_box(factorial.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + arg_fields: vec![Field::new("a", DataType::Utf8, true).into()], + number_rows: 1, + return_field: Field::new("f", DataType::Utf8, true).into(), + config_options: Arc::clone(&config_options), + })) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs index 4d651cf77d534..99b5304d07b00 100644 --- a/datafusion/functions/src/math/factorial.rs +++ b/datafusion/functions/src/math/factorial.rs @@ -16,17 +16,17 @@ // under the License. use arrow::{ - array::{ArrayRef, Int64Array}, + array::{ArrayRef, AsArray, Int64Array}, error::ArrowError, }; use std::any::Any; use std::sync::Arc; -use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; +use arrow::datatypes::{DataType, Int64Type}; use crate::utils::make_scalar_function; -use datafusion_common::{Result, arrow_datafusion_err, exec_err}; +use datafusion_common::{Result, exec_err}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, @@ -92,26 +92,47 @@ impl ScalarUDFImpl for FactorialFunc { } } +const FACTORIALS: [i64; 21] = [ + 1, + 1, + 2, + 6, + 24, + 120, + 720, + 5040, + 40320, + 362880, + 3628800, + 39916800, + 479001600, + 6227020800, + 87178291200, + 1307674368000, + 20922789888000, + 355687428096000, + 6402373705728000, + 121645100408832000, + 2432902008176640000, +]; + /// Factorial SQL function fn factorial(args: &[ArrayRef]) -> Result { match args[0].data_type() { Int64 => { - let arg = downcast_named_arg!((&args[0]), "value", Int64Array); - Ok(arg - .iter() - .map(|a| match a { - Some(a) => (2..=a) - .try_fold(1i64, i64::checked_mul) - .ok_or_else(|| { - arrow_datafusion_err!(ArrowError::ComputeError(format!( - "Overflow happened on FACTORIAL({a})" - ))) - }) - .map(Some), - _ => Ok(None), - }) - .collect::>() - .map(Arc::new)? as ArrayRef) + let result: Int64Array = + args[0].as_primitive::().try_unary(|a| { + if a < 0 { + Ok(1) + } else if a < FACTORIALS.len() as i64 { + Ok(FACTORIALS[a as usize]) + } else { + Err(ArrowError::ComputeError(format!( + "Overflow happened on FACTORIAL({a})" + ))) + } + })?; + Ok(Arc::new(result) as ArrayRef) } other => exec_err!("Unsupported data type {other:?} for function factorial."), } @@ -119,23 +140,31 @@ fn factorial(args: &[ArrayRef]) -> Result { #[cfg(test)] mod test { - - use datafusion_common::cast::as_int64_array; - use super::*; + use datafusion_common::cast::as_int64_array; #[test] fn test_factorial_i64() { let args: Vec = vec![ - Arc::new(Int64Array::from(vec![0, 1, 2, 4])), // input + Arc::new(Int64Array::from(vec![0, 1, 2, 4, 20, -1])), // input ]; let result = factorial(&args).expect("failed to initialize function factorial"); let ints = as_int64_array(&result).expect("failed to initialize function factorial"); - let expected = Int64Array::from(vec![1, 1, 2, 24]); + let expected = Int64Array::from(vec![1, 1, 2, 24, 2432902008176640000, 1]); assert_eq!(ints, &expected); } + + #[test] + fn test_overflow() { + let args: Vec = vec![ + Arc::new(Int64Array::from(vec![21])), // input + ]; + + let result = factorial(&args); + assert!(result.is_err()); + } }