Skip to content

Pairwise summation #577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 43 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
483b29a
Add function declaration for pairwise_sum
LukeMathWalker Jan 4, 2019
75860b6
Base case: array with 512 elements
LukeMathWalker Jan 4, 2019
b8304c0
Base case: use unrolled_fold
LukeMathWalker Jan 4, 2019
ec3722a
Implemented algorithm for not-base case branch
LukeMathWalker Jan 4, 2019
70d32b7
Implemented pairwise summation algorithm for an iterator parameter
LukeMathWalker Jan 4, 2019
c66dc98
Implemented pairwise summation algorithm for an iterator parameter wi…
LukeMathWalker Jan 4, 2019
175f9a2
Refactored: use fold to sum
LukeMathWalker Jan 4, 2019
8e403af
Refactored: use a constant to reuse size value of recursion base case…
LukeMathWalker Jan 4, 2019
465063f
Added documentation
LukeMathWalker Jan 4, 2019
6427d45
Minor edits to the docs
LukeMathWalker Jan 4, 2019
4414450
Don't forget to add the sum of the last elements (<512 ending block).
LukeMathWalker Jan 4, 2019
aeaad0e
Add a benchmark for summing a contiguous array
LukeMathWalker Jan 5, 2019
75109b1
Benchmarks for arrays of different length
LukeMathWalker Jan 5, 2019
3085194
Don't split midpoint, saving one operation
LukeMathWalker Jan 5, 2019
797e212
Revert "Don't split midpoint, saving one operation"
LukeMathWalker Jan 5, 2019
d2b636b
Benches for sum_axis
LukeMathWalker Jan 5, 2019
b3d2b42
Bench for contiguous sum with integer values
LukeMathWalker Jan 5, 2019
8f95705
Alternative implementation for sum_axis
LukeMathWalker Jan 9, 2019
74a74ae
Revert "Alternative implementation for sum_axis"
LukeMathWalker Jan 9, 2019
a592a7d
Ensure equal block size independently of underlying implementation
LukeMathWalker Jan 9, 2019
f73fb2d
Change threshold names
LukeMathWalker Jan 22, 2019
c7fa091
Change sum_axis implementation
LukeMathWalker Jan 22, 2019
f72164a
Reduce partial accumulators pairwise in unrolled_fold
LukeMathWalker Jan 22, 2019
9f1c4d2
Remove unused imports
LukeMathWalker Jan 22, 2019
bbc4a75
Get uniform behaviour across all pairwise_sum implementations
LukeMathWalker Jan 22, 2019
b98e30b
Add more benchmarks of sum/sum_axis
jturner314 Feb 3, 2019
ed88e2e
Improve performance of iterator_pairwise_sum
jturner314 Feb 3, 2019
e7835ee
Make sum pairwise over all dimensions
jturner314 Feb 3, 2019
8301c25
Implement contiguous sum_axis in terms of Zip
jturner314 Feb 3, 2019
82453df
Remove redundant len_of call
jturner314 Feb 3, 2019
1d51f70
Merge pull request #3 from jturner314/pairwise-summation
LukeMathWalker Feb 3, 2019
978f45a
Added test for axis independence
LukeMathWalker Feb 3, 2019
fa0ba30
Make sure we actually exercise the pairwise technique
LukeMathWalker Feb 3, 2019
b4136d7
Test discontinuous arrays
LukeMathWalker Feb 3, 2019
4a63cb3
Add more integer benchmark equivalents
LukeMathWalker Feb 3, 2019
f306b5f
Fix min_stride_axis to prefer axes with length > 1
jturner314 Feb 3, 2019
b7951df
Specialize min_stride_axis for Ix3
jturner314 Feb 3, 2019
3326de4
Enable min_stride_axis as pub(crate) method
jturner314 Feb 3, 2019
65b6046
Simplify fold to use min_stride_axis
jturner314 Feb 3, 2019
b0b391a
Improve performance of sum in certain cases
jturner314 Feb 3, 2019
7f04e6f
Update quickcheck and use quickcheck_macros
jturner314 Feb 3, 2019
1ed1a63
Clarify capacity calculation in iterator_pairwise_sum
jturner314 Feb 4, 2019
1e88385
Merge pull request #4 from jturner314/pairwise-summation
LukeMathWalker Feb 4, 2019
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -45,8 +45,10 @@ serde = { version = "1.0", optional = true }

[dev-dependencies]
defmac = "0.2"
quickcheck = { version = "0.7.2", default-features = false }
quickcheck = { version = "0.8.1", default-features = false }
quickcheck_macros = "0.8"
rawpointer = "0.1"
rand = "0.5.5"

[features]
# Enable blas usage
200 changes: 198 additions & 2 deletions benches/numeric.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@

#![feature(test)]

extern crate test;
use test::Bencher;
use test::{black_box, Bencher};

extern crate ndarray;
use ndarray::prelude::*;
@@ -25,3 +24,200 @@ fn clip(bench: &mut Bencher)
})
});
}


