diff --git a/crates/air/src/prove.rs b/crates/air/src/prove.rs index b97e5a91..0d84b3af 100644 --- a/crates/air/src/prove.rs +++ b/crates/air/src/prove.rs @@ -54,7 +54,7 @@ where columns_up_down_group_packed, air, &extra_data, - Some((zerocheck_challenges, None)), + Some(zerocheck_challenges), prover_state, virtual_column_statement .as_ref() diff --git a/crates/backend/sumcheck/src/lib.rs b/crates/backend/sumcheck/src/lib.rs index c3912353..75746400 100644 --- a/crates/backend/sumcheck/src/lib.rs +++ b/crates/backend/sumcheck/src/lib.rs @@ -1,5 +1,8 @@ #![cfg_attr(not(test), warn(unused_crate_dependencies))] +mod split_eq; +pub use split_eq::*; + mod prove; pub use prove::*; diff --git a/crates/backend/sumcheck/src/prove.rs b/crates/backend/sumcheck/src/prove.rs index a2acefda..2d99371c 100644 --- a/crates/backend/sumcheck/src/prove.rs +++ b/crates/backend/sumcheck/src/prove.rs @@ -11,7 +11,7 @@ pub fn sumcheck_prove<'a, EF, SC, M: Into>>( multilinears_f: M, computation: &SC, extra_data: &SC::ExtraData, - eq_factor: Option<(Vec, Option>)>, // (a, b, c ...), eq_poly(b, c, ...) + eq_factor: Option>, prover_state: &mut impl FSProver, sum: EF, store_intermediate_foldings: bool, @@ -39,7 +39,7 @@ pub fn sumcheck_fold_and_prove<'a, EF, SC, M: Into>>( prev_folding_factor: Option, computation: &SC, extra_data: &SC::ExtraData, - eq_factor: Option<(Vec, Option>)>, // (a, b, c ...), eq_poly(b, c, ...) + eq_factor: Option>, prover_state: &mut impl FSProver, sum: EF, store_intermediate_foldings: bool, @@ -88,7 +88,7 @@ pub fn sumcheck_prove_many_rounds<'a, EF, SC, M: Into>>( mut prev_folding_factor: Option, computation: &SC, extra_data: &SC::ExtraData, - mut eq_factor: Option<(Vec, Option>)>, // (a, b, c ...), eq_poly(b, c, ...) + mut eq_factor: Option>, prover_state: &mut impl FSProver, mut sum: EF, mut missing_mul_factors: Option, @@ -102,31 +102,16 @@ where SC::ExtraData: AlphaPowers, { let mut multilinears: MleGroup<'a, EF> = multilinears_f.into(); - - let mut eq_factor: Option<(Vec, MleOwned)> = eq_factor.take().map(|(eq_point, eq_mle)| { - let eq_mle = eq_mle.unwrap_or_else(|| { - let eval_eq_ext = eval_eq(&eq_point[1..]); - if multilinears.by_ref().is_packed() { - MleOwned::ExtensionPacked(pack_extension(&eval_eq_ext)) - } else { - MleOwned::Extension(eval_eq_ext) - } - }); - (eq_point, eq_mle) - }); - let mut n_vars = multilinears.by_ref().n_vars(); if prev_folding_factor.is_some() { n_vars -= 1; } - if let Some((eq_point, eq_mle)) = &eq_factor { + + let mut eq_factor_and_split: Option<(Vec, SplitEq)> = eq_factor.take().map(|eq_point| { assert_eq!(eq_point.len(), n_vars); - assert_eq!(eq_mle.by_ref().n_vars(), eq_point.len() - 1); - if eq_mle.by_ref().is_packed() && !multilinears.is_packed() { - assert!(eq_point.len() < packing_log_width::()); - multilinears = multilinears.by_ref().unpack().as_owned_or_clone().into(); - } - } + let split_eq = SplitEq::new(&eq_point[1..]); + (eq_point, split_eq) + }); let mut challenges = Vec::new(); for _ in 0..n_rounds { @@ -134,17 +119,14 @@ where if multilinears.by_ref().is_packed() && n_vars <= 1 + packing_log_width::() { // unpack multilinears = multilinears.by_ref().unpack().as_owned_or_clone().into(); - - if let Some((_, eq_mle)) = &mut eq_factor { - *eq_mle = eq_mle.by_ref().unpack().as_owned_or_clone(); - } + // SplitEq handles unpacking transparently via get_unpacked } let ps = compute_and_send_polynomial( &mut multilinears, prev_folding_factor, computation, - &eq_factor, + &eq_factor_and_split, extra_data, prover_state, sum, @@ -157,7 +139,7 @@ where prev_folding_factor = on_challenge_received( &mut multilinears, &mut n_vars, - &mut eq_factor, + &mut eq_factor_and_split, &mut sum, &mut missing_mul_factors, challenge, @@ -178,7 +160,7 @@ fn compute_and_send_polynomial<'a, EF, SC>( multilinears: &mut MleGroup<'a, EF>, prev_folding_factor: Option, computation: &SC, - eq_factor: &Option<(Vec, MleOwned)>, // (a, b, c ...), eq_poly(b, c, ...) + eq_factor_and_split: &Option<(Vec, SplitEq)>, extra_data: &SC::ExtraData, prover_state: &mut impl FSProver, sum: EF, @@ -196,8 +178,10 @@ where let computation_degree = computation.degree(); let sc_params = SumcheckComputeParams { - eq_mle: eq_factor.as_ref().map(|(_, eq_mle)| eq_mle), - first_eq_factor: eq_factor.as_ref().map(|(first_eq_factor, _)| first_eq_factor[0]), + split_eq: eq_factor_and_split.as_ref().map(|(_, split_eq)| split_eq), + first_eq_factor: eq_factor_and_split + .as_ref() + .map(|(first_eq_factor, _)| first_eq_factor[0]), computation, extra_data, missing_mul_factor, @@ -217,7 +201,7 @@ where None => sumcheck_compute(&multilinears.by_ref(), sc_params, computation_degree), }); - let p_at_1 = if let Some((eq_factor, _)) = eq_factor { + let p_at_1 = if let Some((eq_factor, _)) = eq_factor_and_split { (sum - (EF::ONE - eq_factor[0]) * p_evals[0]) / eq_factor[0] } else { sum - p_evals[0] @@ -232,7 +216,7 @@ where .collect::>(), ) .unwrap(); - let eq_alpha = eq_factor.as_ref().map(|(p, _)| p[0]); + let eq_alpha = eq_factor_and_split.as_ref().map(|(p, _)| p[0]); prover_state.add_sumcheck_polynomial(&poly.coeffs, eq_alpha); poly } @@ -241,7 +225,7 @@ where fn on_challenge_received<'a, EF: ExtensionField>>( multilinears: &mut MleGroup<'a, EF>, n_vars: &mut usize, - eq_factor: &mut Option<(Vec, MleOwned)>, // (a, b, c ...), eq_poly(b, c, ...) + eq_factor: &mut Option<(Vec, SplitEq)>, sum: &mut EF, missing_mul_factor: &mut Option, challenge: EF, @@ -253,7 +237,7 @@ fn on_challenge_received<'a, EF: ExtensionField>>( *sum = p.evaluate(challenge); *n_vars -= 1; - if let Some((eq_factor, eq_mle)) = eq_factor { + if let Some((eq_factor, split_eq)) = eq_factor { // Multiply sum by eq(α_i, r_i) since the polynomial doesn't include the eq linear factor let eq_eval = (EF::ONE - eq_factor[0]) * (EF::ONE - challenge) + eq_factor[0] * challenge; *sum *= eq_eval; @@ -262,7 +246,7 @@ fn on_challenge_received<'a, EF: ExtensionField>>( eq_eval * missing_mul_factor.unwrap_or(EF::ONE) / (EF::ONE - eq_factor.get(1).copied().unwrap_or_default()), ); eq_factor.remove(0); - eq_mle.truncate(eq_mle.by_ref().packed_len() / 2); + split_eq.truncate_half(); } if store_intermediate_foldings { diff --git a/crates/backend/sumcheck/src/quotient_computation.rs b/crates/backend/sumcheck/src/quotient_computation.rs index 5d9d03e2..456c04ad 100644 --- a/crates/backend/sumcheck/src/quotient_computation.rs +++ b/crates/backend/sumcheck/src/quotient_computation.rs @@ -1,14 +1,14 @@ use std::{ array, - ops::{Add, Mul}, + ops::{Add, Mul, MulAssign}, }; -use field::{Algebra, ExtensionField, Field}; +use field::{Algebra, ExtensionField, Field, PrimeCharacteristicRing}; use poly::*; use rayon::iter::IntoParallelRefIterator; use rayon::prelude::*; -use crate::{SumcheckComputation, sumcheck_quadratic}; +use crate::{SplitEq, SumcheckComputation, sumcheck_quadratic}; #[derive(Default, Debug)] pub struct GKRQuotientComputation; @@ -71,17 +71,21 @@ fn my_dot_product, A2: Copy>(a: &[A1], b: &[A2]) -> A1 { #[inline(always)] #[allow(clippy::too_many_arguments)] -fn compute_sumcheck_terms + Copy + Send + Sync, EF: Field>( - u0_left: F, - u0_right: F, - u1_left: F, - u1_right: F, - u2_left: F, - u2_right: F, - u3_left: F, - u3_right: F, - eq_val: F, -) -> (F, F, F, F) { +pub fn compute_sumcheck_terms( + u0_left: N, + u0_right: N, + u1_left: N, + u1_right: N, + u2_left: D, + u2_right: D, + u3_left: D, + u3_right: D, + eq_val: D, +) -> (D, D, D, D) +where + N: PrimeCharacteristicRing + Copy, + D: Algebra + Copy + MulAssign, +{ let (mut c0_term_single, mut c2_term_single) = sumcheck_quadratic(((&u2_left, &u2_right), (&u3_left, &u3_right))); c0_term_single *= eq_val; c2_term_single *= eq_val; @@ -97,16 +101,16 @@ fn compute_sumcheck_terms + Copy + Send + Sync, EF: Field>( } #[allow(clippy::too_many_arguments)] -fn finalize_polynomial + Copy + Send + Sync, EF: Field>( - c0_term_single: F, - c2_term_single: F, - c0_term_double: F, - c2_term_double: F, +pub fn finalize_polynomial + Copy + Send + Sync, EF: Field>( + c0_term_single: A, + c2_term_single: A, + c0_term_double: A, + c2_term_double: A, alpha: EF, first_eq_factor: EF, missing_mul_factor: EF, sum: EF, - decompose: impl Fn(F) -> Vec, + decompose: impl Fn(A) -> Vec, ) -> DensePolynomial { let c0 = c0_term_single * alpha + c0_term_double; let c2 = c2_term_single * alpha + c2_term_double; @@ -186,6 +190,84 @@ pub(crate) fn compute_gkr_quotient_sumcheck_polynomial + Copy + S ) } +#[allow(clippy::too_many_arguments, clippy::needless_range_loop)] +pub fn compute_gkr_quotient_sumcheck_polynomial_split_eq( + u0: &[N], + u1: &[N], + u2: &[EFPacking], + u3: &[EFPacking], + alpha: EF, + first_eq_factor: EF, + split_eq: &SplitEq, + missing_mul_factor: EF, + sum: EF, +) -> DensePolynomial +where + EF: ExtensionField>, + N: PrimeCharacteristicRing + Copy + Send + Sync, + EFPacking: Algebra + Algebra, +{ + type EP = EFPacking; + + let n = u0.len(); + let half = n / 2; + + let n_lo = split_eq.n_lo(); + let packed_hi = split_eq.packed_hi(); + let log_packed_hi = split_eq.log_packed_hi; + let eq_lo = &split_eq.eq_lo; + let eq_hi = &split_eq.eq_hi_packed; + + let zero = || (EP::::ZERO, EP::::ZERO, EP::::ZERO, EP::::ZERO); + let add = |a: (EP, EP, EP, EP), b: (EP, EP, EP, EP)| { + (a.0 + b.0, a.1 + b.1, a.2 + b.2, a.3 + b.3) + }; + + let (c0s, c2s, c0d, c2d) = (0..n_lo) + .into_par_iter() + .fold(zero, |mut acc, b_lo| { + let eq_lo_bc = as From>::from(eq_lo[b_lo]); + let base = b_lo << log_packed_hi; + let (mut l0, mut l1, mut l2, mut l3) = (EP::::ZERO, EP::::ZERO, EP::::ZERO, EP::::ZERO); + for k in 0..packed_hi { + let i = base + k; + let t = compute_sumcheck_terms( + u0[i], + u0[i + half], + u1[i], + u1[i + half], + u2[i], + u2[i + half], + u3[i], + u3[i + half], + eq_hi[k], + ); + l0 += t.0; + l1 += t.1; + l2 += t.2; + l3 += t.3; + } + acc.0 += l0 * eq_lo_bc; + acc.1 += l1 * eq_lo_bc; + acc.2 += l2 * eq_lo_bc; + acc.3 += l3 * eq_lo_bc; + acc + }) + .reduce(zero, add); + + finalize_polynomial( + c0s, + c2s, + c0d, + c2d, + alpha, + first_eq_factor, + missing_mul_factor, + sum, + crate::packing_decompose::, + ) +} + #[allow(clippy::too_many_arguments)] pub(crate) fn fold_and_compute_gkr_quotient_sumcheck_polynomial + Copy + Send + Sync, EF: Field>( prev_folding_factor: EF, @@ -279,3 +361,112 @@ pub(crate) fn fold_and_compute_gkr_quotient_sumcheck_polynomial + vec![folded_u0, folded_u1, folded_u2, folded_u3], ) } + +#[allow(clippy::too_many_arguments)] +#[allow(clippy::type_complexity)] +pub fn fold_and_compute_gkr_quotient_split_eq( + u0: &[N], + u1: &[N], + u2: &[EFPacking], + u3: &[EFPacking], + fold_num: impl Fn(&[N], usize, usize, usize) -> (EFPacking, EFPacking) + Sync, + fold_den: impl Fn(&[EFPacking], usize, usize, usize) -> (EFPacking, EFPacking) + Sync, + alpha: EF, + first_eq_factor: EF, + split_eq: &SplitEq, + missing_mul_factor: EF, + sum: EF, +) -> (DensePolynomial, Vec>>) +where + EF: ExtensionField>, + N: Copy + Send + Sync, + EFPacking: Algebra + Algebra, +{ + type EP = EFPacking; + + let n = u0.len(); + let half = n / 2; + let quarter = n / 4; + + let mut folded_u0 = unsafe { uninitialized_vec::>(half) }; + let mut folded_u1 = unsafe { uninitialized_vec::>(half) }; + let mut folded_u2 = unsafe { uninitialized_vec::>(half) }; + let mut folded_u3 = unsafe { uninitialized_vec::>(half) }; + + let zero = || (EP::::ZERO, EP::::ZERO, EP::::ZERO, EP::::ZERO); + let add = |a: (EP, EP, EP, EP), b: (EP, EP, EP, EP)| { + (a.0 + b.0, a.1 + b.1, a.2 + b.2, a.3 + b.3) + }; + + let packed_hi = split_eq.packed_hi(); + let log_packed_hi = split_eq.log_packed_hi; + let eq_lo = &split_eq.eq_lo; + let eq_hi = &split_eq.eq_hi_packed; + + let (c0s, c2s, c0d, c2d) = { + let (fl0, fr0) = folded_u0.split_at_mut(quarter); + let (fl1, fr1) = folded_u1.split_at_mut(quarter); + let (fl2, fr2) = folded_u2.split_at_mut(quarter); + let (fl3, fr3) = folded_u3.split_at_mut(quarter); + + fl0.par_chunks_mut(packed_hi) + .zip(fr0.par_chunks_mut(packed_hi)) + .zip(fl1.par_chunks_mut(packed_hi)) + .zip(fr1.par_chunks_mut(packed_hi)) + .zip(fl2.par_chunks_mut(packed_hi)) + .zip(fr2.par_chunks_mut(packed_hi)) + .zip(fl3.par_chunks_mut(packed_hi)) + .zip(fr3.par_chunks_mut(packed_hi)) + .enumerate() + .fold( + zero, + |mut acc, (b_lo, (((((((fl0, fr0), fl1), fr1), fl2), fr2), fl3), fr3))| { + let eq_lo_bc = as From>::from(eq_lo[b_lo]); + let base = b_lo << log_packed_hi; + let (mut l0, mut l1, mut l2, mut l3) = + (EP::::ZERO, EP::::ZERO, EP::::ZERO, EP::::ZERO); + for k in 0..packed_hi { + let i = base + k; + let (u0l, u0r) = fold_num(u0, i, half, quarter); + fl0[k] = u0l; + fr0[k] = u0r; + let (u1l, u1r) = fold_num(u1, i, half, quarter); + fl1[k] = u1l; + fr1[k] = u1r; + let (u2l, u2r) = fold_den(u2, i, half, quarter); + fl2[k] = u2l; + fr2[k] = u2r; + let (u3l, u3r) = fold_den(u3, i, half, quarter); + fl3[k] = u3l; + fr3[k] = u3r; + let t = compute_sumcheck_terms(u0l, u0r, u1l, u1r, u2l, u2r, u3l, u3r, eq_hi[k]); + l0 += t.0; + l1 += t.1; + l2 += t.2; + l3 += t.3; + } + acc.0 += l0 * eq_lo_bc; + acc.1 += l1 * eq_lo_bc; + acc.2 += l2 * eq_lo_bc; + acc.3 += l3 * eq_lo_bc; + acc + }, + ) + .reduce(zero, add) + }; + + ( + finalize_polynomial( + c0s, + c2s, + c0d, + c2d, + alpha, + first_eq_factor, + missing_mul_factor, + sum, + crate::packing_decompose::, + ), + vec![folded_u0, folded_u1, folded_u2, folded_u3], + ) +} diff --git a/crates/backend/sumcheck/src/sc_computation.rs b/crates/backend/sumcheck/src/sc_computation.rs index 249afa3e..8b9f8009 100644 --- a/crates/backend/sumcheck/src/sc_computation.rs +++ b/crates/backend/sumcheck/src/sc_computation.rs @@ -4,7 +4,7 @@ use field::*; use poly::*; use rayon::prelude::*; use std::any::TypeId; -use std::ops::{Add, AddAssign, Sub}; +use std::ops::{Add, AddAssign, Mul, MulAssign, Sub}; pub trait SumcheckComputation>>: Sync { type ExtraData: Send + Sync + 'static; @@ -169,7 +169,7 @@ fn handle_gkr_quotient<'a, EF: ExtensionField>, ED: AlphaPowers>( group: &MleGroupRef<'a, EF>, extra_data: &ED, first_eq_factor: EF, - eq_mle: &MleOwned, + split_eq: &SplitEq, missing_mul_factor: Option, sum: EF, ) -> Vec { @@ -177,29 +177,49 @@ fn handle_gkr_quotient<'a, EF: ExtensionField>, ED: AlphaPowers>( let mul_factor = missing_mul_factor.unwrap_or(EF::ONE); let poly = match group { - MleGroupRef::Extension(m) => compute_gkr_quotient_sumcheck_polynomial( - m[0], - m[1], - m[2], - m[3], - alpha, - first_eq_factor, - eq_mle.as_extension().unwrap(), - mul_factor, - sum, - identity_decompose, - ), - MleGroupRef::ExtensionPacked(m) => compute_gkr_quotient_sumcheck_polynomial( + MleGroupRef::Extension(m) => { + // Materialize eq for unpacked path (small table at this stage) + let eq_vals: Vec = (0..m[0].len() / 2).map(|i| split_eq.get_unpacked(i)).collect(); + compute_gkr_quotient_sumcheck_polynomial( + m[0], + m[1], + m[2], + m[3], + alpha, + first_eq_factor, + &eq_vals, + mul_factor, + sum, + identity_decompose, + ) + } + MleGroupRef::ExtensionPacked(m) if split_eq.is_remainder_mode() => { + let unpack = |s: &[EFPacking]| -> Vec { EFPacking::::to_ext_iter(s.iter().copied()).collect() }; + let (m0, m1, m2, m3) = (unpack(m[0]), unpack(m[1]), unpack(m[2]), unpack(m[3])); + let eq_vals: Vec = (0..m0.len() / 2).map(|i| split_eq.get_unpacked(i)).collect(); + compute_gkr_quotient_sumcheck_polynomial( + &m0, + &m1, + &m2, + &m3, + alpha, + first_eq_factor, + &eq_vals, + mul_factor, + sum, + identity_decompose, + ) + } + MleGroupRef::ExtensionPacked(m) => compute_gkr_quotient_sumcheck_polynomial_split_eq( m[0], m[1], m[2], m[3], alpha, first_eq_factor, - eq_mle.as_extension_packed().unwrap(), + split_eq, mul_factor, sum, - packing_decompose, ), _ => unimplemented!(), }; @@ -212,7 +232,7 @@ fn handle_gkr_quotient_with_fold<'a, EF: ExtensionField>, ED: AlphaPowers prev_folding_factor: EF, extra_data: &ED, first_eq_factor: EF, - eq_mle: &MleOwned, + split_eq: &SplitEq, missing_mul_factor: Option, sum: EF, ) -> (Vec, MleGroupOwned) { @@ -221,6 +241,8 @@ fn handle_gkr_quotient_with_fold<'a, EF: ExtensionField>, ED: AlphaPowers let (poly, folded_f) = match group { MleGroupRef::Extension(m) => { + // Materialize eq for the fold+compute path (small table, already halved) + let eq_vals: Vec = (0..m[0].len() / 4).map(|i| split_eq.get_unpacked(i)).collect(); let (poly, folded) = fold_and_compute_gkr_quotient_sumcheck_polynomial( prev_folding_factor, m[0], @@ -229,26 +251,51 @@ fn handle_gkr_quotient_with_fold<'a, EF: ExtensionField>, ED: AlphaPowers m[3], alpha, first_eq_factor, - eq_mle.as_extension().unwrap(), + &eq_vals, mul_factor, sum, identity_decompose, ); (poly, MleGroupOwned::Extension(folded)) } - MleGroupRef::ExtensionPacked(m) => { + MleGroupRef::ExtensionPacked(m) if split_eq.is_remainder_mode() => { + let unpack = |s: &[EFPacking]| -> Vec { EFPacking::::to_ext_iter(s.iter().copied()).collect() }; + let (m0, m1, m2, m3) = (unpack(m[0]), unpack(m[1]), unpack(m[2]), unpack(m[3])); + let eq_vals: Vec = (0..m0.len() / 4).map(|i| split_eq.get_unpacked(i)).collect(); let (poly, folded) = fold_and_compute_gkr_quotient_sumcheck_polynomial( prev_folding_factor, + &m0, + &m1, + &m2, + &m3, + alpha, + first_eq_factor, + &eq_vals, + mul_factor, + sum, + identity_decompose, + ); + (poly, MleGroupOwned::Extension(folded)) + } + MleGroupRef::ExtensionPacked(m) => { + let r = prev_folding_factor; + let fold_ext = |u: &[EFPacking], i: usize, half: usize, quarter: usize| { + let left = (u[i + half] - u[i]) * r + u[i]; + let right = (u[i + half + quarter] - u[i + quarter]) * r + u[i + quarter]; + (left, right) + }; + let (poly, folded) = fold_and_compute_gkr_quotient_split_eq( m[0], m[1], m[2], m[3], + fold_ext, + fold_ext, alpha, first_eq_factor, - eq_mle.as_extension_packed().unwrap(), + split_eq, mul_factor, sum, - packing_decompose, ); (poly, MleGroupOwned::ExtensionPacked(folded)) } @@ -257,9 +304,8 @@ fn handle_gkr_quotient_with_fold<'a, EF: ExtensionField>, ED: AlphaPowers (poly_to_evals(&poly), folded_f) } -#[derive(Debug)] pub struct SumcheckComputeParams<'a, EF: ExtensionField>, SC: SumcheckComputation> { - pub eq_mle: Option<&'a MleOwned>, + pub split_eq: Option<&'a SplitEq>, pub first_eq_factor: Option, pub computation: &'a SC, pub extra_data: &'a SC::ExtraData, @@ -277,7 +323,7 @@ where SC::ExtraData: AlphaPowers, { let SumcheckComputeParams { - eq_mle, + split_eq, first_eq_factor, computation, extra_data, @@ -293,7 +339,7 @@ where }; // Handle ProductComputation special case - if TypeId::of::() == TypeId::of::() && eq_mle.is_none() { + if TypeId::of::() == TypeId::of::() && split_eq.is_none() { assert!(missing_mul_factor.is_none()); assert!(extra_data.alpha_powers().is_empty()); assert_eq!(group.n_columns(), 2); @@ -302,23 +348,37 @@ where // Handle GKRQuotientComputation special case if TypeId::of::() == TypeId::of::() { - assert!(eq_mle.is_some()); + assert!(split_eq.is_some()); assert_eq!(group.n_columns(), 4); return handle_gkr_quotient( group, extra_data, first_eq_factor.unwrap(), - eq_mle.unwrap(), + split_eq.unwrap(), missing_mul_factor, sum, ); } match group { + MleGroupRef::ExtensionPacked(multilinears) if split_eq.is_some() => { + assert!(!split_eq.unwrap().is_remainder_mode()); + sumcheck_compute_with_split_eq( + multilinears, + degree, + split_eq.unwrap(), + computation, + extra_data, + missing_mul_factor, + packed_fold_size, + |sc, pf, ed| sc.eval_packed_extension(&pf, ed), + packing_unpack_sum, + ) + } MleGroupRef::ExtensionPacked(multilinears) => sumcheck_compute_core( multilinears, degree, - eq_mle.map(|e| e.as_extension_packed().unwrap()), + |i| split_eq.map(|seq| seq.get_packed(i)), computation, extra_data, missing_mul_factor, @@ -329,7 +389,7 @@ where MleGroupRef::BasePacked(multilinears) => sumcheck_compute_core( multilinears, degree, - eq_mle.map(|e| e.as_extension_packed().unwrap()), + |i| split_eq.map(|seq| seq.get_packed(i)), computation, extra_data, missing_mul_factor, @@ -340,7 +400,7 @@ where MleGroupRef::Base(multilinears) => sumcheck_compute_core( multilinears, degree, - eq_mle.map(|e| e.as_extension().unwrap()), + |i| split_eq.map(|seq| seq.get_unpacked(i)), computation, extra_data, missing_mul_factor, @@ -351,7 +411,7 @@ where MleGroupRef::Extension(multilinears) => sumcheck_compute_core( multilinears, degree, - eq_mle.map(|e| e.as_extension().unwrap()), + |i| split_eq.map(|seq| seq.get_unpacked(i)), computation, extra_data, missing_mul_factor, @@ -374,7 +434,7 @@ where SC::ExtraData: AlphaPowers, { let SumcheckComputeParams { - eq_mle, + split_eq, first_eq_factor, computation, extra_data, @@ -390,7 +450,7 @@ where }; // Handle ProductComputation special case - if TypeId::of::() == TypeId::of::() && eq_mle.is_none() { + if TypeId::of::() == TypeId::of::() && split_eq.is_none() { assert!(missing_mul_factor.is_none()); assert!(extra_data.alpha_powers().is_empty()); assert_eq!(group.n_columns(), 2); @@ -399,26 +459,43 @@ where // Handle GKRQuotientComputation special case if TypeId::of::() == TypeId::of::() { - assert!(eq_mle.is_some()); + assert!(split_eq.is_some()); assert_eq!(group.n_columns(), 4); return handle_gkr_quotient_with_fold( group, prev_folding_factor, extra_data, first_eq_factor.unwrap(), - eq_mle.unwrap(), + split_eq.unwrap(), missing_mul_factor, sum, ); } match group { + MleGroupRef::ExtensionPacked(multilinears) if split_eq.is_some() => { + assert!(!split_eq.unwrap().is_remainder_mode()); + let prev_folded_size = multilinears[0].len() / 2; + sumcheck_fold_and_compute_with_split_eq( + multilinears, + degree, + split_eq.unwrap(), + computation, + extra_data, + missing_mul_factor, + compute_fold_size, + |m, id| (m[id + prev_folded_size] - m[id]) * prev_folding_factor + m[id], + |sc, pf, ed| sc.eval_packed_extension(&pf, ed), + packing_unpack_sum, + MleGroupOwned::ExtensionPacked, + ) + } MleGroupRef::ExtensionPacked(multilinears) => { let prev_folded_size = multilinears[0].len() / 2; sumcheck_fold_and_compute_core( multilinears, degree, - eq_mle.map(|e| e.as_extension_packed().unwrap()), + |i| split_eq.map(|seq| seq.get_packed(i)), computation, extra_data, missing_mul_factor, @@ -435,7 +512,7 @@ where sumcheck_fold_and_compute_core( multilinears, degree, - eq_mle.map(|e| e.as_extension_packed().unwrap()), + |i| split_eq.map(|seq| seq.get_packed(i)), computation, extra_data, missing_mul_factor, @@ -451,7 +528,7 @@ where sumcheck_fold_and_compute_core( multilinears, degree, - eq_mle.map(|e| e.as_extension().unwrap()), + |i| split_eq.map(|seq| seq.get_unpacked(i)), computation, extra_data, missing_mul_factor, @@ -467,7 +544,7 @@ where sumcheck_fold_and_compute_core( multilinears, degree, - eq_mle.map(|e| e.as_extension().unwrap()), + |i| split_eq.map(|seq| seq.get_unpacked(i)), computation, extra_data, missing_mul_factor, @@ -485,7 +562,7 @@ where fn sumcheck_compute_core( multilinears: &[&[IF]], degree: usize, - eq_mle: Option<&[EFT]>, + eq_at: impl Fn(usize) -> Option + Sync + Send, computation: &SC, extra_data: &SC::ExtraData, missing_mul_factor: Option, @@ -496,11 +573,17 @@ fn sumcheck_compute_core( where EF: ExtensionField>, IF: Copy + Sub + Add + AddAssign + Send + Sync, - EFT: PrimeCharacteristicRing + Copy + Sub + Add + Send + Sync, + EFT: PrimeCharacteristicRing + + Copy + + Sub + + Add + + Send + + Sync + + Mul + + MulAssign, SC: SumcheckComputation, { - let compute_iteration = |i: usize| -> Vec { - let eq_mle_eval = eq_mle.map(|e| e[i]); + let compute_at = |i: usize, eq_val: Option| -> Vec { let mut rows = multilinears .iter() .map(|m| { @@ -514,7 +597,7 @@ where // z = 0 let point_0 = rows.iter().map(|row| row[0]).collect::>(); let mut eval_0 = eval_fn(computation, point_0, extra_data); - if let Some(eq) = eq_mle_eval { + if let Some(eq) = eq_val { eval_0 *= eq; } @@ -528,7 +611,7 @@ where } let point_f = rows.iter().map(|row| row[0] + row[2]).collect::>(); let mut eval = eval_fn(computation, point_f, extra_data); - if let Some(eq) = eq_mle_eval { + if let Some(eq) = eq_val { eval *= eq; } evals.push(eval); @@ -536,7 +619,7 @@ where evals }; - let sums = parallel_sum(fold_size, degree, compute_iteration); + let sums = parallel_sum(fold_size, degree, |i| compute_at(i, eq_at(i))); let unpacked_sums = sums.into_iter().map(&unpack_sum); build_evals(unpacked_sums, missing_mul_factor) } @@ -546,7 +629,7 @@ where fn sumcheck_fold_and_compute_core( multilinears: &[&[IF]], degree: usize, - eq_mle: Option<&[FT]>, + eq_at: impl Fn(usize) -> Option + Sync + Send, computation: &SC, extra_data: &SC::ExtraData, missing_mul_factor: Option, @@ -569,7 +652,7 @@ where .collect(); let compute_iteration = |i: usize| -> Vec { - let eq_mle_eval = eq_mle.map(|e| e[i]); + let eq_mle_eval = eq_at(i); let mut rows_f: Vec<[FT; 3]> = multilinears .iter() @@ -616,3 +699,172 @@ where let unpacked_sums = sums.into_iter().map(&unpack_sum); (build_evals(unpacked_sums, missing_mul_factor), wrap_f(folded_f)) } + +#[allow(clippy::too_many_arguments, clippy::needless_range_loop)] +fn sumcheck_compute_with_split_eq( + multilinears: &[&[EFPacking]], + degree: usize, + split_eq: &SplitEq, + computation: &SC, + extra_data: &SC::ExtraData, + missing_mul_factor: Option, + fold_size: usize, + eval_fn: impl Fn(&SC, Vec>, &SC::ExtraData) -> EFPacking + Sync + Send, + unpack_sum: impl Fn(EFPacking) -> EF, +) -> Vec +where + EF: ExtensionField>, + SC: SumcheckComputation, +{ + let n_lo = split_eq.n_lo(); + let packed_hi = split_eq.packed_hi(); + let log_packed_hi = split_eq.log_packed_hi; + let eq_lo = &split_eq.eq_lo; + let eq_hi = &split_eq.eq_hi_packed; + + let zero = || EFPacking::::zero_vec(degree); + let accumulate = |mut acc: Vec>, vals: Vec>| -> Vec> { + for (a, v) in acc.iter_mut().zip(vals.iter()) { + *a += *v; + } + acc + }; + + let sums: Vec> = (0..n_lo) + .into_par_iter() + .map(|b_lo| { + let eq_lo_bc = EFPacking::::from(eq_lo[b_lo]); + let base = b_lo << log_packed_hi; + let mut block_acc = zero(); + for k in 0..packed_hi { + let i = base + k; + let eq_val = eq_hi[k]; + + let mut rows = multilinears + .iter() + .map(|m| { + let lo = m[i]; + let hi = m[i + fold_size]; + let diff = hi - lo; + [lo, diff, diff] + }) + .collect::>(); + + // z = 0 + let p0 = rows.iter().map(|r| r[0]).collect(); + let mut e0 = eval_fn(computation, p0, extra_data); + e0 *= eq_val; + block_acc[0] += e0; + + // z = 2, 3, ... + for d in 1..degree { + for [_, diff, acc] in &mut rows { + *acc += *diff; + } + let pf = rows.iter().map(|r| r[0] + r[2]).collect(); + let mut ev = eval_fn(computation, pf, extra_data); + ev *= eq_val; + block_acc[d] += ev; + } + } + for a in &mut block_acc { + *a *= eq_lo_bc; + } + block_acc + }) + .reduce(zero, accumulate); + + let unpacked = sums.into_iter().map(&unpack_sum); + build_evals(unpacked, missing_mul_factor) +} + +#[allow(clippy::too_many_arguments, clippy::needless_range_loop)] +#[allow(clippy::type_complexity)] +fn sumcheck_fold_and_compute_with_split_eq( + multilinears: &[&[IF]], + degree: usize, + split_eq: &SplitEq, + computation: &SC, + extra_data: &SC::ExtraData, + missing_mul_factor: Option, + compute_fold_size: usize, + fold_f: impl Fn(&[IF], usize) -> EFPacking + Sync + Send, + eval_fn: impl Fn(&SC, Vec>, &SC::ExtraData) -> EFPacking + Sync + Send, + unpack_sum: impl Fn(EFPacking) -> EF, + wrap_f: impl FnOnce(Vec>>) -> MleGroupOwned, +) -> (Vec, MleGroupOwned) +where + EF: ExtensionField>, + IF: Copy + Send + Sync, + SC: SumcheckComputation, +{ + let prev_folded_size = 2 * compute_fold_size; + let folded_f: Vec>> = (0..multilinears.len()) + .map(|_| EFPacking::::zero_vec(prev_folded_size)) + .collect(); + + let n_lo = split_eq.n_lo(); + let packed_hi = split_eq.packed_hi(); + let log_packed_hi = split_eq.log_packed_hi; + let eq_lo = &split_eq.eq_lo; + let eq_hi = &split_eq.eq_hi_packed; + + let zero = || EFPacking::::zero_vec(degree); + let accumulate = |mut acc: Vec>, vals: Vec>| -> Vec> { + for (a, v) in acc.iter_mut().zip(vals.iter()) { + *a += *v; + } + acc + }; + + let sums: Vec> = (0..n_lo) + .into_par_iter() + .map(|b_lo| { + let eq_lo_bc = EFPacking::::from(eq_lo[b_lo]); + let base = b_lo << log_packed_hi; + let mut block_acc = zero(); + for k in 0..packed_hi { + let i = base + k; + let eq_val = eq_hi[k]; + + let mut rows_f: Vec<[EFPacking; 3]> = multilinears + .iter() + .enumerate() + .map(|(j, m)| { + let lo = fold_f(m, i); + let hi = fold_f(m, i + compute_fold_size); + unsafe { + let ptr = folded_f[j].as_ptr() as *mut EFPacking; + *ptr.add(i) = lo; + *ptr.add(i + compute_fold_size) = hi; + } + let diff = hi - lo; + [lo, diff, diff] + }) + .collect(); + + let p0 = rows_f.iter().map(|r| r[0]).collect(); + let mut e0 = eval_fn(computation, p0, extra_data); + e0 *= eq_val; + block_acc[0] += e0; + + for d in 1..degree { + for [_, diff, acc] in &mut rows_f { + *acc += *diff; + } + let pf = rows_f.iter().map(|r| r[0] + r[2]).collect(); + let mut ev = eval_fn(computation, pf, extra_data); + ev *= eq_val; + block_acc[d] += ev; + } + } + for a in &mut block_acc { + *a *= eq_lo_bc; + } + block_acc + }) + .reduce(zero, accumulate); + + let unpacked = sums.into_iter().map(&unpack_sum); + (build_evals(unpacked, missing_mul_factor), wrap_f(folded_f)) +} diff --git a/crates/backend/sumcheck/src/split_eq.rs b/crates/backend/sumcheck/src/split_eq.rs new file mode 100644 index 00000000..a2006a44 --- /dev/null +++ b/crates/backend/sumcheck/src/split_eq.rs @@ -0,0 +1,102 @@ +use field::{ExtensionField, PackedFieldExtension}; +use poly::*; + +pub struct SplitEq>> { + pub eq_lo: Vec, + pub eq_hi_packed: Vec>, + pub log_packed_hi: u32, // = log2(eq_hi_packed.len()), cached for bit-shift in get_packed + /// Unpacked remainder for when the packed table is empty or exhausted. + pub remainder: Vec, +} + +impl>> SplitEq { + pub fn new(eq_point: &[EF]) -> Self { + let n = eq_point.len(); + + if n <= packing_log_width::() * 2 { + return Self { + eq_lo: vec![EF::ONE], + eq_hi_packed: Vec::new(), + log_packed_hi: 0, + remainder: eval_eq(eq_point), + }; + } + + let hi_vars = (n / 2).max(packing_log_width::().max(1)); + let mid = n - hi_vars; + let eq_lo = eval_eq(&eq_point[..mid]); + let eq_hi_packed = eval_eq_packed(&eq_point[mid..]); + let log_packed_hi = eq_hi_packed.len().trailing_zeros(); + Self { + eq_lo, + eq_hi_packed, + log_packed_hi, + remainder: Vec::new(), + } + } + + #[inline] + pub fn is_remainder_mode(&self) -> bool { + !self.remainder.is_empty() || self.eq_hi_packed.is_empty() + } + + #[inline] + pub fn truncate_half(&mut self) { + if self.eq_lo.len() > 1 { + self.eq_lo.truncate(self.eq_lo.len() / 2); + } else if !self.remainder.is_empty() { + self.remainder.truncate(self.remainder.len() / 2); + } else if self.eq_hi_packed.len() > 1 { + let new_len = self.eq_hi_packed.len() / 2; + self.eq_hi_packed.truncate(new_len); + self.log_packed_hi = new_len.trailing_zeros(); + } else { + // eq_hi_packed has 0 or 1 element — unpack to remainder and halve + let mut unpacked: Vec = EFPacking::::to_ext_iter(self.eq_hi_packed.iter().copied()).collect(); + let scale = self.eq_lo[0]; + for v in &mut unpacked { + *v *= scale; + } + self.eq_lo[0] = EF::ONE; + unpacked.truncate(unpacked.len() / 2); + self.remainder = unpacked; + self.eq_hi_packed.clear(); + } + } + + #[inline] + pub fn n_lo(&self) -> usize { + self.eq_lo.len() + } + + #[inline] + pub fn packed_hi(&self) -> usize { + self.eq_hi_packed.len() + } + + #[inline(always)] + pub fn get_packed(&self, i: usize) -> EFPacking { + debug_assert!(!self.is_remainder_mode(), "get_packed called in remainder mode"); + let packed_hi = self.eq_hi_packed.len(); + if self.eq_lo.len() > 1 { + EFPacking::::from(self.eq_lo[i >> self.log_packed_hi]) * self.eq_hi_packed[i & (packed_hi - 1)] + } else { + self.eq_hi_packed[i] * self.eq_lo[0] + } + } + + #[inline(always)] + pub fn get_unpacked(&self, i: usize) -> EF { + if self.is_remainder_mode() { + if self.remainder.is_empty() { + EF::ONE + } else { + self.remainder[i] * self.eq_lo[0] + } + } else { + let width = packing_width::(); + let packed_val = self.get_packed(i / width); + EFPacking::::to_ext_iter([packed_val]).nth(i % width).unwrap() + } + } +} diff --git a/crates/sub_protocols/src/logup.rs b/crates/sub_protocols/src/logup.rs index 2027bad1..06d513df 100644 --- a/crates/sub_protocols/src/logup.rs +++ b/crates/sub_protocols/src/logup.rs @@ -48,8 +48,15 @@ pub fn prove_generic_logup( log_bytecode, &tables_log_heights_sorted.iter().cloned().collect(), ); - let mut numerators = EF::zero_vec(1 << total_gkr_n_vars); - let mut denominators = EF::zero_vec(1 << total_gkr_n_vars); + let mut numerators = F::zero_vec(1 << total_gkr_n_vars); + let width = packing_width::(); + let mut denominators_packed = EFPacking::::zero_vec((1 << total_gkr_n_vars) / width); + let c_packed = EFPacking::::from(c); + let alphas_packed: Vec> = alphas_eq_poly.iter().map(|a| EFPacking::::from(*a)).collect(); + let alpha_last = *alphas_eq_poly.last().unwrap(); + let memory_contrib = EFPacking::::from(alpha_last * F::from_usize(LOGUP_MEMORY_DOMAINSEP)); + let bytecode_contrib = EFPacking::::from(alpha_last * F::from_usize(LOGUP_BYTECODE_DOMAINSEP)); + let precompile_contrib = EFPacking::::from(alpha_last * F::from_usize(LOGUP_PRECOMPILE_DOMAINSEP)); let mut offset = 0; @@ -57,17 +64,22 @@ pub fn prove_generic_logup( assert_eq!(memory.len(), memory_acc.len()); numerators[offset..][..memory.len()] .par_iter_mut() - .zip(memory_acc) // TODO embedding overhead - .for_each(|(num, a)| *num = EF::from(-*a)); // Note the negative sign here - denominators[offset..][..memory.len()] + .zip(memory_acc) + .for_each(|(num, a)| *num = -*a); // Note the negative sign here + denominators_packed[offset / width..][..memory.len() / width] .par_iter_mut() - .zip(memory.par_iter().enumerate()) - .for_each(|(denom, (i, &mem_value))| { - *denom = c - finger_print( - F::from_usize(LOGUP_MEMORY_DOMAINSEP), - &[mem_value, F::from_usize(i)], - alphas_eq_poly, - ) + .enumerate() + .for_each(|(chunk_idx, denom_packed)| { + let base_i = chunk_idx * width; + *denom_packed = c_packed + - finger_print_packed::( + memory_contrib, + &[ + PFPacking::::from_fn(|w| memory[base_i + w]), + PFPacking::::from_fn(|w| F::from_usize(base_i + w)), + ], + &alphas_packed, + ); }); offset += memory.len(); @@ -75,27 +87,29 @@ pub fn prove_generic_logup( assert_eq!(1 << log_bytecode, bytecode_acc.len()); numerators[offset..][..bytecode_acc.len()] .par_iter_mut() - .zip(bytecode_acc) // TODO embedding overhead - .for_each(|(num, a)| *num = EF::from(-*a)); // Note the negative sign here - denominators[offset..][..1 << log_bytecode] - .par_iter_mut() - .zip( - bytecode_multilinear - .par_chunks_exact(N_INSTRUCTION_COLUMNS.next_power_of_two()) - .enumerate(), - ) - .for_each(|(denom, (i, instr))| { - let mut data = [F::ZERO; N_INSTRUCTION_COLUMNS + 1]; - data[..N_INSTRUCTION_COLUMNS].copy_from_slice(&instr[..N_INSTRUCTION_COLUMNS]); - data[N_INSTRUCTION_COLUMNS] = F::from_usize(i); - *denom = c - finger_print(F::from_usize(LOGUP_BYTECODE_DOMAINSEP), &data, alphas_eq_poly) - }); + .zip(bytecode_acc) + .for_each(|(num, a)| *num = -*a); // Note the negative sign here + { + let bytecode_stride = N_INSTRUCTION_COLUMNS.next_power_of_two(); + denominators_packed[offset / width..][..(1 << log_bytecode) / width] + .par_iter_mut() + .enumerate() + .for_each(|(chunk_idx, denom_packed)| { + let base_i = chunk_idx * width; + let mut data = [PFPacking::::ZERO; N_INSTRUCTION_COLUMNS + 1]; + for k in 0..N_INSTRUCTION_COLUMNS { + data[k] = PFPacking::::from_fn(|w| bytecode_multilinear[(base_i + w) * bytecode_stride + k]); + } + data[N_INSTRUCTION_COLUMNS] = PFPacking::::from_fn(|w| F::from_usize(base_i + w)); + *denom_packed = c_packed - finger_print_packed::(bytecode_contrib, &data, &alphas_packed); + }); + } let max_table_height = 1 << tables_log_heights_sorted[0].1; if 1 << log_bytecode < max_table_height { // padding - denominators[offset + (1 << log_bytecode)..offset + max_table_height] + denominators_packed[(offset + (1 << log_bytecode)) / width..(offset + max_table_height) / width] .par_iter_mut() - .for_each(|d| *d = EF::ONE); + .for_each(|d| *d = EFPacking::::ONE); } offset += max_table_height.max(1 << log_bytecode); // ... Rest of the tables: @@ -108,18 +122,19 @@ pub fn prove_generic_logup( let pc_column = &trace.columns[COL_PC]; let bytecode_columns = &trace.columns[N_RUNTIME_COLUMNS..][..N_INSTRUCTION_COLUMNS]; numerators[offset..][..1 << log_n_rows].par_iter_mut().for_each(|num| { - *num = EF::ONE; + *num = F::ONE; }); // TODO embedding overhead - denominators[offset..][..1 << log_n_rows] + denominators_packed[offset / width..][..(1 << log_n_rows) / width] .par_iter_mut() .enumerate() - .for_each(|(i, denom)| { - let mut data = [F::ZERO; N_INSTRUCTION_COLUMNS + 1]; - for j in 0..N_INSTRUCTION_COLUMNS { - data[j] = bytecode_columns[j][i]; + .for_each(|(chunk_idx, denom_packed)| { + let base_i = chunk_idx * width; + let mut data = [PFPacking::::ZERO; N_INSTRUCTION_COLUMNS + 1]; + for k in 0..N_INSTRUCTION_COLUMNS { + data[k] = PFPacking::::from_fn(|w| bytecode_columns[k][base_i + w]); } - data[N_INSTRUCTION_COLUMNS] = pc_column[i]; - *denom = c - finger_print(F::from_usize(LOGUP_BYTECODE_DOMAINSEP), &data, alphas_eq_poly) + data[N_INSTRUCTION_COLUMNS] = PFPacking::::from_fn(|w| pc_column[base_i + w]); + *denom_packed = c_packed - finger_print_packed::(bytecode_contrib, &data, &alphas_packed); }); offset += 1 << log_n_rows; } @@ -130,30 +145,33 @@ pub fn prove_generic_logup( .par_iter_mut() .zip(&trace.columns[bus.selector]) .for_each(|(num, selector)| { - *num = EF::from(match bus.direction { + *num = F::from(match bus.direction { BusDirection::Pull => -*selector, BusDirection::Push => *selector, }) }); // TODO embedding overhead - denominators[offset..][..1 << log_n_rows] - .par_iter_mut() - .enumerate() - .for_each(|(i, denom)| { - *denom = { - let mut bus_data = [F::ZERO; MAX_PRECOMPILE_BUS_WIDTH]; - for (j, entry) in bus.data.iter().enumerate() { + { + let bus_data_entries = &bus.data; + denominators_packed[offset / width..][..(1 << log_n_rows) / width] + .par_iter_mut() + .enumerate() + .for_each(|(chunk_idx, denom_packed)| { + let base_i = chunk_idx * width; + let mut bus_data = [PFPacking::::ZERO; MAX_PRECOMPILE_BUS_WIDTH]; + for (j, entry) in bus_data_entries.iter().enumerate() { bus_data[j] = match entry { - BusData::Column(col) => trace.columns[*col][i], - BusData::Constant(val) => F::from_usize(*val), + BusData::Column(col) => PFPacking::::from_fn(|w| trace.columns[*col][base_i + w]), + BusData::Constant(val) => PFPacking::::from(F::from_usize(*val)), }; } - c + finger_print( - F::from_usize(LOGUP_PRECOMPILE_DOMAINSEP), - &bus_data[..bus.data.len()], - alphas_eq_poly, - ) - } - }); + *denom_packed = c_packed + + finger_print_packed::( + precompile_contrib, + &bus_data[..bus_data_entries.len()], + &alphas_packed, + ); + }); + } offset += 1 << log_n_rows; @@ -164,23 +182,32 @@ pub fn prove_generic_logup( numerators[offset..][..col_values.len() << log_n_rows] .par_iter_mut() .for_each(|num| { - *num = EF::ONE; + *num = F::ONE; }); // TODO embedding overhead - denominators[offset..][..col_values.len() << log_n_rows] - .par_chunks_exact_mut(1 << log_n_rows) - .enumerate() - .for_each(|(i, denom_chunk)| { - let i_field = F::from_usize(i); - denom_chunk.par_iter_mut().enumerate().for_each(|(j, denom)| { - let index = col_index[j] + i_field; - let mem_value = col_values[i][j]; - *denom = c - finger_print( - F::from_usize(LOGUP_MEMORY_DOMAINSEP), - &[mem_value, index], - alphas_eq_poly, - ) + { + let packed_chunk_size = (1 << log_n_rows) / width; + denominators_packed[offset / width..][..col_values.len() * packed_chunk_size] + .par_chunks_exact_mut(packed_chunk_size) + .enumerate() + .for_each(|(i, denom_chunk)| { + let i_field = F::from_usize(i); + denom_chunk + .par_iter_mut() + .enumerate() + .for_each(|(chunk_idx, denom_packed)| { + let base_j = chunk_idx * width; + *denom_packed = c_packed + - finger_print_packed::( + memory_contrib, + &[ + PFPacking::::from_fn(|w| col_values[i][base_j + w]), + PFPacking::::from_fn(|w| col_index[base_j + w] + i_field), + ], + &alphas_packed, + ); + }); }); - }); + } offset += col_values.len() << log_n_rows; } } @@ -197,14 +224,15 @@ pub fn prove_generic_logup( .blue() ); - denominators[offset..].par_iter_mut().for_each(|d| *d = EF::ONE); // padding + denominators_packed[offset / width..] + .par_iter_mut() + .for_each(|d| *d = EFPacking::::ONE); // padding - // TODO pack directly - let numerators_packed = MleRef::Extension(&numerators).pack(); - let denominators_packed = MleRef::Extension(&denominators).pack(); + let numerators_packed = MleRef::Base(&numerators).pack(); + let denom_ref = MleRef::::ExtensionPacked(&denominators_packed); let (sum, claim_point_gkr, numerators_value, denominators_value) = - prove_gkr_quotient(prover_state, &numerators_packed.by_ref(), &denominators_packed.by_ref()); + prove_gkr_quotient(prover_state, &numerators_packed.by_ref(), &denom_ref); let _ = (numerators_value, denominators_value); // TODO use it to avoid some computation below @@ -269,7 +297,9 @@ pub fn prove_generic_logup( trace.columns[table.bus().selector].evaluate(&inner_point) * table.bus().direction.to_field_flag(); prover_state.add_extension_scalar(eval_on_selector); - let eval_on_data = (&denominators[offset..][..1 << log_n_rows]).evaluate(&inner_point); + let eval_on_data = + MleRef::::ExtensionPacked(&denominators_packed[offset / width..][..(1 << log_n_rows) / width]) + .evaluate(&inner_point); prover_state.add_extension_scalar(eval_on_data); bus_numerators_values.insert(*table, eval_on_selector); diff --git a/crates/sub_protocols/src/quotient_gkr.rs b/crates/sub_protocols/src/quotient_gkr.rs index 8dbccb61..f6d64e83 100644 --- a/crates/sub_protocols/src/quotient_gkr.rs +++ b/crates/sub_protocols/src/quotient_gkr.rs @@ -1,5 +1,6 @@ +use std::ops::Mul; + use backend::*; -use tracing::instrument; use crate::MIN_VARS_FOR_PACKING; @@ -7,7 +8,6 @@ use crate::MIN_VARS_FOR_PACKING; GKR to compute sum of fractions. */ -#[instrument(skip_all)] pub fn prove_gkr_quotient>>( prover_state: &mut impl FSProver, numerators: &MleRef<'_, EF>, @@ -59,33 +59,48 @@ fn prove_gkr_quotient_step>>( claim_point: &MultilinearPoint, claims: Vec, ) -> (MultilinearPoint, Vec) { - let prev_numerators_and_denominators_split = match (numerators.by_ref(), denominators.by_ref()) { - (MleRef::ExtensionPacked(numerators), MleRef::ExtensionPacked(denominators)) => { - let (left_nums, right_nums) = numerators.split_at(numerators.len() / 2); - let (left_dens, right_dens) = denominators.split_at(denominators.len() / 2); - MleGroupRef::ExtensionPacked(vec![left_nums, right_nums, left_dens, right_dens]) - } - (MleRef::Extension(numerators), MleRef::Extension(denominators)) => { - let (left_nums, right_nums) = numerators.split_at(numerators.len() / 2); - let (left_dens, right_dens) = denominators.split_at(denominators.len() / 2); - MleGroupRef::Extension(vec![left_nums, right_nums, left_dens, right_dens]) - } - _ => unreachable!(), - }; - let alpha = prover_state.sample(); - assert_eq!(claims.len(), 2); let sum = claims[0] + claims[1] * alpha; - let (mut next_point, inner_evals, _) = sumcheck_prove::( - prev_numerators_and_denominators_split, - &GKRQuotientComputation {}, - &alpha.powers().take(2).collect(), - Some((claim_point.0.clone(), None)), - prover_state, - sum, - false, - ); + let extra_data: Vec = alpha.powers().take(2).collect(); + + let (mut next_point, inner_evals) = match (numerators.by_ref(), denominators.by_ref()) { + (MleRef::BasePacked(nums), MleRef::ExtensionPacked(dens)) => { + prove_gkr_quotient_step_base_ext(prover_state, nums, dens, claim_point, &extra_data, sum) + } + _ => { + let ext_nums_unpacked: Vec; + let group = match (numerators.by_ref(), denominators.by_ref()) { + (MleRef::ExtensionPacked(numerators), MleRef::ExtensionPacked(denominators)) => { + let (ln, rn) = numerators.split_at(numerators.len() / 2); + let (ld, rd) = denominators.split_at(denominators.len() / 2); + MleGroupRef::ExtensionPacked(vec![ln, rn, ld, rd]) + } + (MleRef::Extension(numerators), MleRef::Extension(denominators)) => { + let (ln, rn) = numerators.split_at(numerators.len() / 2); + let (ld, rd) = denominators.split_at(denominators.len() / 2); + MleGroupRef::Extension(vec![ln, rn, ld, rd]) + } + (MleRef::Base(numerators), MleRef::Extension(denominators)) => { + ext_nums_unpacked = numerators.iter().map(|&x| EF::from(x)).collect(); + let (ln, rn) = ext_nums_unpacked.split_at(ext_nums_unpacked.len() / 2); + let (ld, rd) = denominators.split_at(denominators.len() / 2); + MleGroupRef::Extension(vec![ln, rn, ld, rd]) + } + _ => unreachable!(), + }; + let (point, evals, _) = sumcheck_prove::( + group, + &GKRQuotientComputation {}, + &extra_data, + Some(claim_point.0.clone()), + prover_state, + sum, + false, + ); + (point, evals) + } + }; prover_state.add_extension_scalars(&inner_evals); let beta = prover_state.sample(); @@ -100,22 +115,108 @@ fn prove_gkr_quotient_step>>( (next_point, next_claims) } +fn prove_gkr_quotient_step_base_ext>>( + prover_state: &mut impl FSProver, + nums: &[PFPacking], + dens: &[EFPacking], + claim_point: &MultilinearPoint, + extra_data: &[EF], + sum: EF, +) -> (MultilinearPoint, Vec) { + let eq_point = &claim_point.0; + let n_vars = eq_point.len(); + let alpha = extra_data[1]; + + let half = nums.len() / 2; + let (nl, nr) = nums.split_at(half); + let (dl, dr) = dens.split_at(half); + + let mut split_eq = SplitEq::new(&eq_point[1..]); + let poly_0 = + compute_gkr_quotient_sumcheck_polynomial_split_eq(nl, nr, dl, dr, alpha, eq_point[0], &split_eq, EF::ONE, sum); + prover_state.add_sumcheck_polynomial(&poly_0.coeffs, Some(eq_point[0])); + let challenge_0 = prover_state.sample(); + + let eq_eval_0 = (EF::ONE - eq_point[0]) * (EF::ONE - challenge_0) + eq_point[0] * challenge_0; + let sum_1 = poly_0.evaluate(challenge_0) * eq_eval_0; + let mmf_1 = eq_eval_0 / (EF::ONE - eq_point.get(1).copied().unwrap_or_default()); + + split_eq.truncate_half(); + let r = challenge_0; + let r_packed = EFPacking::::from(r); + let fold_base = |u: &[PFPacking], i: usize, half: usize, quarter: usize| { + let left = r_packed * (u[i + half] - u[i]) + u[i]; + let right = r_packed * (u[i + half + quarter] - u[i + quarter]) + u[i + quarter]; + (left, right) + }; + let fold_ext = |u: &[EFPacking], i: usize, half: usize, quarter: usize| { + let left = (u[i + half] - u[i]) * r + u[i]; + let right = (u[i + half + quarter] - u[i + quarter]) * r + u[i + quarter]; + (left, right) + }; + let (poly_1, folded) = fold_and_compute_gkr_quotient_split_eq( + nl, + nr, + dl, + dr, + fold_base, + fold_ext, + alpha, + eq_point[1], + &split_eq, + mmf_1, + sum_1, + ); + prover_state.add_sumcheck_polynomial(&poly_1.coeffs, Some(eq_point[1])); + let challenge_1 = prover_state.sample(); + + let eq_eval_1 = (EF::ONE - eq_point[1]) * (EF::ONE - challenge_1) + eq_point[1] * challenge_1; + let sum_2 = poly_1.evaluate(challenge_1) * eq_eval_1; + let mmf_2 = eq_eval_0 * eq_eval_1; + + let group = MleGroupOwned::ExtensionPacked(folded); + + let (remaining_point, final_group, _) = sumcheck_prove_many_rounds( + group, + Some(challenge_1), + &GKRQuotientComputation {}, + &extra_data.to_vec(), + Some(eq_point[2..].to_vec()), + prover_state, + sum_2, + Some(mmf_2), + n_vars - 2, + false, + 0, + ); + + let final_folds = final_group.as_extension().unwrap(); + let inner_evals: Vec = final_folds + .iter() + .map(|m| { + assert_eq!(m.len(), 1); + m[0] + }) + .collect(); + + let mut point = MultilinearPoint(vec![challenge_0, challenge_1]); + point.0.extend(remaining_point.0); + (point, inner_evals) +} + pub fn verify_gkr_quotient>>( verifier_state: &mut impl FSVerifier, n_vars: usize, ) -> Result<(EF, MultilinearPoint, EF, EF), ProofError> { let last_nums = verifier_state.next_extension_scalars_vec(2)?; let last_dens = verifier_state.next_extension_scalars_vec(2)?; - let quotient = last_nums[0] / last_dens[0] + last_nums[1] / last_dens[1]; - let mut point = MultilinearPoint(vec![verifier_state.sample()]); let mut claims_num = last_nums.evaluate(&point); let mut claims_den = last_dens.evaluate(&point); for i in 1..n_vars { (point, claims_num, claims_den) = verify_gkr_quotient_step(verifier_state, i, &point, claims_num, claims_den)?; } - Ok((quotient, point, claims_num, claims_den)) } @@ -127,12 +228,9 @@ fn verify_gkr_quotient_step>>( claims_den: EF, ) -> Result<(MultilinearPoint, EF, EF), ProofError> { let alpha = verifier_state.sample(); - let expected_sum = claims_num + alpha * claims_den; let postponed = sumcheck_verify(verifier_state, n_vars, 3, expected_sum, Some(&point.0))?; - let inner_evals = verifier_state.next_extension_scalars_vec(4)?; - if postponed.value != point.eq_poly_outside(&postponed.point) * GKRQuotientComputation::eval_extension( @@ -143,14 +241,11 @@ fn verify_gkr_quotient_step>>( { return Err(ProofError::InvalidProof); } - let beta = verifier_state.sample(); - let next_claims_numerators = (&inner_evals[..2]).evaluate(&MultilinearPoint(vec![beta])); let next_claims_denominators = (&inner_evals[2..]).evaluate(&MultilinearPoint(vec![beta])); let mut next_point = postponed.point.clone(); next_point.0.insert(0, beta); - Ok((next_point, next_claims_numerators, next_claims_denominators)) } @@ -159,28 +254,29 @@ fn sum_quotients>>( denominators: MleRef<'_, EF>, ) -> (MleOwned, MleOwned) { match (numerators, denominators) { - (MleRef::ExtensionPacked(numerators), MleRef::ExtensionPacked(denominators)) => { - let (new_numerators, new_denominators) = sum_quotients_2_by_2(numerators, denominators); - ( - MleOwned::ExtensionPacked(new_numerators), - MleOwned::ExtensionPacked(new_denominators), - ) + (MleRef::ExtensionPacked(n), MleRef::ExtensionPacked(d)) => { + let (nn, nd) = sum_quotients_2_by_2(n, d); + (MleOwned::ExtensionPacked(nn), MleOwned::ExtensionPacked(nd)) } - (MleRef::Extension(numerators), MleRef::Extension(denominators)) => { - let (new_numerators, new_denominators) = sum_quotients_2_by_2(numerators, denominators); - ( - MleOwned::Extension(new_numerators), - MleOwned::Extension(new_denominators), - ) + (MleRef::Extension(n), MleRef::Extension(d)) => { + let (nn, nd) = sum_quotients_2_by_2(n, d); + (MleOwned::Extension(nn), MleOwned::Extension(nd)) + } + (MleRef::BasePacked(n), MleRef::ExtensionPacked(d)) => { + let (nn, nd) = sum_quotients_2_by_2(n, d); + (MleOwned::ExtensionPacked(nn), MleOwned::ExtensionPacked(nd)) } _ => unreachable!(), } } -fn sum_quotients_2_by_2( - numerators: &[F], - denominators: &[F], -) -> (Vec, Vec) { + +fn sum_quotients_2_by_2(numerators: &[N], denominators: &[D]) -> (Vec, Vec) +where + N: Copy + Sync + Send, + D: PrimeCharacteristicRing + Sync + Send + Copy + Mul, +{ let n = numerators.len(); + assert_eq!(n, denominators.len()); let new_n = n / 2; let mut new_numerators = unsafe { uninitialized_vec(new_n) }; let mut new_denominators = unsafe { uninitialized_vec(new_n) }; @@ -189,26 +285,24 @@ fn sum_quotients_2_by_2( .zip(new_denominators.par_iter_mut()) .enumerate() .for_each(|(i, (num, den))| { - let my_numerators: [_; 2] = [numerators[i], numerators[i + new_n]]; - let my_denominators: [_; 2] = [denominators[i], denominators[i + new_n]]; - *num = my_numerators[0] * my_denominators[1] + my_numerators[1] * my_denominators[0]; - *den = my_denominators[0] * my_denominators[1]; + *num = denominators[i + new_n] * numerators[i] + denominators[i] * numerators[i + new_n]; + *den = denominators[i] * denominators[i + new_n]; }); (new_numerators, new_denominators) } #[cfg(test)] mod tests { - use std::time::Instant; - use super::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; + use std::time::Instant; use utils::{build_prover_state, build_verifier_state, init_tracing}; + type F = KoalaBear; type EF = QuinticExtensionFieldKB; - fn sum_all_quotients(nums: &[EF], den: &[EF]) -> EF { - nums.iter().zip(den.iter()).map(|(&n, &d)| n / d).sum() + fn sum_all_quotients(nums: &[F], den: &[EF]) -> EF { + nums.par_iter().zip(den).map(|(&n, &d)| EF::from(n) / d).sum() } #[test] @@ -218,8 +312,7 @@ mod tests { init_tracing(); let mut rng = StdRng::seed_from_u64(0); - - let numerators = (0..n).map(|_| rng.random()).collect::>(); + let numerators = (0..n).map(|_| rng.random()).collect::>(); let c: EF = rng.random(); let denominators_indexes = (0..n) .map(|_| PF::::from_usize(rng.random_range(..n))) @@ -228,20 +321,18 @@ mod tests { let real_quotient = sum_all_quotients(&numerators, &denominators); let mut prover_state = build_prover_state(); + let numerators = MleOwned::BasePacked(pack_extension(&numerators)); + let denominators = MleOwned::ExtensionPacked(pack_extension(&denominators)); + let time = Instant::now(); - let prover_statements = prove_gkr_quotient::( - &mut prover_state, - &MleRef::ExtensionPacked(&pack_extension(&numerators)), - &MleRef::ExtensionPacked(&pack_extension(&denominators)), - ); + let prover_statements = + prove_gkr_quotient::(&mut prover_state, &numerators.by_ref(), &denominators.by_ref()); println!("Proving time: {:?}", time.elapsed()); let mut verifier_state = build_verifier_state(prover_state).unwrap(); - let verifier_statements = verify_gkr_quotient::(&mut verifier_state, log_n).unwrap(); assert_eq!(&verifier_statements, &prover_statements); let (retrieved_quotient, claim_point, claim_num, claim_den) = verifier_statements; - assert_eq!(retrieved_quotient, real_quotient); assert_eq!(numerators.evaluate(&claim_point), claim_num); assert_eq!(denominators.evaluate(&claim_point), claim_den); diff --git a/crates/utils/src/multilinear.rs b/crates/utils/src/multilinear.rs index c74e5436..593e7680 100644 --- a/crates/utils/src/multilinear.rs +++ b/crates/utils/src/multilinear.rs @@ -83,6 +83,19 @@ pub fn finger_print>, EF: ExtensionField + *alphas_eq_poly.last().unwrap() * table } +#[inline(always)] +pub fn finger_print_packed>>( + table_contrib: EFPacking, + data: &[PFPacking], + alphas_packed: &[EFPacking], +) -> EFPacking { + let mut result = table_contrib; + for (alpha, d) in alphas_packed.iter().zip(data) { + result += *alpha * *d; + } + result +} + #[cfg(test)] mod tests { use rand::rngs::StdRng;