Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(arrow-ord): support boolean in rank and add tests for sorting lists of booleans #6912

Merged
merged 6 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 163 additions & 2 deletions arrow-ord/src/rank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::{downcast_primitive_array, Array, ArrowNativeTypeOp, GenericByteArray};
use arrow_array::{
downcast_primitive_array, Array, ArrowNativeTypeOp, BooleanArray, GenericByteArray,
};
use arrow_buffer::NullBuffer;
use arrow_schema::{ArrowError, DataType, SortOptions};
use std::cmp::Ordering;
Expand All @@ -29,7 +31,11 @@ pub(crate) fn can_rank(data_type: &DataType) -> bool {
// Returns `true` when `data_type.is_primitive()` holds or when `data_type`
// is one of the non-primitive variants listed in the `matches!` below.
// NOTE(review): this is a rendered diff hunk — the single-line variant list
// appears to be the removed pre-change line and the multi-line list (which
// adds `DataType::Boolean`) its replacement; confirm against the merged file.
data_type.is_primitive()
|| matches!(
data_type,
DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary
DataType::Boolean
| DataType::Utf8
| DataType::LargeUtf8
| DataType::Binary
| DataType::LargeBinary
)
}

Expand All @@ -49,6 +55,7 @@ pub fn rank(array: &dyn Array, options: Option<SortOptions>) -> Result<Vec<u32>,
let options = options.unwrap_or_default();
let ranks = downcast_primitive_array! {
array => primitive_rank(array.values(), array.nulls(), options),
DataType::Boolean => boolean_rank(array.as_boolean(), options),
DataType::Utf8 => bytes_rank(array.as_bytes::<Utf8Type>(), options),
DataType::LargeUtf8 => bytes_rank(array.as_bytes::<LargeUtf8Type>(), options),
DataType::Binary => bytes_rank(array.as_bytes::<BinaryType>(), options),
Expand Down Expand Up @@ -135,6 +142,84 @@ where
out
}

/// Maps one boolean slot to its position in a `[false, true, null]` lookup table.
///
/// | `is_null` | `value` | index |
/// |-----------|---------|-------|
/// | `true`    | any     | 2     |
/// | `false`   | `true`  | 1     |
/// | `false`   | `false` | 0     |
///
/// Non-null slots use the numeric value of the boolean (`false` is 0, `true` is 1),
/// and the value stored under a null slot is ignored.
#[inline]
fn get_boolean_rank_index(value: bool, is_null: bool) -> usize {
    // Non-short-circuiting `&` keeps the computation free of branches.
    let null_bit = usize::from(is_null);
    let value_bit = usize::from(value & !is_null);
    2 * null_bit + value_bit
}

