Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions datafusion-examples/examples/udf/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,16 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
let prods = emit_to.take_needed(&mut self.prods);
let nulls = self.null_state.build(emit_to);

assert_eq!(nulls.len(), prods.len());
if let Some(nulls) = &nulls {
assert_eq!(nulls.len(), counts.len());
}
assert_eq!(counts.len(), prods.len());

// don't evaluate geometric mean with null inputs to avoid errors on null values

let array: PrimitiveArray<Float64Type> = if nulls.null_count() > 0 {
let array: PrimitiveArray<Float64Type> = if let Some(nulls) = &nulls
&& nulls.null_count() > 0
{
let mut builder = PrimitiveBuilder::<Float64Type>::with_capacity(nulls.len());
let iter = prods.into_iter().zip(counts).zip(nulls.iter());

Expand All @@ -337,7 +341,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
.zip(counts)
.map(|(prod, count)| prod.powf(1.0 / count as f64))
.collect::<Vec<_>>();
PrimitiveArray::new(geo_mean.into(), Some(nulls)) // no copy
PrimitiveArray::new(geo_mean.into(), nulls) // no copy
.with_data_type(self.return_data_type.clone())
};

Expand All @@ -347,7 +351,6 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
// return arrays for counts and prods
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
let nulls = self.null_state.build(emit_to);
let nulls = Some(nulls);

let counts = emit_to.take_needed(&mut self.counts);
let counts = UInt32Array::new(counts.into(), nulls.clone()); // zero copy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,70 @@
//! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator

use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::buffer::NullBuffer;
use arrow::datatypes::ArrowPrimitiveType;

use datafusion_expr_common::groups_accumulator::EmitTo;

/// If the input has nulls, then the accumulator must potentially
/// handle each input null value specially (e.g. for `SUM` to mark the
/// corresponding sum as null)
///
/// If there are filters present, `NullState` tracks if it has seen
/// *any* value for that group (as some values may be filtered
/// out). Without a filter, the accumulator is only passed groups that
/// had at least one value to accumulate so they do not need to track
/// if they have seen values for a particular group.
#[derive(Debug)]
pub enum SeenValues {
/// All groups seen so far have seen at least one non-null value
All {
num_values: usize,
},
// Some groups have not yet seen a non-null value
Some {
values: BooleanBufferBuilder,
},
}

impl Default for SeenValues {
fn default() -> Self {
SeenValues::All { num_values: 0 }
}
}

impl SeenValues {
/// Return a mutable reference to the `BooleanBufferBuilder` in `SeenValues::Some`.
///
/// If `self` is `SeenValues::All`, it is transitioned to `SeenValues::Some`
/// by creating a new `BooleanBufferBuilder` where the first `num_values` are true.
///
/// The builder is then ensured to have at least `total_num_groups` length,
/// with any new entries initialized to false.
fn get_builder(&mut self, total_num_groups: usize) -> &mut BooleanBufferBuilder {
match self {
SeenValues::All { num_values } => {
let mut builder = BooleanBufferBuilder::new(total_num_groups);
builder.append_n(*num_values, true);
if total_num_groups > *num_values {
builder.append_n(total_num_groups - *num_values, false);
}
*self = SeenValues::Some { values: builder };
match self {
SeenValues::Some { values } => values,
_ => unreachable!(),
}
}
SeenValues::Some { values } => {
if values.len() < total_num_groups {
values.append_n(total_num_groups - values.len(), false);
}
values
}
}
}
}

/// Track the accumulator null state per row: if any values for that
/// group were null and if any values have been seen at all for that group.
///
Expand Down Expand Up @@ -53,12 +113,14 @@ use datafusion_expr_common::groups_accumulator::EmitTo;
pub struct NullState {
/// Have we seen any non-filtered input values for `group_index`?
///
/// If `seen_values[i]` is true, have seen at least one non null
/// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is true, have seen at least one non null
/// value for group `i`
///
/// If `seen_values[i]` is false, have not seen any values that
/// If `seen_values` is `SeenValues::Some(buffer)` and buffer\[i\] is false, have not seen any values that
/// pass the filter yet for group `i`
seen_values: BooleanBufferBuilder,
///
/// If `seen_values` is `SeenValues::All`, all groups have seen at least one non null value
seen_values: SeenValues,
}

impl Default for NullState {
Expand All @@ -70,14 +132,16 @@ impl Default for NullState {
impl NullState {
pub fn new() -> Self {
Self {
seen_values: BooleanBufferBuilder::new(0),
seen_values: SeenValues::All { num_values: 0 },
}
}

/// return the size of all buffers allocated by this null state, not including self
pub fn size(&self) -> usize {
// capacity is in bits, so convert to bytes
self.seen_values.capacity() / 8
match &self.seen_values {
SeenValues::All { .. } => 0,
SeenValues::Some { values } => values.capacity() / 8,
}
}

/// Invokes `value_fn(group_index, value)` for each non null, non
Expand Down Expand Up @@ -107,10 +171,17 @@ impl NullState {
T: ArrowPrimitiveType + Send,
F: FnMut(usize, T::Native) + Send,
{
// ensure the seen_values is big enough (start everything at
// "not seen" valid)
let seen_values =
initialize_builder(&mut self.seen_values, total_num_groups, false);
// skip null handling if no nulls in input or accumulator
if let SeenValues::All { num_values } = &mut self.seen_values
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One possibility is to make this another another function to reduce some duplication and have a place it could be explained

Maybe something like

if let Some(num_values) = self.all_values_mut() && opt_filter.is_none && values.null_count == 0 {
...
}

?

Though maybe an extra level of indirection would make it harder to follow.

&& opt_filter.is_none()
&& values.null_count() == 0
{
accumulate(group_indices, values, None, value_fn);
*num_values = total_num_groups;
return;
}

let seen_values = self.seen_values.get_builder(total_num_groups);
accumulate(group_indices, values, opt_filter, |group_index, value| {
seen_values.set_bit(group_index, true);
value_fn(group_index, value);
Expand Down Expand Up @@ -140,10 +211,21 @@ impl NullState {
let data = values.values();
assert_eq!(data.len(), group_indices.len());

// ensure the seen_values is big enough (start everything at
// "not seen" valid)
let seen_values =
initialize_builder(&mut self.seen_values, total_num_groups, false);
// skip null handling if no nulls in input or accumulator
if let SeenValues::All { num_values } = &mut self.seen_values
&& opt_filter.is_none()
&& values.null_count() == 0
{
group_indices
.iter()
.zip(data.iter())
.for_each(|(&group_index, new_value)| value_fn(group_index, new_value));
*num_values = total_num_groups;

return;
}

let seen_values = self.seen_values.get_builder(total_num_groups);

// These could be made more performant by iterating in chunks of 64 bits at a time
match (values.null_count() > 0, opt_filter) {
Expand Down Expand Up @@ -211,21 +293,39 @@ impl NullState {
/// for the `emit_to` rows.
///
/// resets the internal state appropriately
pub fn build(&mut self, emit_to: EmitTo) -> NullBuffer {
let nulls: BooleanBuffer = self.seen_values.finish();

let nulls = match emit_to {
EmitTo::All => nulls,
EmitTo::First(n) => {
// split off the first N values in seen_values
let first_n_null: BooleanBuffer = nulls.slice(0, n);
// reset the existing seen buffer
self.seen_values
.append_buffer(&nulls.slice(n, nulls.len() - n));
first_n_null
pub fn build(&mut self, emit_to: EmitTo) -> Option<NullBuffer> {
match emit_to {
EmitTo::All => {
let old_seen = std::mem::take(&mut self.seen_values);
match old_seen {
SeenValues::All { .. } => None,
SeenValues::Some { mut values } => {
Some(NullBuffer::new(values.finish()))
}
}
}
};
NullBuffer::new(nulls)
EmitTo::First(n) => match &mut self.seen_values {
SeenValues::All { num_values } => {
*num_values = num_values.saturating_sub(n);
None
}
SeenValues::Some { .. } => {
let mut old_values = match std::mem::take(&mut self.seen_values) {
SeenValues::Some { values } => values,
_ => unreachable!(),
};
let nulls = old_values.finish();
let first_n_null = nulls.slice(0, n);
let remainder = nulls.slice(n, nulls.len() - n);
let mut new_builder = BooleanBufferBuilder::new(remainder.len());
new_builder.append_buffer(&remainder);
self.seen_values = SeenValues::Some {
values: new_builder,
};
Some(NullBuffer::new(first_n_null))
}
},
}
}
}

Expand Down Expand Up @@ -573,27 +673,14 @@ pub fn accumulate_indices<F>(
}
}

/// Ensures that `builder` contains a `BooleanBufferBuilder with at
/// least `total_num_groups`.
///
/// All new entries are initialized to `default_value`
fn initialize_builder(
builder: &mut BooleanBufferBuilder,
total_num_groups: usize,
default_value: bool,
) -> &mut BooleanBufferBuilder {
if builder.len() < total_num_groups {
let new_groups = total_num_groups - builder.len();
builder.append_n(new_groups, default_value);
}
builder
}

#[cfg(test)]
mod test {
use super::*;

use arrow::array::{Int32Array, UInt32Array};
use arrow::{
array::{Int32Array, UInt32Array},
buffer::BooleanBuffer,
};
use rand::{Rng, rngs::ThreadRng};
use std::collections::HashSet;

Expand Down Expand Up @@ -834,15 +921,24 @@ mod test {
accumulated_values, expected_values,
"\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
);
let seen_values = null_state.seen_values.finish_cloned();
mock.validate_seen_values(&seen_values);

match &null_state.seen_values {
SeenValues::All { num_values } => {
assert_eq!(*num_values, total_num_groups);
}
SeenValues::Some { values } => {
let seen_values = values.finish_cloned();
mock.validate_seen_values(&seen_values);
}
}

// Validate the final buffer (one value per group)
let expected_null_buffer = mock.expected_null_buffer(total_num_groups);

let null_buffer = null_state.build(EmitTo::All);

assert_eq!(null_buffer, expected_null_buffer);
if let Some(nulls) = &null_buffer {
assert_eq!(*nulls, expected_null_buffer);
}
}

// Calls `accumulate_indices`
Expand Down Expand Up @@ -955,15 +1051,25 @@ mod test {
"\n\naccumulated_values:{accumulated_values:#?}\n\nexpected_values:{expected_values:#?}"
);

let seen_values = null_state.seen_values.finish_cloned();
mock.validate_seen_values(&seen_values);
match &null_state.seen_values {
SeenValues::All { num_values } => {
assert_eq!(*num_values, total_num_groups);
}
SeenValues::Some { values } => {
let seen_values = values.finish_cloned();
mock.validate_seen_values(&seen_values);
}
}

// Validate the final buffer (one value per group)
let expected_null_buffer = mock.expected_null_buffer(total_num_groups);
let expected_null_buffer = Some(mock.expected_null_buffer(total_num_groups));

let is_all_seen = matches!(null_state.seen_values, SeenValues::All { .. });
let null_buffer = null_state.build(EmitTo::All);

assert_eq!(null_buffer, expected_null_buffer);
if !is_all_seen {
assert_eq!(null_buffer, expected_null_buffer);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ where
};

let nulls = self.null_state.build(emit_to);
let values = BooleanArray::new(values, Some(nulls));
let values = BooleanArray::new(values, nulls);
Ok(Arc::new(values))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ where
opt_filter,
total_num_groups,
|group_index, new_value| {
let value = &mut self.values[group_index];
// SAFETY: group_index is guaranteed to be in bounds
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend adding safety notes to the docs of GroupsAccumulator in https://github.com/apache/datafusion/blob/36ec9f1de0aeabca60b8f7ebe07d650b8ef03506/datafusion/expr-common/src/groups_accumulator.rs#L114-L113

That explains that all group indexes are guaranteed to be <= total_num_groups and that can be relied on for safety

let value = unsafe { self.values.get_unchecked_mut(group_index) };
(self.prim_fn)(value, new_value);
},
);
Expand All @@ -117,7 +118,7 @@ where
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
let values = emit_to.take_needed(&mut self.values);
let nulls = self.null_state.build(emit_to);
let values = PrimitiveArray::<T>::new(values.into(), Some(nulls)) // no copy
let values = PrimitiveArray::<T>::new(values.into(), nulls) // no copy
.with_data_type(self.data_type.clone());
Ok(Arc::new(values))
}
Expand Down
Loading
Loading