#[bench]
fn contiguous_sum_1e7(bench: &mut Bencher)
{
// 1e7 contiguous f64 elements: exercises the slice-based pairwise-sum
// fast path (as_slice_memory_order succeeds for a standard-layout array).
let n = 1e7 as usize;
let a = Array::linspace(-1e6, 1e6, n);
bench.iter(|| {
a.sum()
});
}

#[bench]
fn contiguous_sum_int_1e7(bench: &mut Bencher)
{
let n = 1e7 as usize;
let a = Array::from_vec((0..n).collect());
bench.iter(|| {
a.sum()
});
}

#[bench]
fn contiguous_sum_1e4(bench: &mut Bencher)
{
let n = 1e4 as usize;
let a = Array::linspace(-1e6, 1e6, n);
bench.iter(|| {
a.sum()
});
}

#[bench]
fn contiguous_sum_int_1e4(bench: &mut Bencher)
{
let n = 1e4 as usize;
let a = Array::from_vec((0..n).collect());
bench.iter(|| {
a.sum()
});
}

#[bench]
fn contiguous_sum_1e2(bench: &mut Bencher)
{
let n = 1e2 as usize;
let a = Array::linspace(-1e6, 1e6, n);
bench.iter(|| {
a.sum()
});
}

#[bench]
fn contiguous_sum_int_1e2(bench: &mut Bencher)
{
let n = 1e2 as usize;
let a = Array::from_vec((0..n).collect());
bench.iter(|| {
a.sum()
});
}

#[bench]
fn contiguous_sum_ix3_1e2(bench: &mut Bencher)
{
let n = 1e2 as usize;
let a = Array::linspace(-1e6, 1e6, n * n * n)
.into_shape([n, n, n])
.unwrap();
bench.iter(|| black_box(&a).sum());
}

#[bench]
fn contiguous_sum_int_ix3_1e2(bench: &mut Bencher)
{
let n = 1e2 as usize;
let a = Array::from_vec((0..n.pow(3)).collect())
.into_shape([n, n, n])
.unwrap();
bench.iter(|| black_box(&a).sum());
}

#[bench]
fn inner_discontiguous_sum_ix3_1e2(bench: &mut Bencher)
{
let n = 1e2 as usize;
let a = Array::linspace(-1e6, 1e6, n * n * 2*n)
.into_shape([n, n, 2*n])
.unwrap();
let v = a.slice(s![.., .., ..;2]);
bench.iter(|| black_box(&v).sum());
}

#[bench]
fn inner_discontiguous_sum_int_ix3_1e2(bench: &mut Bencher)
{
let n = 1e2 as usize;
let a = Array::from_vec((0..(n.pow(3) * 2)).collect())
.into_shape([n, n, 2*n])
.unwrap();
let v = a.slice(s![.., .., ..;2]);
bench.iter(|| black_box(&v).sum());
}

#[bench]
fn middle_discontiguous_sum_ix3_1e2(bench: &mut Bencher)
{
let n = 1e2 as usize;
let a = Array::linspace(-1e6, 1e6, n * 2*n * n)
.into_shape([n, 2*n, n])
.unwrap();
let v = a.slice(s![.., ..;2, ..]);
bench.iter(|| black_box(&v).sum());
}

#[bench]
fn middle_discontiguous_sum_int_ix3_1e2(bench: &mut Bencher)
{
let n = 1e2 as usize;
let a = Array::from_vec((0..(n.pow(3) * 2)).collect())
.into_shape([n, 2*n, n])
.unwrap();
let v = a.slice(s![.., ..;2, ..]);
bench.iter(|| black_box(&v).sum());
}

#[bench]
fn sum_by_row_1e4(bench: &mut Bencher)
{
let n = 1e4 as usize;
let a = Array::linspace(-1e6, 1e6, n * n)
.into_shape([n, n])
.unwrap();
bench.iter(|| {
a.sum_axis(Axis(0))
});
}

#[bench]
fn sum_by_row_int_1e4(bench: &mut Bencher)
{
let n = 1e4 as usize;
let a = Array::from_vec((0..n.pow(2)).collect())
.into_shape([n, n])
.unwrap();
bench.iter(|| {
a.sum_axis(Axis(0))
});
}

#[bench]
fn sum_by_col_1e4(bench: &mut Bencher)
{
let n = 1e4 as usize;
let a = Array::linspace(-1e6, 1e6, n * n)
.into_shape([n, n])
.unwrap();
bench.iter(|| {
a.sum_axis(Axis(1))
});
}

#[bench]
fn sum_by_col_int_1e4(bench: &mut Bencher)
{
let n = 1e4 as usize;
let a = Array::from_vec((0..n.pow(2)).collect())
.into_shape([n, n])
.unwrap();
bench.iter(|| {
a.sum_axis(Axis(1))
});
}

#[bench]
fn sum_by_middle_1e2(bench: &mut Bencher)
{
let n = 1e2 as usize;
let a = Array::linspace(-1e6, 1e6, n * n * n)
.into_shape([n, n, n])
.unwrap();
bench.iter(|| {
a.sum_axis(Axis(1))
});
}

