Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 68 additions & 15 deletions src/symmetric/prf/shake_to_field.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,39 @@
use crate::F;

use super::Pseudorandom;
use p3_field::PrimeCharacteristicRing;
use p3_field::{PrimeCharacteristicRing, PrimeField64};
use serde::{Serialize, de::DeserializeOwned};
use sha3::{
Shake128,
digest::{ExtendableOutput, Update, XofReader},
};

/// Number of pseudorandom bytes to generate one pseudorandom field element
const PRF_BYTES_PER_FE: usize = 16;
/// Number of pseudorandom bytes to generate one pseudorandom field element.
///
/// Per RFC 9380 (hash-to-field), L = ceil((ceil(log2(p)) + k) / 8) where k is
/// the security parameter. For KoalaBear (p = 2^31 - 2^24 + 1, ceil(log2(p)) = 31)
/// and k = 128 (matching SHAKE128): L = ceil((31 + 128) / 8) = 20.
///
/// This gives a statistical distance from uniform of at most p / 2^161 < 2^{-129},
/// meeting the 128-bit security target.
const PRF_BYTES_PER_FE: usize = 20;

/// Reduce a 160-bit big-endian value to a field element with negligible bias.
///
/// Splits the 20-byte input into a 128-bit high part and a 32-bit low part,
/// then computes (hi * 2^32 + lo) mod p using native u128 arithmetic.
#[inline]
fn reduce_160_to_field(buf: &[u8; PRF_BYTES_PER_FE]) -> F {
let hi = u128::from_be_bytes(buf[..16].try_into().unwrap());
let lo = u32::from_be_bytes(buf[16..20].try_into().unwrap()) as u128;

let p = F::ORDER_U64 as u128;
let hi_mod = hi % p;
let two_32_mod_p = (1u128 << 32) % p;

let reduced = (hi_mod * two_32_mod_p + lo) % p;
F::from_u64(reduced as u64)
}

const KEY_LENGTH: usize = 32; // 32 bytes
const PRF_DOMAIN_SEP: [u8; 16] = [
Expand Down Expand Up @@ -61,14 +85,9 @@ where

// Mapping bytes to field elements
std::array::from_fn(|_| {
// Buffer to store the output
let mut buf = [0u8; PRF_BYTES_PER_FE];

// Read the extended output into the buffer
xof_reader.read(&mut buf);

// Mapping bytes to a field element
F::from_u128(u128::from_be_bytes(buf))
reduce_160_to_field(&buf)
})
}

Expand Down Expand Up @@ -105,14 +124,9 @@ where

// Mapping bytes to field elements
std::array::from_fn(|_| {
// Buffer to store the output
let mut buf = [0u8; PRF_BYTES_PER_FE];

// Read the extended output into the buffer
xof_reader.read(&mut buf);

// Mapping bytes to a field element
F::from_u128(u128::from_be_bytes(buf))
reduce_160_to_field(&buf)
})
}
}
Expand All @@ -121,6 +135,8 @@ where
mod tests {
use super::*;
use crate::MESSAGE_LENGTH;
use num_bigint::BigUint;
use p3_field::PrimeField64;
use proptest::prelude::*;

const DOMAIN_LEN: usize = 4;
Expand Down Expand Up @@ -209,5 +225,42 @@ mod tests {
let other_epoch = PRF::get_randomness(&key, epoch.wrapping_add(1), &msg, counter1);
prop_assert_ne!(result1, other_epoch);
}

#[test]
fn proptest_reduce_160_matches_bigint_reference(
bytes in prop::array::uniform20(any::<u8>())
) {
let fast = reduce_160_to_field(&bytes);

let value = BigUint::from_bytes_be(&bytes);
let p = BigUint::from(F::ORDER_U64);
let expected_u64: u64 = (value % p).try_into().unwrap();
let reference = F::from_u64(expected_u64);

prop_assert_eq!(fast, reference);
}
}

#[test]
fn test_prf_bytes_per_fe_matches_rfc9380() {
let ceil_log2_p = 64 - (F::ORDER_U64 - 1).leading_zeros() as usize;
let k = 128;
let expected_l = (ceil_log2_p + k).div_ceil(8);
assert_eq!(
PRF_BYTES_PER_FE, expected_l,
"PRF_BYTES_PER_FE should be L = ceil((ceil(log2(p)) + k) / 8) per RFC 9380"
);
}

#[test]
fn test_reduce_160_boundary_values() {
let all_zeros = [0u8; PRF_BYTES_PER_FE];
assert_eq!(reduce_160_to_field(&all_zeros), F::from_u64(0));

let all_ones = [0xff; PRF_BYTES_PER_FE];
let value = BigUint::from_bytes_be(&all_ones);
let p = BigUint::from(F::ORDER_U64);
let expected: u64 = (value % p).try_into().unwrap();
assert_eq!(reduce_160_to_field(&all_ones), F::from_u64(expected));
}
}