diff --git a/crates/algorithms/sha3/src/simd/arm64.rs b/crates/algorithms/sha3/src/simd/arm64.rs index 852279b8e1..d32b0ed463 100644 --- a/crates/algorithms/sha3/src/simd/arm64.rs +++ b/crates/algorithms/sha3/src/simd/arm64.rs @@ -1,203 +1,16 @@ -use libcrux_intrinsics::arm64::*; - -use crate::{generic_keccak::KeccakState, traits::*}; - -#[allow(non_camel_case_types)] -pub type uint64x2_t = _uint64x2_t; - -#[inline(always)] -fn _veor5q_u64( - a: uint64x2_t, - b: uint64x2_t, - c: uint64x2_t, - d: uint64x2_t, - e: uint64x2_t, -) -> uint64x2_t { - _veor3q_u64(_veor3q_u64(a, b, c), d, e) -} - -#[inline(always)] -fn _vrax1q_u64(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t { - libcrux_intrinsics::arm64::_vrax1q_u64(a, b) -} - -#[inline(always)] -fn _vxarq_u64(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t { - libcrux_intrinsics::arm64::_vxarq_u64::(a, b) -} - -#[inline(always)] -fn _vbcaxq_u64(a: uint64x2_t, b: uint64x2_t, c: uint64x2_t) -> uint64x2_t { - libcrux_intrinsics::arm64::_vbcaxq_u64(a, b, c) -} - -#[inline(always)] -fn _veorq_n_u64(a: uint64x2_t, c: u64) -> uint64x2_t { - let c = _vdupq_n_u64(c); - _veorq_u64(a, c) -} - -#[inline(always)] -pub(crate) fn load_block( - s: &mut [uint64x2_t; 25], - blocks: &[&[u8]; 2], - offset: usize, -) { - #[cfg(not(eurydice))] - debug_assert!(RATE <= blocks[0].len() && RATE % 8 == 0 && blocks[0].len() == blocks[1].len()); - for i in 0..RATE / 16 { - let start = offset + 16 * i; - let v0 = _vld1q_bytes_u64(&blocks[0][start..start + 16]); - let v1 = _vld1q_bytes_u64(&blocks[1][start..start + 16]); - let i0 = (2 * i) / 5; - let j0 = (2 * i) % 5; - let i1 = (2 * i + 1) / 5; - let j1 = (2 * i + 1) % 5; - set_ij( - s, - i0, - j0, - _veorq_u64(*get_ij(s, i0, j0), _vtrn1q_u64(v0, v1)), - ); - set_ij( - s, - i1, - j1, - _veorq_u64(*get_ij(s, i1, j1), _vtrn2q_u64(v0, v1)), - ); - } - if RATE % 16 != 0 { - let i = RATE / 8 - 1; - let mut u = [0u64; 2]; - let start = offset + RATE - 8; - u[0] = u64::from_le_bytes(blocks[0][start..start + 8].try_into().unwrap()); - u[1] = u64::from_le_bytes(blocks[1][start..start + 8].try_into().unwrap()); - let uvec = _vld1q_u64(&u); - set_ij(s, i / 5, i % 5, _veorq_u64(*get_ij(s, i / 5, i % 5), uvec)); - } -} - -#[inline(always)] -pub(crate) fn load_last( - state: &mut [uint64x2_t; 25], - blocks: &[&[u8]; 2], - offset: usize, - len: usize, -) { - #[cfg(not(eurydice))] - debug_assert!(offset + len <= blocks[0].len() && blocks[0].len() == blocks[1].len()); - - let mut buffer0 = [0u8; RATE]; - buffer0[0..len].copy_from_slice(&blocks[0][offset..offset + len]); - buffer0[len] = DELIMITER; - buffer0[RATE - 1] |= 0x80; - - let mut buffer1 = [0u8; RATE]; - buffer1[0..len].copy_from_slice(&blocks[1][offset..offset + len]); - buffer1[len] = DELIMITER; - buffer1[RATE - 1] |= 0x80; - - load_block::(state, &[&buffer0, &buffer1], 0); -} - -#[inline(always)] -pub(crate) fn store_block( - s: &[uint64x2_t; 25], - out0: &mut [u8], - out1: &mut [u8], - start: usize, - len: usize, -) { - #[cfg(not(eurydice))] - debug_assert!(len <= RATE && start + len <= out0.len() && out0.len() == out1.len()); - for i in 0..len / 16 { - let i0 = (2 * i) / 5; - let j0 = (2 * i) % 5; - let i1 = (2 * i + 1) / 5; - let j1 = (2 * i + 1) % 5; - let v0 = _vtrn1q_u64(*get_ij(s, i0, j0), *get_ij(s, i1, j1)); - let v1 = _vtrn2q_u64(*get_ij(s, i0, j0), *get_ij(s, i1, j1)); - _vst1q_bytes_u64(&mut out0[start + 16 * i..start + 16 * (i + 1)], v0); - _vst1q_bytes_u64(&mut out1[start + 16 * i..start + 16 * (i + 1)], v1); - } - let remaining = len % 16; - if remaining > 8 { - let mut out0_tmp = [0u8; 16]; - let mut out1_tmp = [0u8; 16]; - let i = 2 * (len / 16); - let i0 = i / 5; - let j0 = i % 5; - let i1 = (i + 1) / 5; - let j1 = (i + 1) % 5; - let v0 = _vtrn1q_u64(*get_ij(s, i0, j0), *get_ij(s, i1, j1)); - let v1 = _vtrn2q_u64(*get_ij(s, i0, j0), *get_ij(s, i1, j1)); - _vst1q_bytes_u64(&mut out0_tmp, v0); - _vst1q_bytes_u64(&mut out1_tmp, v1); - out0[start + len - remaining..start + len].copy_from_slice(&out0_tmp[0..remaining]); - out1[start + len - remaining..start + len].copy_from_slice(&out1_tmp[0..remaining]); - } else if remaining > 0 { - let mut out01 = [0u8; 16]; - let i = 2 * (len / 16); - _vst1q_bytes_u64(&mut out01, *get_ij(s, i / 5, i % 5)); - out0[start + len - remaining..start + len].copy_from_slice(&out01[0..remaining]); - out1[start + len - remaining..start + len].copy_from_slice(&out01[8..8 + remaining]); - } -} - -impl KeccakItem<2> for uint64x2_t { - #[inline(always)] - fn zero() -> Self { - _vdupq_n_u64(0) - } - #[inline(always)] - fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self { - _veor5q_u64(a, b, c, d, e) - } - #[inline(always)] - fn rotate_left1_and_xor(a: Self, b: Self) -> Self { - _vrax1q_u64(a, b) - } - #[inline(always)] - fn xor_and_rotate(a: Self, b: Self) -> Self { - _vxarq_u64::(a, b) - } - #[inline(always)] - fn and_not_xor(a: Self, b: Self, c: Self) -> Self { - _vbcaxq_u64(a, b, c) - } - #[inline(always)] - fn xor_constant(a: Self, c: u64) -> Self { - _veorq_n_u64(a, c) - } - #[inline(always)] - fn xor(a: Self, b: Self) -> Self { - _veorq_u64(a, b) - } -} - -impl Absorb<2> for KeccakState<2, uint64x2_t> { - fn load_block(&mut self, input: &[&[u8]; 2], start: usize) { - load_block::(&mut self.st, input, start); - } - - fn load_last( - &mut self, - input: &[&[u8]; 2], - start: usize, - len: usize, - ) { - load_last::(&mut self.st, input, start, len); - } -} - -impl Squeeze2 for KeccakState<2, uint64x2_t> { - fn squeeze2( - &self, - out0: &mut [u8], - out1: &mut [u8], - start: usize, - len: usize, - ) { - store_block::(&self.st, out0, out1, start, len); - } -} +//! Arm64 (NEON) SIMD backend for SHA-3. +//! +//! Module-declaration shim; all bodies live in the submodules: +//! - [`wrappers`] — math wrappers, the `uint64x2_t` type alias, and +//! the `KeccakItem<2>` impl. +//! - [`load`] — `load_block`, `load_last`, and the `Absorb<2>` impl. +//! - [`store`] — `store_block` and the `Squeeze2` impl. + +pub(crate) mod load; +pub(crate) mod store; +pub(crate) mod wrappers; + +// Re-export `uint64x2_t` so callers (e.g. `neon.rs`) can keep +// referencing `crate::simd::arm64::uint64x2_t` exactly as before the +// split. +pub use wrappers::uint64x2_t; diff --git a/crates/algorithms/sha3/src/simd/arm64/load.rs b/crates/algorithms/sha3/src/simd/arm64/load.rs new file mode 100644 index 0000000000..7c9187387f --- /dev/null +++ b/crates/algorithms/sha3/src/simd/arm64/load.rs @@ -0,0 +1,86 @@ +//! Arm64 (NEON) block loads and the `Absorb<2>` impl. + +use libcrux_intrinsics::arm64::*; + +use crate::generic_keccak::KeccakState; +use crate::traits::{get_ij, set_ij, Absorb}; + +use super::wrappers::uint64x2_t; + +#[inline(always)] +pub(crate) fn load_block( + s: &mut [uint64x2_t; 25], + blocks: &[&[u8]; 2], + offset: usize, +) { + #[cfg(not(eurydice))] + debug_assert!(RATE <= blocks[0].len() && RATE % 8 == 0 && blocks[0].len() == blocks[1].len()); + for i in 0..RATE / 16 { + let start = offset + 16 * i; + let v0 = _vld1q_bytes_u64(&blocks[0][start..start + 16]); + let v1 = _vld1q_bytes_u64(&blocks[1][start..start + 16]); + let i0 = (2 * i) / 5; + let j0 = (2 * i) % 5; + let i1 = (2 * i + 1) / 5; + let j1 = (2 * i + 1) % 5; + set_ij( + s, + i0, + j0, + _veorq_u64(*get_ij(s, i0, j0), _vtrn1q_u64(v0, v1)), + ); + set_ij( + s, + i1, + j1, + _veorq_u64(*get_ij(s, i1, j1), _vtrn2q_u64(v0, v1)), + ); + } + if RATE % 16 != 0 { + let i = RATE / 8 - 1; + let mut u = [0u64; 2]; + let start = offset + RATE - 8; + u[0] = u64::from_le_bytes(blocks[0][start..start + 8].try_into().unwrap()); + u[1] = u64::from_le_bytes(blocks[1][start..start + 8].try_into().unwrap()); + let uvec = _vld1q_u64(&u); + set_ij(s, i / 5, i % 5, _veorq_u64(*get_ij(s, i / 5, i % 5), uvec)); + } +} + +#[inline(always)] +pub(crate) fn load_last( + state: &mut [uint64x2_t; 25], + blocks: &[&[u8]; 2], + offset: usize, + len: usize, +) { + #[cfg(not(eurydice))] + debug_assert!(offset + len <= blocks[0].len() && blocks[0].len() == blocks[1].len()); + + let mut buffer0 = [0u8; RATE]; + buffer0[0..len].copy_from_slice(&blocks[0][offset..offset + len]); + buffer0[len] = DELIMITER; + buffer0[RATE - 1] |= 0x80; + + let mut buffer1 = [0u8; RATE]; + buffer1[0..len].copy_from_slice(&blocks[1][offset..offset + len]); + buffer1[len] = DELIMITER; + buffer1[RATE - 1] |= 0x80; + + load_block::(state, &[&buffer0, &buffer1], 0); +} + +impl Absorb<2> for KeccakState<2, uint64x2_t> { + fn load_block(&mut self, input: &[&[u8]; 2], start: usize) { + load_block::(&mut self.st, input, start); + } + + fn load_last( + &mut self, + input: &[&[u8]; 2], + start: usize, + len: usize, + ) { + load_last::(&mut self.st, input, start, len); + } +} diff --git a/crates/algorithms/sha3/src/simd/arm64/store.rs b/crates/algorithms/sha3/src/simd/arm64/store.rs new file mode 100644 index 0000000000..9c1e6feb1b --- /dev/null +++ b/crates/algorithms/sha3/src/simd/arm64/store.rs @@ -0,0 +1,64 @@ +//! Arm64 (NEON) block stores and the `Squeeze2` impl. + +use libcrux_intrinsics::arm64::*; + +use crate::generic_keccak::KeccakState; +use crate::traits::{get_ij, Squeeze2}; + +use super::wrappers::uint64x2_t; + +#[inline(always)] +pub(crate) fn store_block( + s: &[uint64x2_t; 25], + out0: &mut [u8], + out1: &mut [u8], + start: usize, + len: usize, +) { + #[cfg(not(eurydice))] + debug_assert!(len <= RATE && start + len <= out0.len() && out0.len() == out1.len()); + for i in 0..len / 16 { + let i0 = (2 * i) / 5; + let j0 = (2 * i) % 5; + let i1 = (2 * i + 1) / 5; + let j1 = (2 * i + 1) % 5; + let v0 = _vtrn1q_u64(*get_ij(s, i0, j0), *get_ij(s, i1, j1)); + let v1 = _vtrn2q_u64(*get_ij(s, i0, j0), *get_ij(s, i1, j1)); + _vst1q_bytes_u64(&mut out0[start + 16 * i..start + 16 * (i + 1)], v0); + _vst1q_bytes_u64(&mut out1[start + 16 * i..start + 16 * (i + 1)], v1); + } + let remaining = len % 16; + if remaining > 8 { + let mut out0_tmp = [0u8; 16]; + let mut out1_tmp = [0u8; 16]; + let i = 2 * (len / 16); + let i0 = i / 5; + let j0 = i % 5; + let i1 = (i + 1) / 5; + let j1 = (i + 1) % 5; + let v0 = _vtrn1q_u64(*get_ij(s, i0, j0), *get_ij(s, i1, j1)); + let v1 = _vtrn2q_u64(*get_ij(s, i0, j0), *get_ij(s, i1, j1)); + _vst1q_bytes_u64(&mut out0_tmp, v0); + _vst1q_bytes_u64(&mut out1_tmp, v1); + out0[start + len - remaining..start + len].copy_from_slice(&out0_tmp[0..remaining]); + out1[start + len - remaining..start + len].copy_from_slice(&out1_tmp[0..remaining]); + } else if remaining > 0 { + let mut out01 = [0u8; 16]; + let i = 2 * (len / 16); + _vst1q_bytes_u64(&mut out01, *get_ij(s, i / 5, i % 5)); + out0[start + len - remaining..start + len].copy_from_slice(&out01[0..remaining]); + out1[start + len - remaining..start + len].copy_from_slice(&out01[8..8 + remaining]); + } +} + +impl Squeeze2 for KeccakState<2, uint64x2_t> { + fn squeeze2( + &self, + out0: &mut [u8], + out1: &mut [u8], + start: usize, + len: usize, + ) { + store_block::(&self.st, out0, out1, start, len); + } +} diff --git a/crates/algorithms/sha3/src/simd/arm64/wrappers.rs b/crates/algorithms/sha3/src/simd/arm64/wrappers.rs new file mode 100644 index 0000000000..67f0faebd7 --- /dev/null +++ b/crates/algorithms/sha3/src/simd/arm64/wrappers.rs @@ -0,0 +1,72 @@ +//! Arm64 (NEON) math wrappers, the `uint64x2_t` type alias, and the +//! `KeccakItem<2>` impl. + +use libcrux_intrinsics::arm64::*; + +use crate::traits::KeccakItem; + +#[allow(non_camel_case_types)] +pub type uint64x2_t = _uint64x2_t; + +#[inline(always)] +fn _veor5q_u64( + a: uint64x2_t, + b: uint64x2_t, + c: uint64x2_t, + d: uint64x2_t, + e: uint64x2_t, +) -> uint64x2_t { + _veor3q_u64(_veor3q_u64(a, b, c), d, e) +} + +#[inline(always)] +fn _vrax1q_u64(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t { + libcrux_intrinsics::arm64::_vrax1q_u64(a, b) +} + +#[inline(always)] +fn _vxarq_u64(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t { + libcrux_intrinsics::arm64::_vxarq_u64::(a, b) +} + +#[inline(always)] +fn _vbcaxq_u64(a: uint64x2_t, b: uint64x2_t, c: uint64x2_t) -> uint64x2_t { + libcrux_intrinsics::arm64::_vbcaxq_u64(a, b, c) +} + +#[inline(always)] +fn _veorq_n_u64(a: uint64x2_t, c: u64) -> uint64x2_t { + let c = _vdupq_n_u64(c); + _veorq_u64(a, c) +} + +impl KeccakItem<2> for uint64x2_t { + #[inline(always)] + fn zero() -> Self { + _vdupq_n_u64(0) + } + #[inline(always)] + fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self { + _veor5q_u64(a, b, c, d, e) + } + #[inline(always)] + fn rotate_left1_and_xor(a: Self, b: Self) -> Self { + _vrax1q_u64(a, b) + } + #[inline(always)] + fn xor_and_rotate(a: Self, b: Self) -> Self { + _vxarq_u64::(a, b) + } + #[inline(always)] + fn and_not_xor(a: Self, b: Self, c: Self) -> Self { + _vbcaxq_u64(a, b, c) + } + #[inline(always)] + fn xor_constant(a: Self, c: u64) -> Self { + _veorq_n_u64(a, c) + } + #[inline(always)] + fn xor(a: Self, b: Self) -> Self { + _veorq_u64(a, b) + } +} diff --git a/crates/algorithms/sha3/src/simd/avx2.rs b/crates/algorithms/sha3/src/simd/avx2.rs index 6a1cfa86ba..909c3086b6 100644 --- a/crates/algorithms/sha3/src/simd/avx2.rs +++ b/crates/algorithms/sha3/src/simd/avx2.rs @@ -1,259 +1,10 @@ -use libcrux_intrinsics::avx2::*; - -use crate::{generic_keccak::KeccakState, traits::*}; - -#[inline(always)] -fn rotate_left(x: Vec256) -> Vec256 { - #[cfg(not(eurydice))] - debug_assert!(LEFT + RIGHT == 64); - // This could be done more efficiently, if the shift values are multiples of 8. - // However, in SHA-3 this function is only called twice with such inputs (8/56). - mm256_xor_si256(mm256_slli_epi64::(x), mm256_srli_epi64::(x)) -} - -#[inline(always)] -fn _veor5q_u64(a: Vec256, b: Vec256, c: Vec256, d: Vec256, e: Vec256) -> Vec256 { - let ab = mm256_xor_si256(a, b); - let cd = mm256_xor_si256(c, d); - let abcd = mm256_xor_si256(ab, cd); - mm256_xor_si256(abcd, e) -} - -#[inline(always)] -fn _vrax1q_u64(a: Vec256, b: Vec256) -> Vec256 { - mm256_xor_si256(a, rotate_left::<1, 63>(b)) -} - -#[inline(always)] -fn _vxarq_u64(a: Vec256, b: Vec256) -> Vec256 { - let ab = mm256_xor_si256(a, b); - rotate_left::(ab) -} - -#[inline(always)] -fn _vbcaxq_u64(a: Vec256, b: Vec256, c: Vec256) -> Vec256 { - mm256_xor_si256(a, mm256_andnot_si256(c, b)) -} - -#[inline(always)] -fn _veorq_n_u64(a: Vec256, c: u64) -> Vec256 { - // Casting here is required, doesn't change the value. - let c = mm256_set1_epi64x(c as i64); - mm256_xor_si256(a, c) -} - -#[inline(always)] -pub(crate) fn load_block( - state: &mut [Vec256; 25], - blocks: &[&[u8]; 4], - offset: usize, -) { - #[cfg(not(eurydice))] - debug_assert!(RATE <= blocks[0].len() && RATE % 8 == 0 && (RATE % 32 == 8 || RATE % 32 == 16)); - for i in 0..RATE / 32 { - let start = offset + 32 * i; - let v0 = mm256_loadu_si256_u8(&blocks[0][start..start + 32]); - let v1 = mm256_loadu_si256_u8(&blocks[1][start..start + 32]); - let v2 = mm256_loadu_si256_u8(&blocks[2][start..start + 32]); - let v3 = mm256_loadu_si256_u8(&blocks[3][start..start + 32]); - - let v0l = mm256_unpacklo_epi64(v0, v1); // 0 0 2 2 - let v1h = mm256_unpackhi_epi64(v0, v1); // 1 1 3 3 - let v2l = mm256_unpacklo_epi64(v2, v3); // 0 0 2 2 - let v3h = mm256_unpackhi_epi64(v2, v3); // 1 1 3 3 - - let v0 = mm256_permute2x128_si256::<0x20>(v0l, v2l); // 0 0 0 0 - let v1 = mm256_permute2x128_si256::<0x20>(v1h, v3h); // 1 1 1 1 - let v2 = mm256_permute2x128_si256::<0x31>(v0l, v2l); // 2 2 2 2 - let v3 = mm256_permute2x128_si256::<0x31>(v1h, v3h); // 3 3 3 3 - - let i0 = (4 * i) / 5; - let j0 = (4 * i) % 5; - let i1 = (4 * i + 1) / 5; - let j1 = (4 * i + 1) % 5; - let i2 = (4 * i + 2) / 5; - let j2 = (4 * i + 2) % 5; - let i3 = (4 * i + 3) / 5; - let j3 = (4 * i + 3) % 5; - - set_ij(state, i0, j0, mm256_xor_si256(*get_ij(state, i0, j0), v0)); - set_ij(state, i1, j1, mm256_xor_si256(*get_ij(state, i1, j1), v1)); - set_ij(state, i2, j2, mm256_xor_si256(*get_ij(state, i2, j2), v2)); - set_ij(state, i3, j3, mm256_xor_si256(*get_ij(state, i3, j3), v3)); - } - - let rem = RATE % 32; // has to be 8 or 16 - let start = offset + 32 * (RATE / 32); - let mut u8s = [0u8; 32]; - u8s[0..8].copy_from_slice(&blocks[0][start..start + 8]); - u8s[8..16].copy_from_slice(&blocks[1][start..start + 8]); - u8s[16..24].copy_from_slice(&blocks[2][start..start + 8]); - u8s[24..32].copy_from_slice(&blocks[3][start..start + 8]); - let u = mm256_loadu_si256_u8(u8s.as_slice()); - let i = (4 * (RATE / 32)) / 5; - let j = (4 * (RATE / 32)) % 5; - set_ij(state, i, j, mm256_xor_si256(*get_ij(state, i, j), u)); - if rem == 16 { - let mut u8s = [0u8; 32]; - u8s[0..8].copy_from_slice(&blocks[0][start + 8..start + 16]); - u8s[8..16].copy_from_slice(&blocks[1][start + 8..start + 16]); - u8s[16..24].copy_from_slice(&blocks[2][start + 8..start + 16]); - u8s[24..32].copy_from_slice(&blocks[3][start + 8..start + 16]); - let u = mm256_loadu_si256_u8(u8s.as_slice()); - let i = (4 * (RATE / 32) + 1) / 5; - let j = (4 * (RATE / 32) + 1) % 5; - set_ij(state, i, j, mm256_xor_si256(*get_ij(state, i, j), u)); - } -} - -#[inline(always)] -pub(crate) fn load_last( - state: &mut [Vec256; 25], - blocks: &[&[u8]; 4], - start: usize, - len: usize, -) { - let mut buffers = [[0u8; RATE]; 4]; - for i in 0..4 { - buffers[i][0..len].copy_from_slice(&blocks[i][start..start + len]); - buffers[i][len] = DELIMITER; - buffers[i][RATE - 1] |= 0x80; - } - - load_block::( - state, - &[ - &buffers[0] as &[u8], - &buffers[1] as &[u8], - &buffers[2] as &[u8], - &buffers[3] as &[u8], - ], - 0, - ); -} - -#[inline(always)] -pub(crate) fn store_block( - s: &[Vec256; 25], - out0: &mut [u8], - out1: &mut [u8], - out2: &mut [u8], - out3: &mut [u8], - start: usize, - len: usize, -) { - let chunks = len / 32; - for i in 0..chunks { - let i0 = (4 * i) / 5; - let j0 = (4 * i) % 5; - let i1 = (4 * i + 1) / 5; - let j1 = (4 * i + 1) % 5; - let i2 = (4 * i + 2) / 5; - let j2 = (4 * i + 2) % 5; - let i3 = (4 * i + 3) / 5; - let j3 = (4 * i + 3) % 5; - - let v0l = mm256_permute2x128_si256::<0x20>(*get_ij(s, i0, j0), *get_ij(s, i2, j2)); - // 0 0 2 2 - let v1h = mm256_permute2x128_si256::<0x20>(*get_ij(s, i1, j1), *get_ij(s, i3, j3)); // 1 1 3 3 - let v2l = mm256_permute2x128_si256::<0x31>(*get_ij(s, i0, j0), *get_ij(s, i2, j2)); // 0 0 2 2 - let v3h = mm256_permute2x128_si256::<0x31>(*get_ij(s, i1, j1), *get_ij(s, i3, j3)); // 1 1 3 3 - - let v0 = mm256_unpacklo_epi64(v0l, v1h); // 0 1 2 3 - let v1 = mm256_unpackhi_epi64(v0l, v1h); // 0 1 2 3 - let v2 = mm256_unpacklo_epi64(v2l, v3h); // 0 1 2 3 - let v3 = mm256_unpackhi_epi64(v2l, v3h); // 0 1 2 3 - - mm256_storeu_si256_u8(&mut out0[start + 32 * i..start + 32 * (i + 1)], v0); - mm256_storeu_si256_u8(&mut out1[start + 32 * i..start + 32 * (i + 1)], v1); - mm256_storeu_si256_u8(&mut out2[start + 32 * i..start + 32 * (i + 1)], v2); - mm256_storeu_si256_u8(&mut out3[start + 32 * i..start + 32 * (i + 1)], v3); - } - - let rem = len % 32; - if rem > 0 { - let offset = start + 32 * chunks; - let mut u8s = [0u8; 32]; - let chunks8 = rem / 8; - for k in 0..chunks8 { - let i = (4 * chunks + k) / 5; - let j = (4 * chunks + k) % 5; - mm256_storeu_si256_u8(&mut u8s, *get_ij(s, i, j)); - out0[offset + 8 * k..offset + 8 * (k + 1)].copy_from_slice(&u8s[0..8]); - out1[offset + 8 * k..offset + 8 * (k + 1)].copy_from_slice(&u8s[8..16]); - out2[offset + 8 * k..offset + 8 * (k + 1)].copy_from_slice(&u8s[16..24]); - out3[offset + 8 * k..offset + 8 * (k + 1)].copy_from_slice(&u8s[24..32]); - } - let rem8 = rem % 8; - let offset_rem8 = offset + chunks8 * 8; - if rem8 > 0 { - let i = (4 * chunks + chunks8) / 5; - let j = (4 * chunks + chunks8) % 5; - mm256_storeu_si256_u8(&mut u8s, *get_ij(s, i, j)); - out0[offset_rem8..offset_rem8 + rem8].copy_from_slice(&u8s[0..rem8]); - out1[offset_rem8..offset_rem8 + rem8].copy_from_slice(&u8s[8..8 + rem8]); - out2[offset_rem8..offset_rem8 + rem8].copy_from_slice(&u8s[16..16 + rem8]); - out3[offset_rem8..offset_rem8 + rem8].copy_from_slice(&u8s[24..24 + rem8]); - } - } -} - -impl KeccakItem<4> for Vec256 { - #[inline(always)] - fn zero() -> Self { - mm256_set1_epi64x(0) - } - #[inline(always)] - fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self { - _veor5q_u64(a, b, c, d, e) - } - #[inline(always)] - fn rotate_left1_and_xor(a: Self, b: Self) -> Self { - _vrax1q_u64(a, b) - } - #[inline(always)] - fn xor_and_rotate(a: Self, b: Self) -> Self { - _vxarq_u64::(a, b) - } - #[inline(always)] - fn and_not_xor(a: Self, b: Self, c: Self) -> Self { - _vbcaxq_u64(a, b, c) - } - #[inline(always)] - fn xor_constant(a: Self, c: u64) -> Self { - _veorq_n_u64(a, c) - } - #[inline(always)] - fn xor(a: Self, b: Self) -> Self { - mm256_xor_si256(a, b) - } -} - -impl Absorb<4> for KeccakState<4, Vec256> { - fn load_block(&mut self, input: &[&[u8]; 4], start: usize) { - load_block::(&mut self.st, input, start); - } - - fn load_last( - &mut self, - input: &[&[u8]; 4], - start: usize, - len: usize, - ) { - load_last::(&mut self.st, input, start, len) - } -} - -impl Squeeze4 for KeccakState<4, Vec256> { - fn squeeze4( - &self, - out0: &mut [u8], - out1: &mut [u8], - out2: &mut [u8], - out3: &mut [u8], - start: usize, - len: usize, - ) { - store_block::(&self.st, out0, out1, out2, out3, start, len) - } -} +//! AVX2 SIMD backend for SHA-3. +//! +//! Module-declaration shim; all bodies live in the submodules: +//! - [`wrappers`] — math wrappers and the `KeccakItem<4>` impl. +//! - [`load`] — `load_block`, `load_last`, and the `Absorb<4>` impl. +//! - [`store`] — `store_block` and the `Squeeze4` impl. + +pub(crate) mod load; +pub(crate) mod store; +pub(crate) mod wrappers; diff --git a/crates/algorithms/sha3/src/simd/avx2/load.rs b/crates/algorithms/sha3/src/simd/avx2/load.rs new file mode 100644 index 0000000000..f0dace2211 --- /dev/null +++ b/crates/algorithms/sha3/src/simd/avx2/load.rs @@ -0,0 +1,111 @@ +//! AVX2 block loads and the `Absorb<4>` impl. + +use libcrux_intrinsics::avx2::*; + +use crate::generic_keccak::KeccakState; +use crate::traits::{get_ij, set_ij, Absorb}; + +#[inline(always)] +pub(crate) fn load_block( + state: &mut [Vec256; 25], + blocks: &[&[u8]; 4], + offset: usize, +) { + #[cfg(not(eurydice))] + debug_assert!(RATE <= blocks[0].len() && RATE % 8 == 0 && (RATE % 32 == 8 || RATE % 32 == 16)); + for i in 0..RATE / 32 { + let start = offset + 32 * i; + let v0 = mm256_loadu_si256_u8(&blocks[0][start..start + 32]); + let v1 = mm256_loadu_si256_u8(&blocks[1][start..start + 32]); + let v2 = mm256_loadu_si256_u8(&blocks[2][start..start + 32]); + let v3 = mm256_loadu_si256_u8(&blocks[3][start..start + 32]); + + let v0l = mm256_unpacklo_epi64(v0, v1); // 0 0 2 2 + let v1h = mm256_unpackhi_epi64(v0, v1); // 1 1 3 3 + let v2l = mm256_unpacklo_epi64(v2, v3); // 0 0 2 2 + let v3h = mm256_unpackhi_epi64(v2, v3); // 1 1 3 3 + + let v0 = mm256_permute2x128_si256::<0x20>(v0l, v2l); // 0 0 0 0 + let v1 = mm256_permute2x128_si256::<0x20>(v1h, v3h); // 1 1 1 1 + let v2 = mm256_permute2x128_si256::<0x31>(v0l, v2l); // 2 2 2 2 + let v3 = mm256_permute2x128_si256::<0x31>(v1h, v3h); // 3 3 3 3 + + let i0 = (4 * i) / 5; + let j0 = (4 * i) % 5; + let i1 = (4 * i + 1) / 5; + let j1 = (4 * i + 1) % 5; + let i2 = (4 * i + 2) / 5; + let j2 = (4 * i + 2) % 5; + let i3 = (4 * i + 3) / 5; + let j3 = (4 * i + 3) % 5; + + set_ij(state, i0, j0, mm256_xor_si256(*get_ij(state, i0, j0), v0)); + set_ij(state, i1, j1, mm256_xor_si256(*get_ij(state, i1, j1), v1)); + set_ij(state, i2, j2, mm256_xor_si256(*get_ij(state, i2, j2), v2)); + set_ij(state, i3, j3, mm256_xor_si256(*get_ij(state, i3, j3), v3)); + } + + let rem = RATE % 32; // has to be 8 or 16 + let start = offset + 32 * (RATE / 32); + let mut u8s = [0u8; 32]; + u8s[0..8].copy_from_slice(&blocks[0][start..start + 8]); + u8s[8..16].copy_from_slice(&blocks[1][start..start + 8]); + u8s[16..24].copy_from_slice(&blocks[2][start..start + 8]); + u8s[24..32].copy_from_slice(&blocks[3][start..start + 8]); + let u = mm256_loadu_si256_u8(u8s.as_slice()); + let i = (4 * (RATE / 32)) / 5; + let j = (4 * (RATE / 32)) % 5; + set_ij(state, i, j, mm256_xor_si256(*get_ij(state, i, j), u)); + if rem == 16 { + let mut u8s = [0u8; 32]; + u8s[0..8].copy_from_slice(&blocks[0][start + 8..start + 16]); + u8s[8..16].copy_from_slice(&blocks[1][start + 8..start + 16]); + u8s[16..24].copy_from_slice(&blocks[2][start + 8..start + 16]); + u8s[24..32].copy_from_slice(&blocks[3][start + 8..start + 16]); + let u = mm256_loadu_si256_u8(u8s.as_slice()); + let i = (4 * (RATE / 32) + 1) / 5; + let j = (4 * (RATE / 32) + 1) % 5; + set_ij(state, i, j, mm256_xor_si256(*get_ij(state, i, j), u)); + } +} + +#[inline(always)] +pub(crate) fn load_last( + state: &mut [Vec256; 25], + blocks: &[&[u8]; 4], + start: usize, + len: usize, +) { + let mut buffers = [[0u8; RATE]; 4]; + for i in 0..4 { + buffers[i][0..len].copy_from_slice(&blocks[i][start..start + len]); + buffers[i][len] = DELIMITER; + buffers[i][RATE - 1] |= 0x80; + } + + load_block::( + state, + &[ + &buffers[0] as &[u8], + &buffers[1] as &[u8], + &buffers[2] as &[u8], + &buffers[3] as &[u8], + ], + 0, + ); +} + +impl Absorb<4> for KeccakState<4, Vec256> { + fn load_block(&mut self, input: &[&[u8]; 4], start: usize) { + load_block::(&mut self.st, input, start); + } + + fn load_last( + &mut self, + input: &[&[u8]; 4], + start: usize, + len: usize, + ) { + load_last::(&mut self.st, input, start, len) + } +} diff --git a/crates/algorithms/sha3/src/simd/avx2/store.rs b/crates/algorithms/sha3/src/simd/avx2/store.rs new file mode 100644 index 0000000000..6c33530999 --- /dev/null +++ b/crates/algorithms/sha3/src/simd/avx2/store.rs @@ -0,0 +1,86 @@ +//! AVX2 block stores and the `Squeeze4` impl. + +use libcrux_intrinsics::avx2::*; + +use crate::generic_keccak::KeccakState; +use crate::traits::{get_ij, Squeeze4}; + +#[inline(always)] +pub(crate) fn store_block( + s: &[Vec256; 25], + out0: &mut [u8], + out1: &mut [u8], + out2: &mut [u8], + out3: &mut [u8], + start: usize, + len: usize, +) { + let chunks = len / 32; + for i in 0..chunks { + let i0 = (4 * i) / 5; + let j0 = (4 * i) % 5; + let i1 = (4 * i + 1) / 5; + let j1 = (4 * i + 1) % 5; + let i2 = (4 * i + 2) / 5; + let j2 = (4 * i + 2) % 5; + let i3 = (4 * i + 3) / 5; + let j3 = (4 * i + 3) % 5; + + let v0l = mm256_permute2x128_si256::<0x20>(*get_ij(s, i0, j0), *get_ij(s, i2, j2)); + // 0 0 2 2 + let v1h = mm256_permute2x128_si256::<0x20>(*get_ij(s, i1, j1), *get_ij(s, i3, j3)); // 1 1 3 3 + let v2l = mm256_permute2x128_si256::<0x31>(*get_ij(s, i0, j0), *get_ij(s, i2, j2)); // 0 0 2 2 + let v3h = mm256_permute2x128_si256::<0x31>(*get_ij(s, i1, j1), *get_ij(s, i3, j3)); // 1 1 3 3 + + let v0 = mm256_unpacklo_epi64(v0l, v1h); // 0 1 2 3 + let v1 = mm256_unpackhi_epi64(v0l, v1h); // 0 1 2 3 + let v2 = mm256_unpacklo_epi64(v2l, v3h); // 0 1 2 3 + let v3 = mm256_unpackhi_epi64(v2l, v3h); // 0 1 2 3 + + mm256_storeu_si256_u8(&mut out0[start + 32 * i..start + 32 * (i + 1)], v0); + mm256_storeu_si256_u8(&mut out1[start + 32 * i..start + 32 * (i + 1)], v1); + mm256_storeu_si256_u8(&mut out2[start + 32 * i..start + 32 * (i + 1)], v2); + mm256_storeu_si256_u8(&mut out3[start + 32 * i..start + 32 * (i + 1)], v3); + } + + let rem = len % 32; + if rem > 0 { + let offset = start + 32 * chunks; + let mut u8s = [0u8; 32]; + let chunks8 = rem / 8; + for k in 0..chunks8 { + let i = (4 * chunks + k) / 5; + let j = (4 * chunks + k) % 5; + mm256_storeu_si256_u8(&mut u8s, *get_ij(s, i, j)); + out0[offset + 8 * k..offset + 8 * (k + 1)].copy_from_slice(&u8s[0..8]); + out1[offset + 8 * k..offset + 8 * (k + 1)].copy_from_slice(&u8s[8..16]); + out2[offset + 8 * k..offset + 8 * (k + 1)].copy_from_slice(&u8s[16..24]); + out3[offset + 8 * k..offset + 8 * (k + 1)].copy_from_slice(&u8s[24..32]); + } + let rem8 = rem % 8; + let offset_rem8 = offset + chunks8 * 8; + if rem8 > 0 { + let i = (4 * chunks + chunks8) / 5; + let j = (4 * chunks + chunks8) % 5; + mm256_storeu_si256_u8(&mut u8s, *get_ij(s, i, j)); + out0[offset_rem8..offset_rem8 + rem8].copy_from_slice(&u8s[0..rem8]); + out1[offset_rem8..offset_rem8 + rem8].copy_from_slice(&u8s[8..8 + rem8]); + out2[offset_rem8..offset_rem8 + rem8].copy_from_slice(&u8s[16..16 + rem8]); + out3[offset_rem8..offset_rem8 + rem8].copy_from_slice(&u8s[24..24 + rem8]); + } + } +} + +impl Squeeze4 for KeccakState<4, Vec256> { + fn squeeze4( + &self, + out0: &mut [u8], + out1: &mut [u8], + out2: &mut [u8], + out3: &mut [u8], + start: usize, + len: usize, + ) { + store_block::(&self.st, out0, out1, out2, out3, start, len) + } +} diff --git a/crates/algorithms/sha3/src/simd/avx2/wrappers.rs b/crates/algorithms/sha3/src/simd/avx2/wrappers.rs new file mode 100644 index 0000000000..4a911be7fe --- /dev/null +++ b/crates/algorithms/sha3/src/simd/avx2/wrappers.rs @@ -0,0 +1,76 @@ +//! AVX2 math wrappers and the `KeccakItem<4>` impl. + +use libcrux_intrinsics::avx2::*; + +use crate::traits::KeccakItem; + +#[inline(always)] +fn rotate_left(x: Vec256) -> Vec256 { + #[cfg(not(eurydice))] + debug_assert!(LEFT + RIGHT == 64); + // This could be done more efficiently, if the shift values are multiples of 8. + // However, in SHA-3 this function is only called twice with such inputs (8/56). + mm256_xor_si256(mm256_slli_epi64::(x), mm256_srli_epi64::(x)) +} + +#[inline(always)] +fn _veor5q_u64(a: Vec256, b: Vec256, c: Vec256, d: Vec256, e: Vec256) -> Vec256 { + let ab = mm256_xor_si256(a, b); + let cd = mm256_xor_si256(c, d); + let abcd = mm256_xor_si256(ab, cd); + mm256_xor_si256(abcd, e) +} + +#[inline(always)] +fn _vrax1q_u64(a: Vec256, b: Vec256) -> Vec256 { + mm256_xor_si256(a, rotate_left::<1, 63>(b)) +} + +#[inline(always)] +fn _vxarq_u64(a: Vec256, b: Vec256) -> Vec256 { + let ab = mm256_xor_si256(a, b); + rotate_left::(ab) +} + +#[inline(always)] +fn _vbcaxq_u64(a: Vec256, b: Vec256, c: Vec256) -> Vec256 { + mm256_xor_si256(a, mm256_andnot_si256(c, b)) +} + +#[inline(always)] +fn _veorq_n_u64(a: Vec256, c: u64) -> Vec256 { + // Casting here is required, doesn't change the value. + let c = mm256_set1_epi64x(c as i64); + mm256_xor_si256(a, c) +} + +impl KeccakItem<4> for Vec256 { + #[inline(always)] + fn zero() -> Self { + mm256_set1_epi64x(0) + } + #[inline(always)] + fn xor5(a: Self, b: Self, c: Self, d: Self, e: Self) -> Self { + _veor5q_u64(a, b, c, d, e) + } + #[inline(always)] + fn rotate_left1_and_xor(a: Self, b: Self) -> Self { + _vrax1q_u64(a, b) + } + #[inline(always)] + fn xor_and_rotate(a: Self, b: Self) -> Self { + _vxarq_u64::(a, b) + } + #[inline(always)] + fn and_not_xor(a: Self, b: Self, c: Self) -> Self { + _vbcaxq_u64(a, b, c) + } + #[inline(always)] + fn xor_constant(a: Self, c: u64) -> Self { + _veorq_n_u64(a, c) + } + #[inline(always)] + fn xor(a: Self, b: Self) -> Self { + mm256_xor_si256(a, b) + } +}