Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion libm-test/benches/icount.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@ fn icount_bench_u128_widen_mul(cases: Vec<(u128, u128)>) {
}
}

#[library_benchmark]
#[bench::linspace(setup_u128_mul())]
fn icount_bench_u256_narrowing_div(cases: Vec<(u128, u128)>) {
use libm::support::NarrowingDiv;
for (x, y) in cases.iter().copied() {
let x = black_box(x.widen_hi());
let y = black_box(y);
black_box(x.checked_narrowing_div_rem(y));
}
}

#[library_benchmark]
#[bench::linspace(setup_u256_add())]
fn icount_bench_u256_add(cases: Vec<(u256, u256)>) {
Expand Down Expand Up @@ -145,7 +156,7 @@ fn icount_bench_u256_shr(cases: Vec<(u256, u32)>) {

library_benchmark_group!(
name = icount_bench_u128_group;
benchmarks = icount_bench_u128_widen_mul, icount_bench_u256_add, icount_bench_u256_sub, icount_bench_u256_shl, icount_bench_u256_shr
benchmarks = icount_bench_u128_widen_mul, icount_bench_u256_narrowing_div, icount_bench_u256_add, icount_bench_u256_sub, icount_bench_u256_shl, icount_bench_u256_shr
);

#[library_benchmark]
Expand Down
26 changes: 26 additions & 0 deletions libm/src/math/support/big/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::string::String;
use std::{eprintln, format};

use super::{HInt, MinInt, i256, u256};
use crate::support::{Int as _, NarrowingDiv};

const LOHI_SPLIT: u128 = 0xaaaaaaaaaaaaaaaaffffffffffffffff;

Expand Down Expand Up @@ -336,3 +337,28 @@ fn i256_shifts() {
x = y;
}
}
#[test]
fn div_u256_by_u128() {
for j in i8::MIN..=i8::MAX {
let y: u128 = (j as i128).rotate_right(4).unsigned();
if y == 0 {
continue;
}
for i in i8::MIN..=i8::MAX {
let x: u128 = (i as i128).rotate_right(4).unsigned();
let xy = x.widen_mul(y);
assert_eq!(xy.checked_narrowing_div_rem(y), Some((x, 0)));
if y != 1 {
assert_eq!((xy + u256::ONE).checked_narrowing_div_rem(y), Some((x, 1)));
}
if x != 0 {
assert_eq!(
(xy - u256::ONE).checked_narrowing_div_rem(y),
Some((x - 1, y - 1))
);
}
let r = ((y as f64) * 0.12345) as u128;
assert_eq!((xy + r.widen()).checked_narrowing_div_rem(y), Some((x, r)));
}
}
}
3 changes: 3 additions & 0 deletions libm/src/math/support/int_traits.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use core::{cmp, fmt, ops};

mod narrowing_div;
pub use narrowing_div::NarrowingDiv;

/// Minimal integer implementations needed on all integer types, including wide integers.
#[allow(dead_code)] // Some constants are only used with tests
pub trait MinInt:
Expand Down
176 changes: 176 additions & 0 deletions libm/src/math/support/int_traits/narrowing_div.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/* SPDX-License-Identifier: MIT OR Apache-2.0 */
use crate::support::{CastInto, DInt, HInt, Int, MinInt, u256};

/// Trait for unsigned division of a double-wide integer
/// when the quotient doesn't overflow.
///
/// This is the inverse of widening multiplication:
/// - for any `x` and nonzero `y`: `x.widen_mul(y).checked_narrowing_div_rem(y) == Some((x, 0))`,
/// - and for any `r in 0..y`: `x.carrying_mul(y, r).checked_narrowing_div_rem(y) == Some((x, r))`,
#[allow(dead_code)]
pub trait NarrowingDiv: DInt + MinInt<Unsigned = Self> {
/// Computes `(self / n, self % n))`
///
/// # Safety
/// The caller must ensure that `self.hi() < n`, or equivalently,
/// that the quotient does not overflow.
unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H);

/// Returns `Some((self / n, self % n))` when `self.hi() < n`.
fn checked_narrowing_div_rem(self, n: Self::H) -> Option<(Self::H, Self::H)> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: from std's convention, this can just be narrowing_div_rem

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which convention? Option-returning methods on e.g. u32 all seem to be named checked_*

if self.hi() < n {
Some(unsafe { self.unchecked_narrowing_div_rem(n) })
} else {
None
}
}
}

// For primitive types we can just use the standard
// division operators in the double-wide type.
macro_rules! impl_narrowing_div_primitive {
($D:ident) => {
impl NarrowingDiv for $D {
unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
if self.hi() >= n {
unsafe { core::hint::unreachable_unchecked() }
}
((self / n.widen()).cast(), (self % n.widen()).cast())
}
}
};
}