#[bench]
fn sum_by_middle_int_1e2(bench: &mut Bencher)
{
let n = 1e2 as usize;
let a = Array::from_vec((0..n.pow(3)).collect())
.into_shape([n, n, n])
.unwrap();
bench.iter(|| {
a.sum_axis(Axis(1))
});
}
29 changes: 23 additions & 6 deletions src/dimension/dimension_trait.rs
Original file line number Diff line number Diff line change
@@ -291,8 +291,8 @@ pub trait Dimension : Clone + Eq + Debug + Send + Sync + Default +
indices
}

/// Compute the minimum stride axis (absolute value), under the constraint
/// that the length of the axis is > 1;
/// Compute the minimum stride axis (absolute value), preferring axes with
/// length > 1.
#[doc(hidden)]
fn min_stride_axis(&self, strides: &Self) -> Axis {
let n = match self.ndim() {
@@ -301,7 +301,7 @@ pub trait Dimension : Clone + Eq + Debug + Send + Sync + Default +
n => n,
};
axes_of(self, strides)
.rev()
.filter(|ax| ax.len() > 1)
.min_by_key(|ax| ax.stride().abs())
.map_or(Axis(n - 1), |ax| ax.axis())
}
@@ -588,9 +588,9 @@ impl Dimension for Dim<[Ix; 2]> {

#[inline]
fn min_stride_axis(&self, strides: &Self) -> Axis {
let s = get!(strides, 0) as Ixs;
let t = get!(strides, 1) as Ixs;
if s.abs() < t.abs() {
let s = (get!(strides, 0) as isize).abs();
let t = (get!(strides, 1) as isize).abs();
if s < t && get!(self, 0) > 1 {
Axis(0)
} else {
Axis(1)
@@ -697,6 +697,23 @@ impl Dimension for Dim<[Ix; 3]> {
Some(Ix3(i, j, k))
}

#[inline]
// Ix3 specialization: pick the axis with the smallest absolute stride,
// preferring axes of length > 1 over length-1 axes (whose stride is
// meaningless for iteration order). Axis(2) is the fallback when no
// earlier axis qualifies — mirroring the Ix2 specialization above.
fn min_stride_axis(&self, strides: &Self) -> Axis {
let s = (get!(strides, 0) as isize).abs();
let t = (get!(strides, 1) as isize).abs();
let u = (get!(strides, 2) as isize).abs();
// First decide between axes 1 and 2; axis 1 only wins if it is both
// strictly smaller-strided and longer than 1.
let (argmin, min) = if t < u && get!(self, 1) > 1 {
(Axis(1), t)
} else {
(Axis(2), u)
};
// Then let axis 0 override under the same two conditions.
if s < min && get!(self, 0) > 1 {
Axis(0)
} else {
argmin
}
}

/// Self is an index, return the stride offset
#[inline]
fn stride_offset(index: &Self, strides: &Self) -> isize {
157 changes: 78 additions & 79 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
@@ -629,7 +629,8 @@ mod test {
use crate::error::{from_kind, ErrorKind};
use crate::slice::Slice;
use num_integer::gcd;
use quickcheck::{quickcheck, TestResult};
use quickcheck::TestResult;
use quickcheck_macros::quickcheck;

#[test]
fn slice_indexing_uncommon_strides() {
@@ -738,30 +739,29 @@ mod test {
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err();
}

quickcheck! {
fn can_index_slice_not_custom_same_as_can_index_slice(data: Vec<u8>, dim: Vec<usize>) -> bool {
let dim = IxDyn(&dim);
let result = can_index_slice_not_custom(&data, &dim);
if dim.size_checked().is_none() {
// Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`.
result.is_err()
} else {
result == can_index_slice(&data, &dim, &dim.default_strides()) &&
result == can_index_slice(&data, &dim, &dim.fortran_strides())
}
#[quickcheck]
fn can_index_slice_not_custom_same_as_can_index_slice(data: Vec<u8>, dim: Vec<usize>) -> bool {
let dim = IxDyn(&dim);
let result = can_index_slice_not_custom(&data, &dim);
if dim.size_checked().is_none() {
// Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`.
result.is_err()
} else {
result == can_index_slice(&data, &dim, &dim.default_strides()) &&
result == can_index_slice(&data, &dim, &dim.fortran_strides())
}
}

quickcheck! {
fn extended_gcd_solves_eq(a: isize, b: isize) -> bool {
let (g, (x, y)) = extended_gcd(a, b);
a * x + b * y == g
}
#[quickcheck]
fn extended_gcd_solves_eq(a: isize, b: isize) -> bool {
let (g, (x, y)) = extended_gcd(a, b);
a * x + b * y == g
}

fn extended_gcd_correct_gcd(a: isize, b: isize) -> bool {
let (g, _) = extended_gcd(a, b);
g == gcd(a, b)
}
#[quickcheck]
fn extended_gcd_correct_gcd(a: isize, b: isize) -> bool {
let (g, _) = extended_gcd(a, b);
g == gcd(a, b)
}

#[test]
@@ -773,73 +773,72 @@ mod test {
assert_eq!(extended_gcd(-5, 0), (5, (-1, 0)));
}

quickcheck! {
fn solve_linear_diophantine_eq_solution_existence(
a: isize, b: isize, c: isize
) -> TestResult {
if a == 0 || b == 0 {
TestResult::discard()
} else {
TestResult::from_bool(
(c % gcd(a, b) == 0) == solve_linear_diophantine_eq(a, b, c).is_some()
)
}
#[quickcheck]
fn solve_linear_diophantine_eq_solution_existence(
a: isize, b: isize, c: isize
) -> TestResult {
if a == 0 || b == 0 {
TestResult::discard()
} else {
TestResult::from_bool(
(c % gcd(a, b) == 0) == solve_linear_diophantine_eq(a, b, c).is_some()
)
}
}

fn solve_linear_diophantine_eq_correct_solution(
a: isize, b: isize, c: isize, t: isize
) -> TestResult {
if a == 0 || b == 0 {
TestResult::discard()
} else {
match solve_linear_diophantine_eq(a, b, c) {
Some((x0, xd)) => {
let x = x0 + xd * t;
let y = (c - a * x) / b;
TestResult::from_bool(a * x + b * y == c)
}
None => TestResult::discard(),
#[quickcheck]
fn solve_linear_diophantine_eq_correct_solution(
a: isize, b: isize, c: isize, t: isize
) -> TestResult {
if a == 0 || b == 0 {
TestResult::discard()
} else {
match solve_linear_diophantine_eq(a, b, c) {
Some((x0, xd)) => {
let x = x0 + xd * t;
let y = (c - a * x) / b;
TestResult::from_bool(a * x + b * y == c)
}
None => TestResult::discard(),
}
}
}

quickcheck! {
fn arith_seq_intersect_correct(
first1: isize, len1: isize, step1: isize,
first2: isize, len2: isize, step2: isize
) -> TestResult {
use std::cmp;
#[quickcheck]
fn arith_seq_intersect_correct(
first1: isize, len1: isize, step1: isize,
first2: isize, len2: isize, step2: isize
) -> TestResult {
use std::cmp;

if len1 == 0 || len2 == 0 {
// This case is impossible to reach in `arith_seq_intersect()`
// because the `min*` and `max*` arguments are inclusive.
return TestResult::discard();
}
let len1 = len1.abs();
let len2 = len2.abs();

// Convert to `min*` and `max*` arguments for `arith_seq_intersect()`.
let last1 = first1 + step1 * (len1 - 1);
let (min1, max1) = (cmp::min(first1, last1), cmp::max(first1, last1));
let last2 = first2 + step2 * (len2 - 1);
let (min2, max2) = (cmp::min(first2, last2), cmp::max(first2, last2));

// Naively determine if the sequences intersect.
let seq1: Vec<_> = (0..len1)
.map(|n| first1 + step1 * n)
.collect();
let intersects = (0..len2)
.map(|n| first2 + step2 * n)
.any(|elem2| seq1.contains(&elem2));

TestResult::from_bool(
arith_seq_intersect(
(min1, max1, if step1 == 0 { 1 } else { step1 }),
(min2, max2, if step2 == 0 { 1 } else { step2 })
) == intersects
)
if len1 == 0 || len2 == 0 {
// This case is impossible to reach in `arith_seq_intersect()`
// because the `min*` and `max*` arguments are inclusive.
return TestResult::discard();
}
let len1 = len1.abs();
let len2 = len2.abs();

// Convert to `min*` and `max*` arguments for `arith_seq_intersect()`.
let last1 = first1 + step1 * (len1 - 1);
let (min1, max1) = (cmp::min(first1, last1), cmp::max(first1, last1));
let last2 = first2 + step2 * (len2 - 1);
let (min2, max2) = (cmp::min(first2, last2), cmp::max(first2, last2));

// Naively determine if the sequences intersect.
let seq1: Vec<_> = (0..len1)
.map(|n| first1 + step1 * n)
.collect();
let intersects = (0..len2)
.map(|n| first2 + step2 * n)
.any(|elem2| seq1.contains(&elem2));

TestResult::from_bool(
arith_seq_intersect(
(min1, max1, if step1 == 0 { 1 } else { step1 }),
(min2, max2, if step2 == 0 { 1 } else { step2 })
) == intersects
)
}

#[test]
70 changes: 47 additions & 23 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
@@ -1619,12 +1619,11 @@ where
axes_of(&self.dim, &self.strides)
}

/*
/// Return the axis with the least stride (by absolute value)
pub fn min_stride_axis(&self) -> Axis {
/// Return the axis with the least stride (by absolute value),
/// preferring axes with len > 1.
pub(crate) fn min_stride_axis(&self) -> Axis {
self.dim.min_stride_axis(&self.strides)
}
*/

/// Return the axis with the greatest stride (by absolute value),
/// preferring axes with len > 1.
@@ -1854,25 +1853,11 @@ where
} else {
let mut v = self.view();
// put the narrowest axis at the last position
match v.ndim() {
0 | 1 => {}
2 => {
if self.len_of(Axis(1)) <= 1
|| self.len_of(Axis(0)) > 1
&& self.stride_of(Axis(0)).abs() < self.stride_of(Axis(1)).abs()
{
v.swap_axes(0, 1);
}
}
n => {
let last = n - 1;
let narrow_axis = v
.axes()
.filter(|ax| ax.len() > 1)
.min_by_key(|ax| ax.stride().abs())
.map_or(last, |ax| ax.axis().index());
v.swap_axes(last, narrow_axis);
}
let n = v.ndim();
if n > 1 {
let last = n - 1;
let narrow_axis = self.min_stride_axis();
v.swap_axes(last, narrow_axis.index());
}
v.into_elements_base().fold(init, f)
}
@@ -2103,3 +2088,42 @@ where
})
}
}

#[cfg(test)]
mod tests {
use crate::prelude::*;

// Checks min_stride_axis across dimensionalities, transposes, axis swaps,
// collapsed axes, and broadcasts (where the broadcast axis has stride 0).
#[test]
fn min_stride_axis() {
// 1-D: only one axis to choose.
let a = Array1::<u8>::zeros(10);
assert_eq!(a.min_stride_axis(), Axis(0));

// 2-D standard layout: last axis is contiguous; transposing flips it.
let a = Array2::<u8>::zeros((3, 3));
assert_eq!(a.min_stride_axis(), Axis(1));
assert_eq!(a.t().min_stride_axis(), Axis(0));

// Same behavior through the IxDyn code path.
let a = ArrayD::<u8>::zeros(vec![3, 3]);
assert_eq!(a.min_stride_axis(), Axis(1));
assert_eq!(a.t().min_stride_axis(), Axis(0));

let min_axis = a.axes().min_by_key(|t| t.2.abs()).unwrap().axis();
assert_eq!(min_axis, Axis(1));

// Swapping any axis with the last must move the minimum-stride axis.
let mut b = ArrayD::<u8>::zeros(vec![2, 3, 4, 5]);
assert_eq!(b.min_stride_axis(), Axis(3));
for ax in 0..3 {
b.swap_axes(3, ax);
assert_eq!(b.min_stride_axis(), Axis(ax));
b.swap_axes(3, ax);
}
// A collapsed (length-1) axis must be skipped in favor of a longer one.
let mut v = b.view();
v.collapse_axis(Axis(3), 0);
assert_eq!(v.min_stride_axis(), Axis(2));

// Broadcast axes have stride 0: preferred when their length is > 1,
// but a length-1 broadcast axis must not win.
let a = Array2::<u8>::zeros((3, 3));
let v = a.broadcast((8, 3, 3)).unwrap();
assert_eq!(v.min_stride_axis(), Axis(0));
let v2 = a.broadcast((1, 3, 3)).unwrap();
assert_eq!(v2.min_stride_axis(), Axis(2));
}
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -104,6 +104,10 @@ extern crate num_integer;

#[cfg(test)]
extern crate quickcheck;
#[cfg(test)]
extern crate quickcheck_macros;
#[cfg(test)]
extern crate rand;

#[cfg(feature = "docs")]
pub mod doc;
186 changes: 165 additions & 21 deletions src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
@@ -8,7 +8,6 @@

use std::ops::{Add, Div, Mul};
use num_traits::{self, Zero, Float, FromPrimitive};
use itertools::free::enumerate;

use crate::imp_prelude::*;
use crate::numeric_util;
@@ -33,17 +32,17 @@ impl<A, S, D> ArrayBase<S, D>
where A: Clone + Add<Output=A> + num_traits::Zero,
{
if let Some(slc) = self.as_slice_memory_order() {
return numeric_util::unrolled_fold(slc, A::zero, A::add);
return numeric_util::pairwise_sum(&slc);
}
let mut sum = A::zero();
for row in self.inner_rows() {
if let Some(slc) = row.as_slice() {
sum = sum + numeric_util::unrolled_fold(slc, A::zero, A::add);
} else {
sum = sum + row.iter().fold(A::zero(), |acc, elt| acc + elt.clone());
if self.ndim() > 1 {
let ax = self.dim.min_stride_axis(&self.strides);
if self.len_of(ax) >= numeric_util::UNROLL_SIZE && self.stride_of(ax) == 1 {
let partial_sums: Vec<_> =
self.lanes(ax).into_iter().map(|lane| lane.sum()).collect();
return numeric_util::pure_pairwise_sum(&partial_sums);
}
}
sum
numeric_util::iterator_pairwise_sum(self.iter())
}

/// Return the sum of all elements in the array.
@@ -104,21 +103,19 @@ impl<A, S, D> ArrayBase<S, D>
D: RemoveAxis,
{
let n = self.len_of(axis);
let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
let stride = self.strides()[axis.index()];
if self.ndim() == 2 && stride == 1 {
if self.stride_of(axis) == 1 {
// contiguous along the axis we are summing
let ax = axis.index();
for (i, elt) in enumerate(&mut res) {
*elt = self.index_axis(Axis(1 - ax), i).sum();
}
let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
Zip::from(&mut res)
.and(self.lanes(axis))
.apply(|sum, lane| *sum = lane.sum());
res
} else if n <= numeric_util::NAIVE_SUM_THRESHOLD {
self.fold_axis(axis, A::zero(), |acc, x| acc.clone() + x.clone())
} else {
for i in 0..n {
let view = self.index_axis(axis, i);
res = res + &view;
}
let (v1, v2) = self.view().split_at(axis, n / 2);
v1.sum_axis(axis) + v2.sum_axis(axis)
}
res
}

/// Return mean along `axis`.
@@ -287,3 +284,150 @@ impl<A, S, D> ArrayBase<S, D>
}
}

#[cfg(test)]
mod tests {
use super::*;
use super::numeric_util::{NAIVE_SUM_THRESHOLD, UNROLL_SIZE};
use self::{Array, s};
use quickcheck::{QuickCheck, StdGen, TestResult};

#[test]
fn test_sum_value_does_not_depend_on_axis() {
// `size` controls the length of the array of data
// We set it to be randomly drawn between 0 and
// a number larger than NAIVE_SUM_THRESHOLD * UNROLL_SIZE
let rng = StdGen::new(
rand::thread_rng(),
5* (NAIVE_SUM_THRESHOLD * UNROLL_SIZE).pow(3)
);
let mut quickcheck = QuickCheck::new().gen(rng).tests(100);
quickcheck.quickcheck(
_sum_value_does_not_depend_on_axis
as fn(
Vec<f64>
) -> TestResult,
);
}

fn _sum_value_does_not_depend_on_axis(xs: Vec<f64>) -> TestResult {
// We want three axis of equal length - we drop some elements
// to get the right number
let axis_length = (xs.len() as f64).cbrt().floor() as usize;
let xs = &xs[..axis_length.pow(3)];

// We want to check that summing with respect to an axis
// is independent from the specific underlying implementation of
// pairwise sum, which is itself conditional on the arrangement
// in memory of the array elements.
// We will thus swap axes and compute the sum, in turn, with respect to
// axes 0, 1 and 2, while making sure that mathematically the same
// number should be spit out (because we are properly transposing before summing).
if axis_length > 0 {
let (a, b, c) = equivalent_arrays(xs.to_vec(), axis_length);

let sum1 = a.sum_axis(Axis(0));
let sum2 = b.sum_axis(Axis(1));
let sum3 = c.sum_axis(Axis(2));

let tol = 1e-10;
let first = (sum2.clone() - sum1.clone()).iter().all(|x| x.abs() < tol);
let second = (sum3.clone() - sum1.clone()).iter().all(|x| x.abs() < tol);
let third = (sum3.clone() - sum2.clone()).iter().all(|x| x.abs() < tol);

if first && second && third {
TestResult::passed()
} else {
TestResult::failed()
}
} else {
TestResult::passed()
}
}

#[test]
fn test_sum_value_does_not_depend_on_axis_with_discontinuous_array() {
// `size` controls the length of the array of data
// We set it to be randomly drawn between 0 and
// a number larger than NAIVE_SUM_THRESHOLD * UNROLL_SIZE
let rng = StdGen::new(
rand::thread_rng(),
5* (NAIVE_SUM_THRESHOLD * UNROLL_SIZE).pow(3)
);
let mut quickcheck = QuickCheck::new().gen(rng).tests(100);
quickcheck.quickcheck(
_sum_value_does_not_depend_on_axis_w_discontinuous_array
as fn(
Vec<f64>
) -> TestResult,
);
}

fn _sum_value_does_not_depend_on_axis_w_discontinuous_array(xs: Vec<f64>) -> TestResult {
// We want three axis of equal length - we drop some elements
// to get the right number
let axis_length = (xs.len() as f64).cbrt().floor() as usize;
let xs = &xs[..axis_length.pow(3)];

// We want to check that summing with respect to an axis
// is independent from the specific underlying implementation of
// pairwise sum, which is itself conditional on the arrangement
// in memory of the array elements.
// We will thus swap axes and compute the sum, in turn, with respect to
// axes 0, 1 and 2, while making sure that mathematically the same
// number should be spit out (because we are properly transposing before summing).
if axis_length > 0 {
let (a, b, c) = equivalent_arrays(xs.to_vec(), axis_length);

let sum1 = a.slice(s![..;2, .., ..]).sum_axis(Axis(0));
let sum2 = b.slice(s![.., ..;2, ..]).sum_axis(Axis(1));
let sum3 = c.slice(s![.., .., ..;2]).sum_axis(Axis(2));

let tol = 1e-10;
let first = (sum2.clone() - sum1.clone()).iter().all(|x| x.abs() < tol);
let second = (sum3.clone() - sum1.clone()).iter().all(|x| x.abs() < tol);
let third = (sum3.clone() - sum2.clone()).iter().all(|x| x.abs() < tol);

if first && second && third {
TestResult::passed()
} else {
TestResult::failed()
}
} else {
TestResult::passed()
}
}

// Given a vector with axis_length^3 elements, returns three arrays built
// from the vector's elements such that (mathematically):
// a.sum_axis(Axis(0)) == b.sum_axis(Axis(1)) == c.sum_axis(Axis(2))
//
// `b` is `a` with axes 0 and 1 exchanged element-wise, and `c` is `a` with
// its axes rotated (i, j, k) <- (k, i, j); summing each array over the
// corresponding axis therefore collapses the same set of addends.
fn equivalent_arrays(xs: Vec<f64>, axis_length: usize) -> (Array3<f64>, Array3<f64>, Array3<f64>) {
assert!(xs.len() == axis_length.pow(3));

let a = Array::from_vec(xs)
.into_shape((axis_length, axis_length, axis_length))
.unwrap();
assert!(a.is_standard_layout());

// b[(i, j, k)] = a[(j, i, k)] — f64 is Copy, no clone() needed.
let mut b = Array::zeros(a.raw_dim());
assert!(b.is_standard_layout());
for i in 0..axis_length {
for j in 0..axis_length {
for k in 0..axis_length {
b[(i, j, k)] = a[(j, i, k)];
}
}
}

// c[(i, j, k)] = a[(k, i, j)]
let mut c = Array::zeros(a.raw_dim());
assert!(c.is_standard_layout());
for i in 0..axis_length {
for j in 0..axis_length {
for k in 0..axis_length {
c[(i, j, k)] = a[(k, i, j)];
}
}
}

(a, b, c)
}

}
114 changes: 106 additions & 8 deletions src/numeric_util.rs
Original file line number Diff line number Diff line change
@@ -6,9 +6,95 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.
use std::cmp;

use std::ops::Add;
use num_traits::{self, Zero};
use crate::LinalgScalar;

/// Size threshold to switch to naive summation in all implementations of pairwise summation.
#[cfg(not(test))]
pub(crate) const NAIVE_SUM_THRESHOLD: usize = 64;
// Set it to a smaller number for testing purposes
#[cfg(test)]
pub(crate) const NAIVE_SUM_THRESHOLD: usize = 2;

/// Number of elements processed by unrolled operators (to leverage SIMD instructions).
pub(crate) const UNROLL_SIZE: usize = 8;

/// An implementation of pairwise summation for a vector slice.
///
/// Pairwise summation computes the sum of a set of *n* numbers by splitting
/// it recursively in two halves, summing their elements and then adding the respective
/// sums.
/// It switches to the naive sum algorithm once the size of the set to be summed
/// is below a certain pre-defined threshold ([`threshold`]).
///
/// Pairwise summation is useful to reduce the accumulated round-off error
/// when summing floating point numbers.
/// Pairwise summation provides an asymptotic error bound of *O(ε* log *n)*, where
/// *ε* is machine precision, compared to *O(εn)* of the naive summation algorithm.
/// For more details, see [`paper`] or [`Wikipedia`].
///
/// [`paper`]: https://epubs.siam.org/doi/10.1137/0914050
/// [`Wikipedia`]: https://en.wikipedia.org/wiki/Pairwise_summation
/// [`threshold`]: constant.NAIVE_SUM_THRESHOLD.html
pub(crate) fn pairwise_sum<A>(v: &[A]) -> A
where
A: Clone + Add<Output=A> + Zero,
{
let n = v.len();
// The effective base-case size is NAIVE_SUM_THRESHOLD * UNROLL_SIZE, so
// that the unrolled fold operates on whole blocks of UNROLL_SIZE elements.
if n <= NAIVE_SUM_THRESHOLD * UNROLL_SIZE {
return unrolled_fold(v, A::zero, A::add);
} else {
let mid_index = n / 2;
let (v1, v2) = v.split_at(mid_index);
pairwise_sum(v1) + pairwise_sum(v2)
}
}

/// Pairwise summation of the items yielded by an iterator.
///
/// The iterator is consumed in runs of [`NAIVE_SUM_THRESHOLD`] elements;
/// each run is accumulated naively into one partial sum, and the partial
/// sums are then combined with [`pure_pairwise_sum`].
///
/// See [`pairwise_sum`] for details on the algorithm.
///
/// [`pairwise_sum`]: fn.pairwise_sum.html
pub(crate) fn iterator_pairwise_sum<'a, I, A: 'a>(iter: I) -> A
where
I: Iterator<Item=&'a A>,
A: Clone + Add<Output=A> + Zero,
{
// Reserve one slot per full (or trailing partial) run of
// NAIVE_SUM_THRESHOLD elements, using the iterator's lower size bound.
let (lower, _) = iter.size_hint();
let cap = lower / NAIVE_SUM_THRESHOLD
+ if lower % NAIVE_SUM_THRESHOLD == 0 { 0 } else { 1 };
let mut block_sums = Vec::with_capacity(cap);

// Accumulate runs of up to NAIVE_SUM_THRESHOLD elements naively; flush
// each completed run into block_sums.
let mut run_len = 0;
let mut run_sum = A::zero();
for elt in iter {
if run_len < NAIVE_SUM_THRESHOLD {
run_sum = run_sum + elt.clone();
run_len += 1;
} else {
block_sums.push(run_sum);
run_sum = elt.clone();
run_len = 1;
}
}
// The final (possibly empty) run always contributes one partial sum,
// so an empty iterator yields A::zero().
block_sums.push(run_sum);

pure_pairwise_sum(&block_sums)
}

/// Pairwise summation of a slice that recurses all the way down to single
/// elements, never switching to naive accumulation.
pub(crate) fn pure_pairwise_sum<A>(v: &[A]) -> A
where
A: Clone + Add<Output=A> + Zero,
{
if v.is_empty() {
A::zero()
} else if v.len() == 1 {
v[0].clone()
} else {
// Split at the midpoint and combine the two halves' sums.
let (left, right) = v.split_at(v.len() / 2);
pure_pairwise_sum(left) + pure_pairwise_sum(right)
}
}

/// Fold over the manually unrolled `xs` with `f`
pub fn unrolled_fold<A, I, F>(mut xs: &[A], init: I, f: F) -> A
where A: Clone,
@@ -17,7 +103,6 @@ pub fn unrolled_fold<A, I, F>(mut xs: &[A], init: I, f: F) -> A
{
// eightfold unrolled so that floating point can be vectorized
// (even with strict floating point accuracy semantics)
let mut acc = init();
let (mut p0, mut p1, mut p2, mut p3,
mut p4, mut p5, mut p6, mut p7) =
(init(), init(), init(), init(),
@@ -34,18 +119,19 @@ pub fn unrolled_fold<A, I, F>(mut xs: &[A], init: I, f: F) -> A

xs = &xs[8..];
}
acc = f(acc.clone(), f(p0, p4));
acc = f(acc.clone(), f(p1, p5));
acc = f(acc.clone(), f(p2, p6));
acc = f(acc.clone(), f(p3, p7));
let (q0, q1, q2, q3) = (f(p0, p4), f(p1, p5), f(p2, p6), f(p3, p7));
let (r0, r1) = (f(q0, q2), f(q1, q3));
let unrolled = f(r0, r1);

// make it clear to the optimizer that this loop is short
// and can not be autovectorized.
let mut partial = init();
for i in 0..xs.len() {
if i >= 7 { break; }
acc = f(acc.clone(), xs[i].clone())
partial = f(partial.clone(), xs[i].clone())
}
acc

f(unrolled, partial)
}

/// Compute the dot product.
@@ -126,3 +212,15 @@ pub fn unrolled_eq<A>(xs: &[A], ys: &[A]) -> bool

true
}

#[cfg(test)]
mod tests {
use quickcheck_macros::quickcheck;
use std::num::Wrapping;
use super::iterator_pairwise_sum;

// Property: over integers, pairwise summation must agree exactly with the
// naive iterator sum (addition of integers is associative). Wrapping<i32>
// avoids overflow panics in debug builds for arbitrary quickcheck inputs.
#[quickcheck]
fn iterator_pairwise_sum_is_correct(xs: Vec<Wrapping<i32>>) -> bool {
iterator_pairwise_sum(xs.iter()) == xs.iter().sum()
}
}
31 changes: 0 additions & 31 deletions tests/dimension.rs
Original file line number Diff line number Diff line change
@@ -132,37 +132,6 @@ fn fastest_varying_order() {

type ArrayF32<D> = Array<f32, D>;

/*
#[test]
fn min_stride_axis() {
let a = ArrayF32::zeros(10);
assert_eq!(a.min_stride_axis(), Axis(0));
let a = ArrayF32::zeros((3, 3));
assert_eq!(a.min_stride_axis(), Axis(1));
assert_eq!(a.t().min_stride_axis(), Axis(0));
let a = ArrayF32::zeros(vec![3, 3]);
assert_eq!(a.min_stride_axis(), Axis(1));
assert_eq!(a.t().min_stride_axis(), Axis(0));
let min_axis = a.axes().min_by_key(|t| t.2.abs()).unwrap().axis();
assert_eq!(min_axis, Axis(1));
let mut b = ArrayF32::zeros(vec![2, 3, 4, 5]);
assert_eq!(b.min_stride_axis(), Axis(3));
for ax in 0..3 {
b.swap_axes(3, ax);
assert_eq!(b.min_stride_axis(), Axis(ax));
b.swap_axes(3, ax);
}
let a = ArrayF32::zeros((3, 3));
let v = a.broadcast((8, 3, 3)).unwrap();
assert_eq!(v.min_stride_axis(), Axis(0));
}
*/

#[test]
fn max_stride_axis() {
let a = ArrayF32::zeros(10);