Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 128 additions & 104 deletions datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ use std::any::Any;

use super::power::PowerFunc;

use crate::utils::{
calculate_binary_math, decimal32_to_i32, decimal64_to_i64, decimal128_to_i128,
};
use crate::utils::calculate_binary_math;
use arrow::array::{Array, ArrayRef};
use arrow::datatypes::{
DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float16Type,
Expand All @@ -44,7 +42,7 @@ use datafusion_expr::{
};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_macros::user_doc;
use num_traits::Float;
use num_traits::{Float, ToPrimitive};

#[user_doc(
doc_section(label = "Math Functions"),
Expand Down Expand Up @@ -104,109 +102,109 @@ impl LogFunc {
}
}

/// Binary function to calculate logarithm of Decimal32 `value` using `base` base
/// Returns error if base is invalid
fn log_decimal32(value: i32, scale: i8, base: f64) -> Result<f64, ArrowError> {
if !base.is_finite() || base.trunc() != base {
return Err(ArrowError::ComputeError(format!(
"Log cannot use non-integer base: {base}"
)));
}
if (base as u32) < 2 {
return Err(ArrowError::ComputeError(format!(
"Log base must be greater than 1: {base}"
)));
}

// Match f64::log behaviour
if value <= 0 {
return Ok(f64::NAN);
}
/// Returns `true` when `base` is a whole number within `2..=u32::MAX`,
/// which is the condition the integer `ilog` fast path requires.
///
/// NaN and infinite bases yield `false`.
#[inline]
fn is_valid_integer_base(base: f64) -> bool {
    let in_range = (2.0..=f64::from(u32::MAX)).contains(&base);
    in_range && base.trunc() == base
}

if scale < 0 {
let actual_value = (value as f64) * 10.0_f64.powi(-(scale as i32));
Ok(actual_value.log(base))
} else {
let unscaled_value = decimal32_to_i32(value, scale)?;
if unscaled_value <= 0 {
return Ok(f64::NAN);
}
let log_value: u32 = unscaled_value.ilog(base as i32);
Ok(log_value as f64)
/// Computes the logarithm of a Decimal32 `value` (stored with `scale`) in base `base`.
///
/// Fast path: when `base` passes `is_valid_integer_base`, `scale` is non-negative
/// and `unscale_to_u32` produces a value, the result is the `u32::ilog` of that
/// unscaled value, or NaN when the unscaled value is zero.
///
/// Otherwise the value is converted with `decimal_to_f64` and `f64::log` is applied.
fn log_decimal32(value: i32, scale: i8, base: f64) -> Result<f64, ArrowError> {
    // The helper is only consulted when both preconditions of the integer path hold.
    let fast_path = if is_valid_integer_base(base) && scale >= 0 {
        unscale_to_u32(value, scale)
    } else {
        None
    };

    match fast_path {
        Some(0) => Ok(f64::NAN),
        Some(unscaled) => Ok(unscaled.ilog(base as u32) as f64),
        None => decimal_to_f64(value, scale).map(|v| v.log(base)),
    }
}

/// Binary function to calculate logarithm of Decimal64 `value` using `base` base
/// Returns error if base is invalid
/// Calculate logarithm for Decimal64 values.
/// For integer bases >= 2 with non-negative scale, uses the efficient u64 ilog algorithm.
/// Otherwise falls back to f64 computation.
fn log_decimal64(value: i64, scale: i8, base: f64) -> Result<f64, ArrowError> {
if !base.is_finite() || base.trunc() != base {
return Err(ArrowError::ComputeError(format!(
"Log cannot use non-integer base: {base}"
)));
}
if (base as u32) < 2 {
return Err(ArrowError::ComputeError(format!(
"Log base must be greater than 1: {base}"
)));
if is_valid_integer_base(base)
&& scale >= 0
&& let Some(unscaled) = unscale_to_u64(value, scale)
{
return if unscaled > 0 {
Ok(unscaled.ilog(base as u64) as f64)
} else {
Ok(f64::NAN)
};
}
decimal_to_f64(value, scale).map(|v| v.log(base))
}

if value <= 0 {
return Ok(f64::NAN);
/// Computes the logarithm of a Decimal128 `value` (stored with `scale`) in base `base`.
///
/// Fast path: when `base` passes `is_valid_integer_base`, `scale` is non-negative
/// and `unscale_to_u128` produces a value, the result is the `u128::ilog` of that
/// unscaled value, or NaN when the unscaled value is zero.
///
/// Otherwise the value is converted with `decimal_to_f64` and `f64::log` is applied.
fn log_decimal128(value: i128, scale: i8, base: f64) -> Result<f64, ArrowError> {
    // The helper is only consulted when both preconditions of the integer path hold.
    let fast_path = if is_valid_integer_base(base) && scale >= 0 {
        unscale_to_u128(value, scale)
    } else {
        None
    };

    match fast_path {
        Some(0) => Ok(f64::NAN),
        Some(unscaled) => Ok(unscaled.ilog(base as u128) as f64),
        None => decimal_to_f64(value, scale).map(|v| v.log(base)),
    }
}

if scale < 0 {
let actual_value = (value as f64) * 10.0_f64.powi(-(scale as i32));
Ok(actual_value.log(base))
} else {
let unscaled_value = decimal64_to_i64(value, scale)?;
if unscaled_value <= 0 {
return Ok(f64::NAN);
}
let log_value: u32 = unscaled_value.ilog(base as i64);
Ok(log_value as f64)
}
/// Drops the fractional digits of a Decimal32 `value` by dividing it by `10^scale`.
///
/// Returns `None` when `value` is negative, when `scale` is negative, or when
/// `10^scale` does not fit in a `u32`.
#[inline]
fn unscale_to_u32(value: i32, scale: i8) -> Option<u32> {
    let magnitude = u32::try_from(value).ok()?;
    let exponent = u32::try_from(scale).ok()?;
    10u32
        .checked_pow(exponent)
        .map(|pow10| magnitude / pow10)
}

/// Binary function to calculate an integer logarithm of Decimal128 `value` using `base` base
/// Returns error if base is invalid
fn log_decimal128(value: i128, scale: i8, base: f64) -> Result<f64, ArrowError> {
if !base.is_finite() || base.trunc() != base {
return Err(ArrowError::ComputeError(format!(
"Log cannot use non-integer base: {base}"
)));
}
if (base as u32) < 2 {
return Err(ArrowError::ComputeError(format!(
"Log base must be greater than 1: {base}"
)));
}
/// Drops the fractional digits of a Decimal64 `value` by dividing it by `10^scale`.
///
/// Returns `None` when `value` is negative, when `scale` is negative, or when
/// `10^scale` does not fit in a `u64`.
#[inline]
fn unscale_to_u64(value: i64, scale: i8) -> Option<u64> {
    let magnitude = u64::try_from(value).ok()?;
    let exponent = u32::try_from(scale).ok()?;
    10u64
        .checked_pow(exponent)
        .map(|pow10| magnitude / pow10)
}

if value <= 0 {
// Reflect f64::log behaviour
return Ok(f64::NAN);
}
/// Drops the fractional digits of a Decimal128 `value` by dividing it by `10^scale`.
///
/// Returns `None` when `value` is negative, when `scale` is negative, or when
/// `10^scale` does not fit in a `u128`.
#[inline]
fn unscale_to_u128(value: i128, scale: i8) -> Option<u128> {
    let magnitude = u128::try_from(value).ok()?;
    let exponent = u32::try_from(scale).ok()?;
    10u128
        .checked_pow(exponent)
        .map(|pow10| magnitude / pow10)
}

if scale < 0 {
let actual_value = (value as f64) * 10.0_f64.powi(-(scale as i32));
Ok(actual_value.log(base))
} else {
let unscaled_value = decimal128_to_i128(value, scale)?;
if unscaled_value <= 0 {
return Ok(f64::NAN);
}
let log_value: u32 = unscaled_value.ilog(base as i128);
Ok(log_value as f64)
}
/// Converts a scaled decimal `value` into its `f64` representation by dividing
/// the converted value by `10^scale`.
///
/// # Errors
/// Returns `ArrowError::ComputeError` when `value.to_f64()` yields `None`.
#[inline]
fn decimal_to_f64<T: ToPrimitive + Copy>(value: T, scale: i8) -> Result<f64, ArrowError> {
    match value.to_f64() {
        Some(unscaled) => Ok(unscaled / 10f64.powi(i32::from(scale))),
        None => Err(ArrowError::ComputeError(
            "Cannot convert value to f64".to_string(),
        )),
    }
}

/// Binary function to calculate the logarithm of Decimal256 `value` using `base` base
/// Values that fit in i128 are delegated to `log_decimal128`; larger values fall back to f64 computation
fn log_decimal256(value: i256, scale: i8, base: f64) -> Result<f64, ArrowError> {
// Try to convert to i128 for the optimized path
match value.to_i128() {
Some(value) => log_decimal128(value, scale, base),
None => Err(ArrowError::NotYetImplemented(format!(
"Log of Decimal256 larger than Decimal128 is not yet supported: {value}"
))),
Some(v) => log_decimal128(v, scale, base),
None => {
// For very large Decimal256 values, use f64 computation
let value_f64 = value.to_f64().ok_or_else(|| {
ArrowError::ComputeError(format!("Cannot convert {value} to f64"))
})?;
let scale_factor = 10f64.powi(scale as i32);
Ok((value_f64 / scale_factor).log(base))
}
}
}

Expand Down Expand Up @@ -1169,7 +1167,8 @@ mod tests {
}

#[test]
fn test_log_decimal128_wrong_base() {
fn test_log_decimal128_invalid_base() {
// Invalid base (-2.0) should return NaN, matching f64::log behavior
let arg_fields = vec![
Field::new("b", DataType::Float64, false).into(),
Field::new("x", DataType::Decimal128(38, 0), false).into(),
Expand All @@ -1184,16 +1183,26 @@ mod tests {
return_field: Field::new("f", DataType::Float64, true).into(),
config_options: Arc::new(ConfigOptions::default()),
};
let result = LogFunc::new().invoke_with_args(args);
assert!(result.is_err());
assert_eq!(
"Arrow error: Compute error: Log base must be greater than 1: -2",
result.unwrap_err().to_string().lines().next().unwrap()
);
let result = LogFunc::new()
.invoke_with_args(args)
.expect("should not error on invalid base");

match result {
ColumnarValue::Array(arr) => {
let floats = as_float64_array(&arr)
.expect("failed to convert result to a Float64Array");
assert_eq!(floats.len(), 1);
assert!(floats.value(0).is_nan());
}
ColumnarValue::Scalar(_) => {
panic!("Expected an array value")
}
}
}

#[test]
fn test_log_decimal256_error() {
fn test_log_decimal256_large() {
// Large Decimal256 values that don't fit in i128 now use f64 fallback
let arg_field = Field::new("a", DataType::Decimal256(38, 0), false).into();
let args = ScalarFunctionArgs {
args: vec![
Expand All @@ -1207,11 +1216,26 @@ mod tests {
return_field: Field::new("f", DataType::Float64, true).into(),
config_options: Arc::new(ConfigOptions::default()),
};
let result = LogFunc::new().invoke_with_args(args);
assert!(result.is_err());
assert_eq!(
result.unwrap_err().to_string().lines().next().unwrap(),
"Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported: 170141183460469231731687303715884106727"
);
let result = LogFunc::new()
.invoke_with_args(args)
.expect("should handle large Decimal256 via f64 fallback");

match result {
ColumnarValue::Array(arr) => {
let floats = as_float64_array(&arr)
.expect("failed to convert result to a Float64Array");
assert_eq!(floats.len(), 1);
// The f64 fallback may lose some precision for very large numbers,
// but we verify we get a reasonable positive result (not NaN/infinity)
let log_result = floats.value(0);
assert!(
log_result.is_finite() && log_result > 0.0,
"Expected positive finite log result, got {log_result}"
);
}
ColumnarValue::Scalar(_) => {
panic!("Expected an array value")
}
}
}
}
12 changes: 10 additions & 2 deletions datafusion/sqllogictest/test_files/decimal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -868,9 +868,11 @@ select log(100000000000000000000000000000000000::decimal(76,0));
----
35

# log(10^50) for decimal256 for a value larger than i128
query error Arrow error: Not yet implemented: Log of Decimal256 larger than Decimal128 is not yet supported
# log(10^50) for decimal256 for a value larger than i128 (uses f64 fallback)
query R
select log(100000000000000000000000000000000000000000000000000::decimal(76,0));
----
50

# log(10^35) for decimal128 with explicit base
query R
Expand Down Expand Up @@ -904,6 +906,12 @@ select log(2.0, 100000000000000000000000000000000000::decimal(38,0));
----
116

# log with non-integer base (fallback to f64)
query R
select log(2.5, 100::decimal(38,0));
----
5.025883189464

# null cases
query R
select log(null, 100);
Expand Down