/// Computes the rank of every slot of a boolean array without sorting it.
///
/// A boolean array holds at most three distinct values (`false`, `true` and null),
/// so the rank of each one follows directly from the three occurrence counts: the
/// rank of a value is the number of slots ordered before it plus its own count.
///
/// `options.descending` decides whether `true` is ordered before `false`, and
/// `options.nulls_first` decides whether nulls are ordered before or after all
/// non-null values.
///
/// Returns one rank per slot, in the same order as the slots of `array`.
#[inline(never)]
fn boolean_rank(array: &BooleanArray, options: SortOptions) -> Vec<u32> {
    let nulls = array.null_count() as u32;
    let trues = array.true_count() as u32;
    let falses = array.len() as u32 - nulls - trues;

    // Number of slots ordered ahead of every non-null value.
    let leading = if options.nulls_first { nulls } else { 0 };

    // Ranks of the two non-null values: whichever comes first gets
    // `leading + own count`, the other one adds its own count on top.
    let (false_rank, true_rank) = if options.descending {
        let first = leading + trues;
        (first + falses, first)
    } else {
        let first = leading + falses;
        (first, first + trues)
    };

    // Nulls are either the very first group or the very last one; in the latter
    // case their rank is the total number of slots.
    let null_rank = if options.nulls_first {
        nulls
    } else {
        trues + falses + nulls
    };

    // Lookup table laid out as [false, true, null], the layout produced by
    // `get_boolean_rank_index`.
    let rank_table: [u32; 3] = [false_rank, true_rank, null_rank];

    let bits = array.values();
    match array.nulls() {
        // At least one null: the validity bit has to be consulted for every slot.
        Some(validity) if validity.null_count() > 0 => bits
            .iter()
            .zip(validity.iter())
            .map(|(bit, valid)| {
                let slot = get_boolean_rank_index(bit, !valid);
                rank_table[slot]
            })
            .collect(),
        // No nulls at all: the value bit alone selects the rank.
        _ => bits
            .iter()
            .map(|bit| rank_table[usize::from(bit)])
            .collect(),
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -177,6 +262,82 @@ mod tests {
assert_eq!(res, &[4, 6, 3, 6, 3, 3]);
}

/// Checks every `(value, is_null)` combination against its expected table index.
#[test]
fn test_get_boolean_rank_index() {
    // (value, is_null, expected index)
    let cases: [(bool, bool, usize); 4] = [
        (true, true, 2),
        (false, true, 2),
        (true, false, 1),
        (false, false, 0),
    ];

    for (value, is_null, expected) in cases {
        assert_eq!(get_boolean_rank_index(value, is_null), expected);
    }
}

/// Ranks a boolean array containing a null under every sort configuration.
#[test]
fn test_nullable_booleans() {
    let input = BooleanArray::from(vec![Some(true), Some(true), None, Some(false), Some(false)]);

    // (sort options, expected ranks); `None` exercises the default options.
    let cases: [(Option<SortOptions>, [u32; 5]); 4] = [
        (None, [5, 5, 1, 3, 3]),
        (
            Some(SortOptions {
                descending: true,
                nulls_first: true,
            }),
            [3, 3, 1, 5, 5],
        ),
        (
            Some(SortOptions {
                descending: false,
                nulls_first: false,
            }),
            [4, 4, 5, 2, 2],
        ),
        (
            Some(SortOptions {
                descending: true,
                nulls_first: false,
            }),
            [2, 2, 5, 4, 4],
        ),
    ];

    for (options, expected) in cases {
        let ranks = rank(&input, options).unwrap();
        assert_eq!(ranks, expected);
    }

    // The value bit stored under a null slot must be ignored: here the null
    // slot carries `true` in the values buffer.
    let validity = NullBuffer::from(vec![true, true, false, true, true]);
    let masked = BooleanArray::new(vec![true, true, true, false, false].into(), Some(validity));
    let ranks = rank(&masked, None).unwrap();
    assert_eq!(ranks, [5, 5, 1, 3, 3]);
}

/// Ranks a boolean array without nulls; `nulls_first` must not influence the result.
#[test]
fn test_booleans() {
    let input = BooleanArray::from(vec![true, false, false, false, true]);

    // (sort options, expected ranks); `None` exercises the default options.
    let cases: [(Option<SortOptions>, [u32; 5]); 4] = [
        (None, [5, 3, 3, 3, 5]),
        (
            Some(SortOptions {
                descending: true,
                nulls_first: true,
            }),
            [2, 5, 5, 5, 2],
        ),
        (
            Some(SortOptions {
                descending: false,
                nulls_first: false,
            }),
            [5, 3, 3, 3, 5],
        ),
        (
            Some(SortOptions {
                descending: true,
                nulls_first: false,
            }),
            [2, 5, 5, 5, 2],
        ),
    ];

    for (options, expected) in cases {
        let ranks = rank(&input, options).unwrap();
        assert_eq!(ranks, expected);
    }
}

#[test]
fn test_bytes() {
let v = vec!["foo", "fo", "bar", "bar"];
Expand Down
Loading
Loading