// Extend division from `u2N / uN` to `u4N / u2N`
// This is not the most efficient algorithm, but it is
// relatively simple.
macro_rules! impl_narrowing_div_recurse {
($D:ident) => {
impl NarrowingDiv for $D {
unsafe fn unchecked_narrowing_div_rem(self, n: Self::H) -> (Self::H, Self::H) {
if self.hi() >= n {
unsafe { core::hint::unreachable_unchecked() }
}

// Normalize the divisor by shifting the most significant one
// to the leading position. `n != 0` is implied by `self.hi() < n`
let lz = n.leading_zeros();
let a = self << lz;
let b = n << lz;
Comment on lines +55 to +59
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make any sense to check if a.hi() == 0 and do a normal div in a smaller type in that case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that could make sense, and it should be even better to check for self.hi() == 0 before normalizing the divisor. However, I don't think that case would be hit very often. I also like how the current construction only depends on NarrowingDiv for the smaller type.


let ah = a.hi();
let (a0, a1) = a.lo().lo_hi();
// SAFETY: For both calls, `b.leading_zeros() == 0` by the above shift.
// SAFETY: `ah < b` follows from `self.hi() < n`
let (q1, r) = unsafe { div_three_digits_by_two(a1, ah, b) };
// SAFETY: `r < b` is given as the postcondition of the previous call
let (q0, r) = unsafe { div_three_digits_by_two(a0, r, b) };

// Undo the earlier normalization for the remainder
(Self::H::from_lo_hi(q0, q1), r >> lz)
}
}
};
}

impl_narrowing_div_primitive!(u16);
impl_narrowing_div_primitive!(u32);
impl_narrowing_div_primitive!(u64);
impl_narrowing_div_primitive!(u128);
impl_narrowing_div_recurse!(u256);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this one is only used once, it can probably just be an impl without the macro. Please no f256 😆

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The macro was convenient for more thorough testing by using the recursive construction all the way from u16, and I'd want to do that for any future changes to it.


/// Implement `u3N / u2N`-division on top of `u2N / uN`-division.
///
/// Returns the quotient and remainder of `(a * R + a0) / n`,
/// where `R = (1 << U::BITS)` is the digit size.
///
/// # Safety
/// Requires that `n.leading_zeros() == 0` and `a < n`.
unsafe fn div_three_digits_by_two<U>(a0: U, a: U::D, n: U::D) -> (U, U::D)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe div_three_3x_by_2x or so? Just thinking that the inputs aren't really digits

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are digits in base R = 2.pow(U::BITS), and this function does division of a 3-digit number by a 2-digit number in that base. Admittedly the signature is rather weird, passing the "high two thirds" of the 3-digit number as one U::D and the least significant third as a U.

where
U: HInt,
U::D: Int + NarrowingDiv,
{
if n.leading_zeros() > 0 || a >= n {
unsafe { core::hint::unreachable_unchecked() }
}

// n = n1R + n0
let (n0, n1) = n.lo_hi();
// a = a2R + a1
let (a1, a2) = a.lo_hi();

let mut q;
let mut r;
let mut wrap;
// `a < n` is guaranteed by the caller, but `a2 == n1 && a1 < n0` is possible
if let Some((q0, r1)) = a.checked_narrowing_div_rem(n1) {
q = q0;
// a = qn1 + r1, where 0 <= r1 < n1

// Include the remainder with the low bits:
// r = a0 + r1R
r = U::D::from_lo_hi(a0, r1);

// Subtract the contribution of the divisor low bits with the estimated quotient
let d = q.widen_mul(n0);
(r, wrap) = r.overflowing_sub(d);

// Since `q` is the quotient of dividing with a slightly smaller divisor,
// it may be an overapproximation, but is never too small, and similarly,
// `r` is now either the correct remainder ...
if !wrap {
return (q, r);
}
// ... or the remainder went "negative" (by as much as `d = qn0 < RR`)
// and we have to adjust.
q -= U::ONE;
} else {
debug_assert!(a2 == n1 && a1 < n0);
// Otherwise, `a2 == n1`, and the estimated quotient would be
// `R + (a1 % n1)`, but the correct quotient can't overflow.
// We'll start from `q = R = (1 << U::BITS)`,
// so `r = aR + a0 - qn = (a - n)R + a0`
r = U::D::from_lo_hi(a0, a1.wrapping_sub(n0));
// Since `a < n`, the first decrement is always needed:
q = U::MAX; /* R - 1 */
}

(r, wrap) = r.overflowing_add(n);
if wrap {
return (q, r);
}

// If the remainder still didn't wrap, we need another step.
q -= U::ONE;
(r, wrap) = r.overflowing_add(n);
// Since `n >= RR/2`, at least one of the two `r += n` must have wrapped.
debug_assert!(wrap, "estimated quotient should be off by at most two");
(q, r)
}

#[cfg(test)]
mod test {
use super::{HInt, NarrowingDiv};

#[test]
fn inverse_mul() {
for x in 0..=u8::MAX {
for y in 1..=u8::MAX {
let xy = x.widen_mul(y);
assert_eq!(xy.checked_narrowing_div_rem(y), Some((x, 0)));
assert_eq!(
(xy + (y - 1) as u16).checked_narrowing_div_rem(y),
Some((x, y - 1))
);
if y > 1 {
assert_eq!((xy + 1).checked_narrowing_div_rem(y), Some((x, 1)));
assert_eq!(
(xy + (y - 2) as u16).checked_narrowing_div_rem(y),
Some((x, y - 2))
);
}
}
}
}
}
3 changes: 2 additions & 1 deletion libm/src/math/support/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ pub use hex_float::hf16;
pub use hex_float::hf128;
#[allow(unused_imports)]
pub use hex_float::{hf32, hf64};
pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt};
#[allow(unused_imports)]
pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt, NarrowingDiv};

/// Hint to the compiler that the current path is cold.
pub fn cold_path() {
Expand Down
Loading