From 755ca64d1ec91b89fade549520fe9e96f64640ad Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 8 Sep 2025 21:47:47 +0800 Subject: [PATCH 01/91] generate septic sum layer witnesses --- ceno_zkvm/src/scheme.rs | 1 + ceno_zkvm/src/scheme/septic_curve.rs | 222 +++++++++++++++++++++++++++ ceno_zkvm/src/scheme/utils.rs | 72 +++++++++ 3 files changed, 295 insertions(+) create mode 100644 ceno_zkvm/src/scheme/septic_curve.rs diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 58a9aae89..a33b890e9 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -29,6 +29,7 @@ pub mod cpu; pub mod gpu; pub mod hal; pub mod prover; +pub mod septic_curve; pub mod utils; pub mod verifier; diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs new file mode 100644 index 000000000..8aeac68e7 --- /dev/null +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -0,0 +1,222 @@ +// The extension field and curve definition are adapted from +// https://github.com/succinctlabs/sp1/blob/v5.2.1/crates/stark/src/septic_curve.rs +use p3::field::Field; +use std::ops::{Add, Mul, Sub}; + +/// F[z] / (z^6 - z - 4) +/// +/// ```sage +/// # finite field F = GF(2^31 - 2^27 + 1) +/// p = 2^31 - 2^27 + 1 +/// F = GF(p) +/// +/// # polynomial ring over F +/// R. = PolynomialRing(F) +/// f = x^6 - x - 4 +/// +/// # check if f(x) is irreducible +/// print(f.is_irreducible()) +/// ``` +pub struct SexticExtension([F; 6]); + +/// F[z] / (z^7 - 2z - 5) +/// +/// ```sage +/// # finite field F = GF(2^31 - 2^27 + 1) +/// p = 2^31 - 2^27 + 1 +/// F = GF(p) +/// +/// # polynomial ring over F +/// R. = PolynomialRing(F) +/// f = x^7 - 2x - 5 +/// +/// # check if f(x) is irreducible +/// print(f.is_irreducible()) +/// ``` +#[derive(Clone, Debug, Default, PartialEq)] +pub struct SepticExtension(pub [F; 7]); + +impl SepticExtension { + pub fn is_zero(&self) -> bool { + self.0.iter().all(|c| *c == F::ZERO) + } + + pub fn inverse(&self) -> Option { + match self.is_zero() { + true => None, + false => { + todo!() + } + } + } + + pub fn square(&self) -> Self { + let mut result = [F::ZERO; 7]; + let two = F::from_canonical_u32(2); + let five = F::from_canonical_u32(5); + + // i < j + for i in 0..7 { + for j in (i + 1)..7 { + let term = two * self.0[i] * self.0[j]; + let mut index = i + j; + if index < 7 { + result[index] += term; + } else { + index -= 7; + // x^7 = 2x + 5 + result[index] += five * term; + result[index + 1] += two * term; + } + } + } + // i == j + result[0] += self.0[0] * self.0[0]; + result[2] += self.0[1] * self.0[1]; + result[4] += self.0[2] * self.0[2]; + result[6] += self.0[3] * self.0[3]; + // a4^2 * x^8 = a4^2 * (2x + 5)x = 5a4^2 * x + 2a4^2 * x^2 + let term = self.0[4] * self.0[4]; + result[1] += five * term; + result[2] += two * term; + // a5^2 * x^10 = a5^2 * (2x + 5)x^3 = 5a5^2 * x^3 + 2a5^2 * x^4 + let term = self.0[5] * self.0[5]; + result[3] += five * term; + result[4] += two * term; + // a6^2 * x^12 = a6^2 * (2x + 5)x^5 = 5a6^2 * x^5 + 2a6^2 * x^6 + let term = self.0[6] * self.0[6]; + result[5] += five * term; + result[6] += two * term; + + Self(result) + } +} + +impl From<[u32; 7]> for SepticExtension { + fn from(arr: [u32; 7]) -> Self { + let mut result = [F::ZERO; 7]; + for i in 0..7 { + result[i] = F::from_canonical_u32(arr[i]); + } + Self(result) + } +} + +impl Add<&Self> for SepticExtension { + type Output = Self; + + fn add(self, other: &Self) -> Self { + let mut result = [F::ZERO; 7]; + for i in 0..7 { + result[i] = self.0[i] + other.0[i]; + } + Self(result) + } +} + +impl Add for SepticExtension { + type Output = Self; + + fn add(self, other: Self) -> Self { + self.add(&other) + } +} + +impl Sub<&Self> for SepticExtension { + type Output = Self; + + fn sub(self, other: &Self) -> Self { + let mut result = [F::ZERO; 7]; + for i in 0..7 { + result[i] = self.0[i] - other.0[i]; + } + Self(result) + } +} + +impl Sub for SepticExtension { + type Output = Self; + + fn sub(self, other: Self) -> Self { + self.sub(&other) + } +} + +impl Mul<&Self> for SepticExtension { + type Output = Self; + + fn mul(self, other: &Self) -> Self { + let mut result = [F::ZERO; 7]; + let five = F::from_canonical_u32(5); + let two = F::from_canonical_u32(2); + for i in 0..7 { + for j in 0..7 { + let term = self.0[i] * other.0[j]; + let mut index = i + j; + if index < 7 { + result[index] += term; + } else { + index -= 7; + // x^7 = 2x + 5 + result[index] += five * term; + result[index + 1] += two * term; + } + } + } + Self(result) + } +} + +impl Mul for SepticExtension { + type Output = Self; + + fn mul(self, other: Self) -> Self { + self.mul(&other) + } +} + +/// A point on the short Weierstrass curve defined by +/// y^2 = x^3 + 2x + 26z^5 +/// over the extension field F[z] / (z^7 - 2z - 5). +/// +/// Note that +/// 1. its cofactor is 1 +/// 2. its order is a large prime number of 31x7 bits +#[derive(Clone, Debug, Default, PartialEq)] +pub struct SepticPoint { + pub x: SepticExtension, + pub y: SepticExtension, +} + +impl Add for SepticPoint { + type Output = Self; + + fn add(self, other: Self) -> Self { + assert!(other.x != self.x, "other = self or other = -self"); + let slope = (other.y - &self.y) * (other.x.clone() - &self.x).inverse().unwrap(); + let x = slope.square() - (self.x.clone() + other.x); + let y = slope * (x.clone() - self.x) - self.y; + + Self { x, y } + } +} + +#[cfg(test)] +mod tests { + use super::SepticExtension; + use p3::babybear::BabyBear; + + type F = BabyBear; + #[test] + fn test_septic_extension_arithmetic() { + // a = z, b = z^6 + z^5 + z^4 + let a: SepticExtension = SepticExtension::from([0, 1, 0, 0, 0, 0, 0]); + let b: SepticExtension = SepticExtension::from([0, 0, 0, 0, 1, 1, 1]); + + assert_eq!( + a * b, + // z^5 + z^6 + 2*z + 5 + SepticExtension::from([5, 2, 0, 0, 0, 1, 1]) + ) + } +} diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 194b77060..a3fd8773c 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -2,6 +2,7 @@ use crate::{ scheme::{ constants::MIN_PAR_SIZE, hal::{MainSumcheckProver, ProofInput, ProverDevice}, + septic_curve::{SepticExtension, SepticPoint}, }, structs::ComposedConstrainSystem, }; @@ -20,6 +21,7 @@ use multilinear_extensions::{ mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, util::ceil_log2, }; +use p3::matrix::{Matrix, dense::RowMajorMatrix}; use rayon::{ iter::{ IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, @@ -297,6 +299,76 @@ pub(crate) fn infer_tower_product_witness( wit_layers } +pub fn log2_strict_usize(n: usize) -> usize { + assert!(n.is_power_of_two()); + n.trailing_zeros() as usize +} + +pub fn infer_septic_sum_witness( + p_mles: RowMajorMatrix, + q_mles: RowMajorMatrix, +) -> Vec>> { + assert!(p_mles.width() > 0); + let num_layers = log2_strict_usize(p_mles.height()); + + let mut layers = Vec::with_capacity(num_layers); + layers.push(vec![p_mles, q_mles]); + for i in (0..num_layers).rev() { + let last_layer = layers.last().unwrap(); + let (p, q) = (&last_layer[0], &last_layer[1]); + + let num_rows = p.height(); + let new_p = RowMajorMatrix::new( + (0..num_rows / 2) + .into_par_iter() + .flat_map_iter(|row| { + let p = p.row_slice(row); + let q = q.row_slice(row); + + let p1 = SepticPoint { + x: SepticExtension(std::array::from_fn(|i| p[i])), + y: SepticExtension(std::array::from_fn(|i| p[i + 7])), + }; + let p2 = SepticPoint { + x: SepticExtension(std::array::from_fn(|i| q[i])), + y: SepticExtension(std::array::from_fn(|i| q[i + 7])), + }; + let q = p1 + p2; + q.x.0.iter().chain(q.y.0.iter()).copied() + }) + .collect::>(), + 14, + ); + let new_q = RowMajorMatrix::new( + ((num_rows / 2)..num_rows) + .into_par_iter() + .flat_map_iter(|row| { + let p = p.row_slice(row); + let q = q.row_slice(row); + + let p1 = SepticPoint { + x: SepticExtension(std::array::from_fn(|i| p[i])), + y: SepticExtension(std::array::from_fn(|i| p[i + 7])), + }; + let p2 = SepticPoint { + x: SepticExtension(std::array::from_fn(|i| q[i])), + y: SepticExtension(std::array::from_fn(|i| q[i + 7])), + }; + let q = p1 + p2; + q.x.0.iter().chain(q.y.0.iter()).copied() + }) + .collect::>(), + 14, + ); + layers.push(vec![new_p, new_q]); + } + + layers + .iter() + .rev() + .map(|layer| layer.iter().map(|m| m.into_mles()).collect::>()) +} + pub fn build_main_witness< 'a, E: ExtensionField, From 3862bc55d42ee9a89f37d767817d20ebe4415815 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Wed, 10 Sep 2025 00:07:20 +0800 Subject: [PATCH 02/91] support ec add in tower verifier: wip --- ceno_zkvm/src/scheme/septic_curve.rs | 19 +++++++++++++- ceno_zkvm/src/scheme/verifier.rs | 38 ++++++++++++++++++++++++++-- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 8aeac68e7..6ad89e225 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -1,7 +1,7 @@ // The extension field and curve definition are adapted from // https://github.com/succinctlabs/sp1/blob/v5.2.1/crates/stark/src/septic_curve.rs use p3::field::Field; -use std::ops::{Add, Mul, Sub}; +use std::ops::{Add, Deref, Mul, Sub}; /// F[z] / (z^6 - z - 4) /// @@ -36,6 +36,23 @@ pub struct SexticExtension([F; 6]); #[derive(Clone, Debug, Default, PartialEq)] pub struct SepticExtension(pub [F; 7]); +impl From<&[F]> for SepticExtension { + fn from(slice: &[F]) -> Self { + assert!(slice.len() == 7); + let mut arr = [F::default(); 7]; + arr.copy_from_slice(&slice[0..7]); + Self(arr) + } +} + +impl Deref for SepticExtension { + type Target = [F]; + + fn deref(&self) -> &[F] { + &self.0 + } +} + impl SepticExtension { pub fn is_zero(&self) -> bool { self.0.iter().all(|c| *c == F::ZERO) diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index f8c1c8a2a..bce4fcb8d 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -25,7 +25,10 @@ use witness::next_pow2_instance_padding; use crate::{ error::ZKVMError, - scheme::constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, + scheme::{ + constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, + septic_curve::SepticPoint, + }, structs::{ComposedConstrainSystem, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, utils::{ eval_inner_repeated_incremental_vec, eval_outer_repeated_incremental_vec, @@ -734,8 +737,10 @@ pub type TowerVerifyResult = Result< impl TowerVerify { pub fn verify( + // TODO: unify prod/logup/ec_add prod_out_evals: Vec>, logup_out_evals: Vec>, + ecc_out_evals: Vec>, tower_proofs: &TowerProofs, num_variables: Vec, num_fanin: usize, @@ -755,6 +760,7 @@ impl TowerVerify { assert!(logup_out_evals.iter().all(|evals| { evals.len() == 4 // [p1, p2, q1, q2] })); + assert_eq!(ecc_out_evals.len(), 2); assert_eq!(num_variables.len(), num_prod_spec + num_logup_spec); let alpha_pows = get_challenge_pows( @@ -792,6 +798,31 @@ impl TowerVerify { ) }) .unzip::<_, _, Vec<_>, Vec<_>>(); + let ecc_eval = { + let SepticPoint { x, y } = &ecc_out_evals[0]; + let SepticPoint { x: x2, y: y2 } = &ecc_out_evals[1]; + + let xs = + x.0.iter() + .cloned() + .zip(x2.iter().cloned()) + .map(|(xi, x2i): (E::BaseField, E::BaseField)| { + vec![xi, x2i].into_mle().evaluate(&initial_rt) + }) + .collect_vec(); + + let ys = + y.0.iter() + .cloned() + .zip(y2.iter().cloned()) + .map(|(yi, y2i): (E::BaseField, E::BaseField)| { + vec![yi, y2i].into_mle().evaluate(&initial_rt) + }) + .collect_vec(); + + vec![xs, ys].concat() + }; + let initial_claim = izip!(&prod_spec_point_n_eval, &alpha_pows) .map(|(point_n_eval, alpha)| point_n_eval.eval * *alpha) .sum::() @@ -800,7 +831,10 @@ impl TowerVerify { &alpha_pows[num_prod_spec..] ) .map(|(point_n_eval, alpha)| point_n_eval.eval * *alpha) - .sum::(); + .sum::() + + izip!(ecc_eval, &alpha_pows[num_prod_spec + num_logup_spec * 2..]) + .map(|(eval, alpha)| eval * *alpha) + .sum::(); let max_num_variables = num_variables.iter().max().unwrap(); From 3ac7ffc0b10991299116d681e1d43eb67670d541 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 11 Sep 2025 01:13:36 +0800 Subject: [PATCH 03/91] ecc accumulation batched into tower verifier --- ceno_zkvm/src/scheme/verifier.rs | 102 ++++++++++++++++++++----------- 1 file changed, 68 insertions(+), 34 deletions(-) diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index bce4fcb8d..1de999583 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,6 +1,7 @@ use std::marker::PhantomData; use ff_ext::ExtensionField; +use p3::field::Field; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; @@ -27,7 +28,7 @@ use crate::{ error::ZKVMError, scheme::{ constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, - septic_curve::SepticPoint, + septic_curve::{SepticExtension, SepticPoint}, }, structs::{ComposedConstrainSystem, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, utils::{ @@ -736,6 +737,34 @@ pub type TowerVerifyResult = Result< >; impl TowerVerify { + fn get_ecc_eval( + p1: &SepticPoint, + p2: &SepticPoint, + rt: &[E], + ) -> SepticPoint { + let SepticPoint { x, y } = p1; + let SepticPoint { x: x2, y: y2 } = p2; + + let xs = + x.0.iter() + .cloned() + .zip(x2.iter().cloned()) + .map(|(xi, x2i)| vec![xi, x2i].into_mle().evaluate(rt)) + .collect_vec(); + + let ys = + y.0.iter() + .cloned() + .zip(y2.iter().cloned()) + .map(|(yi, y2i)| vec![yi, y2i].into_mle().evaluate(rt)) + .collect_vec(); + + SepticPoint { + x: xs.as_slice().into(), + y: ys.as_slice().into(), + } + } + pub fn verify( // TODO: unify prod/logup/ec_add prod_out_evals: Vec>, @@ -764,7 +793,7 @@ impl TowerVerify { assert_eq!(num_variables.len(), num_prod_spec + num_logup_spec); let alpha_pows = get_challenge_pows( - num_prod_spec + num_logup_spec * 2, /* logup occupy 2 sumcheck: numerator and denominator */ + num_prod_spec + num_logup_spec * 2 + 14, /* logup occupy 2 sumcheck: numerator and denominator */ transcript, ); let initial_rt: Point = transcript.sample_and_append_vec(b"product_sum", log2_num_fanin); @@ -798,31 +827,9 @@ impl TowerVerify { ) }) .unzip::<_, _, Vec<_>, Vec<_>>(); - let ecc_eval = { - let SepticPoint { x, y } = &ecc_out_evals[0]; - let SepticPoint { x: x2, y: y2 } = &ecc_out_evals[1]; - - let xs = - x.0.iter() - .cloned() - .zip(x2.iter().cloned()) - .map(|(xi, x2i): (E::BaseField, E::BaseField)| { - vec![xi, x2i].into_mle().evaluate(&initial_rt) - }) - .collect_vec(); - - let ys = - y.0.iter() - .cloned() - .zip(y2.iter().cloned()) - .map(|(yi, y2i): (E::BaseField, E::BaseField)| { - vec![yi, y2i].into_mle().evaluate(&initial_rt) - }) - .collect_vec(); - - vec![xs, ys].concat() - }; + let mut ecc_eval = Self::get_ecc_eval(&ecc_out_evals[0], &ecc_out_evals[1], &initial_rt); + // initial claim = \sum_j alpha^j * out_j[rt] let initial_claim = izip!(&prod_spec_point_n_eval, &alpha_pows) .map(|(point_n_eval, alpha)| point_n_eval.eval * *alpha) .sum::() @@ -831,10 +838,7 @@ impl TowerVerify { &alpha_pows[num_prod_spec..] ) .map(|(point_n_eval, alpha)| point_n_eval.eval * *alpha) - .sum::() - + izip!(ecc_eval, &alpha_pows[num_prod_spec + num_logup_spec * 2..]) - .map(|(eval, alpha)| eval * *alpha) - .sum::(); + .sum::(); let max_num_variables = num_variables.iter().max().unwrap(); @@ -863,12 +867,14 @@ impl TowerVerify { // check expected_evaluation let rt: Point = sumcheck_claim.point.iter().map(|c| c.elements).collect(); - let expected_evaluation: E = (0..num_prod_spec) + let eq = eq_eval(out_rt, &rt); + let mut expected_evaluation: E = (0..num_prod_spec) .zip(alpha_pows.iter()) .zip(num_variables.iter()) .map(|((spec_index, alpha), max_round)| { - eq_eval(out_rt, &rt) - * *alpha + // prod[b] = prod'[0,b] * prod'[1,b] + // prod[out_rt] = \sum_b eq(out_rt, b) * prod[b] = \sum_b eq(out_rt, b) * prod'[0,b] * prod'[1,b] + eq * *alpha * if round < *max_round-1 {tower_proofs.prod_specs_eval[spec_index][round].iter().copied().product()} else { E::ZERO } @@ -878,8 +884,12 @@ impl TowerVerify { .zip_eq(alpha_pows[num_prod_spec..].chunks(2)) .zip_eq(num_variables[num_prod_spec..].iter()) .map(|((spec_index, alpha), max_round)| { + // logup_q[b] = logup_q'[0,b] * logup_q'[1,b] + // logup_p[b] = logup_p'[0,b] * logup_q'[1,b] + logup_p'[1,b] * logup_q'[0,b] + // logup_p[out_rt] = \sum_b eq(out_rt, b) * (logup_p'[0,b] * logup_q'[1,b] + logup_p'[1,b] * logup_q'[0,b]) + // logup_q[out_rt] = \sum_b eq(out_rt, b) * logup_q'[0,b] * logup_q'[1,b] let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); - eq_eval(out_rt, &rt) * if round < *max_round-1 { + eq * if round < *max_round-1 { let evals = &tower_proofs.logup_specs_eval[spec_index][round]; let (p1, p2, q1, q2) = (evals[0], evals[1], evals[2], evals[3]); @@ -890,6 +900,23 @@ impl TowerVerify { } }) .sum::(); + // 0 = \sum_b eq(out_rt, b) * (ecc_x[b] + ecc_x'[0,b] + ecc_x'[1,b]) + // * (ecc_x'[1,b] - ecc_x'[0,b])^2 + // - (ecc_y'[1,b] - ecc_y'[0,b])^2 + let SepticPoint { x, y } = &ecc_eval; + let SepticPoint { x: x1, y: y1 } = &tower_proofs.ecc_evals[round][0]; + let SepticPoint { x: x2, y: y2 } = &tower_proofs.ecc_evals[round][1]; + + // TODO: avoid clone + let xs = (x.clone() + x1 + x2) * (x2.clone() - x1) * (x2.clone() - x1) + - (y2.clone() - y1) * (y2.clone() - y1); + + // 0 = (ecc_y + ecc_y'[0]) * (ecc_x'[1] - ecc_x'[0]) + // - (ecc_y'[1] - ecc_y'[0]) * (ecc_x'[0] - ecc_x) + let ys = (y.clone() + y1) * (x2.clone() - x1) - (y2.clone() - y1) * (x1.clone() - x); + expected_evaluation += izip!(xs.0.iter(), alpha_pows[num_prod_spec + num_logup_spec * 2..].iter().take(7)).map(|(&xi, &alpha)| eq * xi * alpha).sum::(); + expected_evaluation += izip!(ys.0.iter(), alpha_pows[num_prod_spec + num_logup_spec * 2..].iter().skip(7).take(7)).map(|(&yi, &alpha)| eq * yi * alpha).sum::(); + if expected_evaluation != sumcheck_claim.expected_evaluation { return Err(ZKVMError::VerifyError("mismatch tower evaluation".into())); } @@ -912,6 +939,7 @@ impl TowerVerify { .zip(next_alpha_pows.iter()) .zip(num_variables.iter()) .map(|((spec_index, alpha), max_round)| { + // prod'[rt,r_merge] = \sum_b eq(r_merge, b) * prod'[b,rt] if round < max_round -1 { // merged evaluation let evals = izip!( @@ -967,6 +995,12 @@ impl TowerVerify { } }) .sum::(); + // update ecc_eval + ecc_eval = Self::get_ecc_eval( + &tower_proofs.ecc_evals[round][0], + &tower_proofs.ecc_evals[round][1], + &rt_prime, + ); // sum evaluation from different specs let next_eval = next_prod_spec_evals + next_logup_spec_evals; Ok((PointAndEval { From c8f1013b242658dd9bdefedc249806b608180feb Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 11 Sep 2025 09:02:19 +0800 Subject: [PATCH 04/91] wip --- ceno_zkvm/src/scheme/septic_curve.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 6ad89e225..1ffc0572b 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -1,6 +1,7 @@ // The extension field and curve definition are adapted from // https://github.com/succinctlabs/sp1/blob/v5.2.1/crates/stark/src/septic_curve.rs use p3::field::Field; +use serde::{Serialize, Deserialize}; use std::ops::{Add, Deref, Mul, Sub}; /// F[z] / (z^6 - z - 4) @@ -33,7 +34,7 @@ pub struct SexticExtension([F; 6]); /// # check if f(x) is irreducible /// print(f.is_irreducible()) /// ``` -#[derive(Clone, Debug, Default, PartialEq)] +#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] pub struct SepticExtension(pub [F; 7]); impl From<&[F]> for SepticExtension { @@ -197,9 +198,9 @@ impl Mul for SepticExtension { /// over the extension field F[z] / (z^7 - 2z - 5). /// /// Note that -/// 1. its cofactor is 1 -/// 2. its order is a large prime number of 31x7 bits -#[derive(Clone, Debug, Default, PartialEq)] +/// 1. The curve's cofactor is 1 +/// 2. The curve's order is a large prime number of 31x7 bits +#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] pub struct SepticPoint { pub x: SepticExtension, pub y: SepticExtension, From 0350253f93abbc14be0e5a5ded1b696029c79325 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 11 Sep 2025 14:49:39 +0800 Subject: [PATCH 05/91] simplify --- ceno_zkvm/src/scheme/septic_curve.rs | 40 ++++++++++++++++++++++------ ceno_zkvm/src/scheme/verifier.rs | 7 +++-- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 1ffc0572b..5eb992d3c 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -1,7 +1,7 @@ // The extension field and curve definition are adapted from // https://github.com/succinctlabs/sp1/blob/v5.2.1/crates/stark/src/septic_curve.rs -use p3::field::Field; -use serde::{Serialize, Deserialize}; +use p3::field::{Field, FieldAlgebra}; +use serde::{Deserialize, Serialize}; use std::ops::{Add, Deref, Mul, Sub}; /// F[z] / (z^6 - z - 4) @@ -120,8 +120,8 @@ impl From<[u32; 7]> for SepticExtension { } } -impl Add<&Self> for SepticExtension { - type Output = Self; +impl Add<&Self> for SepticExtension { + type Output = SepticExtension; fn add(self, other: &Self) -> Self { let mut result = [F::ZERO; 7]; @@ -132,7 +132,19 @@ impl Add<&Self> for SepticExtension { } } -impl Add for SepticExtension { +impl Add for &SepticExtension { + type Output = SepticExtension; + + fn add(self, other: Self) -> SepticExtension { + let mut result = [F::ZERO; 7]; + for i in 0..7 { + result[i] = self.0[i] + other.0[i]; + } + SepticExtension(result) + } +} + +impl Add for SepticExtension { type Output = Self; fn add(self, other: Self) -> Self { @@ -140,8 +152,8 @@ impl Add for SepticExtension { } } -impl Sub<&Self> for SepticExtension { - type Output = Self; +impl Sub<&Self> for SepticExtension { + type Output = SepticExtension; fn sub(self, other: &Self) -> Self { let mut result = [F::ZERO; 7]; @@ -152,7 +164,19 @@ impl Sub<&Self> for SepticExtension { } } -impl Sub for SepticExtension { +impl Sub for &SepticExtension { + type Output = SepticExtension; + + fn sub(self, other: Self) -> SepticExtension { + let mut result = [F::ZERO; 7]; + for i in 0..7 { + result[i] = self.0[i] - other.0[i]; + } + SepticExtension(result) + } +} + +impl Sub for SepticExtension { type Output = Self; fn sub(self, other: Self) -> Self { diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 1de999583..3e3dfca09 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -907,13 +907,12 @@ impl TowerVerify { let SepticPoint { x: x1, y: y1 } = &tower_proofs.ecc_evals[round][0]; let SepticPoint { x: x2, y: y2 } = &tower_proofs.ecc_evals[round][1]; - // TODO: avoid clone - let xs = (x.clone() + x1 + x2) * (x2.clone() - x1) * (x2.clone() - x1) - - (y2.clone() - y1) * (y2.clone() - y1); + let xs = (x + x1 + x2) * (x2 - x1) * (x2 - x1) + - (y2 - y1) * (y2 - y1); // 0 = (ecc_y + ecc_y'[0]) * (ecc_x'[1] - ecc_x'[0]) // - (ecc_y'[1] - ecc_y'[0]) * (ecc_x'[0] - ecc_x) - let ys = (y.clone() + y1) * (x2.clone() - x1) - (y2.clone() - y1) * (x1.clone() - x); + let ys = (y + y1) * (x2 - x1) - (y2 - y1) * (x1 - x); expected_evaluation += izip!(xs.0.iter(), alpha_pows[num_prod_spec + num_logup_spec * 2..].iter().take(7)).map(|(&xi, &alpha)| eq * xi * alpha).sum::(); expected_evaluation += izip!(ys.0.iter(), alpha_pows[num_prod_spec + num_logup_spec * 2..].iter().skip(7).take(7)).map(|(&yi, &alpha)| eq * yi * alpha).sum::(); From 9dd7adee4213f60b8cebeb00760710ae451ca401 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Sat, 13 Sep 2025 00:41:23 +0800 Subject: [PATCH 06/91] batch ec add into tower prover wip-1 --- ceno_zkvm/src/scheme/cpu/mod.rs | 87 ++++++++++++++++- ceno_zkvm/src/scheme/septic_curve.rs | 137 +++++++++++++++++++++++++++ 2 files changed, 222 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 04ba68306..25eaaaaa0 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -7,6 +7,7 @@ use crate::{ scheme::{ constants::{NUM_FANIN, NUM_FANIN_LOGUP}, hal::{DeviceProvingKey, MainSumcheckEvals, ProofInput, TowerProverSpec}, + septic_curve::{SepticExtension, SymbolicSepticExtension}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, masked_mle_split_to_chunks, wit_infer_by_expr, @@ -53,13 +54,20 @@ impl CpuTowerProver { pub fn create_proof<'a, E: ExtensionField, PCS: PolynomialCommitmentScheme>( prod_specs: Vec>>, logup_specs: Vec>>, + ecc_spec: Option>>, num_fanin: usize, transcript: &mut impl Transcript, ) -> (Point, TowerProofs) { #[derive(Debug, Clone)] enum GroupedMLE<'a, E: ExtensionField> { Prod((usize, Vec>)), // usize is the index in prod_specs - Logup((usize, Vec>)), /* usize is the index in logup_specs */ + Logup((usize, Vec>)), // usize is the index in logup_specs + EcAdd( + ( + Vec>, + Vec>, + ), + ), } // XXX to sumcheck batched product argument with logup, we limit num_product_fanin to 2 @@ -82,7 +90,8 @@ impl CpuTowerProver { let alpha_pows = get_challenge_pows( prod_specs_len + // logup occupy 2 sumcheck: numerator and denominator - logup_specs_len * 2, + logup_specs_len * 2 + + ecc_spec.as_ref().map_or(0, |_| 14), transcript, ); let initial_rt: Point = transcript.sample_and_append_vec(b"product_sum", log_num_fanin); @@ -112,6 +121,16 @@ impl CpuTowerProver { merge_spec_witness(&mut layer_witness, spec, i, GroupedMLE::Logup); } + if let Some(ecc_spec) = ecc_spec { + for i in 0..max_round_index { + layer_witness[i + 1].push(GroupedMLE::EcAdd(( + // TODO: avoid clone + ecc_spec.witness[i].clone(), + ecc_spec.witness[i + 1].clone(), + ))); + } + } + // skip(1) for output layer for (round, mut layer_witness) in layer_witness.into_iter().enumerate().skip(1) { // in first few round we just run on single thread @@ -190,6 +209,59 @@ impl CpuTowerProver { + alpha_denominator * q1 * q2), ); } + GroupedMLE::EcAdd((prev_layer, curr_layer)) => { + assert_eq!(curr_layer.len(), 2 * 14); // 3 points, each point has 14 polys + assert_eq!(prev_layer.len(), 2 * 14); + let (x1, rest) = curr_layer.split_at(7); + let (y1, rest) = rest.split_at(7); + let (x2, rest) = rest.split_at(7); + let (y2, _) = rest.split_at(7); + let (x3, y3) = prev_layer.split_at(7); + + let x1 = &SymbolicSepticExtension::new( + x1.into_iter() + .map(|x| expr_builder.lift(Either::Left(x))) + .collect(), + ); + let y1 = &SymbolicSepticExtension::new( + y1.into_iter() + .map(|y| expr_builder.lift(Either::Left(y))) + .collect(), + ); + let x2 = &SymbolicSepticExtension::new( + x2.into_iter() + .map(|x| expr_builder.lift(Either::Left(x))) + .collect(), + ); + let y2 = &SymbolicSepticExtension::new( + y2.into_iter() + .map(|y| expr_builder.lift(Either::Left(y))) + .collect(), + ); + let x3 = &SymbolicSepticExtension::new( + x3.into_iter() + .map(|x| expr_builder.lift(Either::Left(x))) + .collect(), + ); + let y3 = &SymbolicSepticExtension::new( + y3.into_iter() + .map(|y| expr_builder.lift(Either::Left(y))) + .collect(), + ); + + // 0 = eq * ((x3 + x1 + x2) * (x2 - x1)^2 - (y2 - y1)^2) + exprs.extend( + (((x3 + x1 + x2) * (x2 - x1) * (x2 - x1) - (y2 - y1) * (y2 - y1)) + * &eq_expr) + .to_exprs(), + ); + // 0 = eq * ((y3 + y1) * (x2 - x1) - (y2 - y1) * (x1 - x3)) + exprs.extend( + (((y3 + y1) * (x2 - x1) - (y2 - y1) * (x1 - x3)) + * &eq_expr) + .to_exprs(), + ); + } } } @@ -335,6 +407,8 @@ impl> TowerProver Mul for SepticExtension { } } +pub struct SymbolicSepticExtension(pub Vec>); + +impl Add for &SymbolicSepticExtension { + type Output = SymbolicSepticExtension; + + fn add(self, other: Self) -> Self::Output { + let res = self + .0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| a.clone() + b.clone()) + .collect(); + + SymbolicSepticExtension(res) + } +} + +impl Add<&Self> for SymbolicSepticExtension { + type Output = Self; + + fn add(self, other: &Self) -> Self { + (&self).add(other) + } +} + +impl Add for SymbolicSepticExtension { + type Output = Self; + + fn add(self, other: Self) -> Self { + (&self).add(&other) + } +} + +impl Sub for &SymbolicSepticExtension { + type Output = SymbolicSepticExtension; + + fn sub(self, other: Self) -> Self::Output { + let res = self + .0 + .iter() + .zip(other.0.iter()) + .map(|(a, b)| a.clone() - b.clone()) + .collect(); + + SymbolicSepticExtension(res) + } +} + +impl Sub<&Self> for SymbolicSepticExtension { + type Output = Self; + + fn sub(self, other: &Self) -> Self { + (&self).sub(other) + } +} + +impl Sub for SymbolicSepticExtension { + type Output = Self; + + fn sub(self, other: Self) -> Self { + (&self).sub(&other) + } +} + +impl Mul for &SymbolicSepticExtension { + type Output = SymbolicSepticExtension; + + fn mul(self, other: Self) -> Self::Output { + let mut result = vec![Expression::Constant(Either::Left(E::BaseField::ZERO)); 7]; + let five = Expression::Constant(Either::Left(E::BaseField::from_canonical_u32(5))); + let two = Expression::Constant(Either::Left(E::BaseField::from_canonical_u32(2))); + + for i in 0..7 { + for j in 0..7 { + let term = self.0[i].clone() * other.0[j].clone(); + let mut index = i + j; + if index < 7 { + result[index] += term; + } else { + index -= 7; + // x^7 = 2x + 5 + result[index] += five.clone() * term.clone(); + result[index + 1] += two.clone() * term.clone(); + } + } + } + SymbolicSepticExtension(result) + } +} + +impl Mul<&Self> for SymbolicSepticExtension { + type Output = Self; + + fn mul(self, other: &Self) -> Self { + (&self).mul(other) + } +} + +impl Mul for SymbolicSepticExtension { + type Output = Self; + + fn mul(self, other: Self) -> Self { + (&self).mul(&other) + } +} + +impl Mul<&Expression> for SymbolicSepticExtension { + type Output = SymbolicSepticExtension; + + fn mul(self, other: &Expression) -> Self::Output { + let res = self.0.iter().map(|a| a.clone() * other.clone()).collect(); + SymbolicSepticExtension(res) + } +} + +impl Mul> for SymbolicSepticExtension { + type Output = SymbolicSepticExtension; + + fn mul(self, other: Expression) -> Self::Output { + self.mul(&other) + } +} + +impl SymbolicSepticExtension { + pub fn new(exprs: Vec>) -> Self { + assert!(exprs.len() == 7); + Self(exprs) + } + + pub fn to_exprs(&self) -> Vec> { + self.0.clone() + } +} + /// A point on the short Weierstrass curve defined by /// y^2 = x^3 + 2x + 26z^5 /// over the extension field F[z] / (z^7 - 2z - 5). From f22d7aac18de254161e32d86041fbe799734b282 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 15 Sep 2025 23:48:53 +0800 Subject: [PATCH 07/91] revisit infer_septic_addition_witness --- ceno_zkvm/src/scheme/cpu/mod.rs | 8 +- ceno_zkvm/src/scheme/utils.rs | 223 ++++++++++++++++++++------------ 2 files changed, 147 insertions(+), 84 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 25eaaaaa0..b3779c55b 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -257,9 +257,7 @@ impl CpuTowerProver { ); // 0 = eq * ((y3 + y1) * (x2 - x1) - (y2 - y1) * (x1 - x3)) exprs.extend( - (((y3 + y1) * (x2 - x1) - (y2 - y1) * (x1 - x3)) - * &eq_expr) - .to_exprs(), + (((y3 + y1) * (x2 - x1) - (y2 - y1) * (x1 - x3)) * &eq_expr).to_exprs(), ); } } @@ -910,6 +908,10 @@ mod tests { #[test] fn test_ecc_tower_prover() { + // generate 1 product witness spec + + // generate 1 logup witness spec + // generate 1 ecc add witness } } diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index a3fd8773c..d1d5c033e 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -159,6 +159,11 @@ macro_rules! tower_mle_4 { }}; } +pub fn log2_strict_usize(n: usize) -> usize { + assert!(n.is_power_of_two()); + n.trailing_zeros() as usize +} + /// infer logup witness from last layer /// return is the ([p1,p2], [q1,q2]) for each layer pub(crate) fn infer_tower_logup_witness<'a, E: ExtensionField>( @@ -257,116 +262,172 @@ pub(crate) fn infer_tower_logup_witness<'a, E: ExtensionField>( .collect_vec() } -/// infer tower witness from last layer -pub(crate) fn infer_tower_product_witness( +/// Infer tower witness from input layer (layer 0 is the output layer and layer n is the input layer). +/// The relation between layer i and layer i+1 is as follows: +/// prod[i][b] = ∏_s prod[i+1][s,b] +/// where 2^s is the fanin of the product gate `num_product_fanin`. +pub fn infer_tower_product_witness( num_vars: usize, last_layer: Vec>, num_product_fanin: usize, ) -> Vec>> { + // sanity check assert!(last_layer.len() == num_product_fanin); - assert_eq!(num_product_fanin % 2, 0); - let log2_num_product_fanin = ceil_log2(num_product_fanin); - let mut wit_layers = - (0..(num_vars / log2_num_product_fanin) - 1).fold(vec![last_layer], |mut acc, _| { - let next_layer = acc.last().unwrap(); - let cur_len = next_layer[0].evaluations().len() / num_product_fanin; - let cur_layer: Vec> = (0..num_product_fanin) - .map(|index| { - let mut evaluations = vec![E::ONE; cur_len]; - next_layer.chunks_exact(2).for_each(|f| { - match (f[0].evaluations(), f[1].evaluations()) { - (FieldType::Ext(f1), FieldType::Ext(f2)) => { - let start: usize = index * cur_len; - (start..(start + cur_len)) + assert!(num_product_fanin.is_power_of_two()); + + let log2_num_product_fanin = log2_strict_usize(num_product_fanin); + assert!(num_vars % log2_num_product_fanin == 0); + assert!( + last_layer + .iter() + .all(|p| p.num_vars() == num_vars - log2_num_product_fanin) + ); + + let num_layers = num_vars / log2_num_product_fanin; + + let mut wit_layers = Vec::with_capacity(num_layers); + wit_layers.push(last_layer); + + for _ in (0..num_layers - 1).rev() { + let input_layer = wit_layers.last().unwrap(); + let output_len = input_layer[0].evaluations().len() / num_product_fanin; + + let output_layer: Vec> = (0..num_product_fanin) + .map(|index| { + // avoid the overhead of vector initialization + let mut evaluations: Vec = Vec::with_capacity(output_len); + unsafe { + // will be filled immediately + evaluations.set_len(output_len); + } + + input_layer.chunks_exact(2).enumerate().for_each(|(i, f)| { + match (f[0].evaluations(), f[1].evaluations()) { + (FieldType::Ext(f1), FieldType::Ext(f2)) => { + let start: usize = index * output_len; + + if i == 0 { + (start..(start + output_len)) .into_par_iter() .zip(evaluations.par_iter_mut()) .with_min_len(MIN_PAR_SIZE) - .map(|(index, evaluations)| { + .for_each(|(index, evaluations)| { + *evaluations = f1[index] * f2[index] + }); + } else { + (start..(start + output_len)) + .into_par_iter() + .zip(evaluations.par_iter_mut()) + .with_min_len(MIN_PAR_SIZE) + .for_each(|(index, evaluations)| { *evaluations *= f1[index] * f2[index] - }) - .collect() + }); } - _ => unreachable!("must be extension field"), } - }); - evaluations.into_mle() - }) - .collect_vec(); - acc.push(cur_layer); - acc - }); + _ => unreachable!("must be extension field"), + } + }); + evaluations.into_mle() + }) + .collect_vec(); + wit_layers.push(output_layer); + } + wit_layers.reverse(); - wit_layers -} -pub fn log2_strict_usize(n: usize) -> usize { - assert!(n.is_power_of_two()); - n.trailing_zeros() as usize + wit_layers } +/// Infer from input layer (layer 0) to the output layer (layer n) +/// Note that each layer has 2 * 7 * 2 multilinear polynomials. +/// +/// The relation between layer i and layer i+1 is as follows: +/// 0 = p[i][b] - (p[i+1][0,b] + p[i+1][1,b]) +/// pub fn infer_septic_sum_witness( - p_mles: RowMajorMatrix, - q_mles: RowMajorMatrix, + p_mles: Vec>, ) -> Vec>> { - assert!(p_mles.width() > 0); - let num_layers = log2_strict_usize(p_mles.height()); + assert_eq!(p_mles.len(), 2 * 7 * 2); + assert!(p_mles.iter().map(|p| p.num_vars()).all_equal()); + + // +1 as the input layer has 2*N points where N = 2^num_vars + // and the output layer has 2 points + let num_layers = p_mles[0].num_vars() + 1; let mut layers = Vec::with_capacity(num_layers); - layers.push(vec![p_mles, q_mles]); - for i in (0..num_layers).rev() { - let last_layer = layers.last().unwrap(); - let (p, q) = (&last_layer[0], &last_layer[1]); - - let num_rows = p.height(); - let new_p = RowMajorMatrix::new( - (0..num_rows / 2) - .into_par_iter() - .flat_map_iter(|row| { - let p = p.row_slice(row); - let q = q.row_slice(row); + layers.push(p_mles); - let p1 = SepticPoint { - x: SepticExtension(std::array::from_fn(|i| p[i])), - y: SepticExtension(std::array::from_fn(|i| p[i + 7])), - }; - let p2 = SepticPoint { - x: SepticExtension(std::array::from_fn(|i| q[i])), - y: SepticExtension(std::array::from_fn(|i| q[i + 7])), - }; - let q = p1 + p2; - q.x.0.iter().chain(q.y.0.iter()).copied() - }) - .collect::>(), - 14, - ); - let new_q = RowMajorMatrix::new( - ((num_rows / 2)..num_rows) + for _ in (0..num_layers-1).rev() { + let input_layer = layers.last().unwrap(); + let p = input_layer[0..14] + .iter() + .map(|mle| mle.get_base_field_vec()) + .collect_vec(); + let q = input_layer[14..28] + .iter() + .map(|mle| mle.get_base_field_vec()) + .collect_vec(); + + let output_len = p[0].len() / 2; + let mut outputs: Vec = Vec::with_capacity(28 * output_len); + unsafe { + // will be filled immediately + outputs.set_len(28 * output_len); + } + + (0..2).into_iter().for_each(|chunk| { + (0..output_len) .into_par_iter() - .flat_map_iter(|row| { - let p = p.row_slice(row); - let q = q.row_slice(row); + .with_min_len(MIN_PAR_SIZE) + .zip(outputs.par_chunks_mut(28)) + .for_each(|(idx, output)| { + let row = chunk * output_len + idx; + let offset = chunk * 14; let p1 = SepticPoint { - x: SepticExtension(std::array::from_fn(|i| p[i])), - y: SepticExtension(std::array::from_fn(|i| p[i + 7])), + x: SepticExtension(std::array::from_fn(|i| p[offset + i][row])), + y: SepticExtension(std::array::from_fn(|i| p[offset + i + 7][row])), }; let p2 = SepticPoint { - x: SepticExtension(std::array::from_fn(|i| q[i])), - y: SepticExtension(std::array::from_fn(|i| q[i + 7])), + x: SepticExtension(std::array::from_fn(|i| q[offset + i][row])), + y: SepticExtension(std::array::from_fn(|i| q[offset + i + 7][row])), }; - let q = p1 + p2; - q.x.0.iter().chain(q.y.0.iter()).copied() - }) - .collect::>(), - 14, - ); - layers.push(vec![new_p, new_q]); + + let p3 = p1 + p2; + + output[offset..] + .iter_mut() + .take(7) + .enumerate() + .for_each(|(i, out)| { + *out = p3.x.0[i]; + }); + output[offset..] + .iter_mut() + .skip(7) + .take(7) + .enumerate() + .for_each(|(i, out)| { + *out = p3.y.0[i]; + }); + }); + }); + + // transpose + let output_mles = (0..28) + .map(|i| { + (0..output_len) + .into_par_iter() + .map(|j| outputs[j * 28 + i]) + .collect::>() + .into_mle() + }) + .collect_vec(); + layers.push(output_mles); } + layers.reverse(); layers - .iter() - .rev() - .map(|layer| layer.iter().map(|m| m.into_mles()).collect::>()) } pub fn build_main_witness< From 7b366a15f938e5c11a362336fc98e404a664d0c2 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 16 Sep 2025 21:30:33 +0800 Subject: [PATCH 08/91] sample random point on curve --- ceno_zkvm/src/scheme/septic_curve.rs | 250 +++++++++++++++++++++++++-- 1 file changed, 236 insertions(+), 14 deletions(-) diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index f6a3ee35d..9394e1f68 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -1,9 +1,10 @@ use either::Either; -use ff_ext::ExtensionField; +use ff_ext::{ExtensionField, FromUniformBytes}; use multilinear_extensions::Expression; // The extension field and curve definition are adapted from // https://github.com/succinctlabs/sp1/blob/v5.2.1/crates/stark/src/septic_curve.rs use p3::field::{Field, FieldAlgebra}; +use rand::RngCore; use serde::{Deserialize, Serialize}; use std::ops::{Add, Deref, Mul, Sub}; @@ -62,11 +63,148 @@ impl SepticExtension { self.0.iter().all(|c| *c == F::ZERO) } + // returns z^{i*p} for i = 0..6 + // + // The sage script to compute z^{i*p} is as follows: + // ```sage + // p = 2^31 - 2^27 + 1 + // Fp = GF(p) + // R. = PolynomialRing(Fp) + // mod_poly = z^7 - 2*z - 5 + // Q = R.quotient(mod_poly) + // + // # compute z^(i*p) for i = 1..6 + // for k in range(1, 7): + // power = k * p + // z_power = Q(z)^power + // print(f"z^({k}*p) = {z_power}") + // ``` + fn z_pow_p(i: usize) -> Self { + match i { + 0 => [1, 0, 0, 0, 0, 0, 0].into(), + 1 => [ + 954599710, 1359279693, 566669999, 1982781815, 1735718361, 1174868538, 1120871770, + ] + .into(), + 2 => [ + 862825265, 597046311, 978840770, 1790138282, 1044777201, 835869808, 1342179023, + ] + .into(), + 3 => [ + 596273169, 658837454, 1515468261, 367059247, 781278880, 1544222616, 155490465, + ] + .into(), + 4 => [ + 557608863, 1173670028, 1749546888, 1086464137, 803900099, 1288818584, 1184677604, + ] + .into(), + 5 => [ + 763416381, 1252567168, 628856225, 1771903394, 650712211, 19417363, 57990258, + ] + .into(), + 6 => [ + 1734711039, 1749813853, 1227235221, 1707730636, 424560395, 1007029514, 498034669, + ] + .into(), + _ => unimplemented!("i should be in [0, 7]"), + } + } + + // returns z^{i*p^2} for i = 0..6 + // we can change the above sage script to compute z^{i*p^2} by replacing + // `power = k * p` with `power = k * p * p` + fn z_pow_p_square(i: usize) -> Self { + match i { + 0 => [1, 0, 0, 0, 0, 0, 0].into(), + 1 => [ + 1013489358, 1619071628, 304593143, 1949397349, 1564307636, 327761151, 415430835, + ] + .into(), + 2 => [ + 209824426, 1313900768, 38410482, 256593180, 1708830551, 1244995038, 1555324019, + ] + .into(), + 3 => [ + 1475628651, 777565847, 704492386, 1218528120, 1245363405, 475884575, 649166061, + ] + .into(), + 4 => [ + 550038364, 948935655, 68722023, 1251345762, 1692456177, 1177958698, 350232928, + ] + .into(), + 5 => [ + 882720258, 821925756, 199955840, 812002876, 1484951277, 1063138035, 491712810, + ] + .into(), + 6 => [ + 738287111, 1955364991, 552724293, 1175775744, 341623997, 1454022463, 408193320, + ] + .into(), + _ => unimplemented!("i should be in [0, 7]"), + } + } + + // returns self^p = (a0 + a1*z^p + ... + a6*z^(6p)) + pub fn frobenius(&self) -> Self { + Self::z_pow_p(0) * self.0[0] + + Self::z_pow_p(1) * self.0[1] + + Self::z_pow_p(2) * self.0[2] + + Self::z_pow_p(3) * self.0[3] + + Self::z_pow_p(4) * self.0[4] + + Self::z_pow_p(5) * self.0[5] + + Self::z_pow_p(6) * self.0[6] + } + + // returns self^(p^2) = (a0 + a1*z^(p^2) + ... + a6*z^(6*p^2)) + pub fn double_frobenius(&self) -> Self { + Self::z_pow_p_square(0) * self.0[0] + + Self::z_pow_p_square(1) * self.0[1] + + Self::z_pow_p_square(2) * self.0[2] + + Self::z_pow_p_square(3) * self.0[3] + + Self::z_pow_p_square(4) * self.0[4] + + Self::z_pow_p_square(5) * self.0[5] + + Self::z_pow_p_square(6) * self.0[6] + } + + // returns self^(p + p^2 + ... + p^6) + fn norm_sub(&self) -> Self { + let a = self.frobenius() * self.double_frobenius(); + let b = a.double_frobenius(); + let c = b.double_frobenius(); + + a * b * c + } + + // norm = self^(1 + p + ... + p^6) + // = self^((p^7-1)/(p-1)) + // it's a field element in F since norm^p = norm + fn norm(&self) -> F { + (self.norm_sub() * self).0[0] + } + + pub fn is_square(&self) -> bool { + // since a^((p^7 - 1)/2) = norm(a)^((p-1)/2) + // to test if self^((p^7 - 1) / 2) == 1? + // we can just test if norm(a)^((p-1)/2) == 1? + let power_digits = ((F::order() - 1u32) / 2u32).to_u64_digits(); + debug_assert!(power_digits.len() == 1); + let power = power_digits[0]; + + self.norm().exp_u64(power) == F::ONE + } + pub fn inverse(&self) -> Option { match self.is_zero() { true => None, false => { - todo!() + // since norm(a)^(-1) * a^(p + p^2 + ... + p^6) * a = 1 + // it's easy to see a^(-1) = norm(a)^(-1) * a^(p + p^2 + ... + p^6) + let x = self.norm_sub(); + let norm = (self * &x).0[0]; + // since self is not zero, norm is not zero + let norm_inv = norm.try_inverse().unwrap(); + + Some(x * norm_inv) } } } @@ -91,7 +229,7 @@ impl SepticExtension { } } } - // i == j + // i == j: i \in [0, 3] result[0] += self.0[0] * self.0[0]; result[2] += self.0[1] * self.0[1]; result[4] += self.0[2] * self.0[2]; @@ -111,6 +249,20 @@ impl SepticExtension { Self(result) } + + pub fn sqrt(&self) -> Option { + todo!() + } +} + +impl SepticExtension { + pub fn random(mut rng: impl RngCore) -> Self { + let mut arr = [F::ZERO; 7]; + for i in 0..7 { + arr[i] = F::random(&mut rng); + } + Self(arr) + } } impl From<[u32; 7]> for SepticExtension { @@ -187,10 +339,30 @@ impl Sub for SepticExtension { } } -impl Mul<&Self> for SepticExtension { - type Output = Self; +impl Mul for &SepticExtension { + type Output = SepticExtension; - fn mul(self, other: &Self) -> Self { + fn mul(self, other: F) -> Self::Output { + let mut result = [F::ZERO; 7]; + for i in 0..7 { + result[i] = self.0[i] * other; + } + SepticExtension(result) + } +} + +impl Mul for SepticExtension { + type Output = SepticExtension; + + fn mul(self, other: F) -> Self::Output { + (&self).mul(other) + } +} + +impl Mul for &SepticExtension { + type Output = SepticExtension; + + fn mul(self, other: Self) -> Self::Output { let mut result = [F::ZERO; 7]; let five = F::from_canonical_u32(5); let two = F::from_canonical_u32(2); @@ -208,7 +380,7 @@ impl Mul<&Self> for SepticExtension { } } } - Self(result) + SepticExtension(result) } } @@ -216,7 +388,15 @@ impl Mul for SepticExtension { type Output = Self; fn mul(self, other: Self) -> Self { - self.mul(&other) + (&self).mul(&other) + } +} + +impl Mul<&Self> for SepticExtension { + type Output = Self; + + fn mul(self, other: &Self) -> Self { + (&self).mul(other) } } @@ -380,10 +560,38 @@ impl Add for SepticPoint { } } +impl SepticPoint { + pub fn is_on_curve(&self) -> bool { + let b: SepticExtension = [0, 0, 0, 0, 0, 26, 0].into(); + let a: F = F::from_canonical_u32(2); + + self.y.square() == self.x.square() * &self.x + (&self.x * a) + b + } +} + +impl SepticPoint { + pub fn random(mut rng: impl RngCore) -> Self { + let b: SepticExtension = [0, 0, 0, 0, 0, 26, 0].into(); + let a: F = F::from_canonical_u32(2); + + loop { + let x = SepticExtension::random(&mut rng); + let y2 = x.square() * &x + (&x * a) + &b; + if y2.is_square() { + let y = y2.sqrt().unwrap(); + + return Self { x, y }; + } + } + } +} + #[cfg(test)] mod tests { use super::SepticExtension; - use p3::babybear::BabyBear; + use crate::scheme::septic_curve::SepticPoint; + use p3::{babybear::BabyBear, field::Field}; + use rand::thread_rng; type F = BabyBear; #[test] @@ -392,10 +600,24 @@ mod tests { let a: SepticExtension = SepticExtension::from([0, 1, 0, 0, 0, 0, 0]); let b: SepticExtension = SepticExtension::from([0, 0, 0, 0, 1, 1, 1]); - assert_eq!( - a * b, - // z^5 + z^6 + 2*z + 5 - SepticExtension::from([5, 2, 0, 0, 0, 1, 1]) - ) + let c = SepticExtension::from([5, 2, 0, 0, 0, 1, 1]); + assert_eq!(a * b, c); + + // a^(p^2) = (a^p)^p + assert_eq!(c.double_frobenius(), c.frobenius().frobenius()); + + // norm_sub(a) * a must be in F + let norm = c.norm_sub() * &c; + assert!(norm.0[1..7].iter().all(|x| x.is_zero())); + } + + #[test] + fn test_septic_curve_arithmetic() { + let mut rng = thread_rng(); + let p1 = SepticPoint::::random(&mut rng); + let p2 = SepticPoint::::random(&mut rng); + + let p3 = p1 + p2; + assert!(p3.is_on_curve()); } } From 58d9d22d1e46555d2ecd69127d5b3be19cd6ce0f Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Wed, 17 Sep 2025 22:17:01 +0800 Subject: [PATCH 09/91] add sqrt --- Cargo.lock | 1 + ceno_zkvm/Cargo.toml | 1 + ceno_zkvm/src/scheme/septic_curve.rs | 170 ++++++++++++++++++++++++++- 3 files changed, 166 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index de2220658..df24100fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -382,6 +382,7 @@ dependencies = [ "mpcs", "multilinear_extensions", "ndarray", + "num-bigint", "once_cell", "p3", "parse-size", diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 1b712320e..b38303a73 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -48,6 +48,7 @@ parse-size = "1.1" rand.workspace = true tempfile = "3.14" tiny-keccak.workspace = true +num-bigint = "0.4.6" [target.'cfg(unix)'.dependencies] tikv-jemalloc-ctl = { version = "0.6", features = ["stats"], optional = true } diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 9394e1f68..e5144c52d 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -3,10 +3,11 @@ use ff_ext::{ExtensionField, FromUniformBytes}; use multilinear_extensions::Expression; // The extension field and curve definition are adapted from // https://github.com/succinctlabs/sp1/blob/v5.2.1/crates/stark/src/septic_curve.rs +use num_bigint::BigUint; use p3::field::{Field, FieldAlgebra}; use rand::RngCore; use serde::{Deserialize, Serialize}; -use std::ops::{Add, Deref, Mul, Sub}; +use std::ops::{Add, Deref, Mul, MulAssign, Sub}; /// F[z] / (z^6 - z - 4) /// @@ -63,6 +64,16 @@ impl SepticExtension { self.0.iter().all(|c| *c == F::ZERO) } + pub fn zero() -> Self { + Self([F::ZERO; 7]) + } + + pub fn one() -> Self { + let mut arr = [F::ZERO; 7]; + arr[0] = F::ONE; + Self(arr) + } + // returns z^{i*p} for i = 0..6 // // The sage script to compute z^{i*p} is as follows: @@ -186,11 +197,11 @@ impl SepticExtension { // since a^((p^7 - 1)/2) = norm(a)^((p-1)/2) // to test if self^((p^7 - 1) / 2) == 1? // we can just test if norm(a)^((p-1)/2) == 1? - let power_digits = ((F::order() - 1u32) / 2u32).to_u64_digits(); - debug_assert!(power_digits.len() == 1); - let power = power_digits[0]; + let exp_digits = ((F::order() - 1u32) / 2u32).to_u64_digits(); + debug_assert!(exp_digits.len() == 1); + let exp = exp_digits[0]; - self.norm().exp_u64(power) == F::ONE + self.norm().exp_u64(exp) == F::ONE } pub fn inverse(&self) -> Option { @@ -250,8 +261,140 @@ impl SepticExtension { Self(result) } + pub fn pow(&self, exp: u64) -> Self { + let mut result = Self::one(); + let num_bits = 64 - exp.leading_zeros(); + for j in (0..num_bits).rev() { + result = result.square(); + if (exp >> j) & 1u64 == 1u64 { + result = result * self; + } + } + result + } + pub fn sqrt(&self) -> Option { - todo!() + // the algorithm is adapted from [Cipolla's algorithm](https://en.wikipedia.org/wiki/Cipolla%27s_algorithm + // the code is taken from https://github.com/succinctlabs/sp1/blob/dev/crates/stark/src/septic_extension.rs#L623 + let n = self.clone(); + + if n == Self::zero() || n == Self::one() { + return Some(n); + } + + // norm = n^(1 + p + ... + p^6) = n^(p^7-1)/(p-1) + let norm = n.norm(); + let exp = ((F::order() - 1u32) / 2u32).to_u64_digits()[0]; + // euler's criterion n^((p^7-1)/2) == 1 iff n is quadratic residue + if norm.exp_u64(exp) != F::ONE { + // it's not a square + return None; + }; + + // n_power = n^((p+1)/2) + let exp = ((F::order() + 1u32) / 2u32).to_u64_digits()[0]; + let n_power = self.pow(exp); + + // n^((p^2 + p)/2) + let mut n_frobenius = n_power.frobenius(); + let mut denominator = n_frobenius.clone(); + + // n^((p^4 + p^3)/2) + n_frobenius = n_frobenius.double_frobenius(); + denominator *= n_frobenius.clone(); + // n^((p^6 + p^5)/2) + n_frobenius = n_frobenius.double_frobenius(); + // d = n^((p^6 + p^5 + p^4 + p^3 + p^2 + p) / 2) + // d^2 * n = norm + denominator *= n_frobenius; + // d' = d*n + denominator *= n; + + let base = norm.inverse(); // norm^(-1) + let g = F::GENERATOR; + let mut a = F::ONE; + let mut non_residue = F::ONE - base; + let legendre_exp = (F::order() - 1u32) / 2u32; // (p-1)/2 + + // non_residue = a^2 - 1/norm + // find `a` such that non_residue is not a square in F + while non_residue.exp_u64(legendre_exp.to_u64_digits()[0]) == F::ONE { + a *= g; + non_residue = a.square() - base; + } + + // (p+1)/2 + let cipolla_exp = ((F::order() + 1u32) / 2u32).to_u64_digits()[0]; + // x = (a+i)^((p+1)/2) where a in Fp + // x^2 = (a+i) * (a+i)^p = (a+i)*(a-i) = a^2 - i^2 + // = a^2 - non_residue = 1/norm + // therefore, x is the square root of 1/norm + let mut x = QuadraticExtension::new(a, F::ONE, non_residue); + x = x.pow(cipolla_exp); + + // (x*d')^2 = x^2 * d^2 * n^2 = 1/norm * norm * n + Some(denominator * x.real) + } +} + +// a + bi where i^2 = non_residue +#[derive(Clone, Debug)] +pub struct QuadraticExtension { + pub real: F, + pub imag: F, + pub non_residue: F, +} + +impl QuadraticExtension { + pub fn new(real: F, imag: F, non_residue: F) -> Self { + Self { + real, + imag, + non_residue, + } + } + + pub fn square(&self) -> Self { + // (a + bi)^2 = (a^2 + b^2*i^2) + 2ab*i + let real = self.real * self.real + self.non_residue * self.imag * self.imag; + let mut imag = self.real * self.imag; + imag += imag; + + Self { + real, + imag, + non_residue: self.non_residue, + } + } + + pub fn mul(&self, other: &Self) -> Self { + // (a + bi)(c + di) = (ac + bd*i^2) + (ad + bc)i + let real = self.real * other.real + self.non_residue * self.imag * other.imag; + let imag = self.real * other.imag + self.imag * other.real; + + Self { + real, + imag, + non_residue: self.non_residue, + } + } + + pub fn pow(&self, exp: u64) -> Self { + let mut result = Self { + real: F::ONE, + imag: F::ZERO, + non_residue: self.non_residue, + }; + + let num_bits = 64 - exp.leading_zeros(); + for j in (0..num_bits).rev() { + result = result.square(); + if (exp >> j) & 1u64 == 1u64 { + result = result.mul(self); + } + } + + result } } @@ -400,6 +543,12 @@ impl Mul<&Self> for SepticExtension { } } +impl MulAssign for SepticExtension { + fn mul_assign(&mut self, other: Self) { + *self = (&*self).mul(&other); + } +} + pub struct SymbolicSepticExtension(pub Vec>); impl Add for &SymbolicSepticExtension { @@ -596,6 +745,7 @@ mod tests { type F = BabyBear; #[test] fn test_septic_extension_arithmetic() { + let mut rng = thread_rng(); // a = z, b = z^6 + z^5 + z^4 let a: SepticExtension = SepticExtension::from([0, 1, 0, 0, 0, 0, 0]); let b: SepticExtension = SepticExtension::from([0, 0, 0, 0, 1, 1, 1]); @@ -609,6 +759,14 @@ mod tests { // norm_sub(a) * a must be in F let norm = c.norm_sub() * &c; assert!(norm.0[1..7].iter().all(|x| x.is_zero())); + + let d: SepticExtension = SepticExtension::random(&mut rng); + let e = d.square(); + assert!(e.is_square()); + + let f = e.sqrt().unwrap(); + let zero = SepticExtension::zero(); + assert!(f == d || f == zero - d); } #[test] From 463a4312ce2f27bd51d8bd3ab72502577f78ef60 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Wed, 17 Sep 2025 22:24:07 +0800 Subject: [PATCH 10/91] fix ec add --- ceno_zkvm/src/scheme/septic_curve.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index e5144c52d..8ed9ea017 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -702,8 +702,8 @@ impl Add for SepticPoint { fn add(self, other: Self) -> Self { assert!(other.x != self.x, "other = self or other = -self"); let slope = (other.y - &self.y) * (other.x.clone() - &self.x).inverse().unwrap(); - let x = slope.square() - (self.x.clone() + other.x); - let y = slope * (x.clone() - self.x) - self.y; + let x = slope.square() - (&self.x + &other.x); + let y = slope * (self.x - &x) - self.y; Self { x, y } } From 56c7aa430ad91396bd44d86ff7300ad90cc7eb87 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 18 Sep 2025 20:20:41 +0800 Subject: [PATCH 11/91] infer_ec_sum_witness unit test --- ceno_zkvm/src/scheme/septic_curve.rs | 56 +++++++++++++-- ceno_zkvm/src/scheme/utils.rs | 103 +++++++++++++++++++-------- 2 files changed, 126 insertions(+), 33 deletions(-) diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 8ed9ea017..bddcd896d 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -3,11 +3,13 @@ use ff_ext::{ExtensionField, FromUniformBytes}; use multilinear_extensions::Expression; // The extension field and curve definition are adapted from // https://github.com/succinctlabs/sp1/blob/v5.2.1/crates/stark/src/septic_curve.rs -use num_bigint::BigUint; use p3::field::{Field, FieldAlgebra}; use rand::RngCore; use serde::{Deserialize, Serialize}; -use std::ops::{Add, Deref, Mul, MulAssign, Sub}; +use std::{ + iter::Sum, + ops::{Add, Deref, Mul, MulAssign, Sub}, +}; /// F[z] / (z^6 - z - 4) /// @@ -694,23 +696,63 @@ impl SymbolicSepticExtension { pub struct SepticPoint { pub x: SepticExtension, pub y: SepticExtension, + pub is_infinity: bool, +} + +impl SepticPoint { + pub fn double(&self) -> Self { + todo!() + } } impl Add for SepticPoint { type Output = Self; fn add(self, other: Self) -> Self { - assert!(other.x != self.x, "other = self or other = -self"); + if self.is_infinity { + return other; + } + + if other.is_infinity { + return self; + } + + if self.x == other.x { + if self.y == other.y { + return self.double(); + } else { + return Self { + x: SepticExtension::zero(), + y: SepticExtension::zero(), + is_infinity: true, + }; + } + } + let slope = (other.y - &self.y) * (other.x.clone() - &self.x).inverse().unwrap(); let x = slope.square() - (&self.x + &other.x); let y = slope * (self.x - &x) - self.y; - Self { x, y } + Self { + x, + y, + is_infinity: false, + } + } +} + +impl Sum for SepticPoint { + fn sum>(iter: I) -> Self { + iter.fold(Self::default(), |acc, p| acc + p) } } impl SepticPoint { pub fn is_on_curve(&self) -> bool { + if self.is_infinity && self.x.is_zero() && self.y.is_zero() { + return true; + } + let b: SepticExtension = [0, 0, 0, 0, 0, 26, 0].into(); let a: F = F::from_canonical_u32(2); @@ -729,7 +771,11 @@ impl SepticPoint { if y2.is_square() { let y = y2.sqrt().unwrap(); - return Self { x, y }; + return Self { + x, + y, + is_infinity: false, + }; } } } diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index d1d5c033e..f1dd118bb 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -29,7 +29,7 @@ use rayon::{ }, prelude::ParallelSliceMut, }; -use std::{iter, sync::Arc}; +use std::{array::from_fn, iter, sync::Arc}; use witness::next_pow2_instance_padding; // first computes the masked mle'[j] = mle[j] if j < num_instance, else default @@ -343,7 +343,6 @@ pub fn infer_tower_product_witness( /// /// The relation between layer i and layer i+1 is as follows: /// 0 = p[i][b] - (p[i+1][0,b] + p[i+1][1,b]) -/// pub fn infer_septic_sum_witness( p_mles: Vec>, ) -> Vec>> { @@ -352,12 +351,12 @@ pub fn infer_septic_sum_witness( // +1 as the input layer has 2*N points where N = 2^num_vars // and the output layer has 2 points - let num_layers = p_mles[0].num_vars() + 1; + let num_layers = p_mles[0].num_vars() + 1; let mut layers = Vec::with_capacity(num_layers); layers.push(p_mles); - for _ in (0..num_layers-1).rev() { + for _ in (0..num_layers - 1).rev() { let input_layer = layers.last().unwrap(); let p = input_layer[0..14] .iter() @@ -385,31 +384,22 @@ pub fn infer_septic_sum_witness( let offset = chunk * 14; let p1 = SepticPoint { - x: SepticExtension(std::array::from_fn(|i| p[offset + i][row])), - y: SepticExtension(std::array::from_fn(|i| p[offset + i + 7][row])), + x: SepticExtension(from_fn(|i| p[i][row])), + y: SepticExtension(from_fn(|i| p[i + 7][row])), + is_infinity: false, }; let p2 = SepticPoint { - x: SepticExtension(std::array::from_fn(|i| q[offset + i][row])), - y: SepticExtension(std::array::from_fn(|i| q[offset + i + 7][row])), + x: SepticExtension(from_fn(|i| q[i][row])), + y: SepticExtension(from_fn(|i| q[i + 7][row])), + is_infinity: false, }; + // TODO: change to debug_assert + assert!(p1.is_on_curve() && p2.is_on_curve()); let p3 = p1 + p2; - output[offset..] - .iter_mut() - .take(7) - .enumerate() - .for_each(|(i, out)| { - *out = p3.x.0[i]; - }); - output[offset..] - .iter_mut() - .skip(7) - .take(7) - .enumerate() - .for_each(|(i, out)| { - *out = p3.y.0[i]; - }); + output[offset..offset + 7].clone_from_slice(&p3.x); + output[offset + 7..offset + 14].clone_from_slice(&p3.y); }); }); @@ -639,18 +629,22 @@ pub fn gkr_witness< #[cfg(test)] mod tests { - use ff_ext::{FieldInto, GoldilocksExt2}; + use ff_ext::{BabyBearExt4, FieldInto, GoldilocksExt2}; use itertools::Itertools; use multilinear_extensions::{ commutative_op_mle_pair, mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, smart_slice::SmartSlice, - util::ceil_log2, + util::{ceil_log2, transpose}, }; - use p3::field::FieldAlgebra; + use p3::{babybear::BabyBear, field::FieldAlgebra}; - use crate::scheme::utils::{ - infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, + use crate::scheme::{ + septic_curve::{SepticExtension, SepticPoint}, + utils::{ + infer_septic_sum_witness, infer_tower_logup_witness, infer_tower_product_witness, + interleaving_mles_to_mles, + }, }; #[test] @@ -957,4 +951,57 @@ mod tests { ])) ); } + + #[test] + fn test_infer_septic_addition_witness() { + type F = BabyBear; + type E = BabyBearExt4; + + let n_points = 1 << 4; + let mut rng = rand::thread_rng(); + let points = (0..n_points) + .map(|_| SepticPoint::::random(&mut rng)) + .collect_vec(); + + // transform points to row major matrix + let trace = points + .chunks_exact(2) + .map(|points| { + points + .iter() + .flat_map(|p| p.x.0.iter().chain(p.y.0.iter()).copied()) + .collect_vec() + }) + .collect_vec(); + + let p_mles: Vec> = transpose(trace) + .into_iter() + .map(|v| v.into_mle()) + .collect_vec(); + + let layers = infer_septic_sum_witness(p_mles); + let output_layer = &layers[0]; + assert!(output_layer.iter().all(|mle| mle.num_vars() == 0)); + assert!(output_layer.len() == 28); + + // recover points from output layer + let output_points: Vec> = output_layer + .chunks_exact(14) + .map(|mles| { + mles.iter() + .map(|mle| mle.get_base_field_vec()[0]) + .collect_vec() + }) + .map(|chunk| SepticPoint { + x: SepticExtension(chunk[0..7].try_into().unwrap()), + y: SepticExtension(chunk[7..14].try_into().unwrap()), + is_infinity: false, + }) + .collect_vec(); + assert!(output_points.iter().all(|p| p.is_on_curve())); + + let point_acc: SepticPoint = output_points.into_iter().sum(); + let expected_acc: SepticPoint = points.into_iter().sum(); + assert_eq!(point_acc, expected_acc); + } } From c5c9c118a49c21df82bf3a2c1267d9525526b585 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Fri, 19 Sep 2025 16:45:26 +0800 Subject: [PATCH 12/91] return point at infinity for Default --- ceno_zkvm/src/scheme/septic_curve.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index bddcd896d..265b71ab7 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -692,7 +692,7 @@ impl SymbolicSepticExtension { /// Note that /// 1. The curve's cofactor is 1 /// 2. The curve's order is a large prime number of 31x7 bits -#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct SepticPoint { pub x: SepticExtension, pub y: SepticExtension, @@ -705,6 +705,16 @@ impl SepticPoint { } } +impl Default for SepticPoint { + fn default() -> Self { + Self { + x: SepticExtension::zero(), + y: SepticExtension::zero(), + is_infinity: true, + } + } +} + impl Add for SepticPoint { type Output = Self; @@ -721,6 +731,8 @@ impl Add for SepticPoint { if self.y == other.y { return self.double(); } else { + assert!((self.y + other.y).is_zero()); + return Self { x: SepticExtension::zero(), y: SepticExtension::zero(), From dc4445aa28266a6c80ab7f7e0a6b488756d4af14 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Fri, 19 Sep 2025 18:11:11 +0800 Subject: [PATCH 13/91] finish infer septic addition unit test --- ceno_zkvm/src/scheme/utils.rs | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index f1dd118bb..861919b61 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -393,8 +393,7 @@ pub fn infer_septic_sum_witness( y: SepticExtension(from_fn(|i| q[i + 7][row])), is_infinity: false, }; - // TODO: change to debug_assert - assert!(p1.is_on_curve() && p2.is_on_curve()); + debug_assert!(p1.is_on_curve() && p2.is_on_curve()); let p3 = p1 + p2; @@ -959,21 +958,25 @@ mod tests { let n_points = 1 << 4; let mut rng = rand::thread_rng(); + // sample n points let points = (0..n_points) .map(|_| SepticPoint::::random(&mut rng)) .collect_vec(); - + // transform points to row major matrix - let trace = points - .chunks_exact(2) - .map(|points| { - points + let trace = points[0..n_points / 2] + .iter() + .zip(points[n_points / 2..n_points].iter()) + .map(|(p, q)| { + [p, q] .iter() - .flat_map(|p| p.x.0.iter().chain(p.y.0.iter()).copied()) + .flat_map(|p| p.x.0.iter().chain(p.y.0.iter())) + .copied() .collect_vec() }) .collect_vec(); + // transpose row major matrix to column major matrix let p_mles: Vec> = transpose(trace) .into_iter() .map(|v| v.into_mle()) @@ -999,6 +1002,7 @@ mod tests { }) .collect_vec(); assert!(output_points.iter().all(|p| p.is_on_curve())); + assert_eq!(output_points.len(), 2); let point_acc: SepticPoint = output_points.into_iter().sum(); let expected_acc: SepticPoint = points.into_iter().sum(); From 7887206f998608dbe5712604913c769837223b4c Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 23 Sep 2025 00:09:27 +0800 Subject: [PATCH 14/91] adjust the ec add sumchecks --- ceno_zkvm/src/scheme/cpu/mod.rs | 128 ++++++++++++++++++++++++++++++-- 1 file changed, 121 insertions(+), 7 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index b3779c55b..f800a5dd8 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -48,6 +48,7 @@ pub type TowerRelationOutput = ( Vec>, Vec>, ); + pub struct CpuTowerProver; impl CpuTowerProver { @@ -81,6 +82,7 @@ impl CpuTowerProver { let max_round_index = prod_specs .iter() .chain(logup_specs.iter()) + .chain(ecc_spec.iter()) .map(|m| m.witness.len()) .max() .unwrap() @@ -141,7 +143,11 @@ impl CpuTowerProver { let mut witness_lk_expr = vec![vec![]; logup_specs_len]; let mut eq: MultilinearExtension = build_eq_x_r_vec(&out_rt).into_mle(); + let eq_len = eq.evaluations.len(); + let mut eq_prime = eq.get_ext_field_vec()[0..eq_len / 2].to_vec().into_mle(); + let eq_expr = expr_builder.lift(Either::Right(&mut eq)); + let eq_prime_expr = expr_builder.lift(Either::Right(&mut eq_prime)); // processing exprs for group_witness in layer_witness.iter_mut() { @@ -210,13 +216,22 @@ impl CpuTowerProver { ); } GroupedMLE::EcAdd((prev_layer, curr_layer)) => { - assert_eq!(curr_layer.len(), 2 * 14); // 3 points, each point has 14 polys - assert_eq!(prev_layer.len(), 2 * 14); + assert_eq!(curr_layer.len(), 3 * 14); // 3 points, each point has 14 polys + assert_eq!(prev_layer.len(), 3 * 14); let (x1, rest) = curr_layer.split_at(7); let (y1, rest) = rest.split_at(7); let (x2, rest) = rest.split_at(7); - let (y2, _) = rest.split_at(7); - let (x3, y3) = prev_layer.split_at(7); + let (y2, rest) = rest.split_at(7); + let (x3, y3) = rest.split_at(7); + + // x1'[b] = x3[0,b] + // y1'[b] = y3[0,b] + // x2'[b] = x3[1,b] + // y2'[b] = y3[1,b] + let (x1_prime, rest) = prev_layer.split_at_mut(7); + let (y1_prime, rest) = rest.split_at_mut(7); + let (x2_prime, rest) = rest.split_at_mut(7); + let (y2_prime, _) = rest.split_at_mut(7); let x1 = &SymbolicSepticExtension::new( x1.into_iter() @@ -248,17 +263,60 @@ impl CpuTowerProver { .map(|y| expr_builder.lift(Either::Left(y))) .collect(), ); + let x1_prime_expr = SymbolicSepticExtension::new( + x1_prime + .iter_mut() + .map(|x| expr_builder.lift(x.to_either())) + .collect(), + ); + let y1_prime_expr = SymbolicSepticExtension::new( + y1_prime + .iter_mut() + .map(|y| expr_builder.lift(y.to_either())) + .collect(), + ); + let x2_prime_expr = SymbolicSepticExtension::new( + x2_prime + .iter_mut() + .map(|x| expr_builder.lift(x.to_either())) + .collect(), + ); + let y2_prime_expr = SymbolicSepticExtension::new( + y2_prime + .iter_mut() + .map(|y| expr_builder.lift(y.to_either())) + .collect(), + ); - // 0 = eq * ((x3 + x1 + x2) * (x2 - x1)^2 - (y2 - y1)^2) + // layer i: x3', y3', x1', y1', x2', y2', each has `i` variables + // we copy the first half of x3 to x1', 2nd half to x2' and + // copy the first half of y3 to y1', 2nd half to y2'. + // + // x1'[b] = x3[0,b], y1'[b] = y3[0,b] + // x2'[b] = x3[1,b], y2'[b] = y3[1,b] + // + // layer i+1: x3, y3, x1, y1, x2, y2, each has `i+1` variables + // we requires the elliptic curve addition constraints hold at layer i+1. + // 1. 0 = \sum_b eq(rt,b) * ((x3 + x1 + x2) * (x2 - x1)^2 - (y2 - y1)^2) exprs.extend( (((x3 + x1 + x2) * (x2 - x1) * (x2 - x1) - (y2 - y1) * (y2 - y1)) * &eq_expr) .to_exprs(), ); - // 0 = eq * ((y3 + y1) * (x2 - x1) - (y2 - y1) * (x1 - x3)) + // 2. 0 = \sum_b eq(rt,b) * ((y3 + y1) * (x2 - x1) - (y2 - y1) * (x1 - x3)) exprs.extend( (((y3 + y1) * (x2 - x1) - (y2 - y1) * (x1 - x3)) * &eq_expr).to_exprs(), ); + + // with len = rt.len(), rt' = rt[0..len-1] + // x1'[rt'] = \sum_b' eq(rt',b') * x3[0,b'] + // y1'[rt'] = \sum_b' eq(rt',b') * y3[0,b'] + // x2'[rt'] = \sum_b' eq(rt',b') * x3[1,b'] + // y2'[rt'] = \sum_b' eq(rt',b') * y3[1,b'] + exprs.extend((x1_prime_expr * &eq_prime_expr).to_exprs()); + exprs.extend((y1_prime_expr * &eq_prime_expr).to_exprs()); + exprs.extend((x2_prime_expr * &eq_prime_expr).to_exprs()); + exprs.extend((y2_prime_expr * &eq_prime_expr).to_exprs()); } } } @@ -601,7 +659,8 @@ impl> TowerProver; + // generate 1 product witness spec // generate 1 logup witness spec + // if layer i has n variables, + // then layer i+1 has n-1 variables. // generate 1 ecc add witness + let ecc_spec: TowerProverSpec<'_, CpuBackend> = { + let n_points = 1 << 4; + let mut rng = rand::thread_rng(); + // sample n points + let points = (0..n_points) + .map(|_| SepticPoint::::random(&mut rng)) + .collect_vec(); + + // transform points to row major matrix + let trace = points[0..n_points / 2] + .iter() + .zip(points[n_points / 2..n_points].iter()) + .map(|(p, q)| { + [p, q] + .iter() + .flat_map(|p| p.x.0.iter().chain(p.y.0.iter())) + .copied() + .collect_vec() + }) + .collect_vec(); + + // transpose row major matrix to column major matrix + let p_mles: Vec> = transpose(trace) + .into_iter() + .map(|v| v.into_mle()) + .collect_vec(); + + crate::scheme::hal::TowerProverSpec { + witness: infer_septic_sum_witness(p_mles), + } + }; + let mut transcript = BasicTranscript::new(b"test"); + let prover = + CpuTowerProver::create_proof(vec![], vec![], Some(ecc_spec), 2, &mut transcript); } } From febb1f27811ef680b029e70850547899156def0a Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 23 Sep 2025 21:14:41 +0800 Subject: [PATCH 15/91] support jacobian coordinates --- ceno_zkvm/src/scheme/septic_curve.rs | 271 ++++++++++++++++++++++++++- 1 file changed, 268 insertions(+), 3 deletions(-) diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 265b71ab7..87ce33509 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -8,7 +8,7 @@ use rand::RngCore; use serde::{Deserialize, Serialize}; use std::{ iter::Sum, - ops::{Add, Deref, Mul, MulAssign, Sub}, + ops::{Add, Deref, Mul, MulAssign, Neg, Sub}, }; /// F[z] / (z^6 - z - 4) @@ -452,6 +452,18 @@ impl Add for SepticExtension { } } +impl Neg for SepticExtension { + type Output = Self; + + fn neg(self) -> Self { + let mut result = [F::ZERO; 7]; + for i in 0..7 { + result[i] = -self.0[i]; + } + Self(result) + } +} + impl Sub<&Self> for SepticExtension { type Output = SepticExtension; @@ -484,6 +496,25 @@ impl Sub for SepticExtension { } } +impl Add for &SepticExtension { + type Output = SepticExtension; + + fn add(self, other: F) -> Self::Output { + let mut result = self.clone(); + result.0[0] += other; + + result + } +} + +impl Add for SepticExtension { + type Output = SepticExtension; + + fn add(self, other: F) -> Self::Output { + (&self).add(other) + } +} + impl Mul for &SepticExtension { type Output = SepticExtension; @@ -701,7 +732,25 @@ pub struct SepticPoint { impl SepticPoint { pub fn double(&self) -> Self { - todo!() + let a = F::from_canonical_u32(2); + let three = F::from_canonical_u32(3); + let two = F::from_canonical_u32(2); + + let x1 = &self.x; + let y1 = &self.y; + let x1_sqr = x1.square(); + + // x3 = (3*x1^2 + a)^2 / (2*y1)^2 - x1 - x1 + let slope = (x1_sqr * three + a) * (y1 * two).inverse().unwrap(); + let x3 = slope.square() - x1 - x1; + // y3 = slope * (x1 - x3) - y1 + let y3 = slope * (x1 - &x3) - y1; + + Self { + x: x3, + y: y3, + is_infinity: false, + } } } @@ -715,6 +764,22 @@ impl Default for SepticPoint { } } +impl Neg for SepticPoint { + type Output = SepticPoint; + + fn neg(self) -> Self::Output { + if self.is_infinity { + return self; + } + + Self { + x: self.x, + y: -self.y, + is_infinity: false, + } + } +} + impl Add for SepticPoint { type Output = Self; @@ -770,6 +835,10 @@ impl SepticPoint { self.y.square() == self.x.square() * &self.x + (&self.x * a) + b } + + pub fn point_at_infinity() -> Self { + Self::default() + } } impl SepticPoint { @@ -793,10 +862,186 @@ impl SepticPoint { } } +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct SepticJacobianPoint { + pub x: SepticExtension, + pub y: SepticExtension, + pub z: SepticExtension, +} + +impl From> for SepticJacobianPoint { + fn from(p: SepticPoint) -> Self { + if p.is_infinity { + Self::default() + } else { + Self { + x: p.x, + y: p.y, + z: SepticExtension::one(), + } + } + } +} + +impl Default for SepticJacobianPoint { + fn default() -> Self { + // return the point at infinity + Self { + x: SepticExtension::zero(), + y: SepticExtension::one(), + z: SepticExtension::zero(), + } + } +} + +impl SepticJacobianPoint { + pub fn point_at_infinity() -> Self { + Self::default() + } + + pub fn is_on_curve(&self) -> bool { + if self.z.is_zero() { + return self.x.is_zero() && !self.y.is_zero(); + } + + let b: SepticExtension = [0, 0, 0, 0, 0, 26, 0].into(); + let a: F = F::from_canonical_u32(2); + + let z2 = self.z.square(); + let z4 = z2.square(); + let z6 = &z4 * &z2; + + // y^2 = x^3 + 2x*z^4 + b*z^6 + self.y.square() == self.x.square() * &self.x + (&self.x * a * z4) + (b * &z6) + } + + pub fn into_affine(self) -> SepticPoint { + if self.z.is_zero() { + return SepticPoint::point_at_infinity(); + } + + let z_inv = self.z.inverse().unwrap(); + let z_inv2 = z_inv.square(); + let z_inv3 = &z_inv2 * &z_inv; + + let x = &self.x * &z_inv2; + let y = &self.y * &z_inv3; + + SepticPoint { + x, + y, + is_infinity: false, + } + } +} + +impl Add for &SepticJacobianPoint { + type Output = SepticJacobianPoint; + + fn add(self, rhs: Self) -> Self::Output { + // https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#addition-add-2007-bl + if self.z.is_zero() { + return rhs.clone(); + } + + if rhs.z.is_zero() { + return self.clone(); + } + + let z1z1 = self.z.square(); + let z2z2 = rhs.z.square(); + + let u1 = &self.x * &z2z2; + let u2 = &rhs.x * &z1z1; + + let s1 = &self.y * &z2z2 * &rhs.z; + let s2 = &rhs.y * &z1z1 * &self.z; + + if u1 == u2 { + if s1 == s2 { + return self.double(); + } else { + return SepticJacobianPoint::point_at_infinity(); + } + } + + let two = F::from_canonical_u32(2); + let h = u2 - &u1; + let i = (&h * two).square(); + let j = &h * &i; + let r = (s2 - &s1) * two; + let v = u1 * &i; + + let x3 = r.square() - &j - &v * two; + let y3 = r * (v - &x3) - s1 * &j * two; + let z3 = (&self.z + &rhs.z).square() - &z1z1 - &z2z2; + let z3 = z3 * h; + + Self::Output { + x: x3, + y: y3, + z: z3, + } + } +} + +impl SepticJacobianPoint { + pub fn double(&self) -> Self { + // https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian.html#doubling-dbl-2007-bl + + // y = 0 means self.order = 2 + if self.y.is_zero() { + return SepticJacobianPoint::point_at_infinity(); + } + + let two = F::from_canonical_u32(2); + let three = F::from_canonical_u32(3); + let eight = F::from_canonical_u32(8); + let a = F::from_canonical_u32(2); // The curve coefficient a + + // xx = x1^2 + let xx = self.x.square(); + + // yy = y1^2 + let yy = self.y.square(); + + // yyyy = yy^2 + let yyyy = yy.square(); + + // zz = z1^2 + let zz = self.z.square(); + + // S = 2*((x1 + y1^2)^2 - x1^2 - y1^4) + let s = (&self.x + &yy).square() - &xx - &yyyy; + let s = s * two; + + // M = 3*x1^2 + a*z1^4 + let m = &xx * three + zz.square() * a; + + // T = M^2 - 2*S + let t = m.square() - &s * two; + + // Y3 = M*(S-T)-8*y^4 + let y3 = m * (&s - &t) - &yyyy * eight; + + // X3 = T + let x3 = t; + + // Z3 = (y1+z1)^2 - y1^2 - z1^2 + let z3 = (&self.y + &self.z).square() - &yy - &zz; + + Self { + x: x3, + y: y3, + z: z3, + } + } +} + #[cfg(test)] mod tests { use super::SepticExtension; - use crate::scheme::septic_curve::SepticPoint; + use crate::scheme::septic_curve::{SepticJacobianPoint, SepticPoint}; use p3::{babybear::BabyBear, field::Field}; use rand::thread_rng; @@ -833,7 +1078,27 @@ mod tests { let p1 = SepticPoint::::random(&mut rng); let p2 = SepticPoint::::random(&mut rng); + let j1 = SepticJacobianPoint::from(p1.clone()); + let j2 = SepticJacobianPoint::from(p2.clone()); + let p3 = p1 + p2; + let j3 = &j1 + &j2; + + assert!(j1.is_on_curve()); + assert!(j2.is_on_curve()); + + assert!(j3.is_on_curve()); assert!(p3.is_on_curve()); + + assert_eq!(p3, j3.clone().into_affine()); + + // 2*p3 - p3 = p3 + let p4 = p3.double(); + assert_eq!((-p3.clone() + p4.clone()), p3); + + // 2*j3 = 2*p3 + let j4 = j3.double(); + assert!(j4.is_on_curve()); + assert_eq!(j4.into_affine(), p4); } } From a266248ed204df1bc0a65f1b635e4191d86a843f Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Wed, 24 Sep 2025 16:27:20 +0800 Subject: [PATCH 16/91] infer septic sum using jacobian coordinates --- ceno_zkvm/src/scheme/constants.rs | 2 + ceno_zkvm/src/scheme/septic_curve.rs | 22 +++++++- ceno_zkvm/src/scheme/utils.rs | 82 ++++++++++++++++------------ 3 files changed, 71 insertions(+), 35 deletions(-) diff --git a/ceno_zkvm/src/scheme/constants.rs b/ceno_zkvm/src/scheme/constants.rs index 3cc212e9f..901ce69ed 100644 --- a/ceno_zkvm/src/scheme/constants.rs +++ b/ceno_zkvm/src/scheme/constants.rs @@ -7,3 +7,5 @@ pub const NUM_FANIN_LOGUP: usize = 2; pub const MAX_NUM_VARIABLES: usize = 24; pub const DYNAMIC_RANGE_MAX_BITS: usize = 18; + +pub const SEPTIC_JACOBIAN_NUM_MLES: usize = 3 * 7; diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 87ce33509..b6ce61598 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -985,6 +985,14 @@ impl Add for &SepticJacobianPoint { } } +impl Add for SepticJacobianPoint { + type Output = SepticJacobianPoint; + + fn add(self, rhs: Self) -> Self::Output { + (&self).add(&rhs) + } +} + impl SepticJacobianPoint { pub fn double(&self) -> Self { // https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian.html#doubling-dbl-2007-bl @@ -1038,6 +1046,18 @@ impl SepticJacobianPoint { } } +impl Sum for SepticJacobianPoint { + fn sum>(iter: I) -> Self { + iter.fold(Self::default(), |acc, p| acc + p) + } +} + +impl SepticJacobianPoint { + pub fn random(rng: impl RngCore) -> Self { + SepticPoint::random(rng).into() + } +} + #[cfg(test)] mod tests { use super::SepticExtension; @@ -1095,7 +1115,7 @@ mod tests { // 2*p3 - p3 = p3 let p4 = p3.double(); assert_eq!((-p3.clone() + p4.clone()), p3); - + // 2*j3 = 2*p3 let j4 = j3.double(); assert!(j4.is_on_curve()); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 861919b61..e48978c51 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -1,8 +1,8 @@ use crate::{ scheme::{ - constants::MIN_PAR_SIZE, + constants::{MIN_PAR_SIZE, SEPTIC_JACOBIAN_NUM_MLES}, hal::{MainSumcheckProver, ProofInput, ProverDevice}, - septic_curve::{SepticExtension, SepticPoint}, + septic_curve::{SepticExtension, SepticJacobianPoint, SepticPoint}, }, structs::ComposedConstrainSystem, }; @@ -338,76 +338,83 @@ pub fn infer_tower_product_witness( wit_layers } -/// Infer from input layer (layer 0) to the output layer (layer n) -/// Note that each layer has 2 * 7 * 2 multilinear polynomials. +/// Infer from input layer (layer n) to the output layer (layer 0) +/// Note that each layer has 3 * 7 * 2 multilinear polynomials since we use jacobian coordinates. /// /// The relation between layer i and layer i+1 is as follows: -/// 0 = p[i][b] - (p[i+1][0,b] + p[i+1][1,b]) +/// (x1', y1', z1')[b] = jacobian_add( (x1, y1, z1)[0,b], (x2, y2, z2)[1,b] ) +/// (x2', y2', z2')[b] = jacobian_add( (x3, y3, z3)[0,b], (x4, y4, z4)[1,b] ) +/// +/// TODO handle jacobian_add & jacobian_double at the same time pub fn infer_septic_sum_witness( p_mles: Vec>, ) -> Vec>> { - assert_eq!(p_mles.len(), 2 * 7 * 2); + assert_eq!(p_mles.len(), SEPTIC_JACOBIAN_NUM_MLES * 2); assert!(p_mles.iter().map(|p| p.num_vars()).all_equal()); // +1 as the input layer has 2*N points where N = 2^num_vars // and the output layer has 2 points let num_layers = p_mles[0].num_vars() + 1; + println!("{num_layers} layers in total"); let mut layers = Vec::with_capacity(num_layers); layers.push(p_mles); - for _ in (0..num_layers - 1).rev() { + for layer in (0..num_layers - 1).rev() { let input_layer = layers.last().unwrap(); - let p = input_layer[0..14] + let p = input_layer[0..SEPTIC_JACOBIAN_NUM_MLES] .iter() .map(|mle| mle.get_base_field_vec()) .collect_vec(); - let q = input_layer[14..28] + let q = input_layer[SEPTIC_JACOBIAN_NUM_MLES..] .iter() .map(|mle| mle.get_base_field_vec()) .collect_vec(); let output_len = p[0].len() / 2; - let mut outputs: Vec = Vec::with_capacity(28 * output_len); + let mut outputs: Vec = + Vec::with_capacity(SEPTIC_JACOBIAN_NUM_MLES * 2 * output_len); unsafe { // will be filled immediately - outputs.set_len(28 * output_len); + outputs.set_len(SEPTIC_JACOBIAN_NUM_MLES * 2 * output_len); } (0..2).into_iter().for_each(|chunk| { (0..output_len) .into_par_iter() .with_min_len(MIN_PAR_SIZE) - .zip(outputs.par_chunks_mut(28)) + .zip_eq(outputs.par_chunks_mut(SEPTIC_JACOBIAN_NUM_MLES * 2)) .for_each(|(idx, output)| { let row = chunk * output_len + idx; - let offset = chunk * 14; + let offset = chunk * SEPTIC_JACOBIAN_NUM_MLES; - let p1 = SepticPoint { + let p1 = SepticJacobianPoint { x: SepticExtension(from_fn(|i| p[i][row])), y: SepticExtension(from_fn(|i| p[i + 7][row])), - is_infinity: false, + z: SepticExtension(from_fn(|i| p[i + 14][row])), }; - let p2 = SepticPoint { + let p2 = SepticJacobianPoint { x: SepticExtension(from_fn(|i| q[i][row])), y: SepticExtension(from_fn(|i| q[i + 7][row])), - is_infinity: false, + z: SepticExtension(from_fn(|i| q[i + 14][row])), }; - debug_assert!(p1.is_on_curve() && p2.is_on_curve()); + assert!(p1.is_on_curve(), "{layer}, {row}"); + assert!(p2.is_on_curve(), "{layer}, {row}"); - let p3 = p1 + p2; + let p3 = &p1 + &p2; output[offset..offset + 7].clone_from_slice(&p3.x); output[offset + 7..offset + 14].clone_from_slice(&p3.y); + output[offset + 14..offset + 21].clone_from_slice(&p3.z); }); }); // transpose - let output_mles = (0..28) + let output_mles = (0..SEPTIC_JACOBIAN_NUM_MLES * 2) .map(|i| { (0..output_len) .into_par_iter() - .map(|j| outputs[j * 28 + i]) + .map(|j| outputs[j * SEPTIC_JACOBIAN_NUM_MLES * 2 + i]) .collect::>() .into_mle() }) @@ -639,7 +646,8 @@ mod tests { use p3::{babybear::BabyBear, field::FieldAlgebra}; use crate::scheme::{ - septic_curve::{SepticExtension, SepticPoint}, + constants::SEPTIC_JACOBIAN_NUM_MLES, + septic_curve::{SepticExtension, SepticJacobianPoint, SepticPoint}, utils::{ infer_septic_sum_witness, infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, @@ -956,13 +964,13 @@ mod tests { type F = BabyBear; type E = BabyBearExt4; - let n_points = 1 << 4; + let n_points = 1 << 6; let mut rng = rand::thread_rng(); // sample n points - let points = (0..n_points) - .map(|_| SepticPoint::::random(&mut rng)) + let points: Vec> = (0..n_points) + .map(|_| SepticJacobianPoint::::random(&mut rng)) .collect_vec(); - + // transform points to row major matrix let trace = points[0..n_points / 2] .iter() @@ -970,7 +978,7 @@ mod tests { .map(|(p, q)| { [p, q] .iter() - .flat_map(|p| p.x.0.iter().chain(p.y.0.iter())) + .flat_map(|p| p.x.0.iter().chain(p.y.0.iter()).chain(p.z.0.iter())) .copied() .collect_vec() }) @@ -985,27 +993,33 @@ mod tests { let layers = infer_septic_sum_witness(p_mles); let output_layer = &layers[0]; assert!(output_layer.iter().all(|mle| mle.num_vars() == 0)); - assert!(output_layer.len() == 28); + assert!(output_layer.len() == SEPTIC_JACOBIAN_NUM_MLES * 2); // recover points from output layer - let output_points: Vec> = output_layer - .chunks_exact(14) + let output_points: Vec> = output_layer + .chunks_exact(SEPTIC_JACOBIAN_NUM_MLES) .map(|mles| { mles.iter() .map(|mle| mle.get_base_field_vec()[0]) .collect_vec() }) - .map(|chunk| SepticPoint { + .map(|chunk| SepticJacobianPoint { x: SepticExtension(chunk[0..7].try_into().unwrap()), y: SepticExtension(chunk[7..14].try_into().unwrap()), - is_infinity: false, + z: SepticExtension(chunk[14..21].try_into().unwrap()), }) .collect_vec(); assert!(output_points.iter().all(|p| p.is_on_curve())); assert_eq!(output_points.len(), 2); - let point_acc: SepticPoint = output_points.into_iter().sum(); - let expected_acc: SepticPoint = points.into_iter().sum(); + let point_acc: SepticPoint = output_points + .into_iter() + .sum::>() + .into_affine(); + let expected_acc: SepticPoint = points + .into_iter() + .sum::>() + .into_affine(); assert_eq!(point_acc, expected_acc); } } From d22f76d9b4afabecf7d59bf6188620bc25a9ee0e Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Fri, 26 Sep 2025 00:28:22 +0800 Subject: [PATCH 17/91] update tower verifier for jacobian coordinates --- ceno_zkvm/src/scheme/verifier.rs | 150 ++++++++++++++++++++++--------- 1 file changed, 106 insertions(+), 44 deletions(-) diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 3e3dfca09..3f8a66bfa 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -27,8 +27,11 @@ use witness::next_pow2_instance_padding; use crate::{ error::ZKVMError, scheme::{ - constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, - septic_curve::{SepticExtension, SepticPoint}, + constants::{ + NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE, SEPTIC_EXTENSION_DEGREE, + SEPTIC_JACOBIAN_NUM_MLES, + }, + septic_curve::{SepticJacobianPoint}, }, structs::{ComposedConstrainSystem, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, utils::{ @@ -380,6 +383,7 @@ impl> ZKVMVerifier .chain(proof.w_out_evals.iter().cloned()) .collect_vec(), proof.lk_out_evals.clone(), + vec![], tower_proofs, vec![num_var_with_rotation; num_batched], num_product_fanin, @@ -512,6 +516,7 @@ impl> ZKVMVerifier .iter() .map(|eval| eval.to_vec()) .collect_vec(), + vec![], tower_proofs, expected_rounds, num_logup_fanin, @@ -738,12 +743,16 @@ pub type TowerVerifyResult = Result< impl TowerVerify { fn get_ecc_eval( - p1: &SepticPoint, - p2: &SepticPoint, + p1: &SepticJacobianPoint, + p2: &SepticJacobianPoint, rt: &[E], - ) -> SepticPoint { - let SepticPoint { x, y } = p1; - let SepticPoint { x: x2, y: y2 } = p2; + ) -> Vec> { + let SepticJacobianPoint { x, y, z } = p1; + let SepticJacobianPoint { + x: x2, + y: y2, + z: z2, + } = p2; let xs = x.0.iter() @@ -759,17 +768,25 @@ impl TowerVerify { .map(|(yi, y2i)| vec![yi, y2i].into_mle().evaluate(rt)) .collect_vec(); - SepticPoint { - x: xs.as_slice().into(), - y: ys.as_slice().into(), - } + let zs = + z.0.iter() + .cloned() + .zip(z2.iter().cloned()) + .map(|(zi, z2i)| vec![zi, z2i].into_mle().evaluate(rt)) + .collect_vec(); + + xs.into_iter() + .chain(ys.into_iter()) + .chain(zs.into_iter()) + .map(|eval| PointAndEval::new(rt.to_vec(), eval)) + .collect_vec() } pub fn verify( // TODO: unify prod/logup/ec_add prod_out_evals: Vec>, logup_out_evals: Vec>, - ecc_out_evals: Vec>, + ecc_out_evals: Vec>, tower_proofs: &TowerProofs, num_variables: Vec, num_fanin: usize, @@ -789,11 +806,13 @@ impl TowerVerify { assert!(logup_out_evals.iter().all(|evals| { evals.len() == 4 // [p1, p2, q1, q2] })); - assert_eq!(ecc_out_evals.len(), 2); + assert!(ecc_out_evals.len() == 0 || ecc_out_evals.len() == 2); assert_eq!(num_variables.len(), num_prod_spec + num_logup_spec); + let num_ecc = if ecc_out_evals.is_empty() { 0 } else { 1 }; + let alpha_pows = get_challenge_pows( - num_prod_spec + num_logup_spec * 2 + 14, /* logup occupy 2 sumcheck: numerator and denominator */ + num_prod_spec + num_logup_spec * 2 + num_ecc * SEPTIC_JACOBIAN_NUM_MLES, /* logup occupy 2 sumcheck: numerator and denominator */ transcript, ); let initial_rt: Point = transcript.sample_and_append_vec(b"product_sum", log2_num_fanin); @@ -801,6 +820,9 @@ impl TowerVerify { // out_j[rt] := (record_{j}[rt]) // out_j[rt] := (logup_p{j}[rt]) // out_j[rt] := (logup_q{j}[rt]) + // out_j[rt] := ecc_x{j}[rt] + // out_j[rt] := ecc_y{j}[rt] + // out_j[rt] := ecc_z{j}[rt] // bookkeeping records of latest (point, evaluation) of each layer // prod argument @@ -827,7 +849,11 @@ impl TowerVerify { ) }) .unzip::<_, _, Vec<_>, Vec<_>>(); - let mut ecc_eval = Self::get_ecc_eval(&ecc_out_evals[0], &ecc_out_evals[1], &initial_rt); + let mut ecc_eval = if num_ecc == 1 { + Self::get_ecc_eval(&ecc_out_evals[0], &ecc_out_evals[1], &initial_rt) + } else { + vec![] + }; // initial claim = \sum_j alpha^j * out_j[rt] let initial_claim = izip!(&prod_spec_point_n_eval, &alpha_pows) @@ -838,7 +864,10 @@ impl TowerVerify { &alpha_pows[num_prod_spec..] ) .map(|(point_n_eval, alpha)| point_n_eval.eval * *alpha) - .sum::(); + .sum::() + + izip!(&ecc_eval, &alpha_pows[num_prod_spec + num_logup_spec * 2..]) + .map(|(xi, alpha)| xi.eval * *alpha) + .sum::(); let max_num_variables = num_variables.iter().max().unwrap(); @@ -872,8 +901,8 @@ impl TowerVerify { .zip(alpha_pows.iter()) .zip(num_variables.iter()) .map(|((spec_index, alpha), max_round)| { - // prod[b] = prod'[0,b] * prod'[1,b] - // prod[out_rt] = \sum_b eq(out_rt, b) * prod[b] = \sum_b eq(out_rt, b) * prod'[0,b] * prod'[1,b] + // prod'[b] = prod[0,b] * prod[1,b] + // prod'[out_rt] = \sum_b eq(out_rt,b) * prod'[b] = \sum_b eq(out_rt,b) * prod[0,b] * prod[1,b] eq * *alpha * if round < *max_round-1 {tower_proofs.prod_specs_eval[spec_index][round].iter().copied().product()} else { E::ZERO @@ -884,10 +913,10 @@ impl TowerVerify { .zip_eq(alpha_pows[num_prod_spec..].chunks(2)) .zip_eq(num_variables[num_prod_spec..].iter()) .map(|((spec_index, alpha), max_round)| { - // logup_q[b] = logup_q'[0,b] * logup_q'[1,b] - // logup_p[b] = logup_p'[0,b] * logup_q'[1,b] + logup_p'[1,b] * logup_q'[0,b] - // logup_p[out_rt] = \sum_b eq(out_rt, b) * (logup_p'[0,b] * logup_q'[1,b] + logup_p'[1,b] * logup_q'[0,b]) - // logup_q[out_rt] = \sum_b eq(out_rt, b) * logup_q'[0,b] * logup_q'[1,b] + // logup_q'[b] = logup_q[0,b] * logup_q[1,b] + // logup_p'[b] = logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b] + // logup_p'[out_rt] = \sum_b eq(out_rt,b) * (logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b]) + // logup_q'[out_rt] = \sum_b eq(out_rt,b) * logup_q[0,b] * logup_q[1,b] let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); eq * if round < *max_round-1 { let evals = &tower_proofs.logup_specs_eval[spec_index][round]; @@ -900,21 +929,34 @@ impl TowerVerify { } }) .sum::(); - // 0 = \sum_b eq(out_rt, b) * (ecc_x[b] + ecc_x'[0,b] + ecc_x'[1,b]) - // * (ecc_x'[1,b] - ecc_x'[0,b])^2 - // - (ecc_y'[1,b] - ecc_y'[0,b])^2 - let SepticPoint { x, y } = &ecc_eval; - let SepticPoint { x: x1, y: y1 } = &tower_proofs.ecc_evals[round][0]; - let SepticPoint { x: x2, y: y2 } = &tower_proofs.ecc_evals[round][1]; - - let xs = (x + x1 + x2) * (x2 - x1) * (x2 - x1) - - (y2 - y1) * (y2 - y1); - - // 0 = (ecc_y + ecc_y'[0]) * (ecc_x'[1] - ecc_x'[0]) - // - (ecc_y'[1] - ecc_y'[0]) * (ecc_x'[0] - ecc_x) - let ys = (y + y1) * (x2 - x1) - (y2 - y1) * (x1 - x); - expected_evaluation += izip!(xs.0.iter(), alpha_pows[num_prod_spec + num_logup_spec * 2..].iter().take(7)).map(|(&xi, &alpha)| eq * xi * alpha).sum::(); - expected_evaluation += izip!(ys.0.iter(), alpha_pows[num_prod_spec + num_logup_spec * 2..].iter().skip(7).take(7)).map(|(&yi, &alpha)| eq * yi * alpha).sum::(); + + if num_ecc == 1 { + // (x', y', z')[b] = jacobian_add((x, y, z)[0,b], (x, y, z)[1,b]) + let degree = SEPTIC_EXTENSION_DEGREE; + let (x1, rest) = tower_proofs.ecc_evals[round].split_at(SEPTIC_EXTENSION_DEGREE); + let (y1, rest) = rest.split_at(SEPTIC_EXTENSION_DEGREE); + let (z1, rest) = rest.split_at(SEPTIC_EXTENSION_DEGREE); + let (x2, rest) = rest.split_at(SEPTIC_EXTENSION_DEGREE); + let (y2, rest) = rest.split_at(SEPTIC_EXTENSION_DEGREE); + let (z2, rest) = rest.split_at(SEPTIC_EXTENSION_DEGREE); + + // p1 and p2 are not valid ecc points + // we just want to use ecc addition formula as expression to get + // the expected evaluation for ecc sumcheck + let p1 = SepticJacobianPoint { + x: x1.into(), + y: y1.into(), + z: z1.into(), + }; + let p2 = SepticJacobianPoint { + x: x2.into(), + y: y2.into(), + z: z2.into(), + }; + + let SepticJacobianPoint { x, y, z } = p1 + p2; + expected_evaluation += izip!(x.0.iter().chain(y.0.iter()).chain(z.0.iter()), alpha_pows[num_prod_spec + num_logup_spec * 2..].iter()).map(|(&xi, &alpha)| eq * xi * alpha).sum::(); + } if expected_evaluation != sumcheck_claim.expected_evaluation { return Err(ZKVMError::VerifyError("mismatch tower evaluation".into())); @@ -994,14 +1036,34 @@ impl TowerVerify { } }) .sum::(); - // update ecc_eval - ecc_eval = Self::get_ecc_eval( - &tower_proofs.ecc_evals[round][0], - &tower_proofs.ecc_evals[round][1], - &rt_prime, - ); + // sum evaluation from different specs - let next_eval = next_prod_spec_evals + next_logup_spec_evals; + let mut next_eval = next_prod_spec_evals + next_logup_spec_evals; + + if num_ecc == 1 { + let next_round_expected_eval = if round < *max_num_variables - 1 { + tower_proofs.ecc_evals[round][..SEPTIC_JACOBIAN_NUM_MLES] + .iter() + .zip(tower_proofs.ecc_evals[round][SEPTIC_JACOBIAN_NUM_MLES..].iter()) + .zip(next_alpha_pows[num_prod_spec + num_logup_spec * 2..].iter()) + .zip(ecc_eval.iter_mut()) + .map(|(((a, b), alpha), point_and_eval)| { + let eval = izip!(vec![a, b].into_iter(), coeffs.iter()).map(|(a, b)| *a * *b).sum::(); + + point_and_eval.point = rt_prime.clone(); + point_and_eval.eval = eval; + + eval * *alpha + }) + .sum() + } else { + E::ZERO + }; + if next_round < *max_num_variables - 1 { + next_eval += next_round_expected_eval; + } + } + Ok((PointAndEval { point: rt_prime, eval: next_eval, From a662a49c8040c65af32a279f169039b87480aaea Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Fri, 26 Sep 2025 17:33:06 +0800 Subject: [PATCH 18/91] jacobian coordinates in tower's sumcheck --- ceno_zkvm/src/scheme/constants.rs | 3 +- ceno_zkvm/src/scheme/cpu/mod.rs | 218 +++++++++++++-------------- ceno_zkvm/src/scheme/septic_curve.rs | 1 + ceno_zkvm/src/scheme/verifier.rs | 3 +- 4 files changed, 111 insertions(+), 114 deletions(-) diff --git a/ceno_zkvm/src/scheme/constants.rs b/ceno_zkvm/src/scheme/constants.rs index 901ce69ed..feb3fd923 100644 --- a/ceno_zkvm/src/scheme/constants.rs +++ b/ceno_zkvm/src/scheme/constants.rs @@ -8,4 +8,5 @@ pub const MAX_NUM_VARIABLES: usize = 24; pub const DYNAMIC_RANGE_MAX_BITS: usize = 18; -pub const SEPTIC_JACOBIAN_NUM_MLES: usize = 3 * 7; +pub const SEPTIC_EXTENSION_DEGREE: usize = 7; +pub const SEPTIC_JACOBIAN_NUM_MLES: usize = 3 * SEPTIC_EXTENSION_DEGREE; diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index f800a5dd8..50d7e268b 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -5,7 +5,9 @@ use crate::{ circuit_builder::ConstraintSystem, error::ZKVMError, scheme::{ - constants::{NUM_FANIN, NUM_FANIN_LOGUP}, + constants::{ + NUM_FANIN, NUM_FANIN_LOGUP, SEPTIC_EXTENSION_DEGREE, SEPTIC_JACOBIAN_NUM_MLES, + }, hal::{DeviceProvingKey, MainSumcheckEvals, ProofInput, TowerProverSpec}, septic_curve::{SepticExtension, SymbolicSepticExtension}, utils::{ @@ -63,12 +65,7 @@ impl CpuTowerProver { enum GroupedMLE<'a, E: ExtensionField> { Prod((usize, Vec>)), // usize is the index in prod_specs Logup((usize, Vec>)), // usize is the index in logup_specs - EcAdd( - ( - Vec>, - Vec>, - ), - ), + EcAdd(Vec>), // 2 points, each point has 21 polys } // XXX to sumcheck batched product argument with logup, we limit num_product_fanin to 2 @@ -93,7 +90,7 @@ impl CpuTowerProver { prod_specs_len + // logup occupy 2 sumcheck: numerator and denominator logup_specs_len * 2 - + ecc_spec.as_ref().map_or(0, |_| 14), + + ecc_spec.as_ref().map_or(0, |_| SEPTIC_JACOBIAN_NUM_MLES), transcript, ); let initial_rt: Point = transcript.sample_and_append_vec(b"product_sum", log_num_fanin); @@ -124,13 +121,9 @@ impl CpuTowerProver { } if let Some(ecc_spec) = ecc_spec { - for i in 0..max_round_index { - layer_witness[i + 1].push(GroupedMLE::EcAdd(( - // TODO: avoid clone - ecc_spec.witness[i].clone(), - ecc_spec.witness[i + 1].clone(), - ))); - } + merge_spec_witness(&mut layer_witness, ecc_spec, 0, |(_, v)| { + GroupedMLE::EcAdd(v) + }); } // skip(1) for output layer @@ -143,11 +136,7 @@ impl CpuTowerProver { let mut witness_lk_expr = vec![vec![]; logup_specs_len]; let mut eq: MultilinearExtension = build_eq_x_r_vec(&out_rt).into_mle(); - let eq_len = eq.evaluations.len(); - let mut eq_prime = eq.get_ext_field_vec()[0..eq_len / 2].to_vec().into_mle(); - let eq_expr = expr_builder.lift(Either::Right(&mut eq)); - let eq_prime_expr = expr_builder.lift(Either::Right(&mut eq_prime)); // processing exprs for group_witness in layer_witness.iter_mut() { @@ -215,112 +204,89 @@ impl CpuTowerProver { + alpha_denominator * q1 * q2), ); } - GroupedMLE::EcAdd((prev_layer, curr_layer)) => { - assert_eq!(curr_layer.len(), 3 * 14); // 3 points, each point has 14 polys - assert_eq!(prev_layer.len(), 3 * 14); - let (x1, rest) = curr_layer.split_at(7); - let (y1, rest) = rest.split_at(7); - let (x2, rest) = rest.split_at(7); - let (y2, rest) = rest.split_at(7); - let (x3, y3) = rest.split_at(7); - - // x1'[b] = x3[0,b] - // y1'[b] = y3[0,b] - // x2'[b] = x3[1,b] - // y2'[b] = y3[1,b] - let (x1_prime, rest) = prev_layer.split_at_mut(7); - let (y1_prime, rest) = rest.split_at_mut(7); - let (x2_prime, rest) = rest.split_at_mut(7); - let (y2_prime, _) = rest.split_at_mut(7); + GroupedMLE::EcAdd(layer_polys) => { + assert_eq!(layer_polys.len(), 2 * SEPTIC_JACOBIAN_NUM_MLES); // 2 points, each point has 21 polys + + let (x1, rest) = layer_polys.split_at_mut(SEPTIC_EXTENSION_DEGREE); + let (y1, rest) = rest.split_at_mut(SEPTIC_EXTENSION_DEGREE); + let (z1, rest) = rest.split_at_mut(SEPTIC_EXTENSION_DEGREE); + let (x2, rest) = rest.split_at_mut(SEPTIC_EXTENSION_DEGREE); + let (y2, z2) = rest.split_at_mut(SEPTIC_EXTENSION_DEGREE); let x1 = &SymbolicSepticExtension::new( x1.into_iter() - .map(|x| expr_builder.lift(Either::Left(x))) + .map(|x| expr_builder.lift(x.to_either())) .collect(), ); let y1 = &SymbolicSepticExtension::new( y1.into_iter() - .map(|y| expr_builder.lift(Either::Left(y))) + .map(|y| expr_builder.lift(y.to_either())) + .collect(), + ); + let z1 = &SymbolicSepticExtension::new( + z1.into_iter() + .map(|z| expr_builder.lift(z.to_either())) .collect(), ); let x2 = &SymbolicSepticExtension::new( x2.into_iter() - .map(|x| expr_builder.lift(Either::Left(x))) + .map(|x| expr_builder.lift(x.to_either())) .collect(), ); let y2 = &SymbolicSepticExtension::new( y2.into_iter() - .map(|y| expr_builder.lift(Either::Left(y))) - .collect(), - ); - let x3 = &SymbolicSepticExtension::new( - x3.into_iter() - .map(|x| expr_builder.lift(Either::Left(x))) - .collect(), - ); - let y3 = &SymbolicSepticExtension::new( - y3.into_iter() - .map(|y| expr_builder.lift(Either::Left(y))) - .collect(), - ); - let x1_prime_expr = SymbolicSepticExtension::new( - x1_prime - .iter_mut() - .map(|x| expr_builder.lift(x.to_either())) - .collect(), - ); - let y1_prime_expr = SymbolicSepticExtension::new( - y1_prime - .iter_mut() .map(|y| expr_builder.lift(y.to_either())) .collect(), ); - let x2_prime_expr = SymbolicSepticExtension::new( - x2_prime - .iter_mut() - .map(|x| expr_builder.lift(x.to_either())) - .collect(), - ); - let y2_prime_expr = SymbolicSepticExtension::new( - y2_prime - .iter_mut() - .map(|y| expr_builder.lift(y.to_either())) + let z2 = &SymbolicSepticExtension::new( + z2.into_iter() + .map(|z| expr_builder.lift(z.to_either())) .collect(), ); - // layer i: x3', y3', x1', y1', x2', y2', each has `i` variables - // we copy the first half of x3 to x1', 2nd half to x2' and - // copy the first half of y3 to y1', 2nd half to y2'. - // - // x1'[b] = x3[0,b], y1'[b] = y3[0,b] - // x2'[b] = x3[1,b], y2'[b] = y3[1,b] - // - // layer i+1: x3, y3, x1, y1, x2, y2, each has `i+1` variables - // we requires the elliptic curve addition constraints hold at layer i+1. - // 1. 0 = \sum_b eq(rt,b) * ((x3 + x1 + x2) * (x2 - x1)^2 - (y2 - y1)^2) - exprs.extend( - (((x3 + x1 + x2) * (x2 - x1) * (x2 - x1) - (y2 - y1) * (y2 - y1)) - * &eq_expr) - .to_exprs(), - ); - // 2. 0 = \sum_b eq(rt,b) * ((y3 + y1) * (x2 - x1) - (y2 - y1) * (x1 - x3)) - exprs.extend( - (((y3 + y1) * (x2 - x1) - (y2 - y1) * (x1 - x3)) * &eq_expr).to_exprs(), - ); - - // with len = rt.len(), rt' = rt[0..len-1] - // x1'[rt'] = \sum_b' eq(rt',b') * x3[0,b'] - // y1'[rt'] = \sum_b' eq(rt',b') * y3[0,b'] - // x2'[rt'] = \sum_b' eq(rt',b') * x3[1,b'] - // y2'[rt'] = \sum_b' eq(rt',b') * y3[1,b'] - exprs.extend((x1_prime_expr * &eq_prime_expr).to_exprs()); - exprs.extend((y1_prime_expr * &eq_prime_expr).to_exprs()); - exprs.extend((x2_prime_expr * &eq_prime_expr).to_exprs()); - exprs.extend((y2_prime_expr * &eq_prime_expr).to_exprs()); + let two: Expression = 2.into(); + let four: Expression = 4.into(); + let z1_squared = z1 * z1; + println!("z1_squared: {:?}", z1_squared); + let z1_cubed = &z1_squared * z1; + let z2_squared = z2 * z2; + let z2_cubed = &z2_squared * z2; + + // U1 = X1*Z2^2, U2 = X2*Z1^2 + let u1 = x1 * &z2_squared; + let u2 = x2 * &z1_squared; + + // S1 = Y1*Z2^3, S2 = Y2*Z1^3 + let s1 = y1 * &z2_cubed; + let s2 = y2 * &z1_cubed; + + // H = U2-U1, R = S2-S1 + let h = u2 - &u1; + let h_squared = &h * &h; + let h_cubed = &h_squared * &h; + + let i = h_squared * &four; + let j = h_cubed * &four; + let r = (&s2 - &s1) * &two; + let v = &u1 * &i; + + // Check the formulas for X3, Y3, Z3 + // X3 = R^2 - J - 2*V + let x3 = &r * &r - j.clone() - v.clone() * &two; + // Y3 = R*(V - X3) - 2*S1*J + let y3 = r * (&v - &x3) - s1 * j * &two; + // Z3 = (Z1 + Z2)^2 - Z1Z1 - Z2Z2) * H + let z3 = z1 * z2 * h * &two; + // exprs.extend((x3 * &eq_expr).to_exprs()); + // exprs.extend((y3 * &eq_expr).to_exprs()); + // exprs.extend((z3 * &eq_expr).to_exprs()); } } } + for expr in exprs.iter() { + println!("expr: {:?}", expr); + } let wrap_batch_span = entered_span!("wrap_batch"); let (sumcheck_proofs, state) = IOPProverState::prove( expr_builder.to_virtual_polys(&[exprs.into_iter().sum()], &[]), @@ -976,8 +942,12 @@ mod tests { use transcript::BasicTranscript; use crate::scheme::{ - cpu::CpuTowerProver, hal::TowerProverSpec, septic_curve::SepticPoint, + constants::SEPTIC_JACOBIAN_NUM_MLES, + cpu::CpuTowerProver, + hal::TowerProverSpec, + septic_curve::{SepticExtension, SepticJacobianPoint, SepticPoint}, utils::infer_septic_sum_witness, + verifier::TowerVerify, }; #[test] @@ -986,19 +956,15 @@ mod tests { type F = BabyBear; type PCS = Basefold; - // generate 1 product witness spec - - // generate 1 logup witness spec - // if layer i has n variables, - // then layer i+1 has n-1 variables. + let log2_n = 6; + let n_points = 1 << log2_n; + let mut rng = rand::thread_rng(); // generate 1 ecc add witness let ecc_spec: TowerProverSpec<'_, CpuBackend> = { - let n_points = 1 << 4; - let mut rng = rand::thread_rng(); // sample n points let points = (0..n_points) - .map(|_| SepticPoint::::random(&mut rng)) + .map(|_| SepticJacobianPoint::::random(&mut rng)) .collect_vec(); // transform points to row major matrix @@ -1008,7 +974,7 @@ mod tests { .map(|(p, q)| { [p, q] .iter() - .flat_map(|p| p.x.0.iter().chain(p.y.0.iter())) + .flat_map(|p| p.x.0.iter().chain(p.y.0.iter()).chain(p.z.0.iter())) .copied() .collect_vec() }) @@ -1024,8 +990,38 @@ mod tests { witness: infer_septic_sum_witness(p_mles), } }; + let output_layer = &ecc_spec.witness[0]; + let ecc_out_evals: Vec> = output_layer + .chunks_exact(SEPTIC_JACOBIAN_NUM_MLES) + .map(|mles| { + mles.iter() + .map(|mle| mle.get_base_field_vec()[0]) + .collect_vec() + }) + .map(|chunk| SepticJacobianPoint { + x: SepticExtension(chunk[0..7].try_into().unwrap()), + y: SepticExtension(chunk[7..14].try_into().unwrap()), + z: SepticExtension(chunk[14..21].try_into().unwrap()), + }) + .collect_vec(); + let mut transcript = BasicTranscript::new(b"test"); - let prover = + println!("begin to create tower proof"); + let (_, tower_proof) = CpuTowerProver::create_proof(vec![], vec![], Some(ecc_spec), 2, &mut transcript); + + let mut transcript = BasicTranscript::new(b"test"); + assert!( + TowerVerify::verify( + vec![], + vec![], + ecc_out_evals, + &tower_proof, + vec![], + 2, + &mut transcript + ) + .is_ok() + ); } } diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index b6ce61598..4954a0d2a 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -582,6 +582,7 @@ impl MulAssign for SepticExtension { } } +#[derive(Clone, Debug)] pub struct SymbolicSepticExtension(pub Vec>); impl Add for &SymbolicSepticExtension { diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 3f8a66bfa..dd9c6a8bc 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -31,7 +31,7 @@ use crate::{ NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE, SEPTIC_EXTENSION_DEGREE, SEPTIC_JACOBIAN_NUM_MLES, }, - septic_curve::{SepticJacobianPoint}, + septic_curve::SepticJacobianPoint, }, structs::{ComposedConstrainSystem, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, utils::{ @@ -1063,7 +1063,6 @@ impl TowerVerify { next_eval += next_round_expected_eval; } } - Ok((PointAndEval { point: rt_prime, eval: next_eval, From 502b7b219389ee5c1102a9bc9df4da5652ff871e Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 29 Sep 2025 15:55:57 +0800 Subject: [PATCH 19/91] add quark prover for ecc addition --- ceno_zkvm/src/scheme/cpu/mod.rs | 249 +++++++++++++++++++++++--------- 1 file changed, 182 insertions(+), 67 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 50d7e268b..2434671ab 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -9,7 +9,7 @@ use crate::{ NUM_FANIN, NUM_FANIN_LOGUP, SEPTIC_EXTENSION_DEGREE, SEPTIC_JACOBIAN_NUM_MLES, }, hal::{DeviceProvingKey, MainSumcheckEvals, ProofInput, TowerProverSpec}, - septic_curve::{SepticExtension, SymbolicSepticExtension}, + septic_curve::{SepticExtension, SepticPoint, SymbolicSepticExtension}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, masked_mle_split_to_chunks, wit_infer_by_expr, @@ -27,17 +27,14 @@ use gkr_iop::{ use itertools::{Itertools, chain}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ - Expression, Instance, WitnessId, - mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, - util::ceil_log2, - virtual_poly::build_eq_x_r_vec, - virtual_polys::VirtualPolynomialsBuilder, + mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, util::ceil_log2, virtual_poly::{build_eq_x_r_vec, eq_eval}, virtual_polys::VirtualPolynomialsBuilder, Expression, Instance, WitnessId }; +use p3::field::PackedValue; use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; use std::{collections::BTreeMap, sync::Arc}; use sumcheck::{ macros::{entered_span, exit_span}, - structs::{IOPProverMessage, IOPProverState}, + structs::{IOPProof, IOPProverMessage, IOPProverState}, util::{get_challenge_pows, optimal_sumcheck_threads}, }; use transcript::Transcript; @@ -51,6 +48,131 @@ pub type TowerRelationOutput = ( Vec>, ); +pub struct EccQuarkProof { + pub zerocheck_proof: IOPProof, + + pub evals: Vec, // x[rt,0], x[rt,1], y[rt,0], y[rt,1] +} + +// implement the IOP proposed in [Quark paper](https://eprint.iacr.org/2020/1275.pdf) +// to accumulate N=2^n EC points into one EC point using affine coordinates +pub struct CpuEccProver; + +impl CpuEccProver { + pub fn create_ecc_proof<'a, E: ExtensionField>( + &self, + mut xs: Vec>, + mut ys: Vec>, + invs: Vec>, + transcript: &mut impl Transcript, + ) { + assert_eq!(xs.len(), SEPTIC_EXTENSION_DEGREE); + assert_eq!(ys.len(), SEPTIC_EXTENSION_DEGREE); + + let n = xs[0].num_vars() - 1; + let out_rt = transcript.sample_and_append_vec(b"ecc", n); + let num_threads = optimal_sumcheck_threads(out_rt.len()); + + let mut expr_builder = VirtualPolynomialsBuilder::new(num_threads, out_rt.len()); + + let mut eq: MultilinearExtension<'_, E> = build_eq_x_r_vec(&out_rt).into_mle(); + let eq_expr = expr_builder.lift((&mut eq).to_either()); + let mut exprs = vec![]; + + let filter_bj = |v: &[MultilinearExtension<'_, E>], j: usize| { + v.iter() + .map(|v| { + v.get_base_field_vec() + .iter() + .enumerate() + .filter(|(i, _)| *i % 2 == j) + .map(|(_, v)| v) + .cloned() + .collect_vec() + .into_mle() + }) + .collect_vec() + }; + // build x[b,0], x[b,1], y[b,0], y[b,1] + let mut x0 = filter_bj(&xs, 0); + let mut y0 = filter_bj(&ys, 0); + let mut x1 = filter_bj(&xs, 1); + let mut y1 = filter_bj(&ys, 1); + // build x[1,b], y[1,b], s[0,b] + let x3 = xs + .iter_mut() + .map(|x| x.as_view_slice_mut(2, 1)) + .collect_vec(); + let y3 = ys + .iter_mut() + .map(|x| x.as_view_slice_mut(2, 1)) + .collect_vec(); + let mut s = invs.iter().map(|x| x.as_view_slice(2, 0)).collect_vec(); + + let s = SymbolicSepticExtension::new( + s.iter_mut() + .map(|s| expr_builder.lift(s.to_either())) + .collect(), + ); + let x0 = SymbolicSepticExtension::new( + x0.iter_mut() + .map(|x| expr_builder.lift(x.to_either())) + .collect(), + ); + let y0 = SymbolicSepticExtension::new( + y0.iter_mut() + .map(|y| expr_builder.lift(y.to_either())) + .collect(), + ); + let x1 = SymbolicSepticExtension::new( + x1.iter_mut() + .map(|x| expr_builder.lift(x.to_either())) + .collect(), + ); + let y1 = SymbolicSepticExtension::new( + y1.iter_mut() + .map(|y| expr_builder.lift(y.to_either())) + .collect(), + ); + // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) + exprs.extend_from_slice( + (s.clone() * (&x0 - &x1) - (&y0 - &y1)) + .to_exprs() + .as_slice(), + ); + + // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] + + // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - y[b,0] - y[1,b] + + // reduced to s[0,rt], x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt] + + let (sumcheck_proofs, state) = IOPProverState::prove( + expr_builder + .to_virtual_polys(&[exprs.into_iter().sum::>() * eq_expr], &[]), + transcript, + ); + + let rt = state.collect_raw_challenges(); + let evals = state.get_mle_flatten_final_evaluations(); + + #[cfg(feature = "sanity-check")] + { + let s = invs.iter().map(|x| x.as_view_slice(2, 0)).collect_vec(); + assert_eq!(eq_eval(&out_rt, &rt), evals[0]); + assert_eq!(s[0].evaluate(&rt), evals[1]); + } + } +} + +pub struct EccVerifier; + +impl EccVerifier { + pub fn verify_ecc_proof>() { + todo!() + } +} + pub struct CpuTowerProver; impl CpuTowerProver { @@ -284,9 +406,6 @@ impl CpuTowerProver { } } - for expr in exprs.iter() { - println!("expr: {:?}", expr); - } let wrap_batch_span = entered_span!("wrap_batch"); let (sumcheck_proofs, state) = IOPProverState::prove( expr_builder.to_virtual_polys(&[exprs.into_iter().sum()], &[]), @@ -930,10 +1049,10 @@ where #[cfg(test)] mod tests { + use std::iter::repeat; + use ff_ext::BabyBearExt4; - use gkr_iop::cpu::CpuBackend; use itertools::Itertools; - use mpcs::{Basefold, BasefoldRSParams}; use multilinear_extensions::{ mle::{IntoMLE, MultilinearExtension}, util::transpose, @@ -942,86 +1061,82 @@ mod tests { use transcript::BasicTranscript; use crate::scheme::{ - constants::SEPTIC_JACOBIAN_NUM_MLES, - cpu::CpuTowerProver, - hal::TowerProverSpec, - septic_curve::{SepticExtension, SepticJacobianPoint, SepticPoint}, - utils::infer_septic_sum_witness, - verifier::TowerVerify, + constants::SEPTIC_EXTENSION_DEGREE, + cpu::CpuEccProver, + septic_curve::{SepticExtension, SepticPoint}, }; #[test] - fn test_ecc_tower_prover() { + fn test_ecc_quark_prover() { type E = BabyBearExt4; type F = BabyBear; - type PCS = Basefold; let log2_n = 6; let n_points = 1 << log2_n; let mut rng = rand::thread_rng(); // generate 1 ecc add witness - let ecc_spec: TowerProverSpec<'_, CpuBackend> = { - // sample n points - let points = (0..n_points) - .map(|_| SepticJacobianPoint::::random(&mut rng)) + let ecc_spec: Vec> = { + // sample N = 2^n points + let mut points = (0..n_points) + .map(|_| SepticPoint::::random(&mut rng)) .collect_vec(); + let mut s = Vec::with_capacity(n_points); + + for layer in (1..=log2_n).rev() { + let num_inputs = 1 << layer; + let inputs = &points[points.len() - num_inputs..]; + + s.extend(inputs.chunks_exact(2).map(|chunk| { + let p = &chunk[0]; + let q = &chunk[1]; + + (&p.y - &q.y) * (&p.x - &q.x).inverse().unwrap() + })); + + points.extend( + points[points.len() - num_inputs..] + .chunks_exact(2) + .map(|chunk| { + let p = chunk[0].clone(); + let q = chunk[1].clone(); + + p + q + }) + .collect_vec(), + ); + } + // padding to 2*N + s.extend(repeat(SepticExtension::zero()).take(n_points + 1)); + points.push(SepticPoint::point_at_infinity()); + + assert_eq!(s.len(), 2 * n_points); + assert_eq!(points.len(), 2 * n_points); // transform points to row major matrix - let trace = points[0..n_points / 2] + let trace = points .iter() - .zip(points[n_points / 2..n_points].iter()) - .map(|(p, q)| { - [p, q] - .iter() - .flat_map(|p| p.x.0.iter().chain(p.y.0.iter()).chain(p.z.0.iter())) + .zip_eq(s.iter()) + .map(|(p, s)| { + p.x.iter() + .chain(p.y.iter()) + .chain(s.iter()) .copied() .collect_vec() }) .collect_vec(); // transpose row major matrix to column major matrix - let p_mles: Vec> = transpose(trace) + transpose(trace) .into_iter() .map(|v| v.into_mle()) - .collect_vec(); - - crate::scheme::hal::TowerProverSpec { - witness: infer_septic_sum_witness(p_mles), - } + .collect_vec() }; - let output_layer = &ecc_spec.witness[0]; - let ecc_out_evals: Vec> = output_layer - .chunks_exact(SEPTIC_JACOBIAN_NUM_MLES) - .map(|mles| { - mles.iter() - .map(|mle| mle.get_base_field_vec()[0]) - .collect_vec() - }) - .map(|chunk| SepticJacobianPoint { - x: SepticExtension(chunk[0..7].try_into().unwrap()), - y: SepticExtension(chunk[7..14].try_into().unwrap()), - z: SepticExtension(chunk[14..21].try_into().unwrap()), - }) - .collect_vec(); + let (xs, rest) = ecc_spec.split_at(SEPTIC_EXTENSION_DEGREE); + let (ys, s) = rest.split_at(SEPTIC_EXTENSION_DEGREE); let mut transcript = BasicTranscript::new(b"test"); - println!("begin to create tower proof"); - let (_, tower_proof) = - CpuTowerProver::create_proof(vec![], vec![], Some(ecc_spec), 2, &mut transcript); - - let mut transcript = BasicTranscript::new(b"test"); - assert!( - TowerVerify::verify( - vec![], - vec![], - ecc_out_evals, - &tower_proof, - vec![], - 2, - &mut transcript - ) - .is_ok() - ); + let prover = CpuEccProver {}; + prover.create_ecc_proof(xs.to_vec(), ys.to_vec(), s.to_vec(), &mut transcript); } } From c3391bb0234056925d63f4cc9a92008a716e7fca Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 29 Sep 2025 16:25:10 +0800 Subject: [PATCH 20/91] sanity check on quark's zerocheck --- ceno_zkvm/src/scheme/cpu/mod.rs | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 2434671ab..d9d8a5cbb 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -27,9 +27,12 @@ use gkr_iop::{ use itertools::{Itertools, chain}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ - mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, util::ceil_log2, virtual_poly::{build_eq_x_r_vec, eq_eval}, virtual_polys::VirtualPolynomialsBuilder, Expression, Instance, WitnessId + Expression, Instance, WitnessId, + mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, + util::ceil_log2, + virtual_poly::{build_eq_x_r_vec, eq_eval}, + virtual_polys::VirtualPolynomialsBuilder, }; -use p3::field::PackedValue; use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; use std::{collections::BTreeMap, sync::Arc}; use sumcheck::{ @@ -65,7 +68,7 @@ impl CpuEccProver { mut ys: Vec>, invs: Vec>, transcript: &mut impl Transcript, - ) { + ) -> EccQuarkProof { assert_eq!(xs.len(), SEPTIC_EXTENSION_DEGREE); assert_eq!(ys.len(), SEPTIC_EXTENSION_DEGREE); @@ -144,23 +147,33 @@ impl CpuEccProver { // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - y[b,0] - y[1,b] - - // reduced to s[0,rt], x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt] - - let (sumcheck_proofs, state) = IOPProverState::prove( + let (zerocheck_proof, state) = IOPProverState::prove( expr_builder .to_virtual_polys(&[exprs.into_iter().sum::>() * eq_expr], &[]), transcript, ); let rt = state.collect_raw_challenges(); + // TODO: fix this assertion + assert_eq!(zerocheck_proof.extract_sum(), E::ZERO); let evals = state.get_mle_flatten_final_evaluations(); #[cfg(feature = "sanity-check")] { let s = invs.iter().map(|x| x.as_view_slice(2, 0)).collect_vec(); + let x0 = filter_bj(&xs, 0); + // check evaluations assert_eq!(eq_eval(&out_rt, &rt), evals[0]); - assert_eq!(s[0].evaluate(&rt), evals[1]); + for i in 0..SEPTIC_EXTENSION_DEGREE { + assert_eq!(s[i].evaluate(&rt), evals[1 + i]); + assert_eq!(x0[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE + 1 + i]); + } + } + + // TODO: prove the validity of s[0,rt], x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt] + EccQuarkProof { + zerocheck_proof, + evals, } } } From ed67cd38ddeae41d6207192f3a761fc838118d4b Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 29 Sep 2025 16:52:31 +0800 Subject: [PATCH 21/91] delete ec addition using tower tree --- ceno_zkvm/src/scheme/cpu/mod.rs | 92 +------------------------------- ceno_zkvm/src/scheme/verifier.rs | 72 ++----------------------- 2 files changed, 5 insertions(+), 159 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index d9d8a5cbb..f79c968fe 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -192,7 +192,6 @@ impl CpuTowerProver { pub fn create_proof<'a, E: ExtensionField, PCS: PolynomialCommitmentScheme>( prod_specs: Vec>>, logup_specs: Vec>>, - ecc_spec: Option>>, num_fanin: usize, transcript: &mut impl Transcript, ) -> (Point, TowerProofs) { @@ -200,7 +199,6 @@ impl CpuTowerProver { enum GroupedMLE<'a, E: ExtensionField> { Prod((usize, Vec>)), // usize is the index in prod_specs Logup((usize, Vec>)), // usize is the index in logup_specs - EcAdd(Vec>), // 2 points, each point has 21 polys } // XXX to sumcheck batched product argument with logup, we limit num_product_fanin to 2 @@ -214,7 +212,6 @@ impl CpuTowerProver { let max_round_index = prod_specs .iter() .chain(logup_specs.iter()) - .chain(ecc_spec.iter()) .map(|m| m.witness.len()) .max() .unwrap() @@ -224,8 +221,7 @@ impl CpuTowerProver { let alpha_pows = get_challenge_pows( prod_specs_len + // logup occupy 2 sumcheck: numerator and denominator - logup_specs_len * 2 - + ecc_spec.as_ref().map_or(0, |_| SEPTIC_JACOBIAN_NUM_MLES), + logup_specs_len * 2, transcript, ); let initial_rt: Point = transcript.sample_and_append_vec(b"product_sum", log_num_fanin); @@ -255,12 +251,6 @@ impl CpuTowerProver { merge_spec_witness(&mut layer_witness, spec, i, GroupedMLE::Logup); } - if let Some(ecc_spec) = ecc_spec { - merge_spec_witness(&mut layer_witness, ecc_spec, 0, |(_, v)| { - GroupedMLE::EcAdd(v) - }); - } - // skip(1) for output layer for (round, mut layer_witness) in layer_witness.into_iter().enumerate().skip(1) { // in first few round we just run on single thread @@ -339,83 +329,6 @@ impl CpuTowerProver { + alpha_denominator * q1 * q2), ); } - GroupedMLE::EcAdd(layer_polys) => { - assert_eq!(layer_polys.len(), 2 * SEPTIC_JACOBIAN_NUM_MLES); // 2 points, each point has 21 polys - - let (x1, rest) = layer_polys.split_at_mut(SEPTIC_EXTENSION_DEGREE); - let (y1, rest) = rest.split_at_mut(SEPTIC_EXTENSION_DEGREE); - let (z1, rest) = rest.split_at_mut(SEPTIC_EXTENSION_DEGREE); - let (x2, rest) = rest.split_at_mut(SEPTIC_EXTENSION_DEGREE); - let (y2, z2) = rest.split_at_mut(SEPTIC_EXTENSION_DEGREE); - - let x1 = &SymbolicSepticExtension::new( - x1.into_iter() - .map(|x| expr_builder.lift(x.to_either())) - .collect(), - ); - let y1 = &SymbolicSepticExtension::new( - y1.into_iter() - .map(|y| expr_builder.lift(y.to_either())) - .collect(), - ); - let z1 = &SymbolicSepticExtension::new( - z1.into_iter() - .map(|z| expr_builder.lift(z.to_either())) - .collect(), - ); - let x2 = &SymbolicSepticExtension::new( - x2.into_iter() - .map(|x| expr_builder.lift(x.to_either())) - .collect(), - ); - let y2 = &SymbolicSepticExtension::new( - y2.into_iter() - .map(|y| expr_builder.lift(y.to_either())) - .collect(), - ); - let z2 = &SymbolicSepticExtension::new( - z2.into_iter() - .map(|z| expr_builder.lift(z.to_either())) - .collect(), - ); - - let two: Expression = 2.into(); - let four: Expression = 4.into(); - let z1_squared = z1 * z1; - println!("z1_squared: {:?}", z1_squared); - let z1_cubed = &z1_squared * z1; - let z2_squared = z2 * z2; - let z2_cubed = &z2_squared * z2; - - // U1 = X1*Z2^2, U2 = X2*Z1^2 - let u1 = x1 * &z2_squared; - let u2 = x2 * &z1_squared; - - // S1 = Y1*Z2^3, S2 = Y2*Z1^3 - let s1 = y1 * &z2_cubed; - let s2 = y2 * &z1_cubed; - - // H = U2-U1, R = S2-S1 - let h = u2 - &u1; - let h_squared = &h * &h; - let h_cubed = &h_squared * &h; - - let i = h_squared * &four; - let j = h_cubed * &four; - let r = (&s2 - &s1) * &two; - let v = &u1 * &i; - - // Check the formulas for X3, Y3, Z3 - // X3 = R^2 - J - 2*V - let x3 = &r * &r - j.clone() - v.clone() * &two; - // Y3 = R*(V - X3) - 2*S1*J - let y3 = r * (&v - &x3) - s1 * j * &two; - // Z3 = (Z1 + Z2)^2 - Z1Z1 - Z2Z2) * H - let z3 = z1 * z2 * h * &two; - // exprs.extend((x3 * &eq_expr).to_exprs()); - // exprs.extend((y3 * &eq_expr).to_exprs()); - // exprs.extend((z3 * &eq_expr).to_exprs()); - } } } @@ -757,8 +670,7 @@ impl> TowerProver> ZKVMVerifier .chain(proof.w_out_evals.iter().cloned()) .collect_vec(), proof.lk_out_evals.clone(), - vec![], tower_proofs, vec![num_var_with_rotation; num_batched], num_product_fanin, @@ -516,7 +515,6 @@ impl> ZKVMVerifier .iter() .map(|eval| eval.to_vec()) .collect_vec(), - vec![], tower_proofs, expected_rounds, num_logup_fanin, @@ -783,10 +781,8 @@ impl TowerVerify { } pub fn verify( - // TODO: unify prod/logup/ec_add prod_out_evals: Vec>, logup_out_evals: Vec>, - ecc_out_evals: Vec>, tower_proofs: &TowerProofs, num_variables: Vec, num_fanin: usize, @@ -806,13 +802,10 @@ impl TowerVerify { assert!(logup_out_evals.iter().all(|evals| { evals.len() == 4 // [p1, p2, q1, q2] })); - assert!(ecc_out_evals.len() == 0 || ecc_out_evals.len() == 2); assert_eq!(num_variables.len(), num_prod_spec + num_logup_spec); - let num_ecc = if ecc_out_evals.is_empty() { 0 } else { 1 }; - let alpha_pows = get_challenge_pows( - num_prod_spec + num_logup_spec * 2 + num_ecc * SEPTIC_JACOBIAN_NUM_MLES, /* logup occupy 2 sumcheck: numerator and denominator */ + num_prod_spec + num_logup_spec * 2, /* logup occupy 2 sumcheck: numerator and denominator */ transcript, ); let initial_rt: Point = transcript.sample_and_append_vec(b"product_sum", log2_num_fanin); @@ -849,11 +842,6 @@ impl TowerVerify { ) }) .unzip::<_, _, Vec<_>, Vec<_>>(); - let mut ecc_eval = if num_ecc == 1 { - Self::get_ecc_eval(&ecc_out_evals[0], &ecc_out_evals[1], &initial_rt) - } else { - vec![] - }; // initial claim = \sum_j alpha^j * out_j[rt] let initial_claim = izip!(&prod_spec_point_n_eval, &alpha_pows) @@ -864,10 +852,7 @@ impl TowerVerify { &alpha_pows[num_prod_spec..] ) .map(|(point_n_eval, alpha)| point_n_eval.eval * *alpha) - .sum::() - + izip!(&ecc_eval, &alpha_pows[num_prod_spec + num_logup_spec * 2..]) - .map(|(xi, alpha)| xi.eval * *alpha) - .sum::(); + .sum::(); let max_num_variables = num_variables.iter().max().unwrap(); @@ -930,34 +915,6 @@ impl TowerVerify { }) .sum::(); - if num_ecc == 1 { - // (x', y', z')[b] = jacobian_add((x, y, z)[0,b], (x, y, z)[1,b]) - let degree = SEPTIC_EXTENSION_DEGREE; - let (x1, rest) = tower_proofs.ecc_evals[round].split_at(SEPTIC_EXTENSION_DEGREE); - let (y1, rest) = rest.split_at(SEPTIC_EXTENSION_DEGREE); - let (z1, rest) = rest.split_at(SEPTIC_EXTENSION_DEGREE); - let (x2, rest) = rest.split_at(SEPTIC_EXTENSION_DEGREE); - let (y2, rest) = rest.split_at(SEPTIC_EXTENSION_DEGREE); - let (z2, rest) = rest.split_at(SEPTIC_EXTENSION_DEGREE); - - // p1 and p2 are not valid ecc points - // we just want to use ecc addition formula as expression to get - // the expected evaluation for ecc sumcheck - let p1 = SepticJacobianPoint { - x: x1.into(), - y: y1.into(), - z: z1.into(), - }; - let p2 = SepticJacobianPoint { - x: x2.into(), - y: y2.into(), - z: z2.into(), - }; - - let SepticJacobianPoint { x, y, z } = p1 + p2; - expected_evaluation += izip!(x.0.iter().chain(y.0.iter()).chain(z.0.iter()), alpha_pows[num_prod_spec + num_logup_spec * 2..].iter()).map(|(&xi, &alpha)| eq * xi * alpha).sum::(); - } - if expected_evaluation != sumcheck_claim.expected_evaluation { return Err(ZKVMError::VerifyError("mismatch tower evaluation".into())); } @@ -1038,31 +995,8 @@ impl TowerVerify { .sum::(); // sum evaluation from different specs - let mut next_eval = next_prod_spec_evals + next_logup_spec_evals; - - if num_ecc == 1 { - let next_round_expected_eval = if round < *max_num_variables - 1 { - tower_proofs.ecc_evals[round][..SEPTIC_JACOBIAN_NUM_MLES] - .iter() - .zip(tower_proofs.ecc_evals[round][SEPTIC_JACOBIAN_NUM_MLES..].iter()) - .zip(next_alpha_pows[num_prod_spec + num_logup_spec * 2..].iter()) - .zip(ecc_eval.iter_mut()) - .map(|(((a, b), alpha), point_and_eval)| { - let eval = izip!(vec![a, b].into_iter(), coeffs.iter()).map(|(a, b)| *a * *b).sum::(); + let next_eval = next_prod_spec_evals + next_logup_spec_evals; - point_and_eval.point = rt_prime.clone(); - point_and_eval.eval = eval; - - eval * *alpha - }) - .sum() - } else { - E::ZERO - }; - if next_round < *max_num_variables - 1 { - next_eval += next_round_expected_eval; - } - } Ok((PointAndEval { point: rt_prime, eval: next_eval, From be0629cd820252fab947c41fffce2660ade8638e Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 29 Sep 2025 16:55:37 +0800 Subject: [PATCH 22/91] delete --- ceno_zkvm/src/scheme/verifier.rs | 43 -------------------------------- 1 file changed, 43 deletions(-) diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index d216b0fb6..30842fab9 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -740,46 +740,6 @@ pub type TowerVerifyResult = Result< >; impl TowerVerify { - fn get_ecc_eval( - p1: &SepticJacobianPoint, - p2: &SepticJacobianPoint, - rt: &[E], - ) -> Vec> { - let SepticJacobianPoint { x, y, z } = p1; - let SepticJacobianPoint { - x: x2, - y: y2, - z: z2, - } = p2; - - let xs = - x.0.iter() - .cloned() - .zip(x2.iter().cloned()) - .map(|(xi, x2i)| vec![xi, x2i].into_mle().evaluate(rt)) - .collect_vec(); - - let ys = - y.0.iter() - .cloned() - .zip(y2.iter().cloned()) - .map(|(yi, y2i)| vec![yi, y2i].into_mle().evaluate(rt)) - .collect_vec(); - - let zs = - z.0.iter() - .cloned() - .zip(z2.iter().cloned()) - .map(|(zi, z2i)| vec![zi, z2i].into_mle().evaluate(rt)) - .collect_vec(); - - xs.into_iter() - .chain(ys.into_iter()) - .chain(zs.into_iter()) - .map(|eval| PointAndEval::new(rt.to_vec(), eval)) - .collect_vec() - } - pub fn verify( prod_out_evals: Vec>, logup_out_evals: Vec>, @@ -813,9 +773,6 @@ impl TowerVerify { // out_j[rt] := (record_{j}[rt]) // out_j[rt] := (logup_p{j}[rt]) // out_j[rt] := (logup_q{j}[rt]) - // out_j[rt] := ecc_x{j}[rt] - // out_j[rt] := ecc_y{j}[rt] - // out_j[rt] := ecc_z{j}[rt] // bookkeeping records of latest (point, evaluation) of each layer // prod argument From 74af34fa0913ec095c7ea4cd31b2ca4309696f43 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 29 Sep 2025 16:57:28 +0800 Subject: [PATCH 23/91] clean --- ceno_zkvm/src/scheme/cpu/mod.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index f79c968fe..3ab7aedf3 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -474,8 +474,6 @@ impl> TowerProver Date: Mon, 29 Sep 2025 20:43:46 +0800 Subject: [PATCH 24/91] add selector to turn off affine_add when b = 1...1 --- ceno_zkvm/src/scheme/cpu/mod.rs | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 3ab7aedf3..b93ffce54 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -23,6 +23,8 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, gkr::{self, Evaluation, GKRProof, GKRProverOutput, layer::LayerWitness}, hal::ProverBackend, + selector::SelectorType, + utils::eq_eval_less_or_equal_than, }; use itertools::{Itertools, chain}; use mpcs::{Point, PolynomialCommitmentScheme}; @@ -30,9 +32,10 @@ use multilinear_extensions::{ Expression, Instance, WitnessId, mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, util::ceil_log2, - virtual_poly::{build_eq_x_r_vec, eq_eval}, + virtual_poly::{build_eq_x_r_vec}, virtual_polys::VirtualPolynomialsBuilder, }; +use p3::field::FieldAlgebra; use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; use std::{collections::BTreeMap, sync::Arc}; use sumcheck::{ @@ -76,10 +79,16 @@ impl CpuEccProver { let out_rt = transcript.sample_and_append_vec(b"ecc", n); let num_threads = optimal_sumcheck_threads(out_rt.len()); + let alpha_pows = + transcript.sample_and_append_challenge_pows(SEPTIC_EXTENSION_DEGREE * 3, b"ecc_alpha"); + let mut expr_builder = VirtualPolynomialsBuilder::new(num_threads, out_rt.len()); - let mut eq: MultilinearExtension<'_, E> = build_eq_x_r_vec(&out_rt).into_mle(); - let eq_expr = expr_builder.lift((&mut eq).to_either()); + let sel = SelectorType::Prefix(E::BaseField::ZERO, 0.into()); + let num_instances = (1 << n) - 1; + let mut sel_mle: MultilinearExtension<'_, E> = sel.compute(&out_rt, num_instances).unwrap(); + let sel_expr = expr_builder.lift(sel_mle.to_either()); + let mut exprs = vec![]; let filter_bj = |v: &[MultilinearExtension<'_, E>], j: usize| { @@ -138,10 +147,12 @@ impl CpuEccProver { .collect(), ); // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) - exprs.extend_from_slice( + exprs.extend( (s.clone() * (&x0 - &x1) - (&y0 - &y1)) .to_exprs() - .as_slice(), + .into_iter() + .zip(alpha_pows.iter().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), ); // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] @@ -149,12 +160,11 @@ impl CpuEccProver { // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - y[b,0] - y[1,b] let (zerocheck_proof, state) = IOPProverState::prove( expr_builder - .to_virtual_polys(&[exprs.into_iter().sum::>() * eq_expr], &[]), + .to_virtual_polys(&[exprs.into_iter().sum::>() * sel_expr], &[]), transcript, ); let rt = state.collect_raw_challenges(); - // TODO: fix this assertion assert_eq!(zerocheck_proof.extract_sum(), E::ZERO); let evals = state.get_mle_flatten_final_evaluations(); @@ -163,7 +173,10 @@ impl CpuEccProver { let s = invs.iter().map(|x| x.as_view_slice(2, 0)).collect_vec(); let x0 = filter_bj(&xs, 0); // check evaluations - assert_eq!(eq_eval(&out_rt, &rt), evals[0]); + assert_eq!( + eq_eval_less_or_equal_than(num_instances - 1, &out_rt, &rt), + evals[0] + ); for i in 0..SEPTIC_EXTENSION_DEGREE { assert_eq!(s[i].evaluate(&rt), evals[1 + i]); assert_eq!(x0[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE + 1 + i]); From 6074b14780a85eb55a06921edd19c8082458cfe7 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 30 Sep 2025 11:49:00 +0800 Subject: [PATCH 25/91] add quark ecc verifier --- ceno_zkvm/src/scheme/cpu/mod.rs | 199 +++++++++++++++++++++++++++++--- 1 file changed, 184 insertions(+), 15 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index b93ffce54..4e0104b6e 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -32,15 +32,15 @@ use multilinear_extensions::{ Expression, Instance, WitnessId, mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, util::ceil_log2, - virtual_poly::{build_eq_x_r_vec}, + virtual_poly::{VPAuxInfo, build_eq_x_r_vec}, virtual_polys::VirtualPolynomialsBuilder, }; use p3::field::FieldAlgebra; use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; -use std::{collections::BTreeMap, sync::Arc}; +use std::{collections::BTreeMap, marker::PhantomData, ops::Deref, sync::Arc}; use sumcheck::{ macros::{entered_span, exit_span}, - structs::{IOPProof, IOPProverMessage, IOPProverState}, + structs::{IOPProof, IOPProverMessage, IOPProverState, IOPVerifierState}, util::{get_challenge_pows, optimal_sumcheck_threads}, }; use transcript::Transcript; @@ -56,7 +56,7 @@ pub type TowerRelationOutput = ( pub struct EccQuarkProof { pub zerocheck_proof: IOPProof, - + pub num_vars: usize, pub evals: Vec, // x[rt,0], x[rt,1], y[rt,0], y[rt,1] } @@ -65,6 +65,10 @@ pub struct EccQuarkProof { pub struct CpuEccProver; impl CpuEccProver { + pub fn new() -> Self { + Self {} + } + pub fn create_ecc_proof<'a, E: ExtensionField>( &self, mut xs: Vec>, @@ -111,11 +115,11 @@ impl CpuEccProver { let mut x1 = filter_bj(&xs, 1); let mut y1 = filter_bj(&ys, 1); // build x[1,b], y[1,b], s[0,b] - let x3 = xs + let mut x3 = xs .iter_mut() .map(|x| x.as_view_slice_mut(2, 1)) .collect_vec(); - let y3 = ys + let mut y3 = ys .iter_mut() .map(|x| x.as_view_slice_mut(2, 1)) .collect_vec(); @@ -146,7 +150,18 @@ impl CpuEccProver { .map(|y| expr_builder.lift(y.to_either())) .collect(), ); - // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) + let x3 = SymbolicSepticExtension::new( + x3.iter_mut() + .map(|x| expr_builder.lift(x.to_either())) + .collect(), + ); + let y3 = SymbolicSepticExtension::new( + y3.iter_mut() + .map(|y| expr_builder.lift(y.to_either())) + .collect(), + ); + // affine addition + // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) with b != (1,...,1) exprs.extend( (s.clone() * (&x0 - &x1) - (&y0 - &y1)) .to_exprs() @@ -155,9 +170,32 @@ impl CpuEccProver { .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), ); - // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] + // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] with b != (1,...,1) + exprs.extend( + ((&s * &s) - &x0 - &x1 - &x3) + .to_exprs() + .into_iter() + .zip( + alpha_pows[SEPTIC_EXTENSION_DEGREE..] + .iter() + .take(SEPTIC_EXTENSION_DEGREE), + ) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + + // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) with b != (1,...,1) + exprs.extend( + (s.clone() * (&x0 - &x3) - (&y0 + &y3)) + .to_exprs() + .into_iter() + .zip( + alpha_pows[SEPTIC_EXTENSION_DEGREE * 2..] + .iter() + .take(SEPTIC_EXTENSION_DEGREE), + ) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); - // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - y[b,0] - y[1,b] let (zerocheck_proof, state) = IOPProverState::prove( expr_builder .to_virtual_polys(&[exprs.into_iter().sum::>() * sel_expr], &[]), @@ -165,14 +203,26 @@ impl CpuEccProver { ); let rt = state.collect_raw_challenges(); - assert_eq!(zerocheck_proof.extract_sum(), E::ZERO); let evals = state.get_mle_flatten_final_evaluations(); + assert_eq!(zerocheck_proof.extract_sum(), E::ZERO); + // 7 for x[b,0], x[b,1], y[b,0], y[b,1], x[1,b], y[1,b], s[0,b] + assert_eq!(evals.len(), 1 + SEPTIC_EXTENSION_DEGREE * 7); + #[cfg(feature = "sanity-check")] { + use tracing_subscriber::filter; + let s = invs.iter().map(|x| x.as_view_slice(2, 0)).collect_vec(); let x0 = filter_bj(&xs, 0); + let y0 = filter_bj(&ys, 0); + let x1 = filter_bj(&xs, 1); + let y1 = filter_bj(&ys, 1); + let x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); + let y3 = ys.iter().map(|y| y.as_view_slice(2, 1)).collect_vec(); + // check evaluations + assert_eq!( eq_eval_less_or_equal_than(num_instances - 1, &out_rt, &rt), evals[0] @@ -180,12 +230,33 @@ impl CpuEccProver { for i in 0..SEPTIC_EXTENSION_DEGREE { assert_eq!(s[i].evaluate(&rt), evals[1 + i]); assert_eq!(x0[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE + 1 + i]); + assert_eq!( + y0[i].evaluate(&rt), + evals[SEPTIC_EXTENSION_DEGREE * 2 + 1 + i] + ); + assert_eq!( + x1[i].evaluate(&rt), + evals[SEPTIC_EXTENSION_DEGREE * 3 + 1 + i] + ); + assert_eq!( + y1[i].evaluate(&rt), + evals[SEPTIC_EXTENSION_DEGREE * 4 + 1 + i] + ); + assert_eq!( + x3[i].evaluate(&rt), + evals[SEPTIC_EXTENSION_DEGREE * 5 + 1 + i] + ); + assert_eq!( + y3[i].evaluate(&rt), + evals[SEPTIC_EXTENSION_DEGREE * 6 + 1 + i] + ); } } // TODO: prove the validity of s[0,rt], x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt] EccQuarkProof { zerocheck_proof, + num_vars: n, evals, } } @@ -194,8 +265,97 @@ impl CpuEccProver { pub struct EccVerifier; impl EccVerifier { - pub fn verify_ecc_proof>() { - todo!() + pub fn new() -> Self { + Self {} + } + + pub fn verify_ecc_proof( + &self, + proof: &EccQuarkProof, + transcript: &mut impl Transcript, + ) -> Result<(), ZKVMError> { + let out_rt = transcript.sample_and_append_vec(b"ecc", proof.num_vars); + let alpha_pows = + transcript.sample_and_append_challenge_pows(SEPTIC_EXTENSION_DEGREE * 3, b"ecc_alpha"); + + let sumcheck_claim = IOPVerifierState::verify( + E::ZERO, + &proof.zerocheck_proof, + &VPAuxInfo { + max_degree: 3, + max_num_variables: proof.num_vars, + phantom: PhantomData, + }, + transcript, + ); + + let s0: SepticExtension = proof.evals[1..][0..SEPTIC_EXTENSION_DEGREE] + .try_into() + .unwrap(); + let x0: SepticExtension = proof.evals[1..] + [SEPTIC_EXTENSION_DEGREE..2 * SEPTIC_EXTENSION_DEGREE] + .try_into() + .unwrap(); + let y0: SepticExtension = proof.evals[1..] + [2 * SEPTIC_EXTENSION_DEGREE..3 * SEPTIC_EXTENSION_DEGREE] + .try_into() + .unwrap(); + let x1: SepticExtension = proof.evals[1..] + [3 * SEPTIC_EXTENSION_DEGREE..4 * SEPTIC_EXTENSION_DEGREE] + .try_into() + .unwrap(); + let y1: SepticExtension = proof.evals[1..] + [4 * SEPTIC_EXTENSION_DEGREE..5 * SEPTIC_EXTENSION_DEGREE] + .try_into() + .unwrap(); + let x3: SepticExtension = proof.evals[1..] + [5 * SEPTIC_EXTENSION_DEGREE..6 * SEPTIC_EXTENSION_DEGREE] + .try_into() + .unwrap(); + let y3: SepticExtension = proof.evals[1..] + [6 * SEPTIC_EXTENSION_DEGREE..7 * SEPTIC_EXTENSION_DEGREE] + .try_into() + .unwrap(); + + let num_instances = (1 << proof.num_vars) - 1; + let rt = sumcheck_claim + .point + .iter() + .map(|c| c.elements.clone()) + .collect_vec(); + + // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) + // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] + // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) + let v1: SepticExtension = s0.clone() * (&x0 - &x1) - (&y0 - &y1); + let v2 = s0.square() - &x0 - &x1 - &x3; + let v3 = s0 * (&x0 - &x3) - (&y0 + &y3); + + let v: E = vec![v1, v2, v3] + .into_iter() + .enumerate() + .flat_map(|(i, v)| { + let start = i * SEPTIC_EXTENSION_DEGREE; + let end = (i + 1) * SEPTIC_EXTENSION_DEGREE; + v.0.into_iter() + .zip(alpha_pows[start..end].iter()) + .map(|(c, alpha)| c * *alpha) + }) + .sum(); + + let sel = eq_eval_less_or_equal_than(num_instances - 1, &out_rt, &rt); + if sumcheck_claim.expected_evaluation != v * sel { + return Err(ZKVMError::VerifyError( + (format!( + "ecc zerocheck failed: mismatched evaluation, expected {}, got {}", + sumcheck_claim.expected_evaluation, + v * sel + )) + .into(), + )); + } + + Ok(()) } } @@ -998,7 +1158,7 @@ mod tests { use crate::scheme::{ constants::SEPTIC_EXTENSION_DEGREE, - cpu::CpuEccProver, + cpu::{CpuEccProver, EccVerifier}, septic_curve::{SepticExtension, SepticPoint}, }; @@ -1072,7 +1232,16 @@ mod tests { let (ys, s) = rest.split_at(SEPTIC_EXTENSION_DEGREE); let mut transcript = BasicTranscript::new(b"test"); - let prover = CpuEccProver {}; - prover.create_ecc_proof(xs.to_vec(), ys.to_vec(), s.to_vec(), &mut transcript); + let prover = CpuEccProver::new(); + let quark_proof = + prover.create_ecc_proof(xs.to_vec(), ys.to_vec(), s.to_vec(), &mut transcript); + + let mut transcript = BasicTranscript::new(b"test"); + let verifier = EccVerifier::new(); + assert!( + verifier + .verify_ecc_proof(&quark_proof, &mut transcript) + .is_ok() + ); } } From f3392ea19c36871c0226c1c7e55ce9930af1b731 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 30 Sep 2025 15:20:26 +0800 Subject: [PATCH 26/91] reorg --- ceno_zkvm/src/scheme/cpu/mod.rs | 151 ++++++--------------------- ceno_zkvm/src/scheme/septic_curve.rs | 18 ++++ ceno_zkvm/src/scheme/verifier.rs | 117 +++++++++++++++++++-- ceno_zkvm/src/structs.rs | 19 +++- 4 files changed, 177 insertions(+), 128 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 4e0104b6e..d06c6cd91 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -5,17 +5,15 @@ use crate::{ circuit_builder::ConstraintSystem, error::ZKVMError, scheme::{ - constants::{ - NUM_FANIN, NUM_FANIN_LOGUP, SEPTIC_EXTENSION_DEGREE, SEPTIC_JACOBIAN_NUM_MLES, - }, + constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEPTIC_EXTENSION_DEGREE}, hal::{DeviceProvingKey, MainSumcheckEvals, ProofInput, TowerProverSpec}, - septic_curve::{SepticExtension, SepticPoint, SymbolicSepticExtension}, + septic_curve::{SepticPoint, SymbolicSepticExtension}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, masked_mle_split_to_chunks, wit_infer_by_expr, }, }, - structs::{ComposedConstrainSystem, PointAndEval, TowerProofs}, + structs::{ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs}, }; use either::Either; use ff_ext::ExtensionField; @@ -24,7 +22,6 @@ use gkr_iop::{ gkr::{self, Evaluation, GKRProof, GKRProverOutput, layer::LayerWitness}, hal::ProverBackend, selector::SelectorType, - utils::eq_eval_less_or_equal_than, }; use itertools::{Itertools, chain}; use mpcs::{Point, PolynomialCommitmentScheme}; @@ -32,20 +29,23 @@ use multilinear_extensions::{ Expression, Instance, WitnessId, mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, util::ceil_log2, - virtual_poly::{VPAuxInfo, build_eq_x_r_vec}, + virtual_poly::build_eq_x_r_vec, virtual_polys::VirtualPolynomialsBuilder, }; use p3::field::FieldAlgebra; use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; -use std::{collections::BTreeMap, marker::PhantomData, ops::Deref, sync::Arc}; +use std::{collections::BTreeMap, sync::Arc}; use sumcheck::{ macros::{entered_span, exit_span}, - structs::{IOPProof, IOPProverMessage, IOPProverState, IOPVerifierState}, + structs::{IOPProverMessage, IOPProverState}, util::{get_challenge_pows, optimal_sumcheck_threads}, }; use transcript::Transcript; use witness::next_pow2_instance_padding; +#[cfg(feature = "sanity-check")] +use {crate::scheme::septic_curve::SepticExtension, gkr_iop::utils::eq_eval_less_or_equal_than}; + pub type TowerRelationOutput = ( Point, TowerProofs, @@ -54,12 +54,6 @@ pub type TowerRelationOutput = ( Vec>, ); -pub struct EccQuarkProof { - pub zerocheck_proof: IOPProof, - pub num_vars: usize, - pub evals: Vec, // x[rt,0], x[rt,1], y[rt,0], y[rt,1] -} - // implement the IOP proposed in [Quark paper](https://eprint.iacr.org/2020/1275.pdf) // to accumulate N=2^n EC points into one EC point using affine coordinates pub struct CpuEccProver; @@ -74,6 +68,7 @@ impl CpuEccProver { mut xs: Vec>, mut ys: Vec>, invs: Vec>, + sum: SepticPoint, transcript: &mut impl Transcript, ) -> EccQuarkProof { assert_eq!(xs.len(), SEPTIC_EXTENSION_DEGREE); @@ -211,8 +206,6 @@ impl CpuEccProver { #[cfg(feature = "sanity-check")] { - use tracing_subscriber::filter; - let s = invs.iter().map(|x| x.as_view_slice(2, 0)).collect_vec(); let x0 = filter_bj(&xs, 0); let y0 = filter_bj(&ys, 0); @@ -220,9 +213,18 @@ impl CpuEccProver { let y1 = filter_bj(&ys, 1); let x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); let y3 = ys.iter().map(|y| y.as_view_slice(2, 1)).collect_vec(); + let final_sum_x: SepticExtension = (x3.iter()) + .map(|x| x.get_base_field_vec()[num_instances - 1]) // x[1,...,1,0] + .collect_vec() + .into(); + let final_sum_y: SepticExtension = (y3.iter()) + .map(|y| y.get_base_field_vec()[num_instances - 1]) // x[1,...,1,0] + .collect_vec() + .into(); + let final_sum = SepticPoint::from_affine(final_sum_x, final_sum_y); + assert_eq!(final_sum, sum); // check evaluations - assert_eq!( eq_eval_less_or_equal_than(num_instances - 1, &out_rt, &rt), evals[0] @@ -258,107 +260,11 @@ impl CpuEccProver { zerocheck_proof, num_vars: n, evals, + sum, } } } -pub struct EccVerifier; - -impl EccVerifier { - pub fn new() -> Self { - Self {} - } - - pub fn verify_ecc_proof( - &self, - proof: &EccQuarkProof, - transcript: &mut impl Transcript, - ) -> Result<(), ZKVMError> { - let out_rt = transcript.sample_and_append_vec(b"ecc", proof.num_vars); - let alpha_pows = - transcript.sample_and_append_challenge_pows(SEPTIC_EXTENSION_DEGREE * 3, b"ecc_alpha"); - - let sumcheck_claim = IOPVerifierState::verify( - E::ZERO, - &proof.zerocheck_proof, - &VPAuxInfo { - max_degree: 3, - max_num_variables: proof.num_vars, - phantom: PhantomData, - }, - transcript, - ); - - let s0: SepticExtension = proof.evals[1..][0..SEPTIC_EXTENSION_DEGREE] - .try_into() - .unwrap(); - let x0: SepticExtension = proof.evals[1..] - [SEPTIC_EXTENSION_DEGREE..2 * SEPTIC_EXTENSION_DEGREE] - .try_into() - .unwrap(); - let y0: SepticExtension = proof.evals[1..] - [2 * SEPTIC_EXTENSION_DEGREE..3 * SEPTIC_EXTENSION_DEGREE] - .try_into() - .unwrap(); - let x1: SepticExtension = proof.evals[1..] - [3 * SEPTIC_EXTENSION_DEGREE..4 * SEPTIC_EXTENSION_DEGREE] - .try_into() - .unwrap(); - let y1: SepticExtension = proof.evals[1..] - [4 * SEPTIC_EXTENSION_DEGREE..5 * SEPTIC_EXTENSION_DEGREE] - .try_into() - .unwrap(); - let x3: SepticExtension = proof.evals[1..] - [5 * SEPTIC_EXTENSION_DEGREE..6 * SEPTIC_EXTENSION_DEGREE] - .try_into() - .unwrap(); - let y3: SepticExtension = proof.evals[1..] - [6 * SEPTIC_EXTENSION_DEGREE..7 * SEPTIC_EXTENSION_DEGREE] - .try_into() - .unwrap(); - - let num_instances = (1 << proof.num_vars) - 1; - let rt = sumcheck_claim - .point - .iter() - .map(|c| c.elements.clone()) - .collect_vec(); - - // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) - // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] - // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) - let v1: SepticExtension = s0.clone() * (&x0 - &x1) - (&y0 - &y1); - let v2 = s0.square() - &x0 - &x1 - &x3; - let v3 = s0 * (&x0 - &x3) - (&y0 + &y3); - - let v: E = vec![v1, v2, v3] - .into_iter() - .enumerate() - .flat_map(|(i, v)| { - let start = i * SEPTIC_EXTENSION_DEGREE; - let end = (i + 1) * SEPTIC_EXTENSION_DEGREE; - v.0.into_iter() - .zip(alpha_pows[start..end].iter()) - .map(|(c, alpha)| c * *alpha) - }) - .sum(); - - let sel = eq_eval_less_or_equal_than(num_instances - 1, &out_rt, &rt); - if sumcheck_claim.expected_evaluation != v * sel { - return Err(ZKVMError::VerifyError( - (format!( - "ecc zerocheck failed: mismatched evaluation, expected {}, got {}", - sumcheck_claim.expected_evaluation, - v * sel - )) - .into(), - )); - } - - Ok(()) - } -} - pub struct CpuTowerProver; impl CpuTowerProver { @@ -1158,8 +1064,9 @@ mod tests { use crate::scheme::{ constants::SEPTIC_EXTENSION_DEGREE, - cpu::{CpuEccProver, EccVerifier}, + cpu::CpuEccProver, septic_curve::{SepticExtension, SepticPoint}, + verifier::EccVerifier, }; #[test] @@ -1171,6 +1078,7 @@ mod tests { let n_points = 1 << log2_n; let mut rng = rand::thread_rng(); + let final_sum; // generate 1 ecc add witness let ecc_spec: Vec> = { // sample N = 2^n points @@ -1202,6 +1110,8 @@ mod tests { .collect_vec(), ); } + final_sum = points.last().cloned().unwrap(); + // padding to 2*N s.extend(repeat(SepticExtension::zero()).take(n_points + 1)); points.push(SepticPoint::point_at_infinity()); @@ -1233,8 +1143,13 @@ mod tests { let mut transcript = BasicTranscript::new(b"test"); let prover = CpuEccProver::new(); - let quark_proof = - prover.create_ecc_proof(xs.to_vec(), ys.to_vec(), s.to_vec(), &mut transcript); + let quark_proof = prover.create_ecc_proof( + xs.to_vec(), + ys.to_vec(), + s.to_vec(), + final_sum, + &mut transcript, + ); let mut transcript = BasicTranscript::new(b"test"); let verifier = EccVerifier::new(); diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 4954a0d2a..ed3030d23 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -53,6 +53,15 @@ impl From<&[F]> for SepticExtension { } } +impl From> for SepticExtension { + fn from(v: Vec) -> Self { + assert!(v.len() == 7); + let mut arr = [F::default(); 7]; + arr.copy_from_slice(&v[0..7]); + Self(arr) + } +} + impl Deref for SepticExtension { type Target = [F]; @@ -732,6 +741,15 @@ pub struct SepticPoint { } impl SepticPoint { + pub fn from_affine(x: SepticExtension, y: SepticExtension) -> Self { + let is_infinity = if x.is_zero() && y.is_zero() { + true + } else { + false + }; + + Self { x, y, is_infinity } + } pub fn double(&self) -> Self { let a = F::from_canonical_u32(2); let three = F::from_canonical_u32(3); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 30842fab9..abe0f6ec2 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,12 +1,11 @@ use std::marker::PhantomData; use ff_ext::ExtensionField; -use p3::field::Field; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; -use gkr_iop::gkr::GKRClaims; +use gkr_iop::{gkr::GKRClaims, utils::eq_eval_less_or_equal_than}; use itertools::{Itertools, chain, interleave, izip}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ @@ -27,13 +26,13 @@ use witness::next_pow2_instance_padding; use crate::{ error::ZKVMError, scheme::{ - constants::{ - NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE, SEPTIC_EXTENSION_DEGREE, - SEPTIC_JACOBIAN_NUM_MLES, - }, - septic_curve::SepticJacobianPoint, + constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE, SEPTIC_EXTENSION_DEGREE}, + septic_curve::SepticExtension, + }, + structs::{ + ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs, VerifyingKey, + ZKVMVerifyingKey, }, - structs::{ComposedConstrainSystem, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, utils::{ eval_inner_repeated_incremental_vec, eval_outer_repeated_incremental_vec, eval_stacked_constant_vec, eval_stacked_wellform_address_vec, eval_wellform_address_vec, @@ -839,7 +838,7 @@ impl TowerVerify { // check expected_evaluation let rt: Point = sumcheck_claim.point.iter().map(|c| c.elements).collect(); let eq = eq_eval(out_rt, &rt); - let mut expected_evaluation: E = (0..num_prod_spec) + let expected_evaluation: E = (0..num_prod_spec) .zip(alpha_pows.iter()) .zip(num_variables.iter()) .map(|((spec_index, alpha), max_round)| { @@ -969,3 +968,103 @@ impl TowerVerify { )) } } + +pub struct EccVerifier; + +impl EccVerifier { + pub fn new() -> Self { + Self {} + } + + pub fn verify_ecc_proof( + &self, + proof: &EccQuarkProof, + transcript: &mut impl Transcript, + ) -> Result<(), ZKVMError> { + let out_rt = transcript.sample_and_append_vec(b"ecc", proof.num_vars); + let alpha_pows = + transcript.sample_and_append_challenge_pows(SEPTIC_EXTENSION_DEGREE * 3, b"ecc_alpha"); + + let sumcheck_claim = IOPVerifierState::verify( + E::ZERO, + &proof.zerocheck_proof, + &VPAuxInfo { + max_degree: 3, + max_num_variables: proof.num_vars, + phantom: PhantomData, + }, + transcript, + ); + + let s0: SepticExtension = proof.evals[1..][0..SEPTIC_EXTENSION_DEGREE] + .try_into() + .unwrap(); + let x0: SepticExtension = proof.evals[1..] + [SEPTIC_EXTENSION_DEGREE..2 * SEPTIC_EXTENSION_DEGREE] + .try_into() + .unwrap(); + let y0: SepticExtension = proof.evals[1..] + [2 * SEPTIC_EXTENSION_DEGREE..3 * SEPTIC_EXTENSION_DEGREE] + .try_into() + .unwrap(); + let x1: SepticExtension = proof.evals[1..] + [3 * SEPTIC_EXTENSION_DEGREE..4 * SEPTIC_EXTENSION_DEGREE] + .try_into() + .unwrap(); + let y1: SepticExtension = proof.evals[1..] + [4 * SEPTIC_EXTENSION_DEGREE..5 * SEPTIC_EXTENSION_DEGREE] + .try_into() + .unwrap(); + let x3: SepticExtension = proof.evals[1..] + [5 * SEPTIC_EXTENSION_DEGREE..6 * SEPTIC_EXTENSION_DEGREE] + .try_into() + .unwrap(); + let y3: SepticExtension = proof.evals[1..] + [6 * SEPTIC_EXTENSION_DEGREE..7 * SEPTIC_EXTENSION_DEGREE] + .try_into() + .unwrap(); + + let num_instances = (1 << proof.num_vars) - 1; + let rt = sumcheck_claim + .point + .iter() + .map(|c| c.elements.clone()) + .collect_vec(); + + // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) + // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] + // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) + // + // note that they are not septic extension field elements, + // we just want to reuse the multiply/add/sub formulas + let v1: SepticExtension = s0.clone() * (&x0 - &x1) - (&y0 - &y1); + let v2: SepticExtension = s0.square() - &x0 - &x1 - &x3; + let v3: SepticExtension = s0 * (&x0 - &x3) - (&y0 + &y3); + + let v: E = vec![v1, v2, v3] + .into_iter() + .enumerate() + .flat_map(|(i, v)| { + let start = i * SEPTIC_EXTENSION_DEGREE; + let end = (i + 1) * SEPTIC_EXTENSION_DEGREE; + v.0.into_iter() + .zip(alpha_pows[start..end].iter()) + .map(|(c, alpha)| c * *alpha) + }) + .sum(); + + let sel = eq_eval_less_or_equal_than(num_instances - 1, &out_rt, &rt); + if sumcheck_claim.expected_evaluation != v * sel { + return Err(ZKVMError::VerifyError( + (format!( + "ecc zerocheck failed: mismatched evaluation, expected {}, got {}", + sumcheck_claim.expected_evaluation, + v * sel + )) + .into(), + )); + } + + Ok(()) + } +} diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index d1bed423f..fc828ecde 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -2,6 +2,7 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, instructions::Instruction, + scheme::septic_curve::SepticPoint, state::StateCircuit, tables::{RMMCollections, TableCircuit}, }; @@ -16,9 +17,25 @@ use std::{ collections::{BTreeMap, HashMap}, sync::Arc, }; -use sumcheck::structs::IOPProverMessage; +use sumcheck::structs::{IOPProof, IOPProverMessage}; use witness::RowMajorMatrix; +/// proof that the sum of N=2^n EC points is equal to `sum` +/// in one layer instead of GKR layered circuit approach +/// note that this one layer IOP borrowed ideas from +/// [Quark paper](https://eprint.iacr.org/2020/1275.pdf) +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct EccQuarkProof { + pub zerocheck_proof: IOPProof, + pub num_vars: usize, + pub evals: Vec, // x[rt,0], x[rt,1], y[rt,0], y[rt,1] + pub sum: SepticPoint, +} + #[derive(Clone, Serialize, Deserialize)] #[serde(bound( serialize = "E::BaseField: Serialize", From 06312d6d26fcbadccd36960ca761e9e80eaadbd0 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 30 Sep 2025 15:24:09 +0800 Subject: [PATCH 27/91] refine comments --- ceno_zkvm/src/scheme/cpu/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index d06c6cd91..388a18ac4 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -54,8 +54,8 @@ pub type TowerRelationOutput = ( Vec>, ); -// implement the IOP proposed in [Quark paper](https://eprint.iacr.org/2020/1275.pdf) -// to accumulate N=2^n EC points into one EC point using affine coordinates +// accumulate N=2^n EC points into one EC point using affine coordinates +// in one layer which borrows ideas from the [Quark paper](https://eprint.iacr.org/2020/1275.pdf) pub struct CpuEccProver; impl CpuEccProver { From 96a181f4c86706157949c39b98766102833d3794 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 30 Sep 2025 15:43:18 +0800 Subject: [PATCH 28/91] refine comments --- ceno_zkvm/src/scheme/cpu/mod.rs | 2 +- ceno_zkvm/src/structs.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 388a18ac4..1ab14e3f1 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -201,7 +201,7 @@ impl CpuEccProver { let evals = state.get_mle_flatten_final_evaluations(); assert_eq!(zerocheck_proof.extract_sum(), E::ZERO); - // 7 for x[b,0], x[b,1], y[b,0], y[b,1], x[1,b], y[1,b], s[0,b] + // 7 for x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt], s[0,rt] assert_eq!(evals.len(), 1 + SEPTIC_EXTENSION_DEGREE * 7); #[cfg(feature = "sanity-check")] diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index fc828ecde..c8d37346f 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -32,7 +32,7 @@ use witness::RowMajorMatrix; pub struct EccQuarkProof { pub zerocheck_proof: IOPProof, pub num_vars: usize, - pub evals: Vec, // x[rt,0], x[rt,1], y[rt,0], y[rt,1] + pub evals: Vec, // x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[0,rt], y[0,rt], s[0,rt] pub sum: SepticPoint, } From 37145699a7f670c9def63eb93905be552dfa819b Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 8 Oct 2025 17:45:38 +0800 Subject: [PATCH 29/91] add ram bus --- Cargo.lock | 22 +++++++++++----------- ceno_emul/src/tracer.rs | 14 ++++++++++++-- ceno_emul/src/vm_state.rs | 4 ++++ ceno_zkvm/src/e2e.rs | 30 +++++++++++++++++++++++++++++- 4 files changed, 56 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c0b55ae35..3bb631d1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1825,7 +1825,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=updates-for-precompiles#0f8ab8141aadd78c69a0a67ab6bd49399563e6b9" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" dependencies = [ "once_cell", "p3", @@ -2614,7 +2614,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=updates-for-precompiles#0f8ab8141aadd78c69a0a67ab6bd49399563e6b9" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" dependencies = [ "bincode", "clap", @@ -2638,7 +2638,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=updates-for-precompiles#0f8ab8141aadd78c69a0a67ab6bd49399563e6b9" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" dependencies = [ "either", "ff_ext", @@ -2959,7 +2959,7 @@ dependencies = [ [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=updates-for-precompiles#0f8ab8141aadd78c69a0a67ab6bd49399563e6b9" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" dependencies = [ "p3-baby-bear", "p3-challenger", @@ -3368,7 +3368,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=updates-for-precompiles#0f8ab8141aadd78c69a0a67ab6bd49399563e6b9" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" dependencies = [ "ff_ext", "p3", @@ -4308,7 +4308,7 @@ dependencies = [ [[package]] name = "sp1-curves" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=updates-for-precompiles#0f8ab8141aadd78c69a0a67ab6bd49399563e6b9" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" dependencies = [ "cfg-if", "dashu", @@ -4414,7 +4414,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=updates-for-precompiles#0f8ab8141aadd78c69a0a67ab6bd49399563e6b9" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" dependencies = [ "either", "ff_ext", @@ -4432,7 +4432,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=updates-for-precompiles#0f8ab8141aadd78c69a0a67ab6bd49399563e6b9" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" dependencies = [ "itertools 0.13.0", "p3", @@ -4827,7 +4827,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=updates-for-precompiles#0f8ab8141aadd78c69a0a67ab6bd49399563e6b9" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -5099,7 +5099,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=updates-for-precompiles#0f8ab8141aadd78c69a0a67ab6bd49399563e6b9" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" dependencies = [ "bincode", "clap", @@ -5386,7 +5386,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=updates-for-precompiles#0f8ab8141aadd78c69a0a67ab6bd49399563e6b9" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 8280e8351..5195bc85b 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -306,6 +306,7 @@ pub struct Tracer { // (start_addr -> (start_addr, end_addr, min_access_addr, max_access_addr)) mmio_min_max_access: Option>, latest_accesses: HashMap, + next_accesses: HashMap<(WordAddr, Cycle), Cycle>, } impl Default for Tracer { @@ -363,6 +364,7 @@ impl Tracer { ..StepRecord::default() }, latest_accesses: HashMap::new(), + next_accesses: HashMap::new(), } } @@ -471,9 +473,12 @@ impl Tracer { /// - Record the current instruction as the origin of the latest access. /// - Accesses within the same instruction are distinguished by `subcycle ∈ [0, 3]`. pub fn track_access(&mut self, addr: WordAddr, subcycle: Cycle) -> Cycle { - self.latest_accesses + let prev_cycle = self + .latest_accesses .insert(addr, self.record.cycle + subcycle) - .unwrap_or(0) + .unwrap_or(0); + self.next_accesses.insert((addr, prev_cycle), subcycle); + prev_cycle } /// Return all the addresses that were accessed and the cycle when they were last accessed. @@ -481,6 +486,11 @@ impl Tracer { &self.latest_accesses } + /// Return all the addresses that were accessed and the cycle when they were last accessed. + pub fn next_accesses(self) -> HashMap<(WordAddr, Cycle), Cycle> { + self.next_accesses + } + /// Return the cycle of the pending instruction (after the last completed step). pub fn cycle(&self) -> Cycle { self.record.cycle diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index 51057c2b0..eaac9d639 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -68,6 +68,10 @@ impl VMState { &self.tracer } + pub fn take_tracer(self) -> Tracer { + self.tracer + } + pub fn platform(&self) -> &Platform { &self.platform } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 226231c2b..cf0697fd9 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -16,7 +16,7 @@ use crate::{ tables::{MemFinalRecord, MemInitRecord, ProgramTableCircuit, ProgramTableConfig}, }; use ceno_emul::{ - Addr, ByteAddr, CENO_PLATFORM, EmuContext, InsnKind, IterAddresses, Platform, Program, + Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, InsnKind, IterAddresses, Platform, Program, StepRecord, Tracer, VMState, WORD_SIZE, WordAddr, host_utils::read_all_messages, }; use clap::ValueEnum; @@ -92,6 +92,31 @@ pub struct EmulationResult { pub all_records: Vec, pub final_mem_state: FinalMemState, pub pi: PublicValues, + pub ram_bus: RAMBus, +} + +pub struct RAMBus { + shard_id: usize, + num_shards: usize, + max_cycle: Cycle, + addr_future_accesses: HashMap<(WordAddr, Cycle), Cycle>, +} + +impl RAMBus { + pub fn new( + shard_id: usize, + num_shards: usize, + max_cycle: Cycle, + addr_future_accesses: HashMap<(WordAddr, Cycle), Cycle>, + ) -> Self { + RAMBus { + shard_id, + num_shards, + max_cycle, + addr_future_accesses, + } + } + pub fn send(&mut self) {} } pub fn emulate_program( @@ -270,10 +295,13 @@ pub fn emulate_program( ), ); + let ram_bus = RAMBus::new(0, 1, end_cycle, vm.take_tracer().next_accesses()); + EmulationResult { pi, exit_code, all_records, + ram_bus, final_mem_state: FinalMemState { reg: reg_final, io: io_final, From 0da1c229891f99790d3bbb71ff3c52910957f3a0 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 8 Oct 2025 21:30:11 +0800 Subject: [PATCH 30/91] wip add rambus impl --- ceno_zkvm/src/e2e.rs | 63 +++++++++++++++++----- ceno_zkvm/src/instructions.rs | 4 +- ceno_zkvm/src/instructions/riscv/rv32im.rs | 24 ++++++++- ceno_zkvm/src/structs.rs | 3 ++ 4 files changed, 79 insertions(+), 15 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index cf0697fd9..784def2e1 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -16,18 +16,22 @@ use crate::{ tables::{MemFinalRecord, MemInitRecord, ProgramTableCircuit, ProgramTableConfig}, }; use ceno_emul::{ - Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, InsnKind, IterAddresses, Platform, Program, - StepRecord, Tracer, VMState, WORD_SIZE, WordAddr, host_utils::read_all_messages, + Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, InsnKind, IterAddresses, MemOp, Platform, + Program, StepRecord, Tracer, VMState, WORD_SIZE, WordAddr, host_utils::read_all_messages, }; use clap::ValueEnum; +use either::Either; use ff_ext::ExtensionField; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; use gkr_iop::hal::ProverBackend; use itertools::{Itertools, MinMaxResult, chain}; -use mpcs::{PolynomialCommitmentScheme, SecurityLevel}; +use mpcs::{PolynomialCommitmentScheme, SecurityLevel, util::arithmetic::div_ceil}; +use multilinear_extensions::util::max_usable_threads; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use serde::Serialize; use std::{ + borrow::Cow, collections::{BTreeSet, HashMap, HashSet}, sync::Arc, }; @@ -87,44 +91,78 @@ pub struct FullMemState { type InitMemState = FullMemState; type FinalMemState = FullMemState; -pub struct EmulationResult { +pub struct EmulationResult<'a> { pub exit_code: Option, pub all_records: Vec, pub final_mem_state: FinalMemState, pub pi: PublicValues, - pub ram_bus: RAMBus, + pub ram_bus: RAMBus<'a>, } -pub struct RAMBus { +pub struct RAMBus<'a> { shard_id: usize, num_shards: usize, max_cycle: Cycle, - addr_future_accesses: HashMap<(WordAddr, Cycle), Cycle>, + addr_future_accesses: Cow<'a, HashMap<(WordAddr, Cycle), Cycle>>, + thread_based_record_storage: Either>, &'a mut Vec>, + pub cur_shard_cycle_range: std::ops::Range, } -impl RAMBus { +impl<'a> RAMBus<'a> { pub fn new( shard_id: usize, num_shards: usize, max_cycle: Cycle, addr_future_accesses: HashMap<(WordAddr, Cycle), Cycle>, ) -> Self { + let max_threads = max_usable_threads(); + let max_insts = (max_cycle.div_ceil(4)); + let max_record_per_thread = max_insts.div_ceil(max_threads as u64); + // reserve larger max_record_per_thread even a shard just need smaller space + // TODO optimize mem usage + let thread_based_record_storage = (0..max_threads) + .into_par_iter() + .map(|_| Vec::with_capacity(max_record_per_thread as usize)) + .collect::>(); + let expected_cycles_per_shard = max_cycle.div_ceil(num_shards as u64) as usize; + let cur_shard_cycle_range = + (shard_id * expected_cycles_per_shard..(shard_id + 1) * expected_cycles_per_shard); RAMBus { shard_id, num_shards, max_cycle, - addr_future_accesses, + addr_future_accesses: Cow::Owned(addr_future_accesses), + thread_based_record_storage: Either::Left(thread_based_record_storage), + cur_shard_cycle_range, + } + } + + pub fn get_forked(&mut self) -> Vec { + match &mut self.thread_based_record_storage { + Either::Left(thread_based_record_storage) => thread_based_record_storage + .iter_mut() + .map(|v| RAMBus { + shard_id: self.shard_id, + num_shards: self.num_shards, + max_cycle: self.max_cycle, + addr_future_accesses: Cow::Borrowed(self.addr_future_accesses.as_ref()), + thread_based_record_storage: Either::Right(v), + cur_shard_cycle_range: self.cur_shard_cycle_range.clone(), + }) + .collect_vec(), + Either::Right(_) => panic!("invalid type"), } } + pub fn send(&mut self) {} } -pub fn emulate_program( +pub fn emulate_program<'a>( program: Arc, max_steps: usize, init_mem_state: &InitMemState, platform: &Platform, -) -> EmulationResult { +) -> EmulationResult<'a> { let InitMemState { mem: mem_init, io: io_init, @@ -478,7 +516,7 @@ pub fn generate_fixed_traces( pub fn generate_witness( system_config: &ConstraintSystemConfig, - emul_result: EmulationResult, + mut emul_result: EmulationResult, program: &Program, ) -> ZKVMWitnesses { let mut zkvm_witness = ZKVMWitnesses::default(); @@ -487,6 +525,7 @@ pub fn generate_witness( .config .assign_opcode_circuit( &system_config.zkvm_cs, + &mut emul_result.ram_bus, &mut zkvm_witness, emul_result.all_records, ) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 4591c47e3..da4bb0d90 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -1,5 +1,5 @@ use crate::{ - circuit_builder::CircuitBuilder, error::ZKVMError, structs::ProgramParams, + circuit_builder::CircuitBuilder, e2e::RAMBus, error::ZKVMError, structs::ProgramParams, tables::RMMCollections, witness::LkMultiplicity, }; use ceno_emul::StepRecord; @@ -101,6 +101,7 @@ pub trait Instruction { ) -> Result<(), ZKVMError>; fn assign_instances( + ram_bus: &mut RAMBus, config: &Self::InstructionConfig, num_witin: usize, num_structural_witin: usize, @@ -131,6 +132,7 @@ pub trait Instruction { let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); let raw_structual_witin_iter = raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); + let ram_bus_forks = ram_bus.get_forks(); raw_witin_iter .zip_eq(raw_structual_witin_iter) diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 7bf3a9d34..2a9d54dc0 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -9,6 +9,7 @@ use crate::instructions::riscv::lui::LuiInstruction; #[cfg(not(feature = "u16limb_circuit"))] use crate::tables::PowTableCircuit; use crate::{ + e2e::RAMBus, error::ZKVMError, instructions::{ Instruction, @@ -400,6 +401,7 @@ impl Rv32imConfig { pub fn assign_opcode_circuit( &self, cs: &ZKVMConstraintSystem, + ram_bus: &mut RAMBus, witness: &mut ZKVMWitnesses, steps: Vec, ) -> Result { @@ -452,6 +454,7 @@ impl Rv32imConfig { ($insn_kind:ident,$instruction:ty,$config:ident) => { witness.assign_opcode_circuit::<$instruction>( cs, + ram_bus, &self.$config, all_records.remove(&($insn_kind)).unwrap(), )?; @@ -511,30 +514,40 @@ impl Rv32imConfig { assign_opcode!(SB, SbInstruction, sb_config); // ecall / halt - witness.assign_opcode_circuit::>(cs, &self.halt_config, halt_records)?; + witness.assign_opcode_circuit::>( + cs, + ram_bus, + &self.halt_config, + halt_records, + )?; witness.assign_opcode_circuit::>( cs, + ram_bus, &self.keccak_config, keccak_records, )?; witness.assign_opcode_circuit::>>( cs, + ram_bus, &self.bn254_add_config, bn254_add_records, )?; witness.assign_opcode_circuit::>>( cs, + ram_bus, &self.bn254_double_config, bn254_double_records, )?; witness.assign_opcode_circuit::>>( cs, + ram_bus, &self.secp256k1_add_config, secp256k1_add_records, )?; witness .assign_opcode_circuit::>>( cs, + ram_bus, &self.secp256k1_double_config, secp256k1_double_records, )?; @@ -653,6 +666,7 @@ impl DummyExtraConfig { pub fn assign_opcode_circuit( &self, cs: &ZKVMConstraintSystem, + ram_bus: &mut RAMBus, witness: &mut ZKVMWitnesses, steps: GroupedSteps, ) -> Result<(), ZKVMError> { @@ -682,35 +696,41 @@ impl DummyExtraConfig { witness.assign_opcode_circuit::>( cs, + ram_bus, &self.secp256k1_decompress_config, secp256k1_decompress_steps, )?; witness.assign_opcode_circuit::>( cs, + ram_bus, &self.sha256_extend_config, sha256_extend_steps, )?; witness.assign_opcode_circuit::>( cs, + ram_bus, &self.bn254_fp_add_config, bn254_fp_add_steps, )?; witness.assign_opcode_circuit::>( cs, + ram_bus, &self.bn254_fp_mul_config, bn254_fp_mul_steps, )?; witness.assign_opcode_circuit::>( cs, + ram_bus, &self.bn254_fp2_add_config, bn254_fp2_add_steps, )?; witness.assign_opcode_circuit::>( cs, + ram_bus, &self.bn254_fp2_mul_config, bn254_fp2_mul_steps, )?; - witness.assign_opcode_circuit::>(cs, &self.ecall_config, other_steps)?; + witness.assign_opcode_circuit::>(cs, ram_bus, &self.ecall_config, other_steps)?; let _ = steps.remove(&INVALID); let keys: Vec<&InsnKind> = steps.keys().collect::>(); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index cd76d6fcd..08fcfcf87 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -1,5 +1,6 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::RAMBus, error::ZKVMError, instructions::Instruction, state::StateCircuit, @@ -310,6 +311,7 @@ impl ZKVMWitnesses { pub fn assign_opcode_circuit>( &mut self, cs: &ZKVMConstraintSystem, + ram_bus: &mut RAMBus, config: &OC::InstructionConfig, records: Vec, ) -> Result<(), ZKVMError> { @@ -317,6 +319,7 @@ impl ZKVMWitnesses { let cs = cs.get_cs(&OC::name()).unwrap(); let (witness, logup_multiplicity) = OC::assign_instances( + ram_bus, config, cs.zkvm_v1_css.num_witin as usize, cs.zkvm_v1_css.num_structural_witin as usize, From d2e0d51615a64b83b3726c9de7621aa35dcd15b9 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Fri, 10 Oct 2025 16:27:25 +0800 Subject: [PATCH 31/91] add poseidon2 gadget --- ceno_zkvm/Cargo.toml | 2 + ceno_zkvm/src/gadgets/mod.rs | 2 + ceno_zkvm/src/gadgets/poseidon2.rs | 262 +++++++++++++++++++ ceno_zkvm/src/gadgets/poseidon2_constants.rs | 58 ++++ 4 files changed, 324 insertions(+) create mode 100644 ceno_zkvm/src/gadgets/poseidon2.rs create mode 100644 ceno_zkvm/src/gadgets/poseidon2_constants.rs diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index b38303a73..1cf14bd8f 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -30,6 +30,8 @@ sumcheck.workspace = true transcript.workspace = true whir.workspace = true witness.workspace = true +zkhash.workspace = true +lazy_static.workspace = true itertools.workspace = true ndarray.workspace = true diff --git a/ceno_zkvm/src/gadgets/mod.rs b/ceno_zkvm/src/gadgets/mod.rs index 6660a9339..7ee4652af 100644 --- a/ceno_zkvm/src/gadgets/mod.rs +++ b/ceno_zkvm/src/gadgets/mod.rs @@ -1,5 +1,7 @@ mod div; mod is_lt; +mod poseidon2; +mod poseidon2_constants; mod signed; mod signed_ext; mod signed_limbs; diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs new file mode 100644 index 000000000..62a759550 --- /dev/null +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -0,0 +1,262 @@ +// Poseidon2 over BabyBear field + +use std::{ + borrow::{Borrow, BorrowMut}, + iter::from_fn, + mem::transmute, +}; + +use ff_ext::{BabyBearExt4, ExtensionField}; +use gkr_iop::error::CircuitBuilderError; +use itertools::Itertools; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use num_bigint::BigUint; +use p3::{ + babybear::{BabyBear, BabyBearInternalLayerParameters}, + field::{Field, FieldAlgebra}, + monty_31::InternalLayerBaseParameters, + poseidon2::{MDSMat4, mds_light_permutation}, + poseidon2_air::{FullRound, PartialRound, Poseidon2Cols, SBox, num_cols}, +}; + +use crate::circuit_builder::CircuitBuilder; + +// copied from poseidon2-air/src/constants.rs +// as the original one cannot be accessed here +#[derive(Debug, Clone)] +pub(crate) struct RoundConstants< + F: Field, + const WIDTH: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> { + pub(crate) beginning_full_round_constants: [[F; WIDTH]; HALF_FULL_ROUNDS], + pub(crate) partial_round_constants: [F; PARTIAL_ROUNDS], + pub(crate) ending_full_round_constants: [[F; WIDTH]; HALF_FULL_ROUNDS], +} + +pub type Poseidon2BabyBearConfig = Poseidon2Config; +pub struct Poseidon2Config< + E: ExtensionField, + const STATE_WIDTH: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> { + cols: Vec, + constants: RoundConstants, +} + +impl< + E: ExtensionField, + const STATE_WIDTH: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> Poseidon2Config +{ + // constraints taken from poseidon2_air/src/air.rs + fn eval_sbox( + sbox: &SBox, SBOX_DEGREE, SBOX_REGISTERS>, + x: &mut Expression, + cb: &mut CircuitBuilder, + ) -> Result<(), CircuitBuilderError> { + *x = match (SBOX_DEGREE, SBOX_REGISTERS) { + (3, 0) => x.cube(), + (5, 0) => x.exp_const_u64::<5>(), + (7, 0) => x.exp_const_u64::<7>(), + (5, 1) => { + let committed_x3: Expression = sbox.0[0].clone(); + let x2: Expression = x.square(); + cb.require_zero( + || "x3 = x.cube()", + committed_x3.clone() - x2.clone() * x.clone(), + )?; + committed_x3 * x2 + } + (7, 1) => { + let committed_x3: Expression = sbox.0[0].clone(); + cb.require_zero(|| "x3 = x.cube()", committed_x3.clone() - x.cube())?; + committed_x3.square() * x.clone() + } + _ => panic!( + "Unexpected (SBOX_DEGREE, SBOX_REGISTERS) of ({}, {})", + SBOX_DEGREE, SBOX_REGISTERS + ), + }; + + Ok(()) + } + + fn eval_full_round( + state: &mut [Expression; STATE_WIDTH], + full_round: &FullRound, STATE_WIDTH, SBOX_DEGREE, SBOX_REGISTERS>, + round_constants: &[E::BaseField], + cb: &mut CircuitBuilder, + ) -> Result<(), CircuitBuilderError> { + for (i, (s, r)) in state.iter_mut().zip_eq(round_constants.iter()).enumerate() { + *s = s.clone() + r.expr(); + Self::eval_sbox(&full_round.sbox[i], s, cb)?; + } + Self::external_linear_layer(state); + for (state_i, post_i) in state.iter_mut().zip_eq(full_round.post.iter()) { + cb.require_zero(|| "post_i = state_i", state_i.clone() - post_i)?; + *state_i = post_i.clone(); + } + + Ok(()) + } + + fn eval_partial_round( + state: &mut [Expression; STATE_WIDTH], + partial_round: &PartialRound, STATE_WIDTH, SBOX_DEGREE, SBOX_REGISTERS>, + round_constant: &E::BaseField, + cb: &mut CircuitBuilder, + ) -> Result<(), CircuitBuilderError> { + state[0] = state[0].clone() + round_constant.expr(); + Self::eval_sbox(&partial_round.sbox, &mut state[0], cb)?; + + cb.require_zero( + || "state[0] = post_sbox", + state[0].clone() - partial_round.post_sbox.clone(), + )?; + state[0] = partial_round.post_sbox.clone(); + + Self::internal_linear_layer(state); + + Ok(()) + } + + fn external_linear_layer(state: &mut [Expression; STATE_WIDTH]) { + mds_light_permutation(state, &MDSMat4); + } + + fn internal_linear_layer(state: &mut [Expression; STATE_WIDTH]) { + let sum: Expression = state.iter().map(|s| s.get_monomial_form()).sum(); + // reduce to monomial form + let sum = sum.get_monomial_form(); + let babybear_prime = BigUint::from(0x7800_0001u32); + if E::BaseField::order() == babybear_prime { + // BabyBear + let diag_m1_matrix_bb = + &>:: + INTERNAL_DIAG_MONTY; + let diag_m1_matrix: &[E::BaseField; STATE_WIDTH] = + unsafe { transmute(diag_m1_matrix_bb) }; + for (input, diag_m1) in state.iter_mut().zip_eq(diag_m1_matrix) { + let updated = sum.clone() + Expression::from_f(*diag_m1) * input.clone(); + // reduce to monomial form + *input = updated.get_monomial_form(); + } + } else { + panic!("Unsupported field"); + } + } + + pub fn construct( + cb: &mut CircuitBuilder, + round_constants: RoundConstants< + E::BaseField, + STATE_WIDTH, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >, + ) -> Self { + let num_cols = + num_cols::( + ); + let cols = from_fn(|| Some(cb.create_witin(|| "poseidon2 col"))) + .take(num_cols) + .collect::>(); + println!("{num_cols}"); + let mut col_exprs = cols + .iter() + .map(|c| c.expr()) + .collect::>>(); + + let poseidon2_cols: &mut Poseidon2Cols< + Expression, + STATE_WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = col_exprs.as_mut_slice().borrow_mut(); + + // external linear layer + Self::external_linear_layer(&mut poseidon2_cols.inputs); + + // eval full round + for round in 0..HALF_FULL_ROUNDS { + Self::eval_full_round( + &mut poseidon2_cols.inputs, + &poseidon2_cols.beginning_full_rounds[round], + &round_constants.beginning_full_round_constants[round], + cb, + ) + .unwrap(); + } + + // eval partial round + for round in 0..PARTIAL_ROUNDS { + Self::eval_partial_round( + &mut poseidon2_cols.inputs, + &poseidon2_cols.partial_rounds[round], + &round_constants.partial_round_constants[round], + cb, + ) + .unwrap(); + } + + // TODO: after the last partial round, each state_i has ~STATE_WIDTH terms + // which will make the next full round to have many terms + + // eval full round + for round in 0..HALF_FULL_ROUNDS { + Self::eval_full_round( + &mut poseidon2_cols.inputs, + &poseidon2_cols.ending_full_rounds[round], + &round_constants.ending_full_round_constants[round], + cb, + ) + .unwrap(); + } + + Poseidon2Config { + cols, + constants: round_constants, + } + } + + pub fn assign_instance(&self, instance: &mut [E]) { + let poseidon2_cols: &Poseidon2Cols< + WitIn, + STATE_WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = self.cols.as_slice().borrow(); + } +} + +#[cfg(test)] +mod tests { + use crate::gadgets::{ + poseidon2::Poseidon2BabyBearConfig, poseidon2_constants::horizen_round_consts, + }; + use ff_ext::BabyBearExt4; + use gkr_iop::circuit_builder::{CircuitBuilder, ConstraintSystem}; + + type E = BabyBearExt4; + #[test] + fn test_poseidon2_gadget() { + let mut cs = ConstraintSystem::new(|| "poseidon2 gadget test"); + let mut cb = CircuitBuilder::::new(&mut cs); + + let poseidon2_constants = horizen_round_consts(); + let poseidon2_config = Poseidon2BabyBearConfig::construct(&mut cb, poseidon2_constants); + } +} diff --git a/ceno_zkvm/src/gadgets/poseidon2_constants.rs b/ceno_zkvm/src/gadgets/poseidon2_constants.rs new file mode 100644 index 000000000..cf807a56d --- /dev/null +++ b/ceno_zkvm/src/gadgets/poseidon2_constants.rs @@ -0,0 +1,58 @@ +// taken from openvm/crates/circuits/poseidon2-air/src/babybear.rs +use super::poseidon2::RoundConstants; +use lazy_static::lazy_static; +use p3::{babybear::BabyBear, field::FieldAlgebra}; +use std::array::from_fn; +use zkhash::{ + ark_ff::PrimeField as _, fields::babybear::FpBabyBear as HorizenBabyBear, + poseidon2::poseidon2_instance_babybear::RC16, +}; + +const BABY_BEAR_POSEIDON2_WIDTH: usize = 16; +const BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS: usize = 4; +const BABY_BEAR_POSEIDON2_PARTIAL_ROUNDS: usize = 13; + +pub(crate) fn horizen_to_p3_babybear(horizen_babybear: HorizenBabyBear) -> BabyBear { + BabyBear::from_canonical_u64(horizen_babybear.into_bigint().0[0]) +} + +pub(crate) fn horizen_round_consts() -> RoundConstants< + BabyBear, + BABY_BEAR_POSEIDON2_WIDTH, + BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS, + BABY_BEAR_POSEIDON2_PARTIAL_ROUNDS, +> { + let p3_rc16: Vec> = RC16 + .iter() + .map(|round| { + round + .iter() + .map(|babybear| horizen_to_p3_babybear(*babybear)) + .collect() + }) + .collect(); + let p_end = BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS + BABY_BEAR_POSEIDON2_PARTIAL_ROUNDS; + + let beginning_full_round_constants: [[BabyBear; BABY_BEAR_POSEIDON2_WIDTH]; + BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS] = from_fn(|i| p3_rc16[i].clone().try_into().unwrap()); + let partial_round_constants: [BabyBear; BABY_BEAR_POSEIDON2_PARTIAL_ROUNDS] = + from_fn(|i| p3_rc16[i + BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS][0]); + let ending_full_round_constants: [[BabyBear; BABY_BEAR_POSEIDON2_WIDTH]; + BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS] = + from_fn(|i| p3_rc16[i + p_end].clone().try_into().unwrap()); + + RoundConstants { + beginning_full_round_constants, + partial_round_constants, + ending_full_round_constants, + } +} + +lazy_static! { + pub static ref BABYBEAR_BEGIN_EXT_CONSTS: [[BabyBear; BABY_BEAR_POSEIDON2_WIDTH]; BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS] = + horizen_round_consts().beginning_full_round_constants; + pub static ref BABYBEAR_PARTIAL_CONSTS: [BabyBear; BABY_BEAR_POSEIDON2_PARTIAL_ROUNDS] = + horizen_round_consts().partial_round_constants; + pub static ref BABYBEAR_END_EXT_CONSTS: [[BabyBear; BABY_BEAR_POSEIDON2_WIDTH]; BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS] = + horizen_round_consts().ending_full_round_constants; +} From 33b47ec3cecae39f510535b94c1faecb96f03523 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 10 Oct 2025 20:32:34 +0800 Subject: [PATCH 32/91] add shardcontext --- ceno_emul/src/lib.rs | 1 + ceno_emul/src/shards.rs | 14 ++ ceno_emul/src/tracer.rs | 8 +- ceno_zkvm/src/bin/e2e.rs | 18 +- ceno_zkvm/src/e2e.rs | 181 +++++++++++++++--- ceno_zkvm/src/instructions.rs | 42 ++-- ceno_zkvm/src/instructions/riscv/arith.rs | 8 +- ceno_zkvm/src/instructions/riscv/arith_imm.rs | 2 + .../riscv/arith_imm/arith_imm_circuit_v2.rs | 4 +- ceno_zkvm/src/instructions/riscv/auipc.rs | 6 +- ceno_zkvm/src/instructions/riscv/b_insn.rs | 9 +- .../riscv/branch/branch_circuit.rs | 4 +- .../riscv/branch/branch_circuit_v2.rs | 4 +- .../src/instructions/riscv/branch/test.rs | 7 + ceno_zkvm/src/instructions/riscv/div.rs | 2 + .../instructions/riscv/div/div_circuit_v2.rs | 6 +- .../instructions/riscv/dummy/dummy_circuit.rs | 20 +- .../instructions/riscv/dummy/dummy_ecall.rs | 8 +- .../src/instructions/riscv/dummy/test.rs | 5 + .../src/instructions/riscv/ecall/halt.rs | 2 + .../src/instructions/riscv/ecall/keccak.rs | 10 +- .../riscv/ecall/weierstrass_add.rs | 12 +- .../riscv/ecall/weierstrass_double.rs | 17 +- .../src/instructions/riscv/ecall_base.rs | 28 ++- ceno_zkvm/src/instructions/riscv/i_insn.rs | 8 +- ceno_zkvm/src/instructions/riscv/im_insn.rs | 10 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 70 ++++++- ceno_zkvm/src/instructions/riscv/j_insn.rs | 6 +- .../src/instructions/riscv/jump/jal_v2.rs | 4 +- .../src/instructions/riscv/jump/jalr_v2.rs | 4 +- ceno_zkvm/src/instructions/riscv/jump/test.rs | 3 + .../instructions/riscv/logic/logic_circuit.rs | 7 +- .../src/instructions/riscv/logic/test.rs | 7 +- .../riscv/logic_imm/logic_imm_circuit_v2.rs | 8 +- .../src/instructions/riscv/logic_imm/test.rs | 2 + ceno_zkvm/src/instructions/riscv/lui.rs | 6 +- .../src/instructions/riscv/memory/load_v2.rs | 4 +- .../src/instructions/riscv/memory/store_v2.rs | 4 +- .../src/instructions/riscv/memory/test.rs | 3 + ceno_zkvm/src/instructions/riscv/mulh.rs | 4 + .../riscv/mulh/mulh_circuit_v2.rs | 4 +- ceno_zkvm/src/instructions/riscv/r_insn.rs | 8 +- ceno_zkvm/src/instructions/riscv/rv32im.rs | 39 ++-- ceno_zkvm/src/instructions/riscv/s_insn.rs | 10 +- ceno_zkvm/src/instructions/riscv/shift.rs | 2 + .../riscv/shift/shift_circuit_v2.rs | 7 +- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 2 + ceno_zkvm/src/instructions/riscv/slt.rs | 2 + .../instructions/riscv/slt/slt_circuit_v2.rs | 6 +- ceno_zkvm/src/instructions/riscv/slti.rs | 2 + .../riscv/slti/slti_circuit_v2.rs | 6 +- ceno_zkvm/src/lib.rs | 1 + ceno_zkvm/src/precompiles/lookup_keccakf.rs | 7 +- .../weierstrass/weierstrass_add.rs | 7 +- .../weierstrass/weierstrass_double.rs | 9 +- ceno_zkvm/src/scheme/tests.rs | 20 +- ceno_zkvm/src/structs.rs | 6 +- gkr_iop/src/lib.rs | 5 +- 58 files changed, 555 insertions(+), 156 deletions(-) create mode 100644 ceno_emul/src/shards.rs diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 8f439d036..38bd6fcb2 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -45,3 +45,4 @@ pub mod utils; pub mod test_utils; pub mod host_utils; +pub mod shards; diff --git a/ceno_emul/src/shards.rs b/ceno_emul/src/shards.rs new file mode 100644 index 000000000..a8d06ab78 --- /dev/null +++ b/ceno_emul/src/shards.rs @@ -0,0 +1,14 @@ +pub struct Shards { + pub shard_id: usize, + pub num_shards: usize, +} + +impl Shards { + pub fn new(shard_id: usize, num_shards: usize) -> Self { + assert!(shard_id < num_shards); + Self { + shard_id, + num_shards, + } + } +} diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 5195bc85b..9dc9a0b12 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -473,11 +473,9 @@ impl Tracer { /// - Record the current instruction as the origin of the latest access. /// - Accesses within the same instruction are distinguished by `subcycle ∈ [0, 3]`. pub fn track_access(&mut self, addr: WordAddr, subcycle: Cycle) -> Cycle { - let prev_cycle = self - .latest_accesses - .insert(addr, self.record.cycle + subcycle) - .unwrap_or(0); - self.next_accesses.insert((addr, prev_cycle), subcycle); + let cur_cycle = self.record.cycle + subcycle; + let prev_cycle = self.latest_accesses.insert(addr, cur_cycle).unwrap_or(0); + self.next_accesses.insert((addr, prev_cycle), cur_cycle); prev_cycle } diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index c7ec2b310..a2c2ffde2 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -1,4 +1,4 @@ -use ceno_emul::{IterAddresses, Platform, Program, WORD_SIZE, Word}; +use ceno_emul::{IterAddresses, Platform, Program, WORD_SIZE, Word, shards::Shards}; use ceno_host::{CenoStdin, memory_from_file}; #[cfg(all(feature = "jemalloc", unix, not(test)))] use ceno_zkvm::print_allocated_bytes; @@ -108,6 +108,14 @@ struct Args { /// The security level to use. #[arg(short, long, value_enum, default_value_t = SecurityLevel::default())] security_level: SecurityLevel, + + // shard id + #[arg(long, default_value = "0")] + shard_id: u32, + + // number of total shards + #[arg(long, default_value = "1")] + num_shards: u32, } fn main() { @@ -240,6 +248,7 @@ fn main() { .unwrap_or_default(); let max_steps = args.max_steps.unwrap_or(usize::MAX); + let shards = Shards::new(args.shard_id as usize, args.num_shards as usize); match (args.pcs, args.field) { (PcsKind::Basefold, FieldType::Goldilocks) => { @@ -249,6 +258,7 @@ fn main() { prover, program, platform, + shards, &hints, &public_io, max_steps, @@ -264,6 +274,7 @@ fn main() { prover, program, platform, + shards, &hints, &public_io, max_steps, @@ -279,6 +290,7 @@ fn main() { prover, program, platform, + shards, &hints, &public_io, max_steps, @@ -294,6 +306,7 @@ fn main() { prover, program, platform, + shards, &hints, &public_io, max_steps, @@ -320,6 +333,7 @@ fn run_inner< pd: PD, program: Program, platform: Platform, + shards: Shards, hints: &[u32], public_io: &[u32], max_steps: usize, @@ -328,7 +342,7 @@ fn run_inner< checkpoint: Checkpoint, ) { let result = run_e2e_with_checkpoint::( - pd, program, platform, hints, public_io, max_steps, checkpoint, + pd, program, platform, shards, hints, public_io, max_steps, checkpoint, ); let zkvm_proof = result diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 784def2e1..89720025e 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -16,23 +16,25 @@ use crate::{ tables::{MemFinalRecord, MemInitRecord, ProgramTableCircuit, ProgramTableConfig}, }; use ceno_emul::{ - Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, InsnKind, IterAddresses, MemOp, Platform, - Program, StepRecord, Tracer, VMState, WORD_SIZE, WordAddr, host_utils::read_all_messages, + Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, InsnKind, IterAddresses, Platform, Program, + StepRecord, Tracer, VMState, WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, + shards::Shards, }; use clap::ValueEnum; use either::Either; use ff_ext::ExtensionField; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; -use gkr_iop::hal::ProverBackend; +use gkr_iop::{RAMType, hal::ProverBackend}; use itertools::{Itertools, MinMaxResult, chain}; -use mpcs::{PolynomialCommitmentScheme, SecurityLevel, util::arithmetic::div_ceil}; +use mpcs::{PolynomialCommitmentScheme, SecurityLevel}; use multilinear_extensions::util::max_usable_threads; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use serde::Serialize; use std::{ borrow::Cow, - collections::{BTreeSet, HashMap, HashSet}, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, + mem, sync::Arc, }; use transcript::BasicTranscript as Transcript; @@ -96,52 +98,89 @@ pub struct EmulationResult<'a> { pub all_records: Vec, pub final_mem_state: FinalMemState, pub pi: PublicValues, - pub ram_bus: RAMBus<'a>, + pub shard_ctx: ShardContext<'a>, } -pub struct RAMBus<'a> { +pub enum RAMRecordType { + Read, + Write, +} + +pub struct RAMRecord { + ram_type: RAMRecordType, + id: u64, + addr: WordAddr, + prev_cycle: Cycle, + cycle: Cycle, + prev_value: Option, + value: Word, +} + +pub struct ShardContext<'a> { shard_id: usize, num_shards: usize, max_cycle: Cycle, addr_future_accesses: Cow<'a, HashMap<(WordAddr, Cycle), Cycle>>, - thread_based_record_storage: Either>, &'a mut Vec>, + thread_based_record_storage: Either< + Vec<[BTreeMap; mem::variant_count::()]>, + &'a mut [BTreeMap; mem::variant_count::()], + >, pub cur_shard_cycle_range: std::ops::Range, } -impl<'a> RAMBus<'a> { +impl<'a> Default for ShardContext<'a> { + fn default() -> Self { + let max_threads = max_usable_threads(); + let thread_based_record_storage = (0..max_threads) + .into_par_iter() + .map(|_| std::array::from_fn(|_| BTreeMap::new())) + .collect::>(); + Self { + shard_id: 0, + num_shards: 1, + max_cycle: Cycle::default(), + addr_future_accesses: Cow::Owned(HashMap::new()), + thread_based_record_storage: Either::Left(thread_based_record_storage), + cur_shard_cycle_range: 0..usize::MAX, + } + } +} + +impl<'a> ShardContext<'a> { pub fn new( shard_id: usize, num_shards: usize, - max_cycle: Cycle, + executed_instructions: usize, addr_future_accesses: HashMap<(WordAddr, Cycle), Cycle>, ) -> Self { let max_threads = max_usable_threads(); - let max_insts = (max_cycle.div_ceil(4)); - let max_record_per_thread = max_insts.div_ceil(max_threads as u64); - // reserve larger max_record_per_thread even a shard just need smaller space - // TODO optimize mem usage + // let max_record_per_thread = max_insts.div_ceil(max_threads as u64); + // TODO pre-reserve vector let thread_based_record_storage = (0..max_threads) .into_par_iter() - .map(|_| Vec::with_capacity(max_record_per_thread as usize)) + .map(|_| std::array::from_fn(|_| BTreeMap::new())) .collect::>(); - let expected_cycles_per_shard = max_cycle.div_ceil(num_shards as u64) as usize; - let cur_shard_cycle_range = - (shard_id * expected_cycles_per_shard..(shard_id + 1) * expected_cycles_per_shard); - RAMBus { + + let expected_inst_per_shard = executed_instructions.div_ceil(num_shards) as usize; + let max_cycle = (executed_instructions + 1) * 4; // cycle start from 4 + let cur_shard_cycle_range = (shard_id * expected_inst_per_shard * 4).max(4) + ..((shard_id + 1) * expected_inst_per_shard * 4).min(max_cycle); + + ShardContext { shard_id, num_shards, - max_cycle, + max_cycle: max_cycle as Cycle, addr_future_accesses: Cow::Owned(addr_future_accesses), thread_based_record_storage: Either::Left(thread_based_record_storage), cur_shard_cycle_range, } } - pub fn get_forked(&mut self) -> Vec { + pub fn get_forked(&mut self) -> Vec> { match &mut self.thread_based_record_storage { Either::Left(thread_based_record_storage) => thread_based_record_storage .iter_mut() - .map(|v| RAMBus { + .map(|v| ShardContext { shard_id: self.shard_id, num_shards: self.num_shards, max_cycle: self.max_cycle, @@ -154,7 +193,64 @@ impl<'a> RAMBus<'a> { } } - pub fn send(&mut self) {} + #[inline(always)] + pub fn send( + &mut self, + ram_type: crate::structs::RAMType, + addr: WordAddr, + id: u64, + cycle: Cycle, + prev_cycle: Cycle, + value: Word, + prev_value: Option, + ) { + // check read from external mem bus + if prev_cycle < self.cur_shard_cycle_range.start as Cycle + && cycle >= self.cur_shard_cycle_range.start as Cycle + { + let ram_record = self + .thread_based_record_storage + .as_mut() + .right() + .expect("illegal type"); + ram_record[ram_type as usize].insert( + addr, + RAMRecord { + ram_type: RAMRecordType::Read, + id, + addr, + prev_cycle, + cycle, + prev_value, + value, + }, + ); + } + // check write to external mem bus + if let Some(future_touch_cycle) = self.addr_future_accesses.get(&(addr, cycle)) { + if *future_touch_cycle >= self.cur_shard_cycle_range.end as Cycle + && cycle < self.cur_shard_cycle_range.end as Cycle + { + let ram_record = self + .thread_based_record_storage + .as_mut() + .right() + .expect("illegal type"); + ram_record[ram_type as usize].insert( + addr, + RAMRecord { + ram_type: RAMRecordType::Write, + id, + addr, + prev_cycle, + cycle, + prev_value, + value, + }, + ); + } + } + } } pub fn emulate_program<'a>( @@ -162,6 +258,7 @@ pub fn emulate_program<'a>( max_steps: usize, init_mem_state: &InitMemState, platform: &Platform, + shards: &Shards, ) -> EmulationResult<'a> { let InitMemState { mem: mem_init, @@ -333,13 +430,18 @@ pub fn emulate_program<'a>( ), ); - let ram_bus = RAMBus::new(0, 1, end_cycle, vm.take_tracer().next_accesses()); + let shard_ctx = ShardContext::new( + shards.shard_id, + shards.num_shards, + insts, + vm.take_tracer().next_accesses(), + ); EmulationResult { pi, exit_code, all_records, - ram_bus, + shard_ctx, final_mem_state: FinalMemState { reg: reg_final, io: io_final, @@ -525,14 +627,19 @@ pub fn generate_witness( .config .assign_opcode_circuit( &system_config.zkvm_cs, - &mut emul_result.ram_bus, + &mut emul_result.shard_ctx, &mut zkvm_witness, emul_result.all_records, ) .unwrap(); system_config .dummy_config - .assign_opcode_circuit(&system_config.zkvm_cs, &mut zkvm_witness, dummy_records) + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut emul_result.shard_ctx, + &mut zkvm_witness, + dummy_records, + ) .unwrap(); zkvm_witness.finalize_lk_multiplicities(); @@ -589,6 +696,7 @@ pub type IntermediateState = (Option>, Option { pub program: Arc, pub platform: Platform, + pub shards: Shards, pub static_addrs: Vec, pub pubio_len: usize, pub system_config: ConstraintSystemConfig, @@ -616,7 +724,11 @@ impl> E2ECheckpointResult< } /// Set up a program with the given platform -pub fn setup_program(program: Program, platform: Platform) -> E2EProgramCtx { +pub fn setup_program( + program: Program, + platform: Platform, + shards: Shards, +) -> E2EProgramCtx { let static_addrs = init_static_addrs(&program); let pubio_len = platform.public_io.iter_addresses().len(); let program_params = ProgramParams { @@ -641,6 +753,7 @@ pub fn setup_program(program: Program, platform: Platform) -> E2EProgramCtx { program: Arc::new(program), platform, + shards, static_addrs, pubio_len, system_config, @@ -733,13 +846,14 @@ pub fn run_e2e_with_checkpoint< device: PD, program: Program, platform: Platform, + shards: Shards, hints: &[u32], public_io: &[u32], max_steps: usize, checkpoint: Checkpoint, ) -> E2ECheckpointResult { let start = std::time::Instant::now(); - let ctx = setup_program::(program, platform); + let ctx = setup_program::(program, platform, shards); tracing::debug!("setup_program done in {:?}", start.elapsed()); // Keygen @@ -777,6 +891,7 @@ pub fn run_e2e_with_checkpoint< max_steps, &init_full_mem, &ctx.platform, + &ctx.shards, ); tracing::debug!("emulate done in {:?}", start.elapsed()); @@ -860,7 +975,13 @@ pub fn run_e2e_proof< is_mock_proving: bool, ) -> ZKVMProof { // Emulate program - let emul_result = emulate_program(ctx.program.clone(), max_steps, init_full_mem, &ctx.platform); + let emul_result = emulate_program( + ctx.program.clone(), + max_steps, + init_full_mem, + &ctx.platform, + &ctx.shards, + ); // clone pi before consuming let pi = emul_result.pi.clone(); diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index da4bb0d90..13a3ed22b 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -1,5 +1,5 @@ use crate::{ - circuit_builder::CircuitBuilder, e2e::RAMBus, error::ZKVMError, structs::ProgramParams, + circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, structs::ProgramParams, tables::RMMCollections, witness::LkMultiplicity, }; use ceno_emul::StepRecord; @@ -93,16 +93,17 @@ pub trait Instruction { } // assign single instance giving step from trace - fn assign_instance( + fn assign_instance<'a>( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext<'a>, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError>; fn assign_instances( - ram_bus: &mut RAMBus, config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, steps: Vec, @@ -132,23 +133,32 @@ pub trait Instruction { let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); let raw_structual_witin_iter = raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); - let ram_bus_forks = ram_bus.get_forks(); + let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter .zip_eq(raw_structual_witin_iter) .zip_eq(steps.par_chunks(num_instance_per_batch)) - .flat_map(|((instances, structural_instance), steps)| { - let mut lk_multiplicity = lk_multiplicity.clone(); - instances - .chunks_mut(num_witin) - .zip_eq(structural_instance.chunks_mut(num_structural_witin)) - .zip_eq(steps) - .map(|((instance, structural_instance), step)| { - set_val!(structural_instance, selector_witin, E::BaseField::ONE); - Self::assign_instance(config, instance, &mut lk_multiplicity, step) - }) - .collect::>() - }) + .zip(shard_ctx_vec) + .flat_map( + |(((instances, structural_instance), steps), mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + instances + .chunks_mut(num_witin) + .zip_eq(structural_instance.chunks_mut(num_structural_witin)) + .zip_eq(steps) + .map(|((instance, structural_instance), step)| { + set_val!(structural_instance, selector_witin, E::BaseField::ONE); + Self::assign_instance( + config, + &mut shard_ctx, + instance, + &mut lk_multiplicity, + step, + ) + }) + .collect::>() + }, + ) .collect::>()?; raw_witin.padding_by_strategy(); diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index b73abcda4..a94024b4a 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -2,8 +2,8 @@ use std::marker::PhantomData; use super::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}; use crate::{ - circuit_builder::CircuitBuilder, error::ZKVMError, instructions::Instruction, - structs::ProgramParams, uint::Value, witness::LkMultiplicity, + circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + instructions::Instruction, structs::ProgramParams, uint::Value, witness::LkMultiplicity, }; use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; @@ -87,13 +87,14 @@ impl Instruction for ArithInstruction::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { config .r_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs2_read = Value::new_unchecked(step.rs2().unwrap().value); config @@ -186,6 +187,7 @@ mod test { let insn_code = encode_rv32(I::INST_KIND, 2, 3, 4, 0); let (raw_witin, lkm) = ArithInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index a040681bc..4de4069d0 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -21,6 +21,7 @@ mod test { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::constants::UInt}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, @@ -63,6 +64,7 @@ mod test { let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm); let (raw_witin, lkm) = AddiInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index f969a68b0..8ed175d58 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -70,6 +71,7 @@ impl Instruction for AddiInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -93,7 +95,7 @@ impl Instruction for AddiInstruction { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 7957f7003..3244c5d60 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -4,6 +4,7 @@ use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -142,13 +143,14 @@ impl Instruction for AuipcInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rd_written = split_to_u8(step.rd().unwrap().value.after); config.rd_written.assign_limbs(instance, &rd_written); @@ -189,6 +191,7 @@ mod tests { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{auipc::AuipcInstruction, constants::UInt}, @@ -239,6 +242,7 @@ mod tests { let insn_code = encode_rv32(InsnKind::AUIPC, 0, 0, 4, imm); let (raw_witin, lkm) = AuipcInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index 798902754..e84d6a1a2 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -5,6 +5,7 @@ use super::constants::PC_STEP_SIZE; use crate::{ chip_handler::{RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, tables::InsnRecord, @@ -12,7 +13,6 @@ use crate::{ }; use ff_ext::FieldInto; use multilinear_extensions::{Expression, ToExpr, WitIn}; - // Opcode: 1100011 // Funct3: // 000 BEQ @@ -89,12 +89,15 @@ impl BInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, lk_multiplicity, step)?; - self.rs2.assign_instance(instance, lk_multiplicity, step)?; + self.rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rs2 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Immediate set_val!( diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs index 8aecd50f8..2c97a12ee 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs @@ -6,6 +6,7 @@ use ff_ext::ExtensionField; use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{IsEqualConfig, IsLtConfig, SignedLtConfig}, instructions::{ @@ -137,13 +138,14 @@ impl Instruction for BranchCircuit Result<(), ZKVMError> { config .b_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs1 = Value::new_unchecked(step.rs1().unwrap().value); let rs2 = Value::new_unchecked(step.rs2().unwrap().value); diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 94abb56d1..386d2c286 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, instructions::{ @@ -68,13 +69,14 @@ impl Instruction for BranchCircuit Result<(), ZKVMError> { config .b_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs1 = Value::new_unchecked(step.rs1().unwrap().value); let rs1_limbs = rs1.as_u16_limbs(); diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index aaf468127..82dbcffac 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -6,6 +6,7 @@ use ff_ext::{ExtensionField, GoldilocksExt2}; use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, error::ZKVMError, instructions::Instruction, scheme::mock_prover::{MOCK_PC_START, MockProver}, @@ -39,6 +40,7 @@ fn impl_opcode_beq(equal: bool) { let pc_offset = if equal { 8 } else { PC_STEP_SIZE }; let (raw_witin, lkm) = BeqInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( @@ -79,6 +81,7 @@ fn impl_opcode_bne(equal: bool) { let pc_offset = if equal { PC_STEP_SIZE } else { 8 }; let (raw_witin, lkm) = BneInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( @@ -122,6 +125,7 @@ fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { let insn_code = encode_rv32(InsnKind::BLTU, 2, 3, 0, -8); let (raw_witin, lkm) = BltuInstruction::assign_instances( &config, + &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( @@ -166,6 +170,7 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { let insn_code = encode_rv32(InsnKind::BGEU, 2, 3, 0, -8); let (raw_witin, lkm) = BgeuInstruction::assign_instances( &config, + &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( @@ -217,6 +222,7 @@ fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<() let insn_code = encode_rv32(InsnKind::BLT, 2, 3, 0, -8); let (raw_witin, lkm) = BltInstruction::assign_instances( &config, + &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( @@ -268,6 +274,7 @@ fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<() let insn_code = encode_rv32(InsnKind::BGE, 2, 3, 0, -8); let (raw_witin, lkm) = BgeInstruction::assign_instances( &config, + &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index 7ca30d2b8..966320407 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -53,6 +53,7 @@ mod test { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{ @@ -179,6 +180,7 @@ mod test { // values assignment let ([raw_witin, _], lkm) = Insn::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index d2d2b78ee..f062ea949 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -12,6 +12,7 @@ use super::{ }; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{Instruction, riscv::constants::LIMB_BITS}, structs::ProgramParams, @@ -372,6 +373,7 @@ impl Instruction for ArithInstruction Instruction for ArithInstruction (true, true), diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs index 7c98e2159..e2396942d 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs @@ -9,9 +9,9 @@ use super::super::{ insn_base::{ReadMEM, ReadRS1, ReadRS2, StateInOut, WriteMEM, WriteRD}, }; use crate::{ - chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, error::ZKVMError, - instructions::Instruction, structs::ProgramParams, tables::InsnRecord, uint::Value, - witness::LkMultiplicity, + chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, e2e::ShardContext, + error::ZKVMError, instructions::Instruction, structs::ProgramParams, tables::InsnRecord, + uint::Value, witness::LkMultiplicity, }; use ff_ext::FieldInto; use multilinear_extensions::{ToExpr, WitIn}; @@ -70,11 +70,12 @@ impl Instruction for DummyInstruction::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - config.assign_instance(instance, lk_multiplicity, step) + config.assign_instance(instance, shard_ctx, lk_multiplicity, step) } } @@ -242,6 +243,7 @@ impl DummyConfig { pub(super) fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { @@ -253,19 +255,19 @@ impl DummyConfig { // Registers if let Some((rs1_op, rs1_read)) = &self.rs1 { - rs1_op.assign_instance(instance, lk_multiplicity, step)?; + rs1_op.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs1_val = Value::new_unchecked(step.rs1().expect("rs1 value").value); rs1_read.assign_value(instance, rs1_val); } if let Some((rs2_op, rs2_read)) = &self.rs2 { - rs2_op.assign_instance(instance, lk_multiplicity, step)?; + rs2_op.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs2_val = Value::new_unchecked(step.rs2().expect("rs2 value").value); rs2_read.assign_value(instance, rs2_val); } if let Some((rd_op, rd_written)) = &self.rd { - rd_op.assign_instance(instance, lk_multiplicity, step)?; + rd_op.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rd_val = Value::new_unchecked(step.rd().expect("rd value").value.after); rd_written.assign_value(instance, rd_val); @@ -284,10 +286,10 @@ impl DummyConfig { mem_after.assign_value(instance, Value::new(mem_op.value.after, lk_multiplicity)); } if let Some(mem_read) = &self.mem_read { - mem_read.assign_instance(instance, lk_multiplicity, step)?; + mem_read.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; } if let Some(mem_write) = &self.mem_write { - mem_write.assign_instance::(instance, lk_multiplicity, step)?; + mem_write.assign_instance::(instance, shard_ctx, lk_multiplicity, step)?; } let imm = InsnRecord::::imm_internal(&step.insn()).1; diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 69bdd1648..9cd5cb0f3 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -8,6 +8,7 @@ use super::{super::insn_base::WriteMEM, dummy_circuit::DummyConfig}; use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -84,6 +85,7 @@ impl Instruction for LargeEcallDummy fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -93,14 +95,14 @@ impl Instruction for LargeEcallDummy // Assign instruction. config .dummy_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; set_val!(instance, config.start_addr, u64::from(ops.mem_ops[0].addr)); // Assign registers. for ((value, writer), op) in config.reg_writes.iter().zip_eq(&ops.reg_ops) { value.assign_value(instance, Value::new_unchecked(op.value.after)); - writer.assign_op(instance, lk_multiplicity, step.cycle(), op)?; + writer.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), op)?; } // Assign memory. @@ -112,7 +114,7 @@ impl Instruction for LargeEcallDummy .after .assign_value(instance, Value::new(op.value.after, lk_multiplicity)); set_val!(instance, addr, u64::from(op.addr)); - writer.assign_op(instance, lk_multiplicity, step.cycle(), op)?; + writer.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), op)?; } Ok(()) diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index 6f7a89f73..c6f51d142 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -4,6 +4,7 @@ use ff_ext::GoldilocksExt2; use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{arith::AddOp, branch::BeqOp, ecall::EcallDummy}, @@ -34,6 +35,7 @@ fn test_dummy_ecall() { let insn_code = step.insn(); let (raw_witin, lkm) = EcallDummy::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![step], @@ -63,6 +65,7 @@ fn test_dummy_keccak() { let (step, program) = ceno_emul::test_utils::keccak_step(); let (raw_witin, lkm) = KeccakDummy::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![step], @@ -90,6 +93,7 @@ fn test_dummy_r() { let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); let (raw_witin, lkm) = AddDummy::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( @@ -125,6 +129,7 @@ fn test_dummy_b() { let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, 8); let (raw_witin, lkm) = BeqDummy::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_b_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index e14585727..bf38a67c4 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -1,6 +1,7 @@ use crate::{ chip_handler::{RegisterChipOperations, general::PublicIOQuery}, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::AssertLtConfig, instructions::{ @@ -70,6 +71,7 @@ impl Instruction for HaltInstruction { fn assign_instance( config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index b0ac2a505..2d0e0c2fd 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -21,6 +21,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -156,6 +157,7 @@ impl Instruction for KeccakInstruction { fn assign_instance( _config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, _instance: &mut [::BaseField], _lk_multiplicity: &mut LkMultiplicity, _step: &StepRecord, @@ -165,6 +167,7 @@ impl Instruction for KeccakInstruction { fn assign_instances( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, steps: Vec, @@ -196,11 +199,13 @@ impl Instruction for KeccakInstruction { // each instance are composed of KECCAK_ROUNDS.next_power_of_two() let raw_witin_iter = raw_witin .par_batch_iter_mut(num_instance_per_batch * KECCAK_ROUNDS.next_power_of_two()); + let shard_ctx_vec = shard_ctx.get_forked(); // 1st pass: assign witness outside of gkr-iop scope raw_witin_iter .zip_eq(steps.par_chunks(num_instance_per_batch)) - .flat_map(|(instances, steps)| { + .zip(shard_ctx_vec) + .flat_map(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances @@ -222,6 +227,7 @@ impl Instruction for KeccakInstruction { config.ecall_id.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), &WriteOp::new_register_op( @@ -238,6 +244,7 @@ impl Instruction for KeccakInstruction { )?; config.state_ptr.0.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), &ops.reg_ops[0], @@ -246,6 +253,7 @@ impl Instruction for KeccakInstruction { for (writer, op) in config.mem_rw.iter().zip_eq(&ops.mem_ops) { writer.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), op, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index 6d2a87fc5..bd2ca016a 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -24,6 +24,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -207,6 +208,7 @@ impl Instruction fn assign_instance( _config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, _instance: &mut [::BaseField], _lk_multiplicity: &mut LkMultiplicity, _step: &StepRecord, @@ -216,6 +218,7 @@ impl Instruction fn assign_instances( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, steps: Vec, @@ -255,11 +258,13 @@ impl Instruction ); let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); // 1st pass: assign witness outside of gkr-iop scope raw_witin_iter .zip_eq(steps.par_chunks(num_instance_per_batch)) - .flat_map(|(instances, steps)| { + .zip(shard_ctx_vec) + .flat_map(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances @@ -273,6 +278,7 @@ impl Instruction config.ecall_id.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), &WriteOp::new_register_op( @@ -289,6 +295,7 @@ impl Instruction )?; config.point_ptr_0.0.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), &ops.reg_ops[0], @@ -301,12 +308,13 @@ impl Instruction )?; config.point_ptr_1.0.assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, step.cycle(), &ops.reg_ops[1], )?; for (writer, op) in config.mem_rw.iter().zip_eq(&ops.mem_ops) { - writer.assign_op(instance, &mut lk_multiplicity, step.cycle(), op)?; + writer.assign_op(instance, &mut shard_ctx, &mut lk_multiplicity, step.cycle(), op)?; } // fetch lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 0bd9736fb..1281eba7d 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -24,6 +24,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -179,6 +180,7 @@ impl Instruction::BaseField], _lk_multiplicity: &mut LkMultiplicity, _step: &StepRecord, @@ -188,6 +190,7 @@ impl Instruction, @@ -227,11 +230,13 @@ impl Instruction Instruction Instruction OpFixedRS OpFixedRS IInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, lk_multiplicity, step)?; - self.rd.assign_instance(instance, lk_multiplicity, step)?; + self.rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rd + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index 5fa6cd501..567833b2f 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -7,6 +7,7 @@ use crate::{ witness::LkMultiplicity, }; +use crate::e2e::ShardContext; use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use multilinear_extensions::{Expression, ToExpr}; @@ -67,14 +68,17 @@ impl IMInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, lk_multiplicity, step)?; - self.rd.assign_instance(instance, lk_multiplicity, step)?; + self.rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rd + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; self.mem_read - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 43a72f739..03c654f98 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -10,8 +10,10 @@ use crate::{ RegisterChipOperations, RegisterExpr, }, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::AssertLtConfig, + structs::RAMType, uint::Value, witness::{LkMultiplicity, set_val}, }; @@ -106,6 +108,7 @@ impl ReadRS1 { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { @@ -120,6 +123,15 @@ impl ReadRS1 { op.previous_cycle, step.cycle() + Tracer::SUBCYCLE_RS1, )?; + shard_ctx.send( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS1, + op.previous_cycle, + op.value, + None, + ); Ok(()) } @@ -160,6 +172,7 @@ impl ReadRS2 { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { @@ -175,6 +188,16 @@ impl ReadRS2 { step.cycle() + Tracer::SUBCYCLE_RS2, )?; + shard_ctx.send( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS2, + op.previous_cycle, + op.value, + None, + ); + Ok(()) } } @@ -216,16 +239,18 @@ impl WriteRD { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { let op = step.rd().expect("rd op"); - self.assign_op(instance, lk_multiplicity, step.cycle(), &op) + self.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), &op) } pub fn assign_op( &self, instance: &mut [E::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, cycle: Cycle, op: &WriteOp, @@ -246,6 +271,15 @@ impl WriteRD { op.previous_cycle, cycle + Tracer::SUBCYCLE_RD, )?; + shard_ctx.send( + RAMType::Register, + op.addr, + op.register_index() as u64, + cycle + Tracer::SUBCYCLE_RD, + op.previous_cycle, + op.value.after, + Some(op.value.before), + ); Ok(()) } @@ -284,24 +318,32 @@ impl ReadMEM { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { + let op = step.memory_op().unwrap(); // Memory state - set_val!( - instance, - self.prev_ts, - step.memory_op().unwrap().previous_cycle - ); + set_val!(instance, self.prev_ts, op.previous_cycle); // Memory read self.lt_cfg.assign_instance( instance, lk_multiplicity, - step.memory_op().unwrap().previous_cycle, + op.previous_cycle, step.cycle() + Tracer::SUBCYCLE_MEM, )?; + shard_ctx.send( + RAMType::Memory, + op.addr, + op.addr.baddr().0 as u64, + step.cycle() + Tracer::SUBCYCLE_MEM, + op.previous_cycle, + op.value.after, + None, + ); + Ok(()) } } @@ -337,16 +379,18 @@ impl WriteMEM { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { let op = step.memory_op().unwrap(); - self.assign_op(instance, lk_multiplicity, step.cycle(), &op) + self.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), &op) } pub fn assign_op( &self, instance: &mut [F], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, cycle: Cycle, op: &WriteOp, @@ -360,6 +404,16 @@ impl WriteMEM { cycle + Tracer::SUBCYCLE_MEM, )?; + shard_ctx.send( + RAMType::Memory, + op.addr, + op.addr.baddr().0 as u64, + cycle + Tracer::SUBCYCLE_MEM, + op.previous_cycle, + op.value.after, + Some(op.value.before), + ); + Ok(()) } } diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs index 156aa1cd1..81a954893 100644 --- a/ceno_zkvm/src/instructions/riscv/j_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -4,13 +4,13 @@ use ff_ext::ExtensionField; use crate::{ chip_handler::{RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{StateInOut, WriteRD}, tables::InsnRecord, witness::LkMultiplicity, }; use multilinear_extensions::ToExpr; - // Opcode: 1101111 /// This config handles the common part of the J-type instruction (JAL): @@ -55,11 +55,13 @@ impl JInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { self.vm_state.assign_instance(instance, step)?; - self.rd.assign_instance(instance, lk_multiplicity, step)?; + self.rd + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Fetch the instruction. lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index 0f67be424..545adf275 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -4,6 +4,7 @@ use ff_ext::ExtensionField; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -88,13 +89,14 @@ impl Instruction for JalInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { config .j_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rd_written = split_to_u8(step.rd().unwrap().value.after); config.rd_written.assign_limbs(instance, &rd_written); diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index bfec3a099..7f23ac9b6 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -5,6 +5,7 @@ use crate::{ Value, chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -135,6 +136,7 @@ impl Instruction for JalrInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, @@ -177,7 +179,7 @@ impl Instruction for JalrInstruction { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index 0b379f250..899e5a035 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -2,6 +2,7 @@ use super::{JalInstruction, JalrInstruction}; use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::constants::UInt}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, @@ -42,6 +43,7 @@ fn verify_test_opcode_jal(pc_offset: i32) { let insn_code = encode_rv32(InsnKind::JAL, 0, 0, 4, pc_offset); let (raw_witin, lkm) = JalInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_j_instruction( @@ -117,6 +119,7 @@ fn verify_test_opcode_jalr(rs1_read: Word, imm: i32) { let (raw_witin, lkm) = JalrInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index f761f6102..5a2d8e404 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -6,6 +6,7 @@ use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -53,6 +54,7 @@ impl Instruction for LogicInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -63,7 +65,7 @@ impl Instruction for LogicInstruction { step.rs2().unwrap().value as u64, ); - config.assign_instance(instance, lk_multiplicity, step) + config.assign_instance(instance, shard_ctx, lk_multiplicity, step) } } @@ -106,11 +108,12 @@ impl LogicConfig { fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { self.r_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rs1_read = split_to_u8(step.rs1().unwrap().value); self.rs1_read.assign_limbs(instance, &rs1_read); diff --git a/ceno_zkvm/src/instructions/riscv/logic/test.rs b/ceno_zkvm/src/instructions/riscv/logic/test.rs index dc01487d9..f68135c72 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/test.rs @@ -1,16 +1,16 @@ use ceno_emul::{Change, StepRecord, Word, encode_rv32}; use ff_ext::GoldilocksExt2; +use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::constants::UInt8}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, utils::split_to_u8, }; -use super::*; - const A: Word = 0xbead1010; const B: Word = 0xef552020; @@ -32,6 +32,7 @@ fn test_opcode_and() { let insn_code = encode_rv32(InsnKind::AND, 2, 3, 4, 0); let (raw_witin, lkm) = AndInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( @@ -74,6 +75,7 @@ fn test_opcode_or() { let insn_code = encode_rv32(InsnKind::OR, 2, 3, 4, 0); let (raw_witin, lkm) = OrInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( @@ -116,6 +118,7 @@ fn test_opcode_xor() { let insn_code = encode_rv32(InsnKind::XOR, 2, 3, 4, 0); let (raw_witin, lkm) = XorInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index c72f31efe..b48af7f5f 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -7,6 +7,7 @@ use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -94,6 +95,7 @@ impl Instruction for LogicInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, @@ -115,7 +117,7 @@ impl Instruction for LogicInstruction { imm_hi.into(), ); - config.assign_instance(instance, lkm, step) + config.assign_instance(instance, shard_ctx, lkm, step) } } @@ -163,11 +165,13 @@ impl LogicConfig { fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { let num_limbs = LIMB_BITS / 8; - self.i_insn.assign_instance(instance, lkm, step)?; + self.i_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1_read = split_to_u8(step.rs1().unwrap().value); self.rs1_read.assign_limbs(instance, &rs1_read); diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs index 23aa2d77c..68032fd41 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs @@ -4,6 +4,7 @@ use gkr_iop::circuit_builder::DebugIndex; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{ @@ -70,6 +71,7 @@ fn verify(name: &'static str, rs1_read: u32, imm: u32, expected_rd_w let insn_code = encode_rv32u(I::INST_KIND, 2, 0, 4, imm); let (raw_witin, lkm) = LogicInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 2cc280f04..198bafbc5 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -4,6 +4,7 @@ use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -88,13 +89,14 @@ impl Instruction for LuiInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rd_written = split_to_u8(step.rd().unwrap().value.after); for (val, witin) in izip!(rd_written.iter().skip(1), config.rd_written) { @@ -117,6 +119,7 @@ mod tests { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{constants::UInt, lui::LuiInstruction}, @@ -153,6 +156,7 @@ mod tests { let insn_code = encode_rv32(InsnKind::LUI, 0, 0, 4, imm); let (raw_witin, lkm) = LuiInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 1973e48ea..812e4020a 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, instructions::{ @@ -184,6 +185,7 @@ impl Instruction for LoadInstruction Instruction for LoadInstruction Instruction fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -147,7 +149,7 @@ impl Instruction let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); config .s_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; config.rs1_read.assign_value(instance, rs1); config.rs2_read.assign_value(instance, rs2); set_val!(instance, config.imm, imm.1); diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index 90c5a0273..b2a04326b 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{ @@ -102,6 +103,7 @@ fn impl_opcode_store::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( @@ -217,6 +219,7 @@ mod test { let insn_code = encode_rv32(InsnKind::MULH, 2, 3, 4, 0); let (raw_witin, lkm) = MulhInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( @@ -300,6 +303,7 @@ mod test { let insn_code = encode_rv32(InsnKind::MULHSU, 2, 3, 4, 0); let (raw_witin, lkm) = MulhsuInstruction::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index c1853d7a8..a94f63e74 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -19,6 +19,7 @@ use multilinear_extensions::{Expression, ToExpr as _, WitIn}; use p3::field::{Field, FieldAlgebra}; use witness::set_val; +use crate::e2e::ShardContext; use itertools::Itertools; use std::{array, marker::PhantomData}; @@ -223,6 +224,7 @@ impl Instruction for MulhInstructionBas fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -241,7 +243,7 @@ impl Instruction for MulhInstructionBas // R-type instruction config .r_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let (rd_high, rd_low, carry, rs1_ext, rs2_ext) = run_mulh::( I::INST_KIND, diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index 540ccaffe..a44e6757d 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -4,6 +4,7 @@ use ff_ext::ExtensionField; use crate::{ chip_handler::{RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteRD}, tables::InsnRecord, @@ -63,13 +64,14 @@ impl RInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, lk_multiplicity, step)?; - self.rs2.assign_instance(instance, lk_multiplicity, step)?; - self.rd.assign_instance(instance, lk_multiplicity, step)?; + self.rs1.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rs2.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rd.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 2a9d54dc0..cc9810d45 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -9,7 +9,7 @@ use crate::instructions::riscv::lui::LuiInstruction; #[cfg(not(feature = "u16limb_circuit"))] use crate::tables::PowTableCircuit; use crate::{ - e2e::RAMBus, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -401,7 +401,7 @@ impl Rv32imConfig { pub fn assign_opcode_circuit( &self, cs: &ZKVMConstraintSystem, - ram_bus: &mut RAMBus, + shard_ctx: &mut ShardContext, witness: &mut ZKVMWitnesses, steps: Vec, ) -> Result { @@ -454,7 +454,7 @@ impl Rv32imConfig { ($insn_kind:ident,$instruction:ty,$config:ident) => { witness.assign_opcode_circuit::<$instruction>( cs, - ram_bus, + shard_ctx, &self.$config, all_records.remove(&($insn_kind)).unwrap(), )?; @@ -516,38 +516,38 @@ impl Rv32imConfig { // ecall / halt witness.assign_opcode_circuit::>( cs, - ram_bus, + shard_ctx, &self.halt_config, halt_records, )?; witness.assign_opcode_circuit::>( cs, - ram_bus, + shard_ctx, &self.keccak_config, keccak_records, )?; witness.assign_opcode_circuit::>>( cs, - ram_bus, + shard_ctx, &self.bn254_add_config, bn254_add_records, )?; witness.assign_opcode_circuit::>>( cs, - ram_bus, + shard_ctx, &self.bn254_double_config, bn254_double_records, )?; witness.assign_opcode_circuit::>>( cs, - ram_bus, + shard_ctx, &self.secp256k1_add_config, secp256k1_add_records, )?; witness .assign_opcode_circuit::>>( cs, - ram_bus, + shard_ctx, &self.secp256k1_double_config, secp256k1_double_records, )?; @@ -666,7 +666,7 @@ impl DummyExtraConfig { pub fn assign_opcode_circuit( &self, cs: &ZKVMConstraintSystem, - ram_bus: &mut RAMBus, + shard_ctx: &mut ShardContext, witness: &mut ZKVMWitnesses, steps: GroupedSteps, ) -> Result<(), ZKVMError> { @@ -696,41 +696,46 @@ impl DummyExtraConfig { witness.assign_opcode_circuit::>( cs, - ram_bus, + shard_ctx, &self.secp256k1_decompress_config, secp256k1_decompress_steps, )?; witness.assign_opcode_circuit::>( cs, - ram_bus, + shard_ctx, &self.sha256_extend_config, sha256_extend_steps, )?; witness.assign_opcode_circuit::>( cs, - ram_bus, + shard_ctx, &self.bn254_fp_add_config, bn254_fp_add_steps, )?; witness.assign_opcode_circuit::>( cs, - ram_bus, + shard_ctx, &self.bn254_fp_mul_config, bn254_fp_mul_steps, )?; witness.assign_opcode_circuit::>( cs, - ram_bus, + shard_ctx, &self.bn254_fp2_add_config, bn254_fp2_add_steps, )?; witness.assign_opcode_circuit::>( cs, - ram_bus, + shard_ctx, &self.bn254_fp2_mul_config, bn254_fp2_mul_steps, )?; - witness.assign_opcode_circuit::>(cs, ram_bus, &self.ecall_config, other_steps)?; + witness.assign_opcode_circuit::>( + cs, + shard_ctx, + &self.ecall_config, + other_steps, + )?; let _ = steps.remove(&INVALID); let keys: Vec<&InsnKind> = steps.keys().collect::>(); diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index f46cf4c5d..23a5ff810 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -1,6 +1,7 @@ use crate::{ chip_handler::{AddressExpr, MemoryExpr, RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteMEM}, tables::InsnRecord, @@ -73,14 +74,17 @@ impl SInstructionConfig { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, lk_multiplicity, step)?; - self.rs2.assign_instance(instance, lk_multiplicity, step)?; + self.rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rs2 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; self.mem_write - .assign_instance::(instance, lk_multiplicity, step)?; + .assign_instance::(instance, shard_ctx, lk_multiplicity, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 0c53f1a4c..d09b98c89 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -45,6 +45,7 @@ mod tests { use crate::utils::split_to_u8; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::RIVInstruction}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, @@ -173,6 +174,7 @@ mod tests { let (raw_witin, lkm) = ShiftLogicalInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 4e929670c..fac05279e 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -1,3 +1,4 @@ +use crate::e2e::ShardContext; /// constrain implementation follow from https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/shift/core.rs use crate::{ instructions::{ @@ -321,6 +322,7 @@ impl Instruction for ShiftLogicalInstru fn assign_instance( config: &ShiftRTypeConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut crate::witness::LkMultiplicity, step: &ceno_emul::StepRecord, @@ -352,7 +354,7 @@ impl Instruction for ShiftLogicalInstru ); config .r_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } @@ -419,6 +421,7 @@ impl Instruction for ShiftImmInstructio fn assign_instance( config: &ShiftImmConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut crate::witness::LkMultiplicity, step: &ceno_emul::StepRecord, @@ -449,7 +452,7 @@ impl Instruction for ShiftImmInstructio ); config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 4cf7ac155..1757a0fc7 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -43,6 +43,7 @@ mod test { use crate::utils::split_to_u8; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::RIVInstruction}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, @@ -170,6 +171,7 @@ mod test { let (raw_witin, lkm) = ShiftImmInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index 7b27617ad..3ba12bb39 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -38,6 +38,7 @@ mod test { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{Instruction, riscv::constants::UInt}, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, @@ -72,6 +73,7 @@ mod test { let insn_code = encode_rv32(I::INST_KIND, 2, 3, 4, 0); let (raw_witin, lkm) = SetLessThanInstruction::<_, I>::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_r_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index 391dffb89..cd0b97ce4 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, instructions::{ @@ -75,11 +76,14 @@ impl Instruction for SetLessThanInstruc fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - config.r_insn.assign_instance(instance, lkm, step)?; + config + .r_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1 = step.rs1().unwrap().value; let rs2 = step.rs2().unwrap().value; diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 5802c4229..ff3a78043 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -35,6 +35,7 @@ mod test { use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, instructions::{ Instruction, riscv::{ @@ -185,6 +186,7 @@ mod test { let (raw_witin, lkm) = SetLessThanImmInstruction::::assign_instances( &config, + &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, vec![StepRecord::new_i_instruction( diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 1085561fb..914424247 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, instructions::{ @@ -92,11 +93,14 @@ impl Instruction for SetLessThanImmInst fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - config.i_insn.assign_instance(instance, lkm, step)?; + config + .i_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1 = step.rs1().unwrap().value; let rs1_value = Value::new_unchecked(rs1 as Word); diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 16e7ee821..a72c0ffe6 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -1,6 +1,7 @@ #![deny(clippy::cargo)] #![feature(box_patterns)] #![feature(stmt_expr_attributes)] +#![feature(variant_count)] pub mod error; pub mod instructions; diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index 2fcd8de79..e1823bc70 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -40,6 +40,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::MemoryExpr, + e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{StateInOut, WriteMEM}, precompiles::{ @@ -1025,6 +1026,7 @@ pub fn run_faster_keccakf verify: bool, test_outputs: bool, ) -> Result, BackendError> { + let mut shard_ctx = ShardContext::default(); let num_instances = states.len(); let num_instances_rounds = num_instances * ROUNDS.next_power_of_two(); let log2_num_instance_rounds = ceil_log2(num_instances_rounds); @@ -1073,9 +1075,11 @@ pub fn run_faster_keccakf ); let raw_witin_iter = phase1_witness.par_batch_iter_mut(num_instance_per_batch * ROUNDS.next_power_of_two()); + let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter .zip_eq(instances.par_chunks(num_instance_per_batch)) - .for_each(|(instances, steps)| { + .zip(shard_ctx_vec) + .for_each(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin as usize * ROUNDS.next_power_of_two()) @@ -1095,6 +1099,7 @@ pub fn run_faster_keccakf mem_config .assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, 10, &MemOp { diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs index 365f8632e..f9c76fbf1 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs @@ -63,6 +63,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::MemoryExpr, + e2e::ShardContext, error::ZKVMError, gadgets::{FieldOperation, field_op::FieldOpCols}, instructions::riscv::insn_base::{StateInOut, WriteMEM}, @@ -561,6 +562,7 @@ pub fn run_weierstrass_add< verify: bool, test_outputs: bool, ) -> Result, BackendError> { + let mut shard_ctx = ShardContext::default(); let num_instances = points.len(); let log2_num_instance = ceil_log2(num_instances); let num_threads = optimal_sumcheck_threads(log2_num_instance); @@ -593,9 +595,11 @@ pub fn run_weierstrass_add< InstancePaddingStrategy::Default, ); let raw_witin_iter = phase1_witness.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter .zip_eq(instances.par_chunks(num_instance_per_batch)) - .for_each(|(instances, steps)| { + .zip(shard_ctx_vec) + .for_each(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin as usize) @@ -612,6 +616,7 @@ pub fn run_weierstrass_add< mem_config .assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, 10, &MemOp { diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs index 908ef2897..3922b6c22 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs @@ -64,6 +64,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::MemoryExpr, + e2e::ShardContext, error::ZKVMError, gadgets::{FieldOperation, field_op::FieldOpCols}, instructions::riscv::insn_base::{StateInOut, WriteMEM}, @@ -566,6 +567,7 @@ pub fn run_weierstrass_double< verify: bool, test_outputs: bool, ) -> Result, BackendError> { + let mut shard_ctx = ShardContext::default(); let num_instances = points.len(); let log2_num_instance = ceil_log2(num_instances); let num_threads = optimal_sumcheck_threads(log2_num_instance); @@ -595,9 +597,11 @@ pub fn run_weierstrass_double< InstancePaddingStrategy::Default, ); let raw_witin_iter = phase1_witness.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter - .zip(instances.par_chunks(num_instance_per_batch)) - .for_each(|(instances, steps)| { + .zip_eq(instances.par_chunks(num_instance_per_batch)) + .zip(shard_ctx_vec) + .for_each(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin as usize) @@ -616,6 +620,7 @@ pub fn run_weierstrass_double< mem_config .assign_op( instance, + &mut shard_ctx, &mut lk_multiplicity, 10, &MemOp { diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 60dff6a99..f7970b413 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -42,7 +42,7 @@ use super::{ utils::infer_tower_product_witness, verifier::{TowerVerify, ZKVMVerifier}, }; -use crate::tables::DynamicRangeTableCircuit; +use crate::{e2e::ShardContext, tables::DynamicRangeTableCircuit}; use itertools::Itertools; use mpcs::{ PolynomialCommitmentScheme, SecurityLevel, SecurityLevel::Conjecture100bits, WhirDefault, @@ -90,6 +90,7 @@ impl Instruction for Test fn assign_instance( config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], _lk_multiplicity: &mut LkMultiplicity, _step: &StepRecord, @@ -118,6 +119,7 @@ fn test_rw_lk_expression_combination() { let name = TestCircuit::::name(); let mut zkvm_cs = ZKVMConstraintSystem::default(); let config = zkvm_cs.register_opcode_circuit::>(); + let mut shard_ctx = ShardContext::default(); // generate fixed traces let mut zkvm_fixed_traces = ZKVMFixedTraces::default(); @@ -140,6 +142,7 @@ fn test_rw_lk_expression_combination() { zkvm_witness .assign_opcode_circuit::>( &zkvm_cs, + &mut shard_ctx, &config, vec![StepRecord::default(); num_instances], ) @@ -274,6 +277,7 @@ fn test_single_add_instance_e2e() { Pcs::setup(1 << MAX_NUM_VARIABLES, SecurityLevel::default()).expect("Basefold PCS setup"); let (pp, vp) = Pcs::trim((), 1 << MAX_NUM_VARIABLES).expect("Basefold trim"); let mut zkvm_cs = ZKVMConstraintSystem::default(); + let mut shard_ctx = ShardContext::default(); // opcode circuits let add_config = zkvm_cs.register_opcode_circuit::>(); let halt_config = zkvm_cs.register_opcode_circuit::>(); @@ -339,10 +343,20 @@ fn test_single_add_instance_e2e() { let mut zkvm_witness = ZKVMWitnesses::default(); // assign opcode circuits zkvm_witness - .assign_opcode_circuit::>(&zkvm_cs, &add_config, add_records) + .assign_opcode_circuit::>( + &zkvm_cs, + &mut shard_ctx, + &add_config, + add_records, + ) .unwrap(); zkvm_witness - .assign_opcode_circuit::>(&zkvm_cs, &halt_config, halt_records) + .assign_opcode_circuit::>( + &zkvm_cs, + &mut shard_ctx, + &halt_config, + halt_records, + ) .unwrap(); zkvm_witness.finalize_lk_multiplicities(); zkvm_witness diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 08fcfcf87..04842f349 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -1,6 +1,6 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - e2e::RAMBus, + e2e::ShardContext, error::ZKVMError, instructions::Instruction, state::StateCircuit, @@ -311,7 +311,7 @@ impl ZKVMWitnesses { pub fn assign_opcode_circuit>( &mut self, cs: &ZKVMConstraintSystem, - ram_bus: &mut RAMBus, + shard_ctx: &mut ShardContext, config: &OC::InstructionConfig, records: Vec, ) -> Result<(), ZKVMError> { @@ -319,8 +319,8 @@ impl ZKVMWitnesses { let cs = cs.get_cs(&OC::name()).unwrap(); let (witness, logup_multiplicity) = OC::assign_instances( - ram_bus, config, + shard_ctx, cs.zkvm_v1_css.num_witin as usize, cs.zkvm_v1_css.num_structural_witin as usize, records, diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index fc69037ff..b1bffbb08 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -7,6 +7,7 @@ use either::Either; use ff_ext::ExtensionField; use multilinear_extensions::{Expression, impl_expr_from_unsigned, mle::ArcMultilinearExtension}; use std::marker::PhantomData; +use strum_macros::EnumIter; use transcript::Transcript; use witness::RowMajorMatrix; @@ -77,10 +78,10 @@ pub struct ProtocolVerifier, PCS>( PhantomData<(E, Trans, PCS)>, ); -#[derive(Clone, Debug, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, Copy, EnumIter, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[repr(usize)] pub enum RAMType { - GlobalState, + GlobalState = 0, Register, Memory, } From 7248eec7c1c0794a6ede6e07e3ad25844d4ccaa3 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 13 Oct 2025 10:47:18 +0800 Subject: [PATCH 33/91] global chip wip --- ceno_zkvm/src/gadgets/mod.rs | 2 + ceno_zkvm/src/gadgets/poseidon2.rs | 29 +++++-- ceno_zkvm/src/instructions.rs | 1 + ceno_zkvm/src/instructions/global.rs | 118 +++++++++++++++++++++++++++ 4 files changed, 145 insertions(+), 5 deletions(-) create mode 100644 ceno_zkvm/src/instructions/global.rs diff --git a/ceno_zkvm/src/gadgets/mod.rs b/ceno_zkvm/src/gadgets/mod.rs index 7ee4652af..8e8e9e0a3 100644 --- a/ceno_zkvm/src/gadgets/mod.rs +++ b/ceno_zkvm/src/gadgets/mod.rs @@ -11,6 +11,8 @@ pub use gkr_iop::gadgets::{ AssertLtConfig, InnerLtConfig, IsEqualConfig, IsLtConfig, IsZeroConfig, cal_lt_diff, }; pub use is_lt::{AssertSignedLtConfig, SignedLtConfig}; +pub use poseidon2::{Poseidon2BabyBearConfig, Poseidon2Config}; +pub(crate) use poseidon2_constants::horizen_round_consts; pub use signed::Signed; pub use signed_ext::SignedExtendConfig; pub use signed_limbs::{UIntLimbsLT, UIntLimbsLTConfig}; diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs index 62a759550..021513ac2 100644 --- a/ceno_zkvm/src/gadgets/poseidon2.rs +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -16,7 +16,7 @@ use p3::{ field::{Field, FieldAlgebra}, monty_31::InternalLayerBaseParameters, poseidon2::{MDSMat4, mds_light_permutation}, - poseidon2_air::{FullRound, PartialRound, Poseidon2Cols, SBox, num_cols}, + poseidon2_air::{FullRound, PartialRound, Poseidon2Cols, SBox, generate_trace_rows, num_cols}, }; use crate::circuit_builder::CircuitBuilder; @@ -78,6 +78,10 @@ impl< } (7, 1) => { let committed_x3: Expression = sbox.0[0].clone(); + // TODO: avoid x^3 as x may have ~STATE_WIDTH terms after the linear layer + // we can allocate one more column to store x^2 (which has ~STATE_WIDTH^2 terms) + // then x^3 = x * x^2 + // but this will increase the number of columns (by FULL_ROUNDS * STATE_WIDTH + PARTIAL_ROUNDS) cb.require_zero(|| "x3 = x.cube()", committed_x3.clone() - x.cube())?; committed_x3.square() * x.clone() } @@ -170,7 +174,6 @@ impl< let cols = from_fn(|| Some(cb.create_witin(|| "poseidon2 col"))) .take(num_cols) .collect::>(); - println!("{num_cols}"); let mut col_exprs = cols .iter() .map(|c| c.expr()) @@ -230,16 +233,32 @@ impl< } } - pub fn assign_instance(&self, instance: &mut [E]) { + pub fn inputs(&self) -> Vec> { + let col_exprs = self.cols.iter().map(|c| c.expr()).collect::>(); + let poseidon2_cols: &Poseidon2Cols< - WitIn, + Expression, STATE_WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS, - > = self.cols.as_slice().borrow(); + > = col_exprs.as_slice().borrow(); + + poseidon2_cols.inputs.to_vec() } + + // pub fn assign_instance(&self, input: &[E; STATE_WIDTH]) { + // generate_trace_rows(inputs, constants) + // let poseidon2_cols: &Poseidon2Cols< + // WitIn, + // STATE_WIDTH, + // SBOX_DEGREE, + // SBOX_REGISTERS, + // HALF_FULL_ROUNDS, + // PARTIAL_ROUNDS, + // > = self.cols.as_slice().borrow(); + // } } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 4591c47e3..d546eaa04 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -19,6 +19,7 @@ use rayon::{ }; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_val}; +pub mod global; pub mod riscv; pub trait Instruction { diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs new file mode 100644 index 000000000..949596da5 --- /dev/null +++ b/ceno_zkvm/src/instructions/global.rs @@ -0,0 +1,118 @@ +use crate::gadgets::{Poseidon2BabyBearConfig, horizen_round_consts}; +use ff_ext::{BabyBearExt4, ExtensionField}; +use gkr_iop::{circuit_builder::CircuitBuilder, error::CircuitBuilderError}; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use p3::field::FieldAlgebra; + +use crate::{ + instructions::{Instruction, riscv::constants::UInt}, + scheme::constants::SEPTIC_EXTENSION_DEGREE, +}; + +// opcode circuit + mem init/final table + mem local chip: consistency RAMType::Register / Memory + +// mem local <-> global +// precompile <-> global +pub struct GlobalConfig { + addr: WitIn, + ram_type: WitIn, + value: UInt, + shard: WitIn, + clk: WitIn, + is_write: WitIn, + x: Vec, + y: Vec, + poseidon2: Poseidon2BabyBearConfig, +} + +impl GlobalConfig { + pub fn config(cb: &mut CircuitBuilder) -> Result { + let x = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| cb.create_witin(|| format!("x{}", i))) + .collect(); + let y = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| cb.create_witin(|| format!("y{}", i))) + .collect(); + let addr = cb.create_witin(|| "addr"); + let ram_type = cb.create_witin(|| "ram_type"); + let value = UInt::new(|| "value", cb)?; + let shard = cb.create_witin(|| "shard"); + let clk = cb.create_witin(|| "clk"); + let is_write = cb.create_witin(|| "is_write"); + + let rc = horizen_round_consts(); + let cb: &mut CircuitBuilder<'_, BabyBearExt4> = unsafe { std::mem::transmute(cb) }; + let hasher = Poseidon2BabyBearConfig::construct(cb, rc); + + let mut input = vec![]; + input.push(addr.expr()); + input.push(ram_type.expr()); + // memory expr has same number of limbs as register expr + input.extend(value.memory_expr()); + input.push(shard.expr()); + input.push(clk.expr()); + + for (input_expr, hasher_input) in input.into_iter().zip(hasher.inputs().into_iter()) { + // TODO: replace with cb.require_equal() + cb.require_zero(|| "poseidon2 input", input_expr - hasher_input); + } + + // TODO: enforce x = poseidon2([addr, ram_type, value[0], value[1], shard, clk]) + // TODO: enforce \sum_i (xi, yi) = ecc_sum + // TODO: output ecc_sum as public values + + // TODO: enforce is_write is boolean + // TODO: enforce y < p/2 if is_write = 1 + // enforce p/2 <= y < p if is_write = 0 + + Ok(GlobalConfig { + x, + y, + addr, + ram_type, + value, + shard, + clk, + is_write, + poseidon2: hasher, + }) + } +} + +// This chip is used to manage read/write into a global set +// shared among multiple shards +pub struct GlobalChip {} + +impl Instruction for GlobalChip { + type InstructionConfig = GlobalConfig; + + fn name() -> String { + "Global".to_string() + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + _param: &crate::structs::ProgramParams, + ) -> Result { + let config = GlobalConfig::config(cb)?; + + Ok(config) + } + + fn assign_instance( + config: &Self::InstructionConfig, + instance: &mut [::BaseField], + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + todo!() + } +} + +#[cfg(test)] +mod tests { + #[test] + fn test_global_chip() { + // Test the GlobalChip functionality here + } +} From 912747eb8799b09be467e920273a82b57ffbb270 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 13 Oct 2025 10:53:08 +0800 Subject: [PATCH 34/91] upgrade gkr-backend to v1.0.0-alpha.10 --- Cargo.lock | 479 +++++++++++++++++++++++++++++++++++++++++++++++++++-- Cargo.toml | 20 ++- 2 files changed, 480 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index df24100fd..9e0580cc2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -104,6 +104,76 @@ dependencies = [ "backtrace", ] +[[package]] +name = "ark-ff" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec847af850f44ad29048935519032c33da8aa03340876d351dfab5660d2966ba" +dependencies = [ + "ark-ff-asm", + "ark-ff-macros", + "ark-serialize", + "ark-std", + "derivative", + "digest", + "itertools 0.10.5", + "num-bigint", + "num-traits", + "paste", + "rustc_version", + "zeroize", +] + +[[package]] +name = "ark-ff-asm" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ed4aa4fe255d0bc6d79373f7e31d2ea147bcf486cba1be5ba7ea85abdb92348" +dependencies = [ + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-ff-macros" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7abe79b0e4288889c4574159ab790824d0033b9fdcb2a112a3182fac2e514565" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "ark-serialize" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb7b85a02b83d2f22f89bd5cac66c9c89474240cb6207cb1efc16d098e822a5" +dependencies = [ + "ark-std", + "digest", + "num-bigint", +] + +[[package]] +name = "ark-std" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94893f1e0c6eeab764ade8dc4c0db24caf4fe7cbbaafc0eba0a9030f447b5185" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + [[package]] name = "arrayvec" version = "0.7.6" @@ -189,6 +259,60 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + +[[package]] +name = "blake2b_simd" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06e903a20b159e944f91ec8499fe1e55651480c541ea0a584f5d967c49ad9d99" +dependencies = [ + "arrayref", + "arrayvec", + "constant_time_eq", +] + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bls12_381" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3c196a77437e7cc2fb515ce413a6401291578b5afc8ecb29a3c7ab957f05941" +dependencies = [ + "ff 0.12.1", + "group 0.12.1", + "pairing", + "rand_core", + "subtle", +] + [[package]] name = "bumpalo" version = "3.17.0" @@ -379,6 +503,7 @@ dependencies = [ "gkr_iop", "glob", "itertools 0.13.0", + "lazy_static", "mpcs", "multilinear_extensions", "ndarray", @@ -406,6 +531,7 @@ dependencies = [ "transcript", "whir", "witness", + "zkhash", ] [[package]] @@ -500,6 +626,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "cpp_demangle" version = "0.4.4" @@ -509,6 +641,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "criterion" version = "0.5.1" @@ -576,6 +717,16 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "csv" version = "1.3.1" @@ -678,6 +829,17 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "derive_builder" version = "0.20.2" @@ -730,6 +892,17 @@ dependencies = [ "unicode-xid", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + [[package]] name = "dirs-next" version = "2.0.0" @@ -822,10 +995,32 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "ff" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d013fc25338cc558c5c2cfbad646908fb23591e2404481826742b651c9af7160" +dependencies = [ + "bitvec", + "rand_core", + "subtle", +] + +[[package]] +name = "ff" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0b50bfb653653f9ca9095b427bed08ab8d75a137839d9ad64eb11810d5b6393" +dependencies = [ + "bitvec", + "rand_core", + "subtle", +] + [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.6#3a9e040bdbdf0059ed432b9d8a93a29171200e83" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ "once_cell", "p3", @@ -860,12 +1055,28 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + [[package]] name = "gcd" version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d758ba1b47b00caf47f24925c0074ecb20d6dfcffe7f6d53395c0465674841a" +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "generic_static" version = "0.2.0" @@ -961,6 +1172,29 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" +[[package]] +name = "group" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dfbfb3a6cfbd390d5c9564ab283a0349b9b9fcd46a706c1eb10e0db70bfbac7" +dependencies = [ + "ff 0.12.1", + "memuse", + "rand_core", + "subtle", +] + +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff 0.13.1", + "rand_core", + "subtle", +] + [[package]] name = "half" version = "2.6.0" @@ -971,6 +1205,29 @@ dependencies = [ "crunchy", ] +[[package]] +name = "halo2" +version = "0.1.0-beta.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a23c779b38253fe1538102da44ad5bd5378495a61d2c4ee18d64eaa61ae5995" +dependencies = [ + "halo2_proofs", +] + +[[package]] +name = "halo2_proofs" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e925780549adee8364c7f2b685c753f6f3df23bde520c67416e93bf615933760" +dependencies = [ + "blake2b_simd", + "ff 0.12.1", + "group 0.12.1", + "pasta_curves 0.4.1", + "rand_core", + "rayon", +] + [[package]] name = "hashbrown" version = "0.15.3" @@ -989,6 +1246,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbd780fe5cc30f81464441920d82ac8740e2e46b29a6fad543ddd075229ce37e" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "hex-conservative" version = "0.2.1" @@ -1240,6 +1503,29 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jubjub" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a575df5f985fe1cd5b2b05664ff6accfc46559032b954529fd225a2168d27b0f" +dependencies = [ + "bitvec", + "bls12_381", + "ff 0.12.1", + "group 0.12.1", + "rand_core", + "subtle", +] + +[[package]] +name = "keccak" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654" +dependencies = [ + "cpufeatures", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -1361,6 +1647,12 @@ dependencies = [ "libc", ] +[[package]] +name = "memuse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d97bbf43eb4f088f8ca469930cde17fa036207c9a5e02ccc5107c4e8b17c964" + [[package]] name = "miniz_oxide" version = "0.8.8" @@ -1373,7 +1665,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.6#3a9e040bdbdf0059ed432b9d8a93a29171200e83" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ "bincode", "clap", @@ -1397,7 +1689,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.6#3a9e040bdbdf0059ed432b9d8a93a29171200e83" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ "either", "ff_ext", @@ -1602,8 +1894,9 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.6#3a9e040bdbdf0059ed432b9d8a93a29171200e83" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ + "p3-air", "p3-baby-bear", "p3-challenger", "p3-commit", @@ -1615,12 +1908,23 @@ dependencies = [ "p3-maybe-rayon", "p3-mds", "p3-merkle-tree", + "p3-monty-31", "p3-poseidon", "p3-poseidon2", + "p3-poseidon2-air", "p3-symmetric", "p3-util", ] +[[package]] +name = "p3-air" +version = "0.1.0" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" +dependencies = [ + "p3-field", + "p3-matrix", +] + [[package]] name = "p3-baby-bear" version = "0.1.0" @@ -1836,6 +2140,22 @@ dependencies = [ "rand", ] +[[package]] +name = "p3-poseidon2-air" +version = "0.1.0" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" +dependencies = [ + "p3-air", + "p3-field", + "p3-matrix", + "p3-maybe-rayon", + "p3-poseidon2", + "p3-util", + "rand", + "tikv-jemallocator", + "tracing", +] + [[package]] name = "p3-symmetric" version = "0.1.0" @@ -1854,6 +2174,15 @@ dependencies = [ "serde", ] +[[package]] +name = "pairing" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135590d8bdba2b31346f9cd1fb2a912329f5135e832a4f422942eb6ead8b6b3b" +dependencies = [ + "group 0.12.1", +] + [[package]] name = "parking_lot" version = "0.12.3" @@ -1883,6 +2212,36 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "487f2ccd1e17ce8c1bfab3a65c89525af41cfad4c8659021a1e9a2aacd73b89b" +[[package]] +name = "pasta_curves" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cc65faf8e7313b4b1fbaa9f7ca917a0eed499a9663be71477f87993604341d8" +dependencies = [ + "blake2b_simd", + "ff 0.12.1", + "group 0.12.1", + "lazy_static", + "rand", + "static_assertions", + "subtle", +] + +[[package]] +name = "pasta_curves" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e57598f73cc7e1b2ac63c79c517b31a0877cd7c402cdcaa311b5208de7a095" +dependencies = [ + "blake2b_simd", + "ff 0.13.1", + "group 0.13.0", + "lazy_static", + "rand", + "static_assertions", + "subtle", +] + [[package]] name = "paste" version = "1.0.15" @@ -1953,7 +2312,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.6#3a9e040bdbdf0059ed432b9d8a93a29171200e83" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ "ff_ext", "p3", @@ -2100,6 +2459,12 @@ version = "5.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + [[package]] name = "rancor" version = "0.1.0" @@ -2444,6 +2809,27 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -2483,6 +2869,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "str_stack" version = "0.1.0" @@ -2542,7 +2934,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.6#3a9e040bdbdf0059ed432b9d8a93a29171200e83" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ "either", "ff_ext", @@ -2560,7 +2952,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.6#3a9e040bdbdf0059ed432b9d8a93a29171200e83" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ "itertools 0.13.0", "p3", @@ -2626,6 +3018,12 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + [[package]] name = "tempfile" version = "3.19.1" @@ -2901,7 +3299,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.6#3a9e040bdbdf0059ed432b9d8a93a29171200e83" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -2919,6 +3317,12 @@ dependencies = [ "strength_reduce", ] +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unarray" version = "0.1.4" @@ -3149,7 +3553,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.6#3a9e040bdbdf0059ed432b9d8a93a29171200e83" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ "bincode", "clap", @@ -3294,7 +3698,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.6#3a9e040bdbdf0059ed432b9d8a93a29171200e83" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ "ff_ext", "multilinear_extensions", @@ -3316,6 +3720,15 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + [[package]] name = "yoke" version = "0.7.5" @@ -3401,6 +3814,26 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "zerovec" version = "0.10.4" @@ -3422,3 +3855,29 @@ dependencies = [ "quote", "syn 2.0.101", ] + +[[package]] +name = "zkhash" +version = "0.2.0" +source = "git+https://github.com/HorizenLabs/poseidon2.git?rev=bb476b9#bb476b9ca38198cf5092487283c8b8c5d4317c4e" +dependencies = [ + "ark-ff", + "ark-std", + "bitvec", + "blake2", + "bls12_381", + "byteorder", + "cfg-if", + "group 0.12.1", + "group 0.13.0", + "halo2", + "hex", + "jubjub", + "lazy_static", + "pasta_curves 0.5.1", + "rand", + "serde", + "sha2", + "sha3", + "subtle", +] diff --git a/Cargo.toml b/Cargo.toml index 87da2fd72..47f2e3a82 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,15 +20,15 @@ repository = "https://github.com/scroll-tech/ceno" version = "0.1.0" [workspace.dependencies] -ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", rev = "v1.0.0-alpha.6" } -mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", rev = "v1.0.0-alpha.6" } -multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", rev = "v1.0.0-alpha.6" } -p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", rev = "v1.0.0-alpha.6" } -poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", rev = "v1.0.0-alpha.6" } -sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", rev = "v1.0.0-alpha.6" } -transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", rev = "v1.0.0-alpha.6" } -whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", rev = "v1.0.0-alpha.6" } -witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", rev = "v1.0.0-alpha.6" } +ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", rev = "v1.0.0-alpha.10" } +mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", rev = "v1.0.0-alpha.10" } +multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", rev = "v1.0.0-alpha.10" } +p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", rev = "v1.0.0-alpha.10" } +poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", rev = "v1.0.0-alpha.10" } +sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", rev = "v1.0.0-alpha.10" } +transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", rev = "v1.0.0-alpha.10" } +whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", rev = "v1.0.0-alpha.10" } +witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", rev = "v1.0.0-alpha.10" } anyhow = { version = "1.0", default-features = false } bincode = "1" @@ -63,6 +63,8 @@ tracing = { version = "0.1", features = [ tracing-forest = { version = "0.1.6" } tracing-subscriber = { version = "0.3", features = ["env-filter"] } uint = "0.8" +zkhash = { git = "https://github.com/HorizenLabs/poseidon2.git", rev = "bb476b9" } +lazy_static = "1.5.0" ceno_gpu = { path = "utils/cuda_hal", package = "cuda_hal" } From 098afb9906028ff84149404863cb5623d08925d3 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 13 Oct 2025 13:03:12 +0800 Subject: [PATCH 35/91] chore --- Cargo.lock | 118 ++++++++++++++++------------------------------------- Cargo.toml | 2 +- 2 files changed, 36 insertions(+), 84 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5f36c930d..5691f8624 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -872,7 +872,7 @@ dependencies = [ "ceno_zkvm", "clap", "console", - "ff_ext 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "ff_ext", "get_dir", "gkr_iop", "mpcs", @@ -951,9 +951,9 @@ dependencies = [ "anyhow", "ceno_rt", "elf", - "ff_ext 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "ff_ext", "itertools 0.13.0", - "multilinear_extensions 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "multilinear_extensions", "num-derive", "num-traits", "rrs-succinct", @@ -1011,7 +1011,7 @@ dependencies = [ "cudarc", "derive", "either", - "ff_ext 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "ff_ext", "generic-array 1.2.0", "generic_static", "gkr_iop", @@ -1019,12 +1019,12 @@ dependencies = [ "itertools 0.13.0", "lazy_static", "mpcs", - "multilinear_extensions 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "multilinear_extensions", "ndarray", "num", "num-bigint", "once_cell", - "p3 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "p3", "parse-size", "pprof2", "prettytable-rs", @@ -1352,11 +1352,11 @@ name = "cuda_hal" version = "0.1.0" dependencies = [ "anyhow", - "ff_ext 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "ff_ext", "itertools 0.13.0", "mpcs", - "multilinear_extensions 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", - "p3 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "multilinear_extensions", + "p3", "rand 0.8.5", "rayon", "sumcheck", @@ -1882,18 +1882,7 @@ version = "0.1.0" source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ "once_cell", - "p3 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", - "rand_core 0.6.4", - "serde", -] - -[[package]] -name = "ff_ext" -version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" -dependencies = [ - "once_cell", - "p3 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9)", + "p3", "rand_core 0.6.4", "serde", ] @@ -2042,12 +2031,12 @@ dependencies = [ "cuda_hal", "cudarc", "either", - "ff_ext 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "ff_ext", "itertools 0.13.0", "mpcs", - "multilinear_extensions 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "multilinear_extensions", "once_cell", - "p3 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "p3", "rand 0.8.5", "rayon", "serde", @@ -2738,11 +2727,11 @@ source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10 dependencies = [ "bincode", "clap", - "ff_ext 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "ff_ext", "itertools 0.13.0", - "multilinear_extensions 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "multilinear_extensions", "num-integer", - "p3 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "p3", "rand 0.8.5", "rand_chacha 0.3.1", "rayon", @@ -2761,24 +2750,9 @@ version = "0.1.0" source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ "either", - "ff_ext 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", - "itertools 0.13.0", - "p3 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", - "rand 0.8.5", - "rayon", - "serde", - "tracing", -] - -[[package]] -name = "multilinear_extensions" -version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" -dependencies = [ - "either", - "ff_ext 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9)", + "ff_ext", "itertools 0.13.0", - "p3 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9)", + "p3", "rand 0.8.5", "rayon", "serde", @@ -3116,28 +3090,6 @@ dependencies = [ "p3-util", ] -[[package]] -name = "p3" -version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" -dependencies = [ - "p3-baby-bear", - "p3-challenger", - "p3-commit", - "p3-dft", - "p3-field", - "p3-fri", - "p3-goldilocks", - "p3-matrix", - "p3-maybe-rayon", - "p3-mds", - "p3-merkle-tree", - "p3-poseidon", - "p3-poseidon2", - "p3-symmetric", - "p3-util", -] - [[package]] name = "p3-air" version = "0.1.0" @@ -3594,8 +3546,8 @@ name = "poseidon" version = "0.1.0" source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ - "ff_ext 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", - "p3 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "ff_ext", + "p3", "serde", ] @@ -4532,16 +4484,16 @@ dependencies = [ [[package]] name = "sp1-curves" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9#44e4aa4456b084481a9aef1b7ee5f829221d5a0d" +source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ "cfg-if", "dashu", "elliptic-curve", - "ff_ext 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9)", + "ff_ext", "generic-array 1.2.0", "itertools 0.13.0", "k256", - "multilinear_extensions 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.9)", + "multilinear_extensions", "num", "p256", "p3-field", @@ -4641,10 +4593,10 @@ version = "0.1.0" source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ "either", - "ff_ext 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "ff_ext", "itertools 0.13.0", - "multilinear_extensions 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", - "p3 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "multilinear_extensions", + "p3", "rayon", "serde", "sumcheck_macro", @@ -4659,7 +4611,7 @@ version = "0.1.0" source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ "itertools 0.13.0", - "p3 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "p3", "proc-macro2", "quote", "rand 0.8.5", @@ -5053,9 +5005,9 @@ name = "transcript" version = "0.1.0" source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ - "ff_ext 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "ff_ext", "itertools 0.13.0", - "p3 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "p3", "poseidon", ] @@ -5328,10 +5280,10 @@ dependencies = [ "bincode", "clap", "derive_more 1.0.0", - "ff_ext 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "ff_ext", "itertools 0.14.0", - "multilinear_extensions 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", - "p3 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "multilinear_extensions", + "p3", "rand 0.8.5", "rand_chacha 0.3.1", "rayon", @@ -5612,9 +5564,9 @@ name = "witness" version = "0.1.0" source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" dependencies = [ - "ff_ext 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", - "multilinear_extensions 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", - "p3 0.1.0 (git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10)", + "ff_ext", + "multilinear_extensions", + "p3", "rand 0.8.5", "rayon", "tracing", diff --git a/Cargo.toml b/Cargo.toml index c4f3d6f7e..cd39987af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", rev = "v1.0.0-alpha.10" } p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", rev = "v1.0.0-alpha.10" } poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", rev = "v1.0.0-alpha.10" } -sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", rev = "v1.0.0-alpha.9" } +sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", rev = "v1.0.0-alpha.10" } sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", rev = "v1.0.0-alpha.10" } transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", rev = "v1.0.0-alpha.10" } whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", rev = "v1.0.0-alpha.10" } From d0d2471de3eb8c9a3d2cc2fd41cc4218bca6dd94 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 13 Oct 2025 17:11:24 +0800 Subject: [PATCH 36/91] separate init/final/ramchip --- ceno_zkvm/src/chip_handler/general.rs | 17 +- ceno_zkvm/src/e2e.rs | 90 ++-- ceno_zkvm/src/instructions/riscv/constants.rs | 2 + .../riscv/ecall/weierstrass_add.rs | 8 +- ceno_zkvm/src/instructions/riscv/r_insn.rs | 9 +- ceno_zkvm/src/tables/ram/ram_impl.rs | 437 +++++++++++++++++- gkr_iop/src/circuit_builder.rs | 74 ++- 7 files changed, 574 insertions(+), 63 deletions(-) diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index e1ace19d0..513c4d98b 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -4,8 +4,8 @@ use gkr_iop::{error::CircuitBuilderError, tables::LookupTable}; use crate::{ circuit_builder::CircuitBuilder, instructions::riscv::constants::{ - END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBLIC_IO_IDX, - UINT_LIMBS, + END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, + MEM_BUS_WITH_READ_IDX, MEM_BUS_WITH_WRITE_IDX, PUBLIC_IO_IDX, UINT_LIMBS, }, tables::InsnRecord, }; @@ -22,6 +22,9 @@ pub trait PublicIOQuery { fn query_end_pc(&mut self) -> Result; fn query_end_cycle(&mut self) -> Result; fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError>; + + fn query_mem_bus_with_read(&mut self) -> Result; + fn query_mem_bus_with_write(&mut self) -> Result; } impl<'a, E: ExtensionField> InstFetch for CircuitBuilder<'a, E> { @@ -60,6 +63,16 @@ impl<'a, E: ExtensionField> PublicIOQuery for CircuitBuilder<'a, E> { self.cs.query_instance(|| "end_cycle", END_CYCLE_IDX) } + fn query_mem_bus_with_read(&mut self) -> Result { + self.cs + .query_instance(|| "mem_bus_with_read", MEM_BUS_WITH_READ_IDX) + } + + fn query_mem_bus_with_write(&mut self) -> Result { + self.cs + .query_instance(|| "mem_bus_with_write", MEM_BUS_WITH_WRITE_IDX) + } + fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError> { Ok([ self.cs.query_instance(|| "public_io_low", PUBLIC_IO_IDX)?, diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 89720025e..786ac2d2a 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -101,19 +101,13 @@ pub struct EmulationResult<'a> { pub shard_ctx: ShardContext<'a>, } -pub enum RAMRecordType { - Read, - Write, -} - pub struct RAMRecord { - ram_type: RAMRecordType, - id: u64, - addr: WordAddr, - prev_cycle: Cycle, - cycle: Cycle, - prev_value: Option, - value: Word, + pub id: u64, + pub addr: WordAddr, + pub prev_cycle: Cycle, + pub cycle: Cycle, + pub prev_value: Option, + pub value: Word, } pub struct ShardContext<'a> { @@ -121,7 +115,11 @@ pub struct ShardContext<'a> { num_shards: usize, max_cycle: Cycle, addr_future_accesses: Cow<'a, HashMap<(WordAddr, Cycle), Cycle>>, - thread_based_record_storage: Either< + read_thread_based_record_storage: Either< + Vec<[BTreeMap; mem::variant_count::()]>, + &'a mut [BTreeMap; mem::variant_count::()], + >, + write_thread_based_record_storage: Either< Vec<[BTreeMap; mem::variant_count::()]>, &'a mut [BTreeMap; mem::variant_count::()], >, @@ -131,16 +129,23 @@ pub struct ShardContext<'a> { impl<'a> Default for ShardContext<'a> { fn default() -> Self { let max_threads = max_usable_threads(); - let thread_based_record_storage = (0..max_threads) - .into_par_iter() - .map(|_| std::array::from_fn(|_| BTreeMap::new())) - .collect::>(); Self { shard_id: 0, num_shards: 1, max_cycle: Cycle::default(), addr_future_accesses: Cow::Owned(HashMap::new()), - thread_based_record_storage: Either::Left(thread_based_record_storage), + read_thread_based_record_storage: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| std::array::from_fn(|_| BTreeMap::new())) + .collect::>(), + ), + write_thread_based_record_storage: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| std::array::from_fn(|_| BTreeMap::new())) + .collect::>(), + ), cur_shard_cycle_range: 0..usize::MAX, } } @@ -155,12 +160,6 @@ impl<'a> ShardContext<'a> { ) -> Self { let max_threads = max_usable_threads(); // let max_record_per_thread = max_insts.div_ceil(max_threads as u64); - // TODO pre-reserve vector - let thread_based_record_storage = (0..max_threads) - .into_par_iter() - .map(|_| std::array::from_fn(|_| BTreeMap::new())) - .collect::>(); - let expected_inst_per_shard = executed_instructions.div_ceil(num_shards) as usize; let max_cycle = (executed_instructions + 1) * 4; // cycle start from 4 let cur_shard_cycle_range = (shard_id * expected_inst_per_shard * 4).max(4) @@ -171,25 +170,46 @@ impl<'a> ShardContext<'a> { num_shards, max_cycle: max_cycle as Cycle, addr_future_accesses: Cow::Owned(addr_future_accesses), - thread_based_record_storage: Either::Left(thread_based_record_storage), + // TODO with_capacity optimisation + read_thread_based_record_storage: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| std::array::from_fn(|_| BTreeMap::new())) + .collect::>(), + ), + // TODO with_capacity optimisation + write_thread_based_record_storage: Either::Left( + (0..max_threads) + .into_par_iter() + .map(|_| std::array::from_fn(|_| BTreeMap::new())) + .collect::>(), + ), cur_shard_cycle_range, } } pub fn get_forked(&mut self) -> Vec> { - match &mut self.thread_based_record_storage { - Either::Left(thread_based_record_storage) => thread_based_record_storage + match ( + &mut self.read_thread_based_record_storage, + &mut self.write_thread_based_record_storage, + ) { + ( + Either::Left(read_thread_based_record_storage), + Either::Left(write_thread_based_record_storage), + ) => read_thread_based_record_storage .iter_mut() - .map(|v| ShardContext { + .zip(write_thread_based_record_storage.iter_mut()) + .map(|(read, write)| ShardContext { shard_id: self.shard_id, num_shards: self.num_shards, max_cycle: self.max_cycle, addr_future_accesses: Cow::Borrowed(self.addr_future_accesses.as_ref()), - thread_based_record_storage: Either::Right(v), + read_thread_based_record_storage: Either::Right(read), + write_thread_based_record_storage: Either::Right(write), cur_shard_cycle_range: self.cur_shard_cycle_range.clone(), }) .collect_vec(), - Either::Right(_) => panic!("invalid type"), + _ => panic!("invalid type"), } } @@ -206,17 +226,16 @@ impl<'a> ShardContext<'a> { ) { // check read from external mem bus if prev_cycle < self.cur_shard_cycle_range.start as Cycle - && cycle >= self.cur_shard_cycle_range.start as Cycle + && self.cur_shard_cycle_range.contains(&(cycle as usize)) { let ram_record = self - .thread_based_record_storage + .read_thread_based_record_storage .as_mut() .right() .expect("illegal type"); ram_record[ram_type as usize].insert( addr, RAMRecord { - ram_type: RAMRecordType::Read, id, addr, prev_cycle, @@ -229,17 +248,16 @@ impl<'a> ShardContext<'a> { // check write to external mem bus if let Some(future_touch_cycle) = self.addr_future_accesses.get(&(addr, cycle)) { if *future_touch_cycle >= self.cur_shard_cycle_range.end as Cycle - && cycle < self.cur_shard_cycle_range.end as Cycle + && self.cur_shard_cycle_range.contains(&(cycle as usize)) { let ram_record = self - .thread_based_record_storage + .write_thread_based_record_storage .as_mut() .right() .expect("illegal type"); ram_record[ram_type as usize].insert( addr, RAMRecord { - ram_type: RAMRecordType::Write, id, addr, prev_cycle, diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 1992f4fa3..f471528a6 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -10,6 +10,8 @@ pub const INIT_CYCLE_IDX: usize = 3; pub const END_PC_IDX: usize = 4; pub const END_CYCLE_IDX: usize = 5; pub const PUBLIC_IO_IDX: usize = 6; +pub const MEM_BUS_WITH_READ_IDX: usize = 7; +pub const MEM_BUS_WITH_WRITE_IDX: usize = 8; pub const LIMB_BITS: usize = 16; pub const LIMB_MASK: u32 = 0xFFFF; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index bd2ca016a..03a27a47a 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -314,7 +314,13 @@ impl Instruction &ops.reg_ops[1], )?; for (writer, op) in config.mem_rw.iter().zip_eq(&ops.mem_ops) { - writer.assign_op(instance, &mut shard_ctx, &mut lk_multiplicity, step.cycle(), op)?; + writer.assign_op( + instance, + &mut shard_ctx, + &mut lk_multiplicity, + step.cycle(), + op, + )?; } // fetch lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index a44e6757d..1d559a941 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -69,9 +69,12 @@ impl RInstructionConfig { step: &StepRecord, ) -> Result<(), ZKVMError> { self.vm_state.assign_instance(instance, step)?; - self.rs1.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; - self.rs2.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; - self.rd.assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rs2 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + self.rd + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index f92dc37cc..0e16d4500 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -1,15 +1,22 @@ -use std::{marker::PhantomData, sync::Arc}; - use ceno_emul::{Addr, Cycle, WORD_SIZE}; +use either::Either; use ff_ext::{ExtensionField, SmallField}; use gkr_iop::error::CircuitBuilderError; use itertools::Itertools; -use rayon::iter::{IndexedParallelIterator, ParallelIterator}; -use witness::{InstancePaddingStrategy, RowMajorMatrix, set_fixed_val, set_val}; +use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; +use std::{marker::PhantomData, ops::Neg, sync::Arc}; +use witness::{ + InstancePaddingStrategy, RowMajorMatrix, next_pow2_instance_padding, set_fixed_val, set_val, +}; +use super::{ + MemInitRecord, + ram_circuit::{DynVolatileRamTable, MemFinalRecord, NonVolatileTable}, +}; use crate::{ chip_handler::general::PublicIOQuery, circuit_builder::{CircuitBuilder, SetTableSpec}, + e2e::RAMRecord, instructions::riscv::constants::{LIMB_BITS, LIMB_MASK}, structs::ProgramParams, }; @@ -17,11 +24,8 @@ use ff_ext::FieldInto; use multilinear_extensions::{ Expression, Fixed, StructuralWitIn, StructuralWitInType, ToExpr, WitIn, }; - -use super::{ - MemInitRecord, - ram_circuit::{DynVolatileRamTable, MemFinalRecord, NonVolatileTable}, -}; +use p3::field::FieldAlgebra; +use rayon::prelude::{ParallelSlice, ParallelSliceMut}; /// define a non-volatile memory with init value #[derive(Clone, Debug)] @@ -443,6 +447,421 @@ impl DynVolatileRamTableConfig } } +/// volatile with all init value as 0 +/// dynamic address as witin, relied on augment of knowledge to prove address form +#[derive(Clone, Debug)] +pub struct DynVolatileRamTableInitConfig { + addr: StructuralWitIn, + + phantom: PhantomData, + params: ProgramParams, +} + +impl DynVolatileRamTableInitConfig { + pub fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result { + let max_len = DVRAM::max_len(params); + let addr = cb.create_structural_witin( + || "addr", + StructuralWitInType::EqualDistanceSequence { + max_len, + offset: DVRAM::offset_addr(params), + multi_factor: WORD_SIZE, + descending: DVRAM::DESCENDING, + }, + ); + + assert!(DVRAM::ZERO_INIT); + + let init_expr = vec![Expression::ZERO; DVRAM::V_LIMBS]; + + let init_table = [ + vec![(DVRAM::RAM_TYPE as usize).into()], + vec![addr.expr()], + init_expr, + vec![Expression::ZERO], // Initial cycle. + ] + .concat(); + + cb.w_table_record( + || "init_table", + DVRAM::RAM_TYPE, + SetTableSpec { + len: None, + structural_witins: vec![addr], + }, + init_table, + )?; + + Ok(Self { + addr, + phantom: PhantomData, + params: params.clone(), + }) + } + + /// TODO consider taking RowMajorMatrix as argument to save allocations. + pub fn assign_instances( + &self, + num_witin: usize, + num_structural_witin: usize, + final_mem: &[MemFinalRecord], + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { + assert!(final_mem.len() <= DVRAM::max_len(&self.params)); + assert!(DVRAM::max_len(&self.params).is_power_of_two()); + + let params = self.params.clone(); + let addr_id = self.addr.id as u64; + let addr_padding_fn = move |row: u64, col: u64| { + assert_eq!(col, addr_id); + DVRAM::addr(¶ms, row as usize) as u64 + }; + + let mut witness = + RowMajorMatrix::::new(final_mem.len(), num_witin, InstancePaddingStrategy::Default); + let mut structural_witness = RowMajorMatrix::::new( + final_mem.len(), + num_structural_witin, + InstancePaddingStrategy::Custom(Arc::new(addr_padding_fn)), + ); + + witness + .par_rows_mut() + .zip(structural_witness.par_rows_mut()) + .zip(final_mem) + .enumerate() + .for_each(|(i, ((row, structural_row), rec))| { + assert_eq!( + rec.addr, + DVRAM::addr(&self.params, i), + "rec.addr {:x} != expected {:x}", + rec.addr, + DVRAM::addr(&self.params, i), + ); + set_val!(structural_row, self.addr, rec.addr as u64); + }); + + structural_witness.padding_by_strategy(); + Ok([witness, structural_witness]) + } +} + +/// volatile with all init value as 0 +/// dynamic address as witin, relied on augment of knowledge to prove address form +#[derive(Clone, Debug)] +pub struct DynVolatileRamTableFinalConfig { + // addr is subset and could be any form + // TODO check soundness issue + addr_subset: WitIn, + + final_v: Vec, + final_cycle: WitIn, + + phantom: PhantomData, + params: ProgramParams, +} + +impl DynVolatileRamTableFinalConfig { + pub fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result { + let addr_subset = cb.create_witin(|| format!("addr_subset")); + + let final_v = (0..DVRAM::V_LIMBS) + .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) + .collect::>(); + let final_cycle = cb.create_witin(|| "final_cycle"); + + let final_expr = final_v.iter().map(|v| v.expr()).collect_vec(); + let final_table = [ + // a v t + vec![(DVRAM::RAM_TYPE as usize).into()], + vec![addr_subset.expr()], + final_expr, + vec![final_cycle.expr()], + ] + .concat(); + cb.r_table_record( + || "final_table", + DVRAM::RAM_TYPE, + SetTableSpec { + len: None, + structural_witins: vec![], + }, + final_table, + )?; + + Ok(Self { + addr_subset, + final_v, + final_cycle, + phantom: PhantomData, + params: params.clone(), + }) + } + + /// TODO consider taking RowMajorMatrix as argument to save allocations. + pub fn assign_instances( + &self, + num_witin: usize, + num_structural_witin: usize, + final_mem: &[MemFinalRecord], + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { + assert_eq!(num_structural_witin, 0); + assert!(final_mem.len() <= DVRAM::max_len(&self.params)); + assert!(DVRAM::max_len(&self.params).is_power_of_two()); + + let mut witness = + RowMajorMatrix::::new(final_mem.len(), num_witin, InstancePaddingStrategy::Default); + + witness + .par_rows_mut() + .zip(final_mem) + .enumerate() + .for_each(|(i, (row, rec))| { + assert_eq!( + rec.addr, + DVRAM::addr(&self.params, i), + "rec.addr {:x} != expected {:x}", + rec.addr, + DVRAM::addr(&self.params, i), + ); + + if self.final_v.len() == 1 { + // Assign value directly. + set_val!(row, self.final_v[0], rec.value as u64); + } else { + // Assign value limbs. + self.final_v.iter().enumerate().for_each(|(l, limb)| { + let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }); + } + set_val!(row, self.final_cycle, rec.cycle); + + set_val!(row, self.addr_subset, rec.addr as u64); + }); + + Ok([witness, RowMajorMatrix::empty()]) + } +} + +/// volatile with all init value as 0 +/// dynamic address as witin, relied on augment of knowledge to prove address form +#[derive(Clone, Debug)] +pub struct DynVolatileRAMBusConfig { + addr_subset: WitIn, + + sel_read: StructuralWitIn, + sel_write: StructuralWitIn, + local_write_v: Vec, + local_read_v: Vec, + local_read_cycle: WitIn, + + phantom: PhantomData, + params: ProgramParams, +} + +impl DynVolatileRAMBusConfig { + pub fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result { + let one = Expression::Constant(Either::Left(E::BaseField::ONE)); + let mem_bus_with_read = cb.query_mem_bus_with_read()?; + let mem_bus_with_write = cb.query_mem_bus_with_write()?; + let addr_subset = cb.create_witin(|| "addr_subset"); + // TODO add new selector to support sel_rw + let sel_read = cb.create_structural_witin( + || "sel_read", + StructuralWitInType::EqualDistanceSequence { + max_len: 0, + offset: DVRAM::offset_addr(params), + multi_factor: WORD_SIZE, + descending: DVRAM::DESCENDING, + }, + ); + let sel_write = cb.create_structural_witin( + || "sel_write", + StructuralWitInType::EqualDistanceSequence { + max_len: 0, + offset: DVRAM::offset_addr(params), + multi_factor: WORD_SIZE, + descending: DVRAM::DESCENDING, + }, + ); + + // local write + let local_write_v = (0..DVRAM::V_LIMBS) + .map(|i| cb.create_witin(|| format!("local_write_v_limb_{i}"))) + .collect::>(); + let local_write_v_expr = local_write_v.iter().map(|v| v.expr()).collect_vec(); + + // local read + let local_read_v = (0..DVRAM::V_LIMBS) + .map(|i| cb.create_witin(|| format!("local_read_v_limb_{i}"))) + .collect::>(); + let local_read_v_expr: Vec> = + local_read_v.iter().map(|v| v.expr()).collect_vec(); + let local_read_cycle = cb.create_witin(|| "local_read_cycle"); + + // TODO global write + // TODO global read + + // constraints + // read from global, write to local + // W_{local} = mem_bus_with_read * (sel_read * local_write_record + (1 - sel_read) * ONE) + (1 - mem_bus_with_read) * ONE + let local_raw_write_record = [ + vec![(DVRAM::RAM_TYPE as usize).into()], + vec![addr_subset.expr()], + local_write_v_expr.clone(), + vec![Expression::ZERO], // mem bus local init cycle always 0. + ] + .concat(); + let local_write_record = cb.rlc_chip_record(local_raw_write_record.clone()); + let local_write = mem_bus_with_read.expr() + * (sel_read.expr() * local_write_record + (one.clone() - sel_read.expr()).expr()) + + (one.clone() - mem_bus_with_read.expr()); + cb.w_table_rlc_record( + || "local_write_record", + DVRAM::RAM_TYPE, + SetTableSpec { + len: None, + structural_witins: vec![sel_read], + }, + local_raw_write_record, + local_write, + )?; + // TODO R_{global} = mem_bus_with_read * (sel_read * global_read + (1-sel_read) * EC_INFINITY) + (1 - mem_bus_with_read) * EC_INFINITY + + // write to global, read from local + // R_{local} = mem_bus_with_write * (sel_write * local_read_record + (1 - sel_write) * ONE) + (1 - mem_bus_with_write) * ONE + let local_raw_read_record = [ + vec![(DVRAM::RAM_TYPE as usize).into()], + vec![addr_subset.expr()], + local_read_v_expr.clone(), + vec![local_read_cycle.expr()], + ] + .concat(); + let local_read_record = cb.rlc_chip_record(local_raw_read_record.clone()); + let local_read: Expression = mem_bus_with_write.expr() + * (sel_write.expr() * local_read_record + (one.clone() - sel_write.expr())) + + (one.clone() - mem_bus_with_write.expr()); + cb.r_table_rlc_record( + || "local_read_record", + DVRAM::RAM_TYPE, + SetTableSpec { + len: None, + structural_witins: vec![sel_write], + }, + local_raw_read_record, + local_read, + )?; + // TODO W_{local} = mem_bus_with_write * (sel_write * global_write + (1 - sel_write) * EC_INFINITY) + (1 - mem_bus_with_write) * EC_INFINITY + + Ok(Self { + addr_subset, + sel_write, + sel_read, + local_write_v, + local_read_v, + local_read_cycle, + phantom: PhantomData, + params: params.clone(), + }) + } + + /// TODO consider taking RowMajorMatrix as argument to save allocations. + pub fn assign_instances( + &self, + num_witin: usize, + num_structural_witin: usize, + global_read_mem: &[RAMRecord], + global_write_mem: &[RAMRecord], + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { + assert!(global_read_mem.len() <= DVRAM::max_len(&self.params)); + assert!(DVRAM::max_len(&self.params).is_power_of_two()); + let witness_length = { + let max_len = global_read_mem.len().max(global_write_mem.len()); + // first half write, second half read + next_pow2_instance_padding(max_len) * 2 + }; + + let mut witness = + RowMajorMatrix::::new(witness_length, num_witin, InstancePaddingStrategy::Default); + let witness_mid = witness.values.len() / 2; + let mut structural_witness = RowMajorMatrix::::new( + witness_length, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + let (witness_write, witness_read) = witness.values.split_at_mut(witness_mid); + let structural_witness_mid = structural_witness.values.len() / 2; + let (structural_witness_write, structural_witness_read) = structural_witness + .values + .split_at_mut(structural_witness_mid); + + rayon::join( + // global write, local read + || { + witness_write + .par_chunks_mut(num_witin) + .zip(structural_witness_write.par_chunks_mut(num_structural_witin)) + .zip(global_write_mem) + .enumerate() + .for_each(|(i, ((row, structural_row), rec))| { + if self.local_read_v.len() == 1 { + // Assign value directly. + set_val!(row, self.local_read_v[0], rec.value as u64); + } else { + // Assign value limbs. + self.local_read_v.iter().enumerate().for_each(|(l, limb)| { + let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }); + } + set_val!(row, self.local_read_cycle, rec.cycle); + + set_val!(row, self.addr_subset, rec.addr.baddr().0 as u64); + set_val!(structural_row, self.sel_write, 1u64); + + // TODO assign W_{global} + }); + }, + // global read, local write + || { + witness_read + .par_chunks_mut(num_witin) + .zip(structural_witness_read.par_chunks_mut(num_structural_witin)) + .zip(global_read_mem) + .enumerate() + .for_each(|(i, ((row, structural_row), rec))| { + if self.local_write_v.len() == 1 { + // Assign value directly. + set_val!(row, self.local_write_v[0], rec.value as u64); + } else { + // Assign value limbs. + self.local_write_v.iter().enumerate().for_each(|(l, limb)| { + let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }); + } + set_val!(row, self.addr_subset, rec.addr.baddr().0 as u64); + set_val!(structural_row, self.sel_read, 1u64); + + // TODO assign R_{global} + }); + }, + ); + + structural_witness.padding_by_strategy(); + Ok([witness, structural_witness]) + } +} + #[cfg(test)] mod tests { use std::iter::successors; diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index e4129bfe8..2da0271ff 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -329,12 +329,21 @@ impl ConstraintSystem { N: FnOnce() -> NR, { let rlc_record = self.rlc_chip_record(record.clone()); - assert_eq!( - rlc_record.degree(), - 1, - "rlc record degree {} != 1", - rlc_record.degree() - ); + self.r_table_rlc_record(name_fn, ram_type, table_spec, record, rlc_record) + } + + pub fn r_table_rlc_record( + &mut self, + name_fn: N, + ram_type: RAMType, + table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { self.r_table_expressions.push(SetTableExpression { expr: rlc_record, table_spec, @@ -358,12 +367,21 @@ impl ConstraintSystem { N: FnOnce() -> NR, { let rlc_record = self.rlc_chip_record(record.clone()); - assert_eq!( - rlc_record.degree(), - 1, - "rlc record degree {} != 1", - rlc_record.degree() - ); + self.w_table_rlc_record(name_fn, ram_type, table_spec, record, rlc_record) + } + + pub fn w_table_rlc_record( + &mut self, + name_fn: N, + ram_type: RAMType, + table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { self.w_table_expressions.push(SetTableExpression { expr: rlc_record, table_spec, @@ -579,6 +597,22 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { .r_table_record(name_fn, ram_type, table_spec, record) } + pub fn r_table_rlc_record( + &mut self, + name_fn: N, + ram_type: RAMType, + table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs + .r_table_rlc_record(name_fn, ram_type, table_spec, record, rlc_record) + } + pub fn w_table_record( &mut self, name_fn: N, @@ -594,6 +628,22 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { .w_table_record(name_fn, ram_type, table_spec, record) } + pub fn w_table_rlc_record( + &mut self, + name_fn: N, + ram_type: RAMType, + table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs + .w_table_rlc_record(name_fn, ram_type, table_spec, record, rlc_record) + } + pub fn read_record( &mut self, name_fn: N, From 6e2f8d696136296ad8103b158516b498ddf0d90a Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 13 Oct 2025 17:38:41 +0800 Subject: [PATCH 37/91] make Instruction stateful --- ceno_zkvm/src/instructions.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index d546eaa04..af83d0695 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -33,15 +33,17 @@ pub trait Instruction { /// construct circuit and manipulate circuit builder, then return the respective config fn construct_circuit( + &self, circuit_builder: &mut CircuitBuilder, param: &ProgramParams, ) -> Result; fn build_gkr_iop_circuit( + &self, cb: &mut CircuitBuilder, param: &ProgramParams, ) -> Result<(Self::InstructionConfig, GKRCircuit), ZKVMError> { - let config = Self::construct_circuit(cb, param)?; + let config = self.construct_circuit(cb, param)?; let w_len = cb.cs.w_expressions.len(); let r_len = cb.cs.r_expressions.len(); let lk_len = cb.cs.lk_expressions.len(); From a0c7b91a2f3a0576a538732ad4051903df611fff Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 13 Oct 2025 19:18:30 +0800 Subject: [PATCH 38/91] local ram circuit --- ceno_zkvm/src/tables/ram/ram_impl.rs | 47 ++++++++++++++++++---------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 0e16d4500..87968b97b 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -555,6 +555,7 @@ pub struct DynVolatileRamTableFinalConfig, final_cycle: WitIn, @@ -568,15 +569,26 @@ impl DynVolatileRamTableFinalC cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { - let addr_subset = cb.create_witin(|| format!("addr_subset")); + let addr_subset = cb.create_witin(|| "addr_subset"); + + let sel = cb.create_structural_witin( + || "sel", + StructuralWitInType::EqualDistanceSequence { + max_len: 0, + offset: DVRAM::offset_addr(params), + multi_factor: WORD_SIZE, + descending: DVRAM::DESCENDING, + }, + ); let final_v = (0..DVRAM::V_LIMBS) .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) .collect::>(); let final_cycle = cb.create_witin(|| "final_cycle"); + // R_{local} = sel * rlc_final_table + (1 - sel) * ONE let final_expr = final_v.iter().map(|v| v.expr()).collect_vec(); - let final_table = [ + let raw_final_table = [ // a v t vec![(DVRAM::RAM_TYPE as usize).into()], vec![addr_subset.expr()], @@ -584,18 +596,22 @@ impl DynVolatileRamTableFinalC vec![final_cycle.expr()], ] .concat(); - cb.r_table_record( + let final_table_expr = sel.expr() * cb.rlc_chip_record(raw_final_table.clone()) + + (Expression::Constant(Either::Left(E::BaseField::ONE)) - sel.expr()); + cb.r_table_rlc_record( || "final_table", DVRAM::RAM_TYPE, SetTableSpec { len: None, - structural_witins: vec![], + structural_witins: vec![sel], }, - final_table, + raw_final_table, + final_table_expr, )?; Ok(Self { addr_subset, + sel, final_v, final_cycle, phantom: PhantomData, @@ -610,26 +626,24 @@ impl DynVolatileRamTableFinalC num_structural_witin: usize, final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { - assert_eq!(num_structural_witin, 0); + assert_eq!(num_structural_witin, 1); assert!(final_mem.len() <= DVRAM::max_len(&self.params)); assert!(DVRAM::max_len(&self.params).is_power_of_two()); let mut witness = RowMajorMatrix::::new(final_mem.len(), num_witin, InstancePaddingStrategy::Default); + let mut structural_witness = RowMajorMatrix::::new( + final_mem.len(), + num_structural_witin, + InstancePaddingStrategy::Default, + ); witness .par_rows_mut() + .zip(structural_witness.par_rows_mut()) .zip(final_mem) .enumerate() - .for_each(|(i, (row, rec))| { - assert_eq!( - rec.addr, - DVRAM::addr(&self.params, i), - "rec.addr {:x} != expected {:x}", - rec.addr, - DVRAM::addr(&self.params, i), - ); - + .for_each(|(i, ((row, structural_row), rec))| { if self.final_v.len() == 1 { // Assign value directly. set_val!(row, self.final_v[0], rec.value as u64); @@ -643,9 +657,10 @@ impl DynVolatileRamTableFinalC set_val!(row, self.final_cycle, rec.cycle); set_val!(row, self.addr_subset, rec.addr as u64); + set_val!(row, self.sel, 1u64); }); - Ok([witness, RowMajorMatrix::empty()]) + Ok([witness, structural_witness]) } } From 4963848d7e2510e7f6158780a39b3e9cfb4f9737 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 13 Oct 2025 19:34:05 +0800 Subject: [PATCH 39/91] add shard info to public io --- ceno_emul/src/shards.rs | 8 ++++++++ ceno_zkvm/src/chip_handler/general.rs | 8 ++++++-- ceno_zkvm/src/e2e.rs | 3 +++ ceno_zkvm/src/instructions/riscv/constants.rs | 3 ++- ceno_zkvm/src/scheme.rs | 12 ++++++++++++ 5 files changed, 31 insertions(+), 3 deletions(-) diff --git a/ceno_emul/src/shards.rs b/ceno_emul/src/shards.rs index a8d06ab78..fd34baf85 100644 --- a/ceno_emul/src/shards.rs +++ b/ceno_emul/src/shards.rs @@ -11,4 +11,12 @@ impl Shards { num_shards, } } + + pub fn is_first_shard(&self) -> bool { + self.shard_id == 0 + } + + pub fn is_last_shard(&self) -> bool { + self.shard_id == self.num_shards - 1 + } } diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 513c4d98b..805b60baf 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -4,7 +4,7 @@ use gkr_iop::{error::CircuitBuilderError, tables::LookupTable}; use crate::{ circuit_builder::CircuitBuilder, instructions::riscv::constants::{ - END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, + END_CYCLE_IDX, END_PC_IDX, END_SHARD_ID_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, MEM_BUS_WITH_READ_IDX, MEM_BUS_WITH_WRITE_IDX, PUBLIC_IO_IDX, UINT_LIMBS, }, tables::InsnRecord, @@ -22,7 +22,7 @@ pub trait PublicIOQuery { fn query_end_pc(&mut self) -> Result; fn query_end_cycle(&mut self) -> Result; fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError>; - + fn query_shard_id(&mut self) -> Result; fn query_mem_bus_with_read(&mut self) -> Result; fn query_mem_bus_with_write(&mut self) -> Result; } @@ -63,6 +63,10 @@ impl<'a, E: ExtensionField> PublicIOQuery for CircuitBuilder<'a, E> { self.cs.query_instance(|| "end_cycle", END_CYCLE_IDX) } + fn query_shard_id(&mut self) -> Result { + self.cs.query_instance(|| "shard_id", END_SHARD_ID_IDX) + } + fn query_mem_bus_with_read(&mut self) -> Result { self.cs .query_instance(|| "mem_bus_with_read", MEM_BUS_WITH_READ_IDX) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 786ac2d2a..c0330558f 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -334,6 +334,9 @@ pub fn emulate_program<'a>( Tracer::SUBCYCLES_PER_INSN, vm.get_pc().into(), end_cycle, + shards.shard_id as u32, + !shards.is_first_shard(), // first shard disable global read + !shards.is_last_shard(), // last shard disable global write io_init.iter().map(|rec| rec.value).collect_vec(), ); diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index f471528a6..17316f956 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -9,9 +9,10 @@ pub const INIT_PC_IDX: usize = 2; pub const INIT_CYCLE_IDX: usize = 3; pub const END_PC_IDX: usize = 4; pub const END_CYCLE_IDX: usize = 5; -pub const PUBLIC_IO_IDX: usize = 6; +pub const END_SHARD_ID_IDX: usize = 6; pub const MEM_BUS_WITH_READ_IDX: usize = 7; pub const MEM_BUS_WITH_WRITE_IDX: usize = 8; +pub const PUBLIC_IO_IDX: usize = 9; pub const LIMB_BITS: usize = 16; pub const LIMB_MASK: u32 = 0xFFFF; diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 58a9aae89..f2f81c096 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -72,6 +72,9 @@ pub struct PublicValues { init_cycle: u64, end_pc: u32, end_cycle: u64, + shard_id: u32, + mem_bus_with_read: bool, + mem_bus_with_write: bool, public_io: Vec, } @@ -82,6 +85,9 @@ impl PublicValues { init_cycle: u64, end_pc: u32, end_cycle: u64, + shard_id: u32, + mem_bus_with_read: bool, + mem_bus_with_write: bool, public_io: Vec, ) -> Self { Self { @@ -90,6 +96,9 @@ impl PublicValues { init_cycle, end_pc, end_cycle, + shard_id, + mem_bus_with_read, + mem_bus_with_write, public_io, } } @@ -103,6 +112,9 @@ impl PublicValues { vec![E::BaseField::from_canonical_u64(self.init_cycle)], vec![E::BaseField::from_canonical_u32(self.end_pc)], vec![E::BaseField::from_canonical_u64(self.end_cycle)], + vec![E::BaseField::from_canonical_u32(self.shard_id)], + vec![E::BaseField::from_bool(self.mem_bus_with_read)], + vec![E::BaseField::from_bool(self.mem_bus_with_write)], ] .into_iter() .chain( From 160291e1fb06c303ea0712eac773559144a37918 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 13 Oct 2025 23:25:40 +0800 Subject: [PATCH 40/91] make all opcode circuits to be stateful --- ceno_emul/src/syscalls/bn254/bn254_fptower.rs | 5 ++ ceno_emul/src/syscalls/keccak_permute.rs | 1 + ceno_emul/src/syscalls/secp256k1.rs | 5 ++ ceno_emul/src/syscalls/sha256.rs | 1 + ceno_zkvm/src/chip_handler/general.rs | 27 ++++++- ceno_zkvm/src/gadgets/mod.rs | 1 + ceno_zkvm/src/gadgets/poseidon2.rs | 6 +- ceno_zkvm/src/instructions/global.rs | 80 ++++++++++++------- ceno_zkvm/src/instructions/riscv.rs | 2 +- ceno_zkvm/src/instructions/riscv/arith.rs | 16 ++-- ceno_zkvm/src/instructions/riscv/arith_imm.rs | 8 +- .../riscv/arith_imm/arith_imm_circuit_v2.rs | 2 + ceno_zkvm/src/instructions/riscv/auipc.rs | 6 +- ceno_zkvm/src/instructions/riscv/branch.rs | 11 +++ .../riscv/branch/branch_circuit.rs | 4 +- .../riscv/branch/branch_circuit_v2.rs | 2 + .../src/instructions/riscv/branch/test.rs | 22 ++--- ceno_zkvm/src/instructions/riscv/constants.rs | 1 + ceno_zkvm/src/instructions/riscv/div.rs | 15 +++- .../instructions/riscv/div/div_circuit_v2.rs | 2 + .../instructions/riscv/dummy/dummy_circuit.rs | 2 + .../instructions/riscv/dummy/dummy_ecall.rs | 2 + .../src/instructions/riscv/dummy/test.rs | 12 ++- ceno_zkvm/src/instructions/riscv/ecall.rs | 2 + .../src/instructions/riscv/ecall/halt.rs | 2 + .../src/instructions/riscv/ecall/keccak.rs | 3 + .../riscv/ecall/weierstrass_add.rs | 3 + .../riscv/ecall/weierstrass_decompress.rs | 3 + .../riscv/ecall/weierstrass_double.rs | 3 + .../src/instructions/riscv/jump/jal_v2.rs | 2 + ceno_zkvm/src/instructions/riscv/jump/jalr.rs | 1 + .../src/instructions/riscv/jump/jalr_v2.rs | 2 + ceno_zkvm/src/instructions/riscv/jump/test.rs | 6 +- ceno_zkvm/src/instructions/riscv/logic.rs | 6 ++ .../instructions/riscv/logic/logic_circuit.rs | 2 + .../src/instructions/riscv/logic/test.rs | 9 ++- ceno_zkvm/src/instructions/riscv/logic_imm.rs | 6 ++ .../riscv/logic_imm/logic_imm_circuit_v2.rs | 2 + .../src/instructions/riscv/logic_imm/test.rs | 13 +-- ceno_zkvm/src/instructions/riscv/lui.rs | 6 +- ceno_zkvm/src/instructions/riscv/memory.rs | 15 ++++ .../src/instructions/riscv/memory/load_v2.rs | 4 +- .../src/instructions/riscv/memory/store_v2.rs | 2 + .../src/instructions/riscv/memory/test.rs | 18 ++++- ceno_zkvm/src/instructions/riscv/mulh.rs | 32 ++++---- .../riscv/mulh/mulh_circuit_v2.rs | 4 +- ceno_zkvm/src/instructions/riscv/shift.rs | 14 ++-- .../riscv/shift/shift_circuit_v2.rs | 6 +- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 12 ++- ceno_zkvm/src/instructions/riscv/slt.rs | 10 ++- .../instructions/riscv/slt/slt_circuit_v2.rs | 2 + ceno_zkvm/src/instructions/riscv/slti.rs | 12 +-- .../riscv/slti/slti_circuit_v2.rs | 4 +- ceno_zkvm/src/instructions/riscv/test.rs | 6 +- ceno_zkvm/src/scheme/tests.rs | 3 + ceno_zkvm/src/scheme/utils.rs | 3 +- ceno_zkvm/src/structs.rs | 10 ++- gkr_iop/src/circuit_builder.rs | 31 +++++++ 58 files changed, 365 insertions(+), 127 deletions(-) diff --git a/ceno_emul/src/syscalls/bn254/bn254_fptower.rs b/ceno_emul/src/syscalls/bn254/bn254_fptower.rs index 0e7c21db6..5e52d5fad 100644 --- a/ceno_emul/src/syscalls/bn254/bn254_fptower.rs +++ b/ceno_emul/src/syscalls/bn254/bn254_fptower.rs @@ -11,7 +11,9 @@ use crate::{ use super::types::{BN254_FP_WORDS, BN254_FP2_WORDS}; +#[derive(Default)] pub struct Bn254FpAddSpec; + impl SyscallSpec for Bn254FpAddSpec { const NAME: &'static str = "BN254_FP_ADD"; @@ -20,6 +22,7 @@ impl SyscallSpec for Bn254FpAddSpec { const CODE: u32 = ceno_rt::syscalls::BN254_FP_ADD; } +#[derive(Default)] pub struct Bn254Fp2AddSpec; impl SyscallSpec for Bn254Fp2AddSpec { const NAME: &'static str = "BN254_FP2_ADD"; @@ -29,6 +32,7 @@ impl SyscallSpec for Bn254Fp2AddSpec { const CODE: u32 = ceno_rt::syscalls::BN254_FP2_ADD; } +#[derive(Default)] pub struct Bn254FpMulSpec; impl SyscallSpec for Bn254FpMulSpec { const NAME: &'static str = "BN254_FP_MUL"; @@ -38,6 +42,7 @@ impl SyscallSpec for Bn254FpMulSpec { const CODE: u32 = ceno_rt::syscalls::BN254_FP_MUL; } +#[derive(Default)] pub struct Bn254Fp2MulSpec; impl SyscallSpec for Bn254Fp2MulSpec { const NAME: &'static str = "BN254_FP2_MUL"; diff --git a/ceno_emul/src/syscalls/keccak_permute.rs b/ceno_emul/src/syscalls/keccak_permute.rs index 31757ea38..022bc9597 100644 --- a/ceno_emul/src/syscalls/keccak_permute.rs +++ b/ceno_emul/src/syscalls/keccak_permute.rs @@ -8,6 +8,7 @@ use super::{SyscallEffects, SyscallSpec, SyscallWitness}; const KECCAK_CELLS: usize = 25; // u64 cells pub const KECCAK_WORDS: usize = KECCAK_CELLS * 2; // u32 words +#[derive(Default)] pub struct KeccakSpec; impl SyscallSpec for KeccakSpec { diff --git a/ceno_emul/src/syscalls/secp256k1.rs b/ceno_emul/src/syscalls/secp256k1.rs index 2e0e89506..59d98df31 100644 --- a/ceno_emul/src/syscalls/secp256k1.rs +++ b/ceno_emul/src/syscalls/secp256k1.rs @@ -5,8 +5,13 @@ use std::iter; use super::{SyscallEffects, SyscallSpec, SyscallWitness}; +#[derive(Default)] pub struct Secp256k1AddSpec; + +#[derive(Default)] pub struct Secp256k1DoubleSpec; + +#[derive(Default)] pub struct Secp256k1DecompressSpec; impl SyscallSpec for Secp256k1AddSpec { diff --git a/ceno_emul/src/syscalls/sha256.rs b/ceno_emul/src/syscalls/sha256.rs index 08bdf0fb5..42e1e0f60 100644 --- a/ceno_emul/src/syscalls/sha256.rs +++ b/ceno_emul/src/syscalls/sha256.rs @@ -4,6 +4,7 @@ use super::{SyscallEffects, SyscallSpec, SyscallWitness}; pub const SHA_EXTEND_WORDS: usize = 64; // u64 cells +#[derive(Default)] pub struct Sha256ExtendSpec; impl SyscallSpec for Sha256ExtendSpec { diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index e1ace19d0..a86d82dd0 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -4,9 +4,10 @@ use gkr_iop::{error::CircuitBuilderError, tables::LookupTable}; use crate::{ circuit_builder::CircuitBuilder, instructions::riscv::constants::{ - END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBLIC_IO_IDX, - UINT_LIMBS, + END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, GLOBAL_RW_SUM_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, + PUBLIC_IO_IDX, UINT_LIMBS, }, + scheme::constants::SEPTIC_EXTENSION_DEGREE, tables::InsnRecord, }; use multilinear_extensions::{Expression, Instance}; @@ -21,6 +22,7 @@ pub trait PublicIOQuery { fn query_init_cycle(&mut self) -> Result; fn query_end_pc(&mut self) -> Result; fn query_end_cycle(&mut self) -> Result; + fn query_global_rw_sum(&mut self) -> Result, CircuitBuilderError>; fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError>; } @@ -67,4 +69,25 @@ impl<'a, E: ExtensionField> PublicIOQuery for CircuitBuilder<'a, E> { .query_instance(|| "public_io_high", PUBLIC_IO_IDX + 1)?, ]) } + + fn query_global_rw_sum(&mut self) -> Result, CircuitBuilderError> { + let x = (0..SEPTIC_EXTENSION_DEGREE) + .into_iter() + .map(|i| { + self.cs + .query_instance(|| format!("global_rw_sum_x_{}", i), GLOBAL_RW_SUM_IDX + i) + }) + .collect::, CircuitBuilderError>>()?; + let y = (0..SEPTIC_EXTENSION_DEGREE) + .into_iter() + .map(|i| { + self.cs.query_instance( + || format!("global_rw_sum_y_{}", i), + GLOBAL_RW_SUM_IDX + SEPTIC_EXTENSION_DEGREE + i, + ) + }) + .collect::, CircuitBuilderError>>()?; + + Ok([x, y].concat()) + } } diff --git a/ceno_zkvm/src/gadgets/mod.rs b/ceno_zkvm/src/gadgets/mod.rs index fe8dda7b4..d0d8ed67d 100644 --- a/ceno_zkvm/src/gadgets/mod.rs +++ b/ceno_zkvm/src/gadgets/mod.rs @@ -15,6 +15,7 @@ pub use gkr_iop::gadgets::{ AssertLtConfig, InnerLtConfig, IsEqualConfig, IsLtConfig, IsZeroConfig, cal_lt_diff, }; pub use is_lt::{AssertSignedLtConfig, SignedLtConfig}; +pub(crate) use poseidon2::RoundConstants; pub use poseidon2::{Poseidon2BabyBearConfig, Poseidon2Config}; pub(crate) use poseidon2_constants::horizen_round_consts; pub use signed::Signed; diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs index 021513ac2..c713c332c 100644 --- a/ceno_zkvm/src/gadgets/poseidon2.rs +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -12,7 +12,7 @@ use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn}; use num_bigint::BigUint; use p3::{ - babybear::{BabyBear, BabyBearInternalLayerParameters}, + babybear::{BabyBearInternalLayerParameters}, field::{Field, FieldAlgebra}, monty_31::InternalLayerBaseParameters, poseidon2::{MDSMat4, mds_light_permutation}, @@ -248,6 +248,10 @@ impl< poseidon2_cols.inputs.to_vec() } + pub fn output(&self) -> Vec> { + todo!() + } + // pub fn assign_instance(&self, input: &[E; STATE_WIDTH]) { // generate_trace_rows(inputs, constants) // let poseidon2_cols: &Poseidon2Cols< diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 949596da5..2c943fb7e 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -1,7 +1,14 @@ -use crate::gadgets::{Poseidon2BabyBearConfig, horizen_round_consts}; -use ff_ext::{BabyBearExt4, ExtensionField}; -use gkr_iop::{circuit_builder::CircuitBuilder, error::CircuitBuilderError}; -use multilinear_extensions::{Expression, ToExpr, WitIn}; +use std::iter::repeat; + +use crate::{ + chip_handler::general::PublicIOQuery, + gadgets::{Poseidon2Config, RoundConstants}, +}; +use ff_ext::ExtensionField; +use gkr_iop::{ + circuit_builder::CircuitBuilder, error::CircuitBuilderError, +}; +use multilinear_extensions::{ToExpr, WitIn}; use p3::field::FieldAlgebra; use crate::{ @@ -9,10 +16,12 @@ use crate::{ scheme::constants::SEPTIC_EXTENSION_DEGREE, }; -// opcode circuit + mem init/final table + mem local chip: consistency RAMType::Register / Memory - -// mem local <-> global -// precompile <-> global +// opcode circuit + mem init/final table + global chip: +// have read/write consistency for RAMType::Register +// and RAMType::Memory +// +// global chip: read from and write into a global set shared +// among multiple shards pub struct GlobalConfig { addr: WitIn, ram_type: WitIn, @@ -22,15 +31,19 @@ pub struct GlobalConfig { is_write: WitIn, x: Vec, y: Vec, - poseidon2: Poseidon2BabyBearConfig, + poseidon2: Poseidon2Config, } impl GlobalConfig { - pub fn config(cb: &mut CircuitBuilder) -> Result { - let x = (0..SEPTIC_EXTENSION_DEGREE) + // TODO: make `WIDTH`, `HALF_FULL_ROUNDS`, `PARTIAL_ROUNDS` generic parameters + pub fn configure( + cb: &mut CircuitBuilder, + rc: RoundConstants, + ) -> Result { + let x: Vec = (0..SEPTIC_EXTENSION_DEGREE) .map(|i| cb.create_witin(|| format!("x{}", i))) .collect(); - let y = (0..SEPTIC_EXTENSION_DEGREE) + let y: Vec = (0..SEPTIC_EXTENSION_DEGREE) .map(|i| cb.create_witin(|| format!("y{}", i))) .collect(); let addr = cb.create_witin(|| "addr"); @@ -40,9 +53,8 @@ impl GlobalConfig { let clk = cb.create_witin(|| "clk"); let is_write = cb.create_witin(|| "is_write"); - let rc = horizen_round_consts(); - let cb: &mut CircuitBuilder<'_, BabyBearExt4> = unsafe { std::mem::transmute(cb) }; - let hasher = Poseidon2BabyBearConfig::construct(cb, rc); + // TODO: support other field + let hasher = Poseidon2Config::construct(cb, rc); let mut input = vec![]; input.push(addr.expr()); @@ -51,15 +63,24 @@ impl GlobalConfig { input.extend(value.memory_expr()); input.push(shard.expr()); input.push(clk.expr()); + input.extend(repeat(E::BaseField::ZERO.expr()).take(16 - 6)); + + // enforces final_sum = \sum_i (x_i, y_i) using ecc quark protocol + let final_sum = cb.query_global_rw_sum()?; + cb.ec_sum( + x.iter().map(|xi| xi.expr()).collect::>(), + y.iter().map(|yi| yi.expr()).collect::>(), + final_sum.into_iter().map(|x| x.expr()).collect::>(), + ); + // enforces x = poseidon2([addr, ram_type, value[0], value[1], shard, clk, 0]) for (input_expr, hasher_input) in input.into_iter().zip(hasher.inputs().into_iter()) { // TODO: replace with cb.require_equal() - cb.require_zero(|| "poseidon2 input", input_expr - hasher_input); + cb.require_zero(|| "poseidon2 input", input_expr - hasher_input)?; + } + for (xi, hasher_output) in x.iter().zip(hasher.output().into_iter()) { + cb.require_zero(|| "poseidon2 output", xi.expr() - hasher_output)?; } - - // TODO: enforce x = poseidon2([addr, ram_type, value[0], value[1], shard, clk]) - // TODO: enforce \sum_i (xi, yi) = ecc_sum - // TODO: output ecc_sum as public values // TODO: enforce is_write is boolean // TODO: enforce y < p/2 if is_write = 1 @@ -79,11 +100,13 @@ impl GlobalConfig { } } -// This chip is used to manage read/write into a global set +// This chip is used to manage read/write into a global set // shared among multiple shards -pub struct GlobalChip {} +pub struct GlobalChip { + rc: RoundConstants, +} -impl Instruction for GlobalChip { +impl Instruction for GlobalChip { type InstructionConfig = GlobalConfig; fn name() -> String { @@ -91,19 +114,20 @@ impl Instruction for GlobalChip { } fn construct_circuit( + &self, cb: &mut CircuitBuilder, _param: &crate::structs::ProgramParams, ) -> Result { - let config = GlobalConfig::config(cb)?; + let config = GlobalConfig::configure(cb, self.rc.clone())?; Ok(config) } fn assign_instance( - config: &Self::InstructionConfig, - instance: &mut [::BaseField], - lk_multiplicity: &mut crate::witness::LkMultiplicity, - step: &ceno_emul::StepRecord, + _config: &Self::InstructionConfig, + _instance: &mut [::BaseField], + _lk_multiplicity: &mut crate::witness::LkMultiplicity, + _step: &ceno_emul::StepRecord, ) -> Result<(), crate::error::ZKVMError> { todo!() } diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index 69c656148..7b86cfca5 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -44,7 +44,7 @@ mod test; #[cfg(test)] mod test_utils; -pub trait RIVInstruction { +pub trait RIVInstruction: Default { const INST_KIND: InsnKind; } diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index b73abcda4..a71147f18 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -18,15 +18,20 @@ pub struct ArithConfig { rd_written: UInt, } -pub struct ArithInstruction(PhantomData<(E, I)>); +#[derive(Default)] +pub struct ArithInstruction(PhantomData<(E, I)>); +#[derive(Default)] pub struct AddOp; + impl RIVInstruction for AddOp { const INST_KIND: InsnKind = InsnKind::ADD; } pub type AddInstruction = ArithInstruction; +#[derive(Default)] pub struct SubOp; + impl RIVInstruction for SubOp { const INST_KIND: InsnKind = InsnKind::SUB; } @@ -40,6 +45,7 @@ impl Instruction for ArithInstruction, _params: &ProgramParams, ) -> Result { @@ -163,15 +169,11 @@ mod test { fn verify(name: &'static str, rs1: u32, rs2: u32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = ArithInstruction::::default(); let config = cb .namespace( || format!("{:?}_({name})", I::INST_KIND), - |cb| { - Ok(ArithInstruction::::construct_circuit( - cb, - &ProgramParams::default(), - )) - }, + |cb| Ok(inst.construct_circuit(cb, &ProgramParams::default())), ) .unwrap() .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index a040681bc..974c77fd9 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -3,6 +3,8 @@ mod arith_imm_circuit; #[cfg(feature = "u16limb_circuit")] mod arith_imm_circuit_v2; +use ff_ext::ExtensionField; + #[cfg(feature = "u16limb_circuit")] pub use crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::AddiInstruction; @@ -11,7 +13,7 @@ pub use crate::instructions::riscv::arith_imm::arith_imm_circuit::AddiInstructio use super::RIVInstruction; -impl RIVInstruction for AddiInstruction { +impl RIVInstruction for AddiInstruction { const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::ADDI; } @@ -48,12 +50,12 @@ mod test { fn test_opcode_addi_internal(rs1: u32, rd: u32, imm: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = AddiInstruction::::default(); let config = cb .namespace( || "addi", |cb| { - let config = - AddiInstruction::::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index f969a68b0..cbadc807d 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -17,6 +17,7 @@ use p3::field::FieldAlgebra; use std::marker::PhantomData; use witness::set_val; +#[derive(Default)] pub struct AddiInstruction(PhantomData); pub struct InstructionConfig { @@ -37,6 +38,7 @@ impl Instruction for AddiInstruction { } fn construct_circuit( + &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 7957f7003..5304016e0 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -32,6 +32,7 @@ pub struct AuipcConfig { pub rd_written: UInt8, } +#[derive(Default)] pub struct AuipcInstruction(PhantomData); impl Instruction for AuipcInstruction { @@ -42,6 +43,7 @@ impl Instruction for AuipcInstruction { } fn construct_circuit( + &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result, ZKVMError> { @@ -224,12 +226,12 @@ mod tests { fn test_opcode_auipc(rd: u32, imm: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = AuipcInstruction::default(); let config = cb .namespace( || "auipc", |cb| { - let config = - AuipcInstruction::::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/branch.rs b/ceno_zkvm/src/instructions/riscv/branch.rs index dc2c8c9e6..082c3f897 100644 --- a/ceno_zkvm/src/instructions/riscv/branch.rs +++ b/ceno_zkvm/src/instructions/riscv/branch.rs @@ -6,19 +6,25 @@ mod branch_circuit_v2; #[cfg(test)] mod test; +#[derive(Default)] pub struct BeqOp; + impl RIVInstruction for BeqOp { const INST_KIND: InsnKind = InsnKind::BEQ; } pub type BeqInstruction = branch_circuit::BranchCircuit; +#[derive(Default)] pub struct BneOp; + impl RIVInstruction for BneOp { const INST_KIND: InsnKind = InsnKind::BNE; } pub type BneInstruction = branch_circuit::BranchCircuit; +#[derive(Default)] pub struct BltuOp; + impl RIVInstruction for BltuOp { const INST_KIND: InsnKind = InsnKind::BLTU; } @@ -27,7 +33,9 @@ pub type BltuInstruction = branch_circuit_v2::BranchCircuit; #[cfg(not(feature = "u16limb_circuit"))] pub type BltuInstruction = branch_circuit::BranchCircuit; +#[derive(Default)] pub struct BgeuOp; + impl RIVInstruction for BgeuOp { const INST_KIND: InsnKind = InsnKind::BGEU; } @@ -36,7 +44,9 @@ pub type BgeuInstruction = branch_circuit_v2::BranchCircuit; #[cfg(not(feature = "u16limb_circuit"))] pub type BgeuInstruction = branch_circuit::BranchCircuit; +#[derive(Default)] pub struct BltOp; + impl RIVInstruction for BltOp { const INST_KIND: InsnKind = InsnKind::BLT; } @@ -45,6 +55,7 @@ pub type BltInstruction = branch_circuit_v2::BranchCircuit; #[cfg(not(feature = "u16limb_circuit"))] pub type BltInstruction = branch_circuit::BranchCircuit; +#[derive(Default)] pub struct BgeOp; impl RIVInstruction for BgeOp { const INST_KIND: InsnKind = InsnKind::BGE; diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs index 8aecd50f8..efbe64d5a 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs @@ -22,7 +22,8 @@ use crate::{ use multilinear_extensions::Expression; pub use p3::field::FieldAlgebra; -pub struct BranchCircuit(PhantomData<(E, I)>); +#[derive(Default)] +pub struct BranchCircuit(PhantomData<(E, I)>); pub struct BranchConfig { pub b_insn: BInstructionConfig, @@ -41,6 +42,7 @@ impl Instruction for BranchCircuit; fn construct_circuit( + &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result, ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 94abb56d1..3d4ed92de 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -15,6 +15,7 @@ use ff_ext::ExtensionField; use multilinear_extensions::Expression; use std::marker::PhantomData; +#[derive(Default)] pub struct BranchCircuit(PhantomData<(E, I)>); pub struct BranchConfig { @@ -34,6 +35,7 @@ impl Instruction for BranchCircuit, _param: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index aaf468127..84d61553e 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -24,11 +24,12 @@ fn test_opcode_beq() { fn impl_opcode_beq(equal: bool) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = BeqInstruction::default(); let config = cb .namespace( || "beq", |cb| { - let config = BeqInstruction::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -64,11 +65,12 @@ fn test_opcode_bne() { fn impl_opcode_bne(equal: bool) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = BneInstruction::default(); let config = cb .namespace( || "bne", |cb| { - let config = BneInstruction::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -110,8 +112,8 @@ fn test_bltu_circuit() -> Result<(), ZKVMError> { fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { let mut cs = ConstraintSystem::new(|| "riscv"); let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let config = - BltuInstruction::construct_circuit(&mut circuit_builder, &ProgramParams::default())?; + let inst = BltuInstruction::default(); + let config = inst.construct_circuit(&mut circuit_builder, &ProgramParams::default())?; let pc_after = if taken { ByteAddr(MOCK_PC_START.0 - 8) @@ -154,8 +156,8 @@ fn test_bgeu_circuit() -> Result<(), ZKVMError> { fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { let mut cs = ConstraintSystem::new(|| "riscv"); let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let config = - BgeuInstruction::construct_circuit(&mut circuit_builder, &ProgramParams::default())?; + let inst = BgeuInstruction::default(); + let config = inst.construct_circuit(&mut circuit_builder, &ProgramParams::default())?; let pc_after = if taken { ByteAddr(MOCK_PC_START.0 - 8) @@ -205,8 +207,8 @@ fn test_blt_circuit() -> Result<(), ZKVMError> { fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { let mut cs = ConstraintSystem::new(|| "riscv"); let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let config = - BltInstruction::construct_circuit(&mut circuit_builder, &ProgramParams::default())?; + let inst = BltInstruction::default(); + let config = inst.construct_circuit(&mut circuit_builder, &ProgramParams::default())?; let pc_after = if taken { ByteAddr(MOCK_PC_START.0 - 8) @@ -256,8 +258,8 @@ fn test_bge_circuit() -> Result<(), ZKVMError> { fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { let mut cs = ConstraintSystem::new(|| "riscv"); let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let config = - BgeInstruction::construct_circuit(&mut circuit_builder, &ProgramParams::default())?; + let inst = BgeInstruction::default(); + let config = inst.construct_circuit(&mut circuit_builder, &ProgramParams::default())?; let pc_after = if taken { ByteAddr(MOCK_PC_START.0 - 8) diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 1992f4fa3..a9ab6abcc 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -10,6 +10,7 @@ pub const INIT_CYCLE_IDX: usize = 3; pub const END_PC_IDX: usize = 4; pub const END_CYCLE_IDX: usize = 5; pub const PUBLIC_IO_IDX: usize = 6; +pub const GLOBAL_RW_SUM_IDX: usize = PUBLIC_IO_IDX + 2; pub const LIMB_BITS: usize = 16; pub const LIMB_MASK: u32 = 0xFFFF; diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index 7ca30d2b8..9fb0695bc 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -7,7 +7,9 @@ mod div_circuit_v2; use super::RIVInstruction; +#[derive(Default)] pub struct DivuOp; + impl RIVInstruction for DivuOp { const INST_KIND: InsnKind = InsnKind::DIVU; } @@ -16,7 +18,9 @@ pub type DivuInstruction = div_circuit_v2::ArithInstruction; #[cfg(not(feature = "u16limb_circuit"))] pub type DivuInstruction = div_circuit::ArithInstruction; +#[derive(Default)] pub struct RemuOp; + impl RIVInstruction for RemuOp { const INST_KIND: InsnKind = InsnKind::REMU; } @@ -25,7 +29,9 @@ pub type RemuInstruction = div_circuit_v2::ArithInstruction; #[cfg(not(feature = "u16limb_circuit"))] pub type RemuInstruction = div_circuit::ArithInstruction; +#[derive(Default)] pub struct RemOp; + impl RIVInstruction for RemOp { const INST_KIND: InsnKind = InsnKind::REM; } @@ -34,7 +40,9 @@ pub type RemInstruction = div_circuit_v2::ArithInstruction; #[cfg(not(feature = "u16limb_circuit"))] pub type RemInstruction = div_circuit::ArithInstruction; +#[derive(Default)] pub struct DivOp; + impl RIVInstruction for DivOp { const INST_KIND: InsnKind = InsnKind::DIV; } @@ -158,7 +166,7 @@ mod test { const INSN_KIND: InsnKind = InsnKind::REMU; } - fn verify + TestInstance>( + fn verify + TestInstance + Default>( name: &str, dividend: >::NumType, divisor: >::NumType, @@ -167,10 +175,11 @@ mod test { ) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = Insn::default(); let config = cb .namespace( || format!("{}_({})", Insn::name(), name), - |cb| Ok(Insn::construct_circuit(cb, &ProgramParams::default())), + |cb| Ok(inst.construct_circuit(cb, &ProgramParams::default())), ) .unwrap() .unwrap(); @@ -220,7 +229,7 @@ mod test { } // shortcut to verify given pair produces correct output - fn verify_positive + TestInstance>( + fn verify_positive + TestInstance + Default>( name: &str, dividend: >::NumType, divisor: >::NumType, diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index d2d2b78ee..d11330f27 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -43,6 +43,7 @@ pub struct DivRemConfig { lt_diff: WitIn, } +#[derive(Default)] pub struct ArithInstruction(PhantomData<(E, I)>); impl Instruction for ArithInstruction { @@ -53,6 +54,7 @@ impl Instruction for ArithInstruction, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs index 7c98e2159..04e59cc96 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs @@ -20,6 +20,7 @@ use p3::field::FieldAlgebra; use witness::set_val; /// DummyInstruction can handle any instruction and produce its side-effects. +#[derive(Default)] pub struct DummyInstruction(PhantomData<(E, I)>); impl Instruction for DummyInstruction { @@ -30,6 +31,7 @@ impl Instruction for DummyInstruction, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 69bdd1648..662f5a0e1 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -24,6 +24,7 @@ use witness::set_val; /// including multiple memory operations. /// /// Unsafe: The content is not constrained. +#[derive(Default)] pub struct LargeEcallDummy(PhantomData<(E, S)>); impl Instruction for LargeEcallDummy { @@ -33,6 +34,7 @@ impl Instruction for LargeEcallDummy S::NAME.to_owned() } fn construct_circuit( + &self, cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index 6f7a89f73..2d8a24ac3 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -19,11 +19,12 @@ type BeqDummy = DummyInstruction; fn test_dummy_ecall() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = EcallDummy::default(); let config = cb .namespace( || "ecall_dummy", |cb| { - let config = EcallDummy::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -49,11 +50,12 @@ fn test_dummy_keccak() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = KeccakDummy::default(); let config = cb .namespace( || "keccak_dummy", |cb| { - let config = KeccakDummy::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -76,11 +78,12 @@ fn test_dummy_keccak() { fn test_dummy_r() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = AddDummy::default(); let config = cb .namespace( || "add_dummy", |cb| { - let config = AddDummy::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -111,11 +114,12 @@ fn test_dummy_r() { fn test_dummy_b() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = BeqDummy::default(); let config = cb .namespace( || "beq_dummy", |cb| { - let config = BeqDummy::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/ecall.rs b/ceno_zkvm/src/instructions/riscv/ecall.rs index a25bbeeb6..ba3a9d00e 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall.rs @@ -14,7 +14,9 @@ pub use halt::HaltInstruction; use super::{RIVInstruction, dummy::DummyInstruction}; +#[derive(Default)] pub struct EcallOp; + impl RIVInstruction for EcallOp { const INST_KIND: InsnKind = InsnKind::ECALL; } diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index e14585727..e5709adc0 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -26,6 +26,7 @@ pub struct HaltConfig { lt_x10_cfg: AssertLtConfig, } +#[derive(Default)] pub struct HaltInstruction(PhantomData); impl Instruction for HaltInstruction { @@ -36,6 +37,7 @@ impl Instruction for HaltInstruction { } fn construct_circuit( + &self, cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index b0ac2a505..57bd13897 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -49,6 +49,7 @@ pub struct EcallKeccakConfig { } /// KeccakInstruction can handle any instruction and produce its side-effects. +#[derive(Default)] pub struct KeccakInstruction(PhantomData); impl Instruction for KeccakInstruction { @@ -59,6 +60,7 @@ impl Instruction for KeccakInstruction { } fn construct_circuit( + &self, _circuit_builder: &mut CircuitBuilder, _param: &ProgramParams, ) -> Result { @@ -66,6 +68,7 @@ impl Instruction for KeccakInstruction { } fn build_gkr_iop_circuit( + &self, cb: &mut CircuitBuilder, _param: &ProgramParams, ) -> Result<(Self::InstructionConfig, GKRCircuit), ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index 6365cfcd2..e2fa19e7a 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -52,6 +52,7 @@ pub struct EcallWeierstrassAddAssignConfig } /// WeierstrassAddAssignInstruction can handle any instruction and produce its side-effects. +#[derive(Default)] pub struct WeierstrassAddAssignInstruction(PhantomData<(E, EC)>); impl Instruction @@ -64,6 +65,7 @@ impl Instruction } fn construct_circuit( + &self, _circuit_builder: &mut CircuitBuilder, _param: &ProgramParams, ) -> Result { @@ -71,6 +73,7 @@ impl Instruction } fn build_gkr_iop_circuit( + &self, cb: &mut CircuitBuilder, _param: &ProgramParams, ) -> Result<(Self::InstructionConfig, GKRCircuit), ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index 6003f9794..7d094e612 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -59,6 +59,7 @@ pub struct EcallWeierstrassDecompressConfig(PhantomData<(E, EC)>); impl Instruction @@ -71,6 +72,7 @@ impl Instruction, _param: &ProgramParams, ) -> Result { @@ -78,6 +80,7 @@ impl Instruction, _param: &ProgramParams, ) -> Result<(Self::InstructionConfig, GKRCircuit), ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index aa8e18972..210fe81c9 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -54,6 +54,7 @@ pub struct EcallWeierstrassDoubleAssignConfig< } /// WeierstrassDoubleAssignInstruction can handle any instruction and produce its side-effects. +#[derive(Default)] pub struct WeierstrassDoubleAssignInstruction(PhantomData<(E, EC)>); impl Instruction @@ -66,6 +67,7 @@ impl Instruction, _param: &ProgramParams, ) -> Result { @@ -73,6 +75,7 @@ impl Instruction, _param: &ProgramParams, ) -> Result<(Self::InstructionConfig, GKRCircuit), ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index 0f67be424..9a6830e6f 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -26,6 +26,7 @@ pub struct JalConfig { pub rd_written: UInt8, } +#[derive(Default)] pub struct JalInstruction(PhantomData); /// JAL instruction circuit @@ -47,6 +48,7 @@ impl Instruction for JalInstruction { } fn construct_circuit( + &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result, ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index f1ba94aa7..ba6f4d2d9 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -28,6 +28,7 @@ pub struct JalrConfig { pub rd_written: UInt, } +#[derive(Default)] pub struct JalrInstruction(PhantomData); /// JALR instruction circuit diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index bfec3a099..f3ff2990a 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -33,6 +33,7 @@ pub struct JalrConfig { pub rd_high: WitIn, } +#[derive(Default)] pub struct JalrInstruction(PhantomData); /// JALR instruction circuit @@ -47,6 +48,7 @@ impl Instruction for JalrInstruction { } fn construct_circuit( + &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result, ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index 0b379f250..7de183fee 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -27,11 +27,12 @@ fn test_opcode_jal() { fn verify_test_opcode_jal(pc_offset: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = JalInstruction::default(); let config = cb .namespace( || "jal", |cb| { - let config = JalInstruction::::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -86,11 +87,12 @@ fn test_opcode_jalr() { fn verify_test_opcode_jalr(rs1_read: Word, imm: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = JalrInstruction::default(); let config = cb .namespace( || "jalr", |cb| { - let config = JalrInstruction::::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/logic.rs b/ceno_zkvm/src/instructions/riscv/logic.rs index 9ac2cd4c1..8749231ed 100644 --- a/ceno_zkvm/src/instructions/riscv/logic.rs +++ b/ceno_zkvm/src/instructions/riscv/logic.rs @@ -7,21 +7,27 @@ mod test; use ceno_emul::InsnKind; +#[derive(Default)] pub struct AndOp; + impl LogicOp for AndOp { const INST_KIND: InsnKind = InsnKind::AND; type OpsTable = AndTable; } pub type AndInstruction = LogicInstruction; +#[derive(Default)] pub struct OrOp; + impl LogicOp for OrOp { const INST_KIND: InsnKind = InsnKind::OR; type OpsTable = OrTable; } pub type OrInstruction = LogicInstruction; +#[derive(Default)] pub struct XorOp; + impl LogicOp for XorOp { const INST_KIND: InsnKind = InsnKind::XOR; type OpsTable = XorTable; diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index f761f6102..c57a20b8e 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -24,6 +24,7 @@ pub trait LogicOp { } /// The Instruction circuit for a given LogicOp. +#[derive(Default)] pub struct LogicInstruction(PhantomData<(E, I)>); impl Instruction for LogicInstruction { @@ -34,6 +35,7 @@ impl Instruction for LogicInstruction { } fn construct_circuit( + &self, cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/logic/test.rs b/ceno_zkvm/src/instructions/riscv/logic/test.rs index dc01487d9..87de3f9a1 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/test.rs @@ -18,11 +18,12 @@ const B: Word = 0xef552020; fn test_opcode_and() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = AndInstruction::default(); let config = cb .namespace( || "and", |cb| { - let config = AndInstruction::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -60,11 +61,12 @@ fn test_opcode_and() { fn test_opcode_or() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = OrInstruction::default(); let config = cb .namespace( || "or", |cb| { - let config = OrInstruction::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -102,11 +104,12 @@ fn test_opcode_or() { fn test_opcode_xor() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = XorInstruction::default(); let config = cb .namespace( || "xor", |cb| { - let config = XorInstruction::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm.rs b/ceno_zkvm/src/instructions/riscv/logic_imm.rs index a4b46edcc..1e9dc25e1 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm.rs @@ -24,21 +24,27 @@ use gkr_iop::tables::ops::{AndTable, OrTable, XorTable}; use ceno_emul::InsnKind; use gkr_iop::tables::OpsTable; +#[derive(Default)] pub struct AndiOp; + impl LogicOp for AndiOp { const INST_KIND: InsnKind = InsnKind::ANDI; type OpsTable = AndTable; } pub type AndiInstruction = LogicInstruction; +#[derive(Default)] pub struct OriOp; + impl LogicOp for OriOp { const INST_KIND: InsnKind = InsnKind::ORI; type OpsTable = OrTable; } pub type OriInstruction = LogicInstruction; +#[derive(Default)] pub struct XoriOp; + impl LogicOp for XoriOp { const INST_KIND: InsnKind = InsnKind::XORI; type OpsTable = XorTable; diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index c72f31efe..6f98e6c74 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -26,6 +26,7 @@ use ceno_emul::{InsnKind, StepRecord}; use multilinear_extensions::ToExpr; /// The Instruction circuit for a given LogicOp. +#[derive(Default)] pub struct LogicInstruction(PhantomData<(E, I)>); impl Instruction for LogicInstruction { @@ -36,6 +37,7 @@ impl Instruction for LogicInstruction { } fn construct_circuit( + &self, cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs index 23aa2d77c..33b5f8065 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs @@ -42,9 +42,15 @@ fn test_opcode_xori() { verify::("negative imm", TEST, NEG, TEST ^ NEG); } -fn verify(name: &'static str, rs1_read: u32, imm: u32, expected_rd_written: u32) { +fn verify( + name: &'static str, + rs1_read: u32, + imm: u32, + expected_rd_written: u32, +) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = LogicInstruction::::default(); let (prefix, rd_written) = match I::INST_KIND { InsnKind::ANDI => ("ANDI", rs1_read & imm), @@ -57,10 +63,7 @@ fn verify(name: &'static str, rs1_read: u32, imm: u32, expected_rd_w .namespace( || format!("{prefix}_({name})"), |cb| { - let config = LogicInstruction::::construct_circuit( - cb, - &ProgramParams::default(), - ); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 2cc280f04..c495cfb04 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -29,6 +29,7 @@ pub struct LuiConfig { pub rd_written: [WitIn; UINT_BYTE_LIMBS - 1], } +#[derive(Default)] pub struct LuiInstruction(PhantomData); impl Instruction for LuiInstruction { @@ -39,6 +40,7 @@ impl Instruction for LuiInstruction { } fn construct_circuit( + &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result, ZKVMError> { @@ -138,12 +140,12 @@ mod tests { fn test_opcode_lui(rd: u32, imm: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = LuiInstruction::default(); let config = cb .namespace( || "lui", |cb| { - let config = - LuiInstruction::::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/memory.rs b/ceno_zkvm/src/instructions/riscv/memory.rs index bb29491f7..612058667 100644 --- a/ceno_zkvm/src/instructions/riscv/memory.rs +++ b/ceno_zkvm/src/instructions/riscv/memory.rs @@ -24,6 +24,7 @@ pub use crate::instructions::riscv::memory::store_v2::StoreInstruction; use ceno_emul::InsnKind; +#[derive(Default)] pub struct LwOp; impl RIVInstruction for LwOp { @@ -32,43 +33,57 @@ impl RIVInstruction for LwOp { pub type LwInstruction = LoadInstruction; +#[derive(Default)] pub struct LhOp; + impl RIVInstruction for LhOp { const INST_KIND: InsnKind = InsnKind::LH; } pub type LhInstruction = LoadInstruction; +#[derive(Default)] pub struct LhuOp; + impl RIVInstruction for LhuOp { const INST_KIND: InsnKind = InsnKind::LHU; } pub type LhuInstruction = LoadInstruction; +#[derive(Default)] pub struct LbOp; + impl RIVInstruction for LbOp { const INST_KIND: InsnKind = InsnKind::LB; } pub type LbInstruction = LoadInstruction; +#[derive(Default)] pub struct LbuOp; + impl RIVInstruction for LbuOp { const INST_KIND: InsnKind = InsnKind::LBU; } pub type LbuInstruction = LoadInstruction; +#[derive(Default)] pub struct SWOp; + impl RIVInstruction for SWOp { const INST_KIND: InsnKind = InsnKind::SW; } pub type SwInstruction = StoreInstruction; +#[derive(Default)] pub struct SHOp; + impl RIVInstruction for SHOp { const INST_KIND: InsnKind = InsnKind::SH; } pub type ShInstruction = StoreInstruction; +#[derive(Default)] pub struct SBOp; + impl RIVInstruction for SBOp { const INST_KIND: InsnKind = InsnKind::SB; } diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 1973e48ea..4a008f009 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -37,7 +37,8 @@ pub struct LoadConfig { signed_extend_config: Option>, } -pub struct LoadInstruction(PhantomData<(E, I)>); +#[derive(Default)] +pub struct LoadInstruction(PhantomData<(E, I)>); impl Instruction for LoadInstruction { type InstructionConfig = LoadConfig; @@ -47,6 +48,7 @@ impl Instruction for LoadInstruction, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index f07968d19..ce85ced04 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -35,6 +35,7 @@ pub struct StoreConfig { next_memory_value: Option>, } +#[derive(Default)] pub struct StoreInstruction(PhantomData<(E, I)>); impl Instruction @@ -47,6 +48,7 @@ impl Instruction } fn construct_circuit( + &self, circuit_builder: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index 90c5a0273..031d27398 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -75,14 +75,21 @@ fn load(mem_value: Word, insn: InsnKind, shift: u32) -> Word { } } -fn impl_opcode_store>(imm: i32) { +fn impl_opcode_store< + E: ExtensionField + Hash, + I: RIVInstruction, + Inst: Instruction + Default, +>( + imm: i32, +) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = Inst::default(); let config = cb .namespace( || Inst::name(), |cb| { - let config = Inst::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -137,14 +144,17 @@ fn impl_opcode_store>(imm: i32) { +fn impl_opcode_load + Default>( + imm: i32, +) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = Inst::default(); let config = cb .namespace( || Inst::name(), |cb| { - let config = Inst::construct_circuit(cb, &ProgramParams::default()); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/mulh.rs b/ceno_zkvm/src/instructions/riscv/mulh.rs index dd8e3b3f5..0ed1ddd71 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh.rs @@ -11,26 +11,34 @@ use mulh_circuit::MulhInstructionBase; #[cfg(feature = "u16limb_circuit")] use mulh_circuit_v2::MulhInstructionBase; +#[derive(Default)] pub struct MulOp; + impl RIVInstruction for MulOp { const INST_KIND: InsnKind = InsnKind::MUL; } pub type MulInstruction = MulhInstructionBase; +#[derive(Default)] pub struct MulhOp; + impl RIVInstruction for MulhOp { const INST_KIND: InsnKind = InsnKind::MULH; } pub type MulhInstruction = MulhInstructionBase; +#[derive(Default)] pub struct MulhuOp; + impl RIVInstruction for MulhuOp { const INST_KIND: InsnKind = InsnKind::MULHU; } pub type MulhuInstruction = MulhInstructionBase; +#[derive(Default)] pub struct MulhsuOp; + impl RIVInstruction for MulhsuOp { const INST_KIND: InsnKind = InsnKind::MULHSU; } @@ -111,15 +119,11 @@ mod test { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = MulhInstructionBase::::default(); let config = cb .namespace( || format!("{:?}_({name})", I::INST_KIND), - |cb| { - Ok(MulhInstructionBase::::construct_circuit( - cb, - &ProgramParams::default(), - )) - }, + |cb| Ok(inst.construct_circuit(cb, &ProgramParams::default())), ) .unwrap() .unwrap(); @@ -198,15 +202,11 @@ mod test { fn verify_mulh(rs1: i32, rs2: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = MulhInstruction::::default(); let config = cb .namespace( || "mulh", - |cb| { - Ok(MulhInstruction::construct_circuit( - cb, - &ProgramParams::default(), - )) - }, + |cb| Ok(inst.construct_circuit(cb, &ProgramParams::default())), ) .unwrap() .unwrap(); @@ -281,15 +281,11 @@ mod test { fn verify_mulhsu(rs1: i32, rs2: u32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = MulhsuInstruction::::default(); let config = cb .namespace( || "mulhsu", - |cb| { - Ok(MulhsuInstruction::construct_circuit( - cb, - &ProgramParams::default(), - )) - }, + |cb| Ok(inst.construct_circuit(cb, &ProgramParams::default())), ) .unwrap() .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index c1853d7a8..6b3d5c13c 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -22,7 +22,8 @@ use witness::set_val; use itertools::Itertools; use std::{array, marker::PhantomData}; -pub struct MulhInstructionBase(PhantomData<(E, I)>); +#[derive(Default)] +pub struct MulhInstructionBase(PhantomData<(E, I)>); pub struct MulhConfig { rs1_read: UInt, @@ -43,6 +44,7 @@ impl Instruction for MulhInstructionBas } fn construct_circuit( + &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result, ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 0c53f1a4c..841c4e97d 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -11,19 +11,25 @@ use crate::instructions::riscv::shift::shift_circuit::ShiftLogicalInstruction; #[cfg(feature = "u16limb_circuit")] use crate::instructions::riscv::shift::shift_circuit_v2::ShiftLogicalInstruction; +#[derive(Default)] pub struct SllOp; + impl RIVInstruction for SllOp { const INST_KIND: InsnKind = InsnKind::SLL; } pub type SllInstruction = ShiftLogicalInstruction; +#[derive(Default)] pub struct SrlOp; + impl RIVInstruction for SrlOp { const INST_KIND: InsnKind = InsnKind::SRL; } pub type SrlInstruction = ShiftLogicalInstruction; +#[derive(Default)] pub struct SraOp; + impl RIVInstruction for SraOp { const INST_KIND: InsnKind = InsnKind::SRA; } @@ -122,6 +128,7 @@ mod tests { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = ShiftLogicalInstruction::::default(); let shift = rs2_read & 0b11111; let (prefix, insn_code, rd_written) = match I::INST_KIND { InsnKind::SLL => ( @@ -145,12 +152,7 @@ mod tests { let config = cb .namespace( || format!("{prefix}_({name})"), - |cb| { - Ok(ShiftLogicalInstruction::::construct_circuit( - cb, - &ProgramParams::default(), - )) - }, + |cb| Ok(inst.construct_circuit(cb, &ProgramParams::default())), ) .unwrap() .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 4e929670c..c7915ca74 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -271,6 +271,7 @@ pub struct ShiftRTypeConfig { r_insn: RInstructionConfig, } +#[derive(Default)] pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); impl Instruction for ShiftLogicalInstruction { @@ -281,6 +282,7 @@ impl Instruction for ShiftLogicalInstru } fn construct_circuit( + &self, circuit_builder: &mut crate::circuit_builder::CircuitBuilder, _params: &ProgramParams, ) -> Result { @@ -366,7 +368,8 @@ pub struct ShiftImmConfig { imm: WitIn, } -pub struct ShiftImmInstruction(PhantomData<(E, I)>); +#[derive(Default)] +pub struct ShiftImmInstruction(PhantomData<(E, I)>); impl Instruction for ShiftImmInstruction { type InstructionConfig = ShiftImmConfig; @@ -376,6 +379,7 @@ impl Instruction for ShiftImmInstructio } fn construct_circuit( + &self, circuit_builder: &mut crate::circuit_builder::CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 4cf7ac155..a3d43eedf 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -9,19 +9,25 @@ use crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmInstruction; #[cfg(not(feature = "u16limb_circuit"))] use crate::instructions::riscv::shift_imm::shift_imm_circuit::ShiftImmInstruction; +#[derive(Default)] pub struct SlliOp; + impl RIVInstruction for SlliOp { const INST_KIND: InsnKind = InsnKind::SLLI; } pub type SlliInstruction = ShiftImmInstruction; +#[derive(Default)] pub struct SraiOp; + impl RIVInstruction for SraiOp { const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SRAI; } pub type SraiInstruction = ShiftImmInstruction; +#[derive(Default)] pub struct SrliOp; + impl RIVInstruction for SrliOp { const INST_KIND: ceno_emul::InsnKind = InsnKind::SRLI; } @@ -142,10 +148,8 @@ mod test { .namespace( || format!("{prefix}_({name})"), |cb| { - let config = ShiftImmInstruction::::construct_circuit( - cb, - &ProgramParams::default(), - ); + let inst = ShiftImmInstruction::::default(); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index 7b27617ad..b01779742 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -7,7 +7,9 @@ use ceno_emul::InsnKind; use super::RIVInstruction; +#[derive(Default)] pub struct SltOp; + impl RIVInstruction for SltOp { const INST_KIND: InsnKind = InsnKind::SLT; } @@ -16,7 +18,9 @@ pub type SltInstruction = slt_circuit_v2::SetLessThanInstruction; #[cfg(not(feature = "u16limb_circuit"))] pub type SltInstruction = slt_circuit::SetLessThanInstruction; +#[derive(Default)] pub struct SltuOp; + impl RIVInstruction for SltuOp { const INST_KIND: InsnKind = InsnKind::SLTU; } @@ -55,14 +59,12 @@ mod test { ) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); + let inst = SetLessThanInstruction::::default(); let config = cb .namespace( || format!("{}/{name}", I::INST_KIND), |cb| { - let config = SetLessThanInstruction::<_, I>::construct_circuit( - cb, - &ProgramParams::default(), - ); + let config = inst.construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index 391dffb89..16050e733 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -14,6 +14,7 @@ use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use std::marker::PhantomData; +#[derive(Default)] pub struct SetLessThanInstruction(PhantomData<(E, I)>); /// This config handles R-Instructions that represent registers values as 2 * u16. @@ -35,6 +36,7 @@ impl Instruction for SetLessThanInstruc } fn construct_circuit( + &self, cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 5802c4229..b7ed0ffab 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -12,13 +12,17 @@ use crate::instructions::riscv::slti::slti_circuit::SetLessThanImmInstruction; use super::RIVInstruction; +#[derive(Default)] pub struct SltiOp; + impl RIVInstruction for SltiOp { const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SLTI; } pub type SltiInstruction = SetLessThanImmInstruction; +#[derive(Default)] pub struct SltiuOp; + impl RIVInstruction for SltiuOp { const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SLTIU; } @@ -169,16 +173,12 @@ mod test { let mut cb = CircuitBuilder::new(&mut cs); let insn_code = encode_rv32(I::INST_KIND, 2, 0, 4, imm); + let inst = SetLessThanImmInstruction::::default(); let config = cb .namespace( || format!("{:?}_({name})", I::INST_KIND), - |cb| { - Ok(SetLessThanImmInstruction::::construct_circuit( - cb, - &ProgramParams::default(), - )) - }, + |cb| Ok(inst.construct_circuit(cb, &ProgramParams::default())), ) .unwrap() .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 1085561fb..1e1b1c9b7 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -36,7 +36,8 @@ pub struct SetLessThanImmConfig { uint_lt_config: UIntLimbsLTConfig, } -pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); +#[derive(Default)] +pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); impl Instruction for SetLessThanImmInstruction { type InstructionConfig = SetLessThanImmConfig; @@ -46,6 +47,7 @@ impl Instruction for SetLessThanImmInst } fn construct_circuit( + &self, cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/test.rs b/ceno_zkvm/src/instructions/riscv/test.rs index 47c0ba178..41a3877fe 100644 --- a/ceno_zkvm/src/instructions/riscv/test.rs +++ b/ceno_zkvm/src/instructions/riscv/test.rs @@ -17,13 +17,15 @@ fn test_multiple_opcode() { let params = ProgramParams::default(); let mut cs = ConstraintSystem::new(|| "riscv"); + let add_inst = AddInstruction::::default(); let _add_config = cs.namespace( || "add", - |cs| AddInstruction::construct_circuit(&mut CircuitBuilder::::new(cs), ¶ms), + |cs| add_inst.construct_circuit(&mut CircuitBuilder::::new(cs), ¶ms), ); + let sub_inst = SubInstruction::::default(); let _sub_config = cs.namespace( || "sub", - |cs| SubInstruction::construct_circuit(&mut CircuitBuilder::::new(cs), ¶ms), + |cs| sub_inst.construct_circuit(&mut CircuitBuilder::::new(cs), ¶ms), ); let param = Pcs::setup(1 << 10, SecurityLevel::default()).unwrap(); let (_, _) = Pcs::trim(param, 1 << 10).unwrap(); diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 60dff6a99..4f3404fd9 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -55,6 +55,8 @@ use transcript::{BasicTranscript, Transcript}; struct TestConfig { pub(crate) reg_id: WitIn, } + +#[derive(Default)] struct TestCircuit { phantom: PhantomData, } @@ -67,6 +69,7 @@ impl Instruction for Test } fn construct_circuit( + &self, cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index e48978c51..52887f9a2 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -2,7 +2,7 @@ use crate::{ scheme::{ constants::{MIN_PAR_SIZE, SEPTIC_JACOBIAN_NUM_MLES}, hal::{MainSumcheckProver, ProofInput, ProverDevice}, - septic_curve::{SepticExtension, SepticJacobianPoint, SepticPoint}, + septic_curve::{SepticExtension, SepticJacobianPoint}, }, structs::ComposedConstrainSystem, }; @@ -21,7 +21,6 @@ use multilinear_extensions::{ mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, util::ceil_log2, }; -use p3::matrix::{Matrix, dense::RowMajorMatrix}; use rayon::{ iter::{ IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 0586c9c2a..0e6c085a9 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -198,11 +198,15 @@ impl ZKVMConstraintSystem { } } - pub fn register_opcode_circuit>(&mut self) -> OC::InstructionConfig { + pub fn register_opcode_circuit + Default>( + &mut self, + ) -> OC::InstructionConfig { let mut cs = ConstraintSystem::new(|| format!("riscv_opcode/{}", OC::name())); let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let (config, gkr_iop_circuit) = - OC::build_gkr_iop_circuit(&mut circuit_builder, &self.params).unwrap(); + let op_circuit = OC::default(); + let (config, gkr_iop_circuit) = op_circuit + .build_gkr_iop_circuit(&mut circuit_builder, &self.params) + .unwrap(); let cs = ComposedConstrainSystem { zkvm_v1_css: cs, gkr_circuit: Some(gkr_iop_circuit), diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index e4129bfe8..eb1cb4d03 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -103,6 +103,9 @@ pub struct ConstraintSystem { pub instance_name_map: HashMap, + pub ec_point_exprs: Vec>, + pub ec_final_sum: Vec>, + pub r_selector: Option>, pub r_expressions: Vec>, pub r_expressions_namespace_map: Vec, @@ -167,6 +170,8 @@ impl ConstraintSystem { fixed_namespace_map: vec![], ns: NameSpace::new(root_name_fn), instance_name_map: HashMap::new(), + ec_final_sum: vec![], + ec_point_exprs: vec![], r_selector: None, r_expressions: vec![], r_expressions_namespace_map: vec![], @@ -405,6 +410,23 @@ impl ConstraintSystem { Ok(()) } + pub fn ec_sum( + &mut self, + xs: Vec>, + ys: Vec>, + final_sum: Vec>, + ) { + assert_eq!(xs.len(), 7); + assert_eq!(ys.len(), 7); + assert_eq!(final_sum.len(), 7 * 2); + + assert_eq!(self.ec_point_exprs.len(), 0); + self.ec_point_exprs.extend(xs.into_iter()); + self.ec_point_exprs.extend(ys.into_iter()); + + self.ec_final_sum = final_sum; + } + pub fn require_zero, N: FnOnce() -> NR>( &mut self, name_fn: N, @@ -624,6 +646,15 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.cs.rlc_chip_record(records) } + pub fn ec_sum( + &mut self, + xs: Vec>, + ys: Vec>, + final_sum: Vec>, + ) { + self.cs.ec_sum(xs, ys, final_sum); + } + pub fn create_bit(&mut self, name_fn: N) -> Result where NR: Into, From 88601697643add92a1594ec2c076783778091b8c Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 13 Oct 2025 23:26:49 +0800 Subject: [PATCH 41/91] fmt --- ceno_zkvm/src/gadgets/poseidon2.rs | 2 +- ceno_zkvm/src/instructions/global.rs | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs index c713c332c..f70049bcb 100644 --- a/ceno_zkvm/src/gadgets/poseidon2.rs +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -12,7 +12,7 @@ use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn}; use num_bigint::BigUint; use p3::{ - babybear::{BabyBearInternalLayerParameters}, + babybear::BabyBearInternalLayerParameters, field::{Field, FieldAlgebra}, monty_31::InternalLayerBaseParameters, poseidon2::{MDSMat4, mds_light_permutation}, diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 2c943fb7e..35db317ea 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -5,9 +5,7 @@ use crate::{ gadgets::{Poseidon2Config, RoundConstants}, }; use ff_ext::ExtensionField; -use gkr_iop::{ - circuit_builder::CircuitBuilder, error::CircuitBuilderError, -}; +use gkr_iop::{circuit_builder::CircuitBuilder, error::CircuitBuilderError}; use multilinear_extensions::{ToExpr, WitIn}; use p3::field::FieldAlgebra; From c542201bb91b75a8cdeb9ac977d5ffa980ffa168 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 14 Oct 2025 13:08:35 +0800 Subject: [PATCH 42/91] wip config as trait --- .../src/instructions/riscv/rv32im/mmu.rs | 9 +- ceno_zkvm/src/tables/ram.rs | 4 +- ceno_zkvm/src/tables/ram/ram_circuit.rs | 91 +++++++++++-- ceno_zkvm/src/tables/ram/ram_impl.rs | 125 ++++++++++++++++-- 4 files changed, 204 insertions(+), 25 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index d8c032c7b..bf85accd8 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -9,11 +9,16 @@ use crate::{ structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ HeapCircuit, HintsCircuit, MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, - PubIOTable, RegTable, RegTableCircuit, StackCircuit, StaticMemCircuit, StaticMemTable, - TableCircuit, + PubIOTable, RegTable, RegTableCircuit, RegTableInitCircuit, StackCircuit, StaticMemCircuit, + StaticMemTable, TableCircuit, }, }; +pub struct RegConfigs { + pub reg_init_config: as TableCircuit>::TableConfig, + pub reg_mem_bus: as TableCircuit>::TableConfig, +} + pub struct MmuConfig { /// Initialization of registers. pub reg_config: as TableCircuit>::TableConfig, diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index e34ce1dcc..5839d7125 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -8,6 +8,7 @@ use crate::{ mod ram_circuit; mod ram_impl; +use crate::tables::ram::ram_circuit::NonVolatileInitRamCircuit; pub use ram_circuit::{DynVolatileRamTable, MemFinalRecord, MemInitRecord, NonVolatileTable}; #[derive(Clone)] @@ -108,7 +109,8 @@ impl NonVolatileTable for RegTable { } } -pub type RegTableCircuit = NonVolatileRamCircuit; +// pub type RegTableCircuit = NonVolatileRamCircuit; +pub type RegTableInitCircuit = NonVolatileInitRamCircuit; #[derive(Clone)] pub struct StaticMemTable; diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 0a8b6bf97..633670649 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -1,17 +1,19 @@ use std::{collections::HashMap, marker::PhantomData}; -use ceno_emul::{Addr, Cycle, GetAddr, WORD_SIZE, Word}; -use ff_ext::ExtensionField; -use witness::{InstancePaddingStrategy, RowMajorMatrix}; - use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, structs::{ProgramParams, RAMType}, tables::{RMMCollections, TableCircuit}, }; +use ceno_emul::{Addr, Cycle, GetAddr, WORD_SIZE, Word}; +use ff_ext::{ExtensionField, SmallField}; +use gkr_iop::error::CircuitBuilderError; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; -use super::ram_impl::{DynVolatileRamTableConfig, NonVolatileTableConfig, PubIOTableConfig}; +use super::ram_impl::{ + DynVolatileRamTableConfig, NonVolatileInitTableConfig, NonVolatileTableConfig, PubIOTableConfig, +}; #[derive(Clone, Debug)] pub struct MemInitRecord { @@ -104,6 +106,55 @@ impl TableCirc } } +/// NonVolatileRamCircuit initializes and finalizes memory +/// - at fixed addresses, +/// - with fixed initial content, +/// - with witnessed final content that the program wrote, if WRITABLE, +/// - or final content equal to initial content, if not WRITABLE. +pub struct NonVolatileInitRamCircuit(PhantomData<(E, R)>); + +impl TableCircuit + for NonVolatileInitRamCircuit +{ + type TableConfig = NonVolatileInitTableConfig; + type FixedInput = [MemInitRecord]; + type WitnessInput = [MemFinalRecord]; + + fn name() -> String { + format!("RAM_{:?}_{}", NVRAM::RAM_TYPE, NVRAM::name()) + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result { + Ok(cb.namespace( + || Self::name(), + |cb| Self::TableConfig::construct_circuit(cb, params), + )?) + } + + fn generate_fixed_traces( + config: &Self::TableConfig, + num_fixed: usize, + init_v: &Self::FixedInput, + ) -> RowMajorMatrix { + // assume returned table is well-formed include padding + config.gen_init_state(num_fixed, init_v) + } + + fn assign_instances( + config: &Self::TableConfig, + num_witin: usize, + num_structural_witin: usize, + _multiplicity: &[HashMap], + final_v: &Self::WitnessInput, + ) -> Result, ZKVMError> { + // assume returned table is well-formed include padding + Ok(config.assign_instances(num_witin, num_structural_witin, final_v)?) + } +} + /// PubIORamCircuit initializes and finalizes memory /// - at fixed addresses, /// - with content from the public input of proofs. @@ -189,6 +240,20 @@ pub trait DynVolatileRamTable { } } +pub trait DynVolatileRamTableConfigTrait: Sized + Send + Sync { + type Output: Sized + Send + Sync; + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result; + fn assign_instances( + &self, + num_witin: usize, + num_structural_witin: usize, + final_mem: &[MemFinalRecord], + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError>; +} + /// DynVolatileRamCircuit initializes and finalizes memory /// - at witnessed addresses, in a contiguous range chosen by the prover, /// - with zeros as initial content if ZERO_INIT, @@ -197,12 +262,15 @@ pub trait DynVolatileRamTable { /// If not ZERO_INIT: /// - The initial content is an unconstrained prover hint. /// - The final content is equal to this initial content. -pub struct DynVolatileRamCircuit(PhantomData<(E, R)>); +pub struct DynVolatileRamCircuit(PhantomData<(E, R, C)>); -impl TableCircuit - for DynVolatileRamCircuit +impl< + E: ExtensionField, + DVRAM: DynVolatileRamTable + Send + Sync + Clone, + C: DynVolatileRamTableConfigTrait, +> TableCircuit for DynVolatileRamCircuit { - type TableConfig = DynVolatileRamTableConfig; + type TableConfig = C::Output; type FixedInput = (); type WitnessInput = [MemFinalRecord]; @@ -214,10 +282,7 @@ impl TableC cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { - Ok(cb.namespace( - || Self::name(), - |cb| Self::TableConfig::construct_circuit(cb, params), - )?) + Ok(cb.namespace(|| Self::name(), |cb| C::construct_circuit(cb, params))?) } fn generate_fixed_traces( diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 87968b97b..657def000 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -19,6 +19,7 @@ use crate::{ e2e::RAMRecord, instructions::riscv::constants::{LIMB_BITS, LIMB_MASK}, structs::ProgramParams, + tables::ram::ram_circuit::DynVolatileRamTableConfigTrait, }; use ff_ext::FieldInto; use multilinear_extensions::{ @@ -186,6 +187,105 @@ impl NonVolatileTableConfig { + init_v: Vec, + addr: Fixed, + + phantom: PhantomData, + params: ProgramParams, +} + +impl NonVolatileInitTableConfig { + pub fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result { + let init_v = (0..NVRAM::V_LIMBS) + .map(|i| cb.create_fixed(|| format!("init_v_limb_{i}"))) + .collect_vec(); + let addr = cb.create_fixed(|| "addr"); + + let init_table = [ + vec![(NVRAM::RAM_TYPE as usize).into()], + vec![Expression::Fixed(addr)], + init_v.iter().map(|v| v.expr()).collect_vec(), + vec![Expression::ZERO], // Initial cycle. + ] + .concat(); + + cb.w_table_record( + || "init_table", + NVRAM::RAM_TYPE, + SetTableSpec { + len: Some(NVRAM::len(params)), + structural_witins: vec![], + }, + init_table, + )?; + + Ok(Self { + init_v, + addr, + phantom: PhantomData, + params: params.clone(), + }) + } + + pub fn gen_init_state( + &self, + num_fixed: usize, + init_mem: &[MemInitRecord], + ) -> RowMajorMatrix { + assert!( + NVRAM::len(&self.params).is_power_of_two(), + "{} len {} must be a power of 2", + NVRAM::name(), + NVRAM::len(&self.params) + ); + + let mut init_table = RowMajorMatrix::::new( + NVRAM::len(&self.params), + num_fixed, + InstancePaddingStrategy::Default, + ); + assert_eq!(init_table.num_padding_instances(), 0); + + init_table + .par_rows_mut() + .zip_eq(init_mem) + .for_each(|(row, rec)| { + if self.init_v.len() == 1 { + // Assign value directly. + set_fixed_val!(row, self.init_v[0], (rec.value as u64).into_f()); + } else { + // Assign value limbs. + self.init_v.iter().enumerate().for_each(|(l, limb)| { + let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_fixed_val!(row, limb, (val as u64).into_f()); + }); + } + set_fixed_val!(row, self.addr, (rec.addr as u64).into_f()); + }); + + init_table + } + + /// TODO consider taking RowMajorMatrix as argument to save allocations. + pub fn assign_instances( + &self, + num_witin: usize, + num_structural_witin: usize, + _final_mem: &[MemFinalRecord], + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { + assert_eq!(num_structural_witin, 0); + assert!(_final_mem.is_empty()); + + Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]) + } +} + /// define public io /// init value set by instance #[derive(Clone, Debug)] @@ -315,8 +415,11 @@ pub struct DynVolatileRamTableConfig DynVolatileRamTableConfig { - pub fn construct_circuit( +impl DynVolatileRamTableConfigTrait + for DynVolatileRamTableConfig +{ + type Output = DynVolatileRamTableConfig; + fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { @@ -389,7 +492,7 @@ impl DynVolatileRamTableConfig } /// TODO consider taking RowMajorMatrix as argument to save allocations. - pub fn assign_instances( + fn assign_instances( &self, num_witin: usize, num_structural_witin: usize, @@ -457,8 +560,10 @@ pub struct DynVolatileRamTableInitConfig DynVolatileRamTableInitConfig { - pub fn construct_circuit( +impl DynVolatileRamTableConfigTrait + for DynVolatileRamTableInitConfig +{ + fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { @@ -503,7 +608,7 @@ impl DynVolatileRamTableInitCo } /// TODO consider taking RowMajorMatrix as argument to save allocations. - pub fn assign_instances( + fn assign_instances( &self, num_witin: usize, num_structural_witin: usize, @@ -564,8 +669,10 @@ pub struct DynVolatileRamTableFinalConfig DynVolatileRamTableFinalConfig { - pub fn construct_circuit( +impl DynVolatileRamTableConfigTrait + for DynVolatileRamTableFinalConfig +{ + fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { @@ -620,7 +727,7 @@ impl DynVolatileRamTableFinalC } /// TODO consider taking RowMajorMatrix as argument to save allocations. - pub fn assign_instances( + fn assign_instances( &self, num_witin: usize, num_structural_witin: usize, From 706aa08ea67e1cabccf1a7cd708617ddd10f7df3 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 14 Oct 2025 14:17:58 +0800 Subject: [PATCH 43/91] global chip unit test wip --- ceno_zkvm/src/instructions/global.rs | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 35db317ea..049802be7 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -3,7 +3,9 @@ use std::iter::repeat; use crate::{ chip_handler::general::PublicIOQuery, gadgets::{Poseidon2Config, RoundConstants}, + witness::LkMultiplicity, }; +use ceno_emul::StepRecord; use ff_ext::ExtensionField; use gkr_iop::{circuit_builder::CircuitBuilder, error::CircuitBuilderError}; use multilinear_extensions::{ToExpr, WitIn}; @@ -123,10 +125,16 @@ impl Instruction for GlobalChip { fn assign_instance( _config: &Self::InstructionConfig, - _instance: &mut [::BaseField], - _lk_multiplicity: &mut crate::witness::LkMultiplicity, - _step: &ceno_emul::StepRecord, + _instance: &mut [E::BaseField], + _lk_multiplicity: &mut LkMultiplicity, + _step: &StepRecord, ) -> Result<(), crate::error::ZKVMError> { + // assign (x, y) + + // assign [addr, ram_type, value, shard, clk, is_write] + + // assign poseidon2 hasher + todo!() } } @@ -136,5 +144,13 @@ mod tests { #[test] fn test_global_chip() { // Test the GlobalChip functionality here + + // init global chip with horizen_rc_consts + + // create a bunch of random memory read/write records + + // assign witness + + // create chip proof for global chip } } From b84f74e4de3a85b055066996a219ac365e96518b Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 14 Oct 2025 14:50:58 +0800 Subject: [PATCH 44/91] separate circuit into init/final --- .../src/instructions/riscv/rv32im/mmu.rs | 1 + ceno_zkvm/src/tables/ram.rs | 29 +- ceno_zkvm/src/tables/ram/ram_circuit.rs | 92 ++----- ceno_zkvm/src/tables/ram/ram_impl.rs | 253 +++++++++++++----- 4 files changed, 244 insertions(+), 131 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index bf85accd8..5cf537f3c 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -16,6 +16,7 @@ use crate::{ pub struct RegConfigs { pub reg_init_config: as TableCircuit>::TableConfig, + pub reg_final_config: as TableCircuit>::TableConfig, pub reg_mem_bus: as TableCircuit>::TableConfig, } diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index 5839d7125..416ce83ff 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -8,7 +8,10 @@ use crate::{ mod ram_circuit; mod ram_impl; -use crate::tables::ram::ram_circuit::NonVolatileInitRamCircuit; +use crate::tables::ram::ram_impl::{ + DynVolatileRamTableConfig, DynVolatileRamTableFinalConfig, DynVolatileRamTableInitConfig, + NonVolatileFinalTableConfig, NonVolatileInitTableConfig, +}; pub use ram_circuit::{DynVolatileRamTable, MemFinalRecord, MemInitRecord, NonVolatileTable}; #[derive(Clone)] @@ -33,7 +36,10 @@ impl DynVolatileRamTable for HeapTable { } } -pub type HeapCircuit = DynVolatileRamCircuit; +pub type HeapInitCircuit = + DynVolatileRamCircuit>; +pub type HeapFinalCircuit = + DynVolatileRamCircuit>; #[derive(Clone)] pub struct StackTable; @@ -67,7 +73,10 @@ impl DynVolatileRamTable for StackTable { } } -pub type StackCircuit = DynVolatileRamCircuit; +pub type StackInitCircuit = + DynVolatileRamCircuit>; +pub type StackFinalCircuit = + DynVolatileRamCircuit>; #[derive(Clone)] pub struct HintsTable; @@ -89,7 +98,8 @@ impl DynVolatileRamTable for HintsTable { "HintsTable" } } -pub type HintsCircuit = DynVolatileRamCircuit; +pub type HintsCircuit = + DynVolatileRamCircuit>; /// RegTable, fix size without offset #[derive(Clone)] @@ -109,8 +119,10 @@ impl NonVolatileTable for RegTable { } } -// pub type RegTableCircuit = NonVolatileRamCircuit; -pub type RegTableInitCircuit = NonVolatileInitRamCircuit; +pub type RegTableInitCircuit = + NonVolatileRamCircuit>; +pub type RegTableFinalCircuit = + NonVolatileRamCircuit>; #[derive(Clone)] pub struct StaticMemTable; @@ -129,7 +141,10 @@ impl NonVolatileTable for StaticMemTable { } } -pub type StaticMemCircuit = NonVolatileRamCircuit; +pub type StaticMemInitCircuit = + NonVolatileRamCircuit>; +pub type StaticMemFinalCircuit = + NonVolatileRamCircuit>; #[derive(Clone)] pub struct PubIOTable; diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 633670649..854548c95 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -12,7 +12,8 @@ use gkr_iop::error::CircuitBuilderError; use witness::{InstancePaddingStrategy, RowMajorMatrix}; use super::ram_impl::{ - DynVolatileRamTableConfig, NonVolatileInitTableConfig, NonVolatileTableConfig, PubIOTableConfig, + DynVolatileRamTableConfig, NonVolatileInitTableConfig, NonVolatileTableConfig, + NonVolatileTableConfigTrait, PubIOTableConfig, }; #[derive(Clone, Debug)] @@ -62,12 +63,15 @@ pub trait NonVolatileTable { /// - with fixed initial content, /// - with witnessed final content that the program wrote, if WRITABLE, /// - or final content equal to initial content, if not WRITABLE. -pub struct NonVolatileRamCircuit(PhantomData<(E, R)>); +pub struct NonVolatileRamCircuit(PhantomData<(E, R, C)>); -impl TableCircuit - for NonVolatileRamCircuit +impl< + E: ExtensionField, + NVRAM: NonVolatileTable + Send + Sync + Clone, + C: NonVolatileTableConfigTrait, +> TableCircuit for NonVolatileRamCircuit { - type TableConfig = NonVolatileTableConfig; + type TableConfig = C::Config; type FixedInput = [MemInitRecord]; type WitnessInput = [MemFinalRecord]; @@ -79,10 +83,7 @@ impl TableCirc cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { - Ok(cb.namespace( - || Self::name(), - |cb| Self::TableConfig::construct_circuit(cb, params), - )?) + Ok(cb.namespace(|| Self::name(), |cb| C::construct_circuit(cb, params))?) } fn generate_fixed_traces( @@ -91,7 +92,7 @@ impl TableCirc init_v: &Self::FixedInput, ) -> RowMajorMatrix { // assume returned table is well-formed include padding - config.gen_init_state(num_fixed, init_v) + C::gen_init_state(config, num_fixed, init_v) } fn assign_instances( @@ -102,57 +103,13 @@ impl TableCirc final_v: &Self::WitnessInput, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding - Ok(config.assign_instances(num_witin, num_structural_witin, final_v)?) - } -} - -/// NonVolatileRamCircuit initializes and finalizes memory -/// - at fixed addresses, -/// - with fixed initial content, -/// - with witnessed final content that the program wrote, if WRITABLE, -/// - or final content equal to initial content, if not WRITABLE. -pub struct NonVolatileInitRamCircuit(PhantomData<(E, R)>); - -impl TableCircuit - for NonVolatileInitRamCircuit -{ - type TableConfig = NonVolatileInitTableConfig; - type FixedInput = [MemInitRecord]; - type WitnessInput = [MemFinalRecord]; - - fn name() -> String { - format!("RAM_{:?}_{}", NVRAM::RAM_TYPE, NVRAM::name()) - } - - fn construct_circuit( - cb: &mut CircuitBuilder, - params: &ProgramParams, - ) -> Result { - Ok(cb.namespace( - || Self::name(), - |cb| Self::TableConfig::construct_circuit(cb, params), + Ok(C::assign_instances( + config, + num_witin, + num_structural_witin, + final_v, )?) } - - fn generate_fixed_traces( - config: &Self::TableConfig, - num_fixed: usize, - init_v: &Self::FixedInput, - ) -> RowMajorMatrix { - // assume returned table is well-formed include padding - config.gen_init_state(num_fixed, init_v) - } - - fn assign_instances( - config: &Self::TableConfig, - num_witin: usize, - num_structural_witin: usize, - _multiplicity: &[HashMap], - final_v: &Self::WitnessInput, - ) -> Result, ZKVMError> { - // assume returned table is well-formed include padding - Ok(config.assign_instances(num_witin, num_structural_witin, final_v)?) - } } /// PubIORamCircuit initializes and finalizes memory @@ -241,13 +198,13 @@ pub trait DynVolatileRamTable { } pub trait DynVolatileRamTableConfigTrait: Sized + Send + Sync { - type Output: Sized + Send + Sync; + type Config: Sized + Send + Sync; fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, - ) -> Result; + ) -> Result; fn assign_instances( - &self, + config: &Self::Config, num_witin: usize, num_structural_witin: usize, final_mem: &[MemFinalRecord], @@ -270,7 +227,7 @@ impl< C: DynVolatileRamTableConfigTrait, > TableCircuit for DynVolatileRamCircuit { - type TableConfig = C::Output; + type TableConfig = C::Config; type FixedInput = (); type WitnessInput = [MemFinalRecord]; @@ -301,6 +258,13 @@ impl< final_v: &Self::WitnessInput, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding - Ok(config.assign_instances(num_witin, num_structural_witin, final_v)?) + Ok( + >::assign_instances( + config, + num_witin, + num_structural_witin, + final_v, + )?, + ) } } diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 657def000..8ef622539 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -28,6 +28,28 @@ use multilinear_extensions::{ use p3::field::FieldAlgebra; use rayon::prelude::{ParallelSlice, ParallelSliceMut}; +pub trait NonVolatileTableConfigTrait: Sized + Send + Sync { + type Config: Sized + Send + Sync; + + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result; + + fn gen_init_state( + config: &Self::Config, + num_fixed: usize, + init_mem: &[MemInitRecord], + ) -> RowMajorMatrix; + + fn assign_instances( + config: &Self::Config, + num_witin: usize, + num_structural_witin: usize, + final_mem: &[MemFinalRecord], + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError>; +} + /// define a non-volatile memory with init value #[derive(Clone, Debug)] pub struct NonVolatileTableConfig { @@ -41,8 +63,11 @@ pub struct NonVolatileTableConfig params: ProgramParams, } -impl NonVolatileTableConfig { - pub fn construct_circuit( +impl NonVolatileTableConfigTrait + for NonVolatileTableConfig +{ + type Config = NonVolatileTableConfig; + fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { @@ -111,20 +136,20 @@ impl NonVolatileTableConfig( - &self, + fn gen_init_state( + config: &Self::Config, num_fixed: usize, init_mem: &[MemInitRecord], ) -> RowMajorMatrix { assert!( - NVRAM::len(&self.params).is_power_of_two(), + NVRAM::len(&config.params).is_power_of_two(), "{} len {} must be a power of 2", NVRAM::name(), - NVRAM::len(&self.params) + NVRAM::len(&config.params) ); let mut init_table = RowMajorMatrix::::new( - NVRAM::len(&self.params), + NVRAM::len(&config.params), num_fixed, InstancePaddingStrategy::Default, ); @@ -134,32 +159,32 @@ impl NonVolatileTableConfig> (l * LIMB_BITS)) & LIMB_MASK; set_fixed_val!(row, limb, (val as u64).into_f()); }); } - set_fixed_val!(row, self.addr, (rec.addr as u64).into_f()); + set_fixed_val!(row, config.addr, (rec.addr as u64).into_f()); }); init_table } /// TODO consider taking RowMajorMatrix as argument to save allocations. - pub fn assign_instances( - &self, + fn assign_instances( + config: &Self::Config, num_witin: usize, num_structural_witin: usize, final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { assert_eq!(num_structural_witin, 0); let mut final_table = RowMajorMatrix::::new( - NVRAM::len(&self.params), + NVRAM::len(&config.params), num_witin, InstancePaddingStrategy::Default, ); @@ -168,7 +193,7 @@ impl NonVolatileTableConfig NonVolatileTableConfig NonVolatileInitTableConfig { - pub fn construct_circuit( +impl NonVolatileTableConfigTrait + for NonVolatileInitTableConfig +{ + type Config = NonVolatileInitTableConfig; + + fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { + assert!(NVRAM::WRITABLE); let init_v = (0..NVRAM::V_LIMBS) .map(|i| cb.create_fixed(|| format!("init_v_limb_{i}"))) .collect_vec(); @@ -233,20 +263,20 @@ impl NonVolatileInitTableConfig( - &self, + fn gen_init_state( + config: &Self::Config, num_fixed: usize, init_mem: &[MemInitRecord], ) -> RowMajorMatrix { assert!( - NVRAM::len(&self.params).is_power_of_two(), + NVRAM::len(&config.params).is_power_of_two(), "{} len {} must be a power of 2", NVRAM::name(), - NVRAM::len(&self.params) + NVRAM::len(&config.params) ); let mut init_table = RowMajorMatrix::::new( - NVRAM::len(&self.params), + NVRAM::len(&config.params), num_fixed, InstancePaddingStrategy::Default, ); @@ -256,26 +286,26 @@ impl NonVolatileInitTableConfig> (l * LIMB_BITS)) & LIMB_MASK; set_fixed_val!(row, limb, (val as u64).into_f()); }); } - set_fixed_val!(row, self.addr, (rec.addr as u64).into_f()); + set_fixed_val!(row, config.addr, (rec.addr as u64).into_f()); }); init_table } /// TODO consider taking RowMajorMatrix as argument to save allocations. - pub fn assign_instances( - &self, - num_witin: usize, + fn assign_instances( + _config: &Self::Config, + _num_witin: usize, num_structural_witin: usize, _final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { @@ -286,6 +316,106 @@ impl NonVolatileInitTableConfig { + addr: WitIn, + + final_v: Vec, + final_cycle: WitIn, + + phantom: PhantomData, + params: ProgramParams, +} + +impl NonVolatileTableConfigTrait + for NonVolatileFinalTableConfig +{ + type Config = NonVolatileFinalTableConfig; + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result { + assert!(NVRAM::WRITABLE); + let addr = cb.create_witin(|| "addr"); + + let final_cycle = cb.create_witin(|| "final_cycle"); + let final_v = (0..NVRAM::V_LIMBS) + .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) + .collect::>(); + + let final_table = [ + // a v t + vec![(NVRAM::RAM_TYPE as usize).into()], + vec![addr.expr()], + final_v.iter().map(|v| v.expr()).collect_vec(), + vec![final_cycle.expr()], + ] + .concat(); + + cb.r_table_record( + || "final_table", + NVRAM::RAM_TYPE, + SetTableSpec { + len: Some(NVRAM::len(params)), + structural_witins: vec![], + }, + final_table, + )?; + + Ok(Self { + final_v, + addr, + final_cycle, + phantom: PhantomData, + params: params.clone(), + }) + } + + fn gen_init_state( + config: &Self::Config, + num_fixed: usize, + init_mem: &[MemInitRecord], + ) -> RowMajorMatrix { + RowMajorMatrix::empty() + } + + /// TODO consider taking RowMajorMatrix as argument to save allocations. + fn assign_instances( + config: &Self::Config, + num_witin: usize, + num_structural_witin: usize, + final_mem: &[MemFinalRecord], + ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { + assert_eq!(num_structural_witin, 0); + let mut final_table = RowMajorMatrix::::new( + NVRAM::len(&config.params), + num_witin, + InstancePaddingStrategy::Default, + ); + + final_table + .par_rows_mut() + .zip_eq(final_mem) + .for_each(|(row, rec)| { + if config.final_v.len() == 1 { + // Assign value directly. + set_val!(row, config.final_v[0], rec.value as u64); + } else { + // Assign value limbs. + config.final_v.iter().enumerate().for_each(|(l, limb)| { + let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }); + } + set_val!(row, config.addr, rec.addr as u64); + set_val!(row, config.final_cycle, rec.cycle); + }); + + Ok([final_table, RowMajorMatrix::empty()]) + } +} + /// define public io /// init value set by instance #[derive(Clone, Debug)] @@ -418,7 +548,7 @@ pub struct DynVolatileRamTableConfig DynVolatileRamTableConfigTrait for DynVolatileRamTableConfig { - type Output = DynVolatileRamTableConfig; + type Config = DynVolatileRamTableConfig; fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, @@ -493,16 +623,16 @@ impl DynVolatileRamTableConfig /// TODO consider taking RowMajorMatrix as argument to save allocations. fn assign_instances( - &self, + config: &Self::Config, num_witin: usize, num_structural_witin: usize, final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { - assert!(final_mem.len() <= DVRAM::max_len(&self.params)); - assert!(DVRAM::max_len(&self.params).is_power_of_two()); + assert!(final_mem.len() <= DVRAM::max_len(&config.params)); + assert!(DVRAM::max_len(&config.params).is_power_of_two()); - let params = self.params.clone(); - let addr_id = self.addr.id as u64; + let params = config.params.clone(); + let addr_id = config.addr.id as u64; let addr_padding_fn = move |row: u64, col: u64| { assert_eq!(col, addr_id); DVRAM::addr(¶ms, row as usize) as u64 @@ -524,25 +654,25 @@ impl DynVolatileRamTableConfig .for_each(|(i, ((row, structural_row), rec))| { assert_eq!( rec.addr, - DVRAM::addr(&self.params, i), + DVRAM::addr(&config.params, i), "rec.addr {:x} != expected {:x}", rec.addr, - DVRAM::addr(&self.params, i), + DVRAM::addr(&config.params, i), ); - if self.final_v.len() == 1 { + if config.final_v.len() == 1 { // Assign value directly. - set_val!(row, self.final_v[0], rec.value as u64); + set_val!(row, config.final_v[0], rec.value as u64); } else { // Assign value limbs. - self.final_v.iter().enumerate().for_each(|(l, limb)| { + config.final_v.iter().enumerate().for_each(|(l, limb)| { let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; set_val!(row, limb, val as u64); }); } - set_val!(row, self.final_cycle, rec.cycle); + set_val!(row, config.final_cycle, rec.cycle); - set_val!(structural_row, self.addr, rec.addr as u64); + set_val!(structural_row, config.addr, rec.addr as u64); }); structural_witness.padding_by_strategy(); @@ -563,6 +693,8 @@ pub struct DynVolatileRamTableInitConfig DynVolatileRamTableConfigTrait for DynVolatileRamTableInitConfig { + type Config = DynVolatileRamTableInitConfig; + fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, @@ -609,16 +741,16 @@ impl DynVolatileRamTableConfig /// TODO consider taking RowMajorMatrix as argument to save allocations. fn assign_instances( - &self, + config: &Self::Config, num_witin: usize, num_structural_witin: usize, final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { - assert!(final_mem.len() <= DVRAM::max_len(&self.params)); - assert!(DVRAM::max_len(&self.params).is_power_of_two()); + assert!(final_mem.len() <= DVRAM::max_len(&config.params)); + assert!(DVRAM::max_len(&config.params).is_power_of_two()); - let params = self.params.clone(); - let addr_id = self.addr.id as u64; + let params = config.params.clone(); + let addr_id = config.addr.id as u64; let addr_padding_fn = move |row: u64, col: u64| { assert_eq!(col, addr_id); DVRAM::addr(¶ms, row as usize) as u64 @@ -640,12 +772,12 @@ impl DynVolatileRamTableConfig .for_each(|(i, ((row, structural_row), rec))| { assert_eq!( rec.addr, - DVRAM::addr(&self.params, i), + DVRAM::addr(&config.params, i), "rec.addr {:x} != expected {:x}", rec.addr, - DVRAM::addr(&self.params, i), + DVRAM::addr(&config.params, i), ); - set_val!(structural_row, self.addr, rec.addr as u64); + set_val!(structural_row, config.addr, rec.addr as u64); }); structural_witness.padding_by_strategy(); @@ -672,6 +804,7 @@ pub struct DynVolatileRamTableFinalConfig DynVolatileRamTableConfigTrait for DynVolatileRamTableFinalConfig { + type Config = DynVolatileRamTableFinalConfig; fn construct_circuit( cb: &mut CircuitBuilder, params: &ProgramParams, @@ -728,14 +861,14 @@ impl DynVolatileRamTableConfig /// TODO consider taking RowMajorMatrix as argument to save allocations. fn assign_instances( - &self, + config: &Self::Config, num_witin: usize, num_structural_witin: usize, final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { assert_eq!(num_structural_witin, 1); - assert!(final_mem.len() <= DVRAM::max_len(&self.params)); - assert!(DVRAM::max_len(&self.params).is_power_of_two()); + assert!(final_mem.len() <= DVRAM::max_len(&config.params)); + assert!(DVRAM::max_len(&config.params).is_power_of_two()); let mut witness = RowMajorMatrix::::new(final_mem.len(), num_witin, InstancePaddingStrategy::Default); @@ -751,20 +884,20 @@ impl DynVolatileRamTableConfig .zip(final_mem) .enumerate() .for_each(|(i, ((row, structural_row), rec))| { - if self.final_v.len() == 1 { + if config.final_v.len() == 1 { // Assign value directly. - set_val!(row, self.final_v[0], rec.value as u64); + set_val!(row, config.final_v[0], rec.value as u64); } else { // Assign value limbs. - self.final_v.iter().enumerate().for_each(|(l, limb)| { + config.final_v.iter().enumerate().for_each(|(l, limb)| { let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; set_val!(row, limb, val as u64); }); } - set_val!(row, self.final_cycle, rec.cycle); + set_val!(row, config.final_cycle, rec.cycle); - set_val!(row, self.addr_subset, rec.addr as u64); - set_val!(row, self.sel, 1u64); + set_val!(row, config.addr_subset, rec.addr as u64); + set_val!(row, config.sel, 1u64); }); Ok([witness, structural_witness]) From 4d7423b5ef971fc81d0a4240946b465f209099a6 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 14 Oct 2025 22:22:05 +0800 Subject: [PATCH 45/91] finish most constraints for Global chip --- ceno_zkvm/src/gadgets/poseidon2.rs | 17 +++- ceno_zkvm/src/instructions/global.rs | 127 +++++++++++++++++++++------ 2 files changed, 118 insertions(+), 26 deletions(-) diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs index f70049bcb..aafeb3c32 100644 --- a/ceno_zkvm/src/gadgets/poseidon2.rs +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -249,7 +249,22 @@ impl< } pub fn output(&self) -> Vec> { - todo!() + let col_exprs = self.cols.iter().map(|c| c.expr()).collect::>(); + + let poseidon2_cols: &Poseidon2Cols< + Expression, + STATE_WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = col_exprs.as_slice().borrow(); + + poseidon2_cols + .ending_full_rounds + .last() + .map(|r| r.post.to_vec()) + .unwrap() } // pub fn assign_instance(&self, input: &[E; STATE_WIDTH]) { diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 049802be7..8f49c7ed7 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -3,12 +3,14 @@ use std::iter::repeat; use crate::{ chip_handler::general::PublicIOQuery, gadgets::{Poseidon2Config, RoundConstants}, + structs::RAMType, witness::LkMultiplicity, }; use ceno_emul::StepRecord; use ff_ext::ExtensionField; use gkr_iop::{circuit_builder::CircuitBuilder, error::CircuitBuilderError}; -use multilinear_extensions::{ToExpr, WitIn}; +use itertools::Itertools; +use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::FieldAlgebra; use crate::{ @@ -17,18 +19,24 @@ use crate::{ }; // opcode circuit + mem init/final table + global chip: -// have read/write consistency for RAMType::Register -// and RAMType::Memory +// have read/write consistency for RAMType::Register and RAMType::Memory // // global chip: read from and write into a global set shared // among multiple shards pub struct GlobalConfig { addr: WitIn, - ram_type: WitIn, + is_ram_register: WitIn, value: UInt, shard: WitIn, - clk: WitIn, - is_write: WitIn, + global_clk: WitIn, + local_clk: WitIn, + nonce: WitIn, + // if it's a write to global set, then insert a local read record + // s.t. local offline memory checking can cancel out + // this serves as propagating local write to global. + is_global_write: WitIn, + r_record: WitIn, + w_record: WitIn, x: Vec, y: Vec, poseidon2: Poseidon2Config, @@ -47,23 +55,86 @@ impl GlobalConfig { .map(|i| cb.create_witin(|| format!("y{}", i))) .collect(); let addr = cb.create_witin(|| "addr"); - let ram_type = cb.create_witin(|| "ram_type"); + let is_ram_register = cb.create_witin(|| "is_ram_register"); let value = UInt::new(|| "value", cb)?; let shard = cb.create_witin(|| "shard"); - let clk = cb.create_witin(|| "clk"); - let is_write = cb.create_witin(|| "is_write"); - - // TODO: support other field + let global_clk = cb.create_witin(|| "global_clk"); + let local_clk = cb.create_witin(|| "local_clk"); + let nonce = cb.create_witin(|| "nonce"); + let is_global_write = cb.create_witin(|| "is_global_write"); + let r_record = cb.create_witin(|| "r_record"); + let w_record = cb.create_witin(|| "w_record"); + + let is_ram_reg: Expression = is_ram_register.expr(); + let reg: Expression = RAMType::Register.into(); + let mem: Expression = RAMType::Memory.into(); + let ram_type: Expression = is_ram_reg.clone() * reg + (1 - is_ram_reg) * mem; let hasher = Poseidon2Config::construct(cb, rc); let mut input = vec![]; input.push(addr.expr()); - input.push(ram_type.expr()); + input.push(ram_type.clone()); // memory expr has same number of limbs as register expr input.extend(value.memory_expr()); input.push(shard.expr()); - input.push(clk.expr()); - input.extend(repeat(E::BaseField::ZERO.expr()).take(16 - 6)); + input.push(global_clk.expr()); + // add nonce to ensure poseidon2(input) always map to a valid ec point + input.push(nonce.expr()); + input.extend(repeat(E::BaseField::ZERO.expr()).take(16 - input.len())); + + let mut record = vec![]; + record.push(addr.expr()); + record.push(ram_type); + record.extend(value.memory_expr()); + record.push(shard.expr()); + record.push(local_clk.expr()); + let rlc = cb.rlc_chip_record(record); + + // if is_global_write = 1, then it means we are propagating a local write to global + // so we need to insert a local read record to cancel out this local write + // otherwise, we insert a padding value 1 to avoid affecting local memory checking + + cb.assert_bit(|| "is_global_write must be boolean", is_global_write.expr())?; + // r_record = select(is_global_write, rlc, 1) + cb.condition_require_equal( + || "r_record = select(is_global_write, rlc, 1)", + is_global_write.expr(), + r_record.expr(), + rlc.clone(), + E::BaseField::ONE.expr(), + )?; + + // if we are reading from global set, then this record should be + // considered as a initial local write to that address. + // otherwise, we insert a padding value 1 as if we are not writing anything + + // w_record = select(is_global_write, 1, rlc) + cb.condition_require_equal( + || "w_record = select(is_global_write, 1, rlc)", + is_global_write.expr(), + w_record.expr(), + E::BaseField::ONE.expr(), + rlc, + )?; + + // local read/write consistency + cb.condition_require_zero( + || "is_global_read => local_clk = 0", + 1 - is_global_write.expr(), + local_clk.expr(), + )?; + // TODO: enforce shard = shard_id in the public values + + cb.read_record( + || "r_record", + gkr_iop::RAMType::Register, // TODO fixme + vec![r_record.expr()], + )?; + cb.write_record( + || "w_record", + gkr_iop::RAMType::Register, // TODO fixme + vec![w_record.expr()], + )?; // enforces final_sum = \sum_i (x_i, y_i) using ecc quark protocol let final_sum = cb.query_global_rw_sum()?; @@ -73,28 +144,34 @@ impl GlobalConfig { final_sum.into_iter().map(|x| x.expr()).collect::>(), ); - // enforces x = poseidon2([addr, ram_type, value[0], value[1], shard, clk, 0]) - for (input_expr, hasher_input) in input.into_iter().zip(hasher.inputs().into_iter()) { - // TODO: replace with cb.require_equal() - cb.require_zero(|| "poseidon2 input", input_expr - hasher_input)?; + // enforces x = poseidon2([addr, ram_type, value[0], value[1], shard, global_clk, nonce, 0, ..., 0]) + for (input_expr, hasher_input) in input.into_iter().zip_eq(hasher.inputs().into_iter()) { + cb.require_equal(|| "poseidon2 input", input_expr, hasher_input)?; } for (xi, hasher_output) in x.iter().zip(hasher.output().into_iter()) { - cb.require_zero(|| "poseidon2 output", xi.expr() - hasher_output)?; + cb.require_equal(|| "x = poseidon2's output", xi.expr(), hasher_output)?; } - // TODO: enforce is_write is boolean - // TODO: enforce y < p/2 if is_write = 1 - // enforce p/2 <= y < p if is_write = 0 + // both (x, y) and (x, -y) are valid ec points + // if is_global_write = 1, then y should be in [0, p/2) + // if is_global_write = 0, then y should be in [p/2, p) + + // TODO: enforce 0 <= y < p/2 if is_global_write = 1 + // enforce p/2 <= y < p if is_global_write = 0 Ok(GlobalConfig { x, y, addr, - ram_type, + is_ram_register, value, shard, - clk, - is_write, + global_clk, + local_clk, + nonce, + is_global_write, + r_record, + w_record, poseidon2: hasher, }) } From 03092e9f1707b1c7a3e08f64fd6d88a8ad34fe3d Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 14 Oct 2025 23:47:35 +0800 Subject: [PATCH 46/91] complete local finalized mem chip logic --- ceno_zkvm/src/chip_handler/general.rs | 14 +- ceno_zkvm/src/e2e.rs | 45 +- ceno_zkvm/src/instructions/riscv/constants.rs | 4 +- .../src/instructions/riscv/rv32im/mmu.rs | 121 ++-- ceno_zkvm/src/tables/ram.rs | 18 +- ceno_zkvm/src/tables/ram/ram_circuit.rs | 112 +++- ceno_zkvm/src/tables/ram/ram_impl.rs | 521 ++++++------------ gkr_iop/src/lib.rs | 1 + 8 files changed, 403 insertions(+), 433 deletions(-) diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index 805b60baf..d04c50329 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -5,7 +5,7 @@ use crate::{ circuit_builder::CircuitBuilder, instructions::riscv::constants::{ END_CYCLE_IDX, END_PC_IDX, END_SHARD_ID_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, - MEM_BUS_WITH_READ_IDX, MEM_BUS_WITH_WRITE_IDX, PUBLIC_IO_IDX, UINT_LIMBS, + PUBLIC_IO_IDX, UINT_LIMBS, }, tables::InsnRecord, }; @@ -23,8 +23,6 @@ pub trait PublicIOQuery { fn query_end_cycle(&mut self) -> Result; fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError>; fn query_shard_id(&mut self) -> Result; - fn query_mem_bus_with_read(&mut self) -> Result; - fn query_mem_bus_with_write(&mut self) -> Result; } impl<'a, E: ExtensionField> InstFetch for CircuitBuilder<'a, E> { @@ -67,16 +65,6 @@ impl<'a, E: ExtensionField> PublicIOQuery for CircuitBuilder<'a, E> { self.cs.query_instance(|| "shard_id", END_SHARD_ID_IDX) } - fn query_mem_bus_with_read(&mut self) -> Result { - self.cs - .query_instance(|| "mem_bus_with_read", MEM_BUS_WITH_READ_IDX) - } - - fn query_mem_bus_with_write(&mut self) -> Result { - self.cs - .query_instance(|| "mem_bus_with_write", MEM_BUS_WITH_WRITE_IDX) - } - fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError> { Ok([ self.cs.query_instance(|| "public_io_low", PUBLIC_IO_IDX)?, diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index c0330558f..b6f51c362 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -213,6 +213,21 @@ impl<'a> ShardContext<'a> { } } + #[inline(always)] + pub fn is_first_shard(&self) -> bool { + self.shard_id == 0 + } + + #[inline(always)] + pub fn is_last_shard(&self) -> bool { + self.shard_id == self.num_shards - 1 + } + + #[inline(always)] + pub fn is_current_shard_cycle(&self, cycle: Cycle) -> bool { + self.cur_shard_cycle_range.contains(&(cycle as usize)) + } + #[inline(always)] pub fn send( &mut self, @@ -226,7 +241,7 @@ impl<'a> ShardContext<'a> { ) { // check read from external mem bus if prev_cycle < self.cur_shard_cycle_range.start as Cycle - && self.cur_shard_cycle_range.contains(&(cycle as usize)) + && self.is_current_shard_cycle(cycle) { let ram_record = self .read_thread_based_record_storage @@ -248,7 +263,7 @@ impl<'a> ShardContext<'a> { // check write to external mem bus if let Some(future_touch_cycle) = self.addr_future_accesses.get(&(addr, cycle)) { if *future_touch_cycle >= self.cur_shard_cycle_range.end as Cycle - && self.cur_shard_cycle_range.contains(&(cycle as usize)) + && self.is_current_shard_cycle(cycle) { let ram_record = self .write_thread_based_record_storage @@ -348,6 +363,7 @@ pub fn emulate_program<'a>( if index < VMState::REG_COUNT { let vma: WordAddr = Platform::register_vma(index).into(); MemFinalRecord { + ram_type: RAMType::Memory, addr: rec.addr, value: vm.peek_register(index), cycle: *final_access.get(&vma).unwrap_or(&0), @@ -355,6 +371,7 @@ pub fn emulate_program<'a>( } else { // The table is padded beyond the number of registers. MemFinalRecord { + ram_type: RAMType::Memory, addr: rec.addr, value: 0, cycle: 0, @@ -369,6 +386,7 @@ pub fn emulate_program<'a>( .map(|rec| { let vma: WordAddr = rec.addr.into(); MemFinalRecord { + ram_type: RAMType::Memory, addr: rec.addr, value: vm.peek_memory(vma), cycle: *final_access.get(&vma).unwrap_or(&0), @@ -380,6 +398,7 @@ pub fn emulate_program<'a>( let io_final = io_init .iter() .map(|rec| MemFinalRecord { + ram_type: RAMType::Memory, addr: rec.addr, value: rec.value, cycle: *final_access.get(&rec.addr.into()).unwrap_or(&0), @@ -390,6 +409,7 @@ pub fn emulate_program<'a>( let hints_final = hints_init .iter() .map(|rec| MemFinalRecord { + ram_type: RAMType::Memory, addr: rec.addr, value: rec.value, cycle: *final_access.get(&rec.addr.into()).unwrap_or(&0), @@ -407,6 +427,7 @@ pub fn emulate_program<'a>( .map(|vma| { let byte_addr = vma.baddr(); MemFinalRecord { + ram_type: RAMType::Memory, addr: byte_addr.0, value: vm.peek_memory(vma), cycle: *final_access.get(&vma).unwrap_or(&0), @@ -430,6 +451,7 @@ pub fn emulate_program<'a>( .map(|vma| { let byte_addr = vma.baddr(); MemFinalRecord { + ram_type: RAMType::Memory, addr: byte_addr.0, value: vm.peek_memory(vma), cycle: *final_access.get(&vma).unwrap_or(&0), @@ -578,17 +600,17 @@ pub fn init_static_addrs(program: &Program) -> Vec { program_addrs } -pub struct ConstraintSystemConfig { +pub struct ConstraintSystemConfig<'a, E: ExtensionField> { pub zkvm_cs: ZKVMConstraintSystem, pub config: Rv32imConfig, - pub mmu_config: MmuConfig, + pub mmu_config: MmuConfig<'a, E>, pub dummy_config: DummyExtraConfig, pub prog_config: ProgramTableConfig, } -pub fn construct_configs( +pub fn construct_configs<'a, E: ExtensionField>( program_params: ProgramParams, -) -> ConstraintSystemConfig { +) -> ConstraintSystemConfig<'a, E> { let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params); let config = Rv32imConfig::::construct_circuits(&mut zkvm_cs); @@ -673,6 +695,7 @@ pub fn generate_witness( .mmu_config .assign_table_circuit( &system_config.zkvm_cs, + &emul_result.shard_ctx, &mut zkvm_witness, &emul_result.final_mem_state.reg, &emul_result.final_mem_state.mem, @@ -714,13 +737,13 @@ pub enum Checkpoint { pub type IntermediateState = (Option>, Option>); /// Context construct from a program and given platform -pub struct E2EProgramCtx { +pub struct E2EProgramCtx<'a, E: ExtensionField> { pub program: Arc, pub platform: Platform, pub shards: Shards, pub static_addrs: Vec, pub pubio_len: usize, - pub system_config: ConstraintSystemConfig, + pub system_config: ConstraintSystemConfig<'a, E>, pub reg_init: Vec, pub io_init: Vec, pub zkvm_fixed_traces: ZKVMFixedTraces, @@ -745,11 +768,11 @@ impl> E2ECheckpointResult< } /// Set up a program with the given platform -pub fn setup_program( +pub fn setup_program<'a, E: ExtensionField>( program: Program, platform: Platform, shards: Shards, -) -> E2EProgramCtx { +) -> E2EProgramCtx<'a, E> { let static_addrs = init_static_addrs(&program); let pubio_len = platform.public_io.iter_addresses().len(); let program_params = ProgramParams { @@ -784,7 +807,7 @@ pub fn setup_program( } } -impl E2EProgramCtx { +impl E2EProgramCtx<'_, E> { pub fn keygen + 'static>( &self, max_num_variables: usize, diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 17316f956..4e3786235 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -10,9 +10,7 @@ pub const INIT_CYCLE_IDX: usize = 3; pub const END_PC_IDX: usize = 4; pub const END_CYCLE_IDX: usize = 5; pub const END_SHARD_ID_IDX: usize = 6; -pub const MEM_BUS_WITH_READ_IDX: usize = 7; -pub const MEM_BUS_WITH_WRITE_IDX: usize = 8; -pub const PUBLIC_IO_IDX: usize = 9; +pub const PUBLIC_IO_IDX: usize = 7; pub const LIMB_BITS: usize = 16; pub const LIMB_MASK: u32 = 0xFFFF; diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index 5cf537f3c..c37aa1615 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -1,60 +1,63 @@ -use std::{collections::HashSet, iter::zip, ops::Range}; - -use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word}; -use ff_ext::ExtensionField; -use itertools::{Itertools, chain}; - use crate::{ + e2e::ShardContext, error::ZKVMError, structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ - HeapCircuit, HintsCircuit, MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, - PubIOTable, RegTable, RegTableCircuit, RegTableInitCircuit, StackCircuit, StaticMemCircuit, + DynVolatileRamTable, HeapInitCircuit, HeapTable, HintsCircuit, LocalFinalCircuit, + MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, RBCircuit, + RegTable, RegTableInitCircuit, StackInitCircuit, StackTable, StaticMemInitCircuit, StaticMemTable, TableCircuit, }, }; +use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word}; +use ff_ext::ExtensionField; +use itertools::{Itertools, chain}; +use std::{collections::HashSet, iter::zip, ops::Range, sync::Arc}; +use witness::InstancePaddingStrategy; -pub struct RegConfigs { - pub reg_init_config: as TableCircuit>::TableConfig, - pub reg_final_config: as TableCircuit>::TableConfig, - pub reg_mem_bus: as TableCircuit>::TableConfig, -} - -pub struct MmuConfig { +pub struct MmuConfig<'a, E: ExtensionField> { /// Initialization of registers. - pub reg_config: as TableCircuit>::TableConfig, + pub reg_init_config: as TableCircuit>::TableConfig, /// Initialization of memory with static addresses. - pub static_mem_config: as TableCircuit>::TableConfig, + pub static_mem_init_config: as TableCircuit>::TableConfig, /// Initialization of public IO. pub public_io_config: as TableCircuit>::TableConfig, /// Initialization of hints. pub hints_config: as TableCircuit>::TableConfig, /// Initialization of heap. - pub heap_config: as TableCircuit>::TableConfig, + pub heap_init_config: as TableCircuit>::TableConfig, /// Initialization of stack. - pub stack_config: as TableCircuit>::TableConfig, + pub stack_init_config: as TableCircuit>::TableConfig, + /// finalized circuit for all MMIO + pub local_final_circuit: as TableCircuit>::TableConfig, + /// ram bus to deal with cross shard read/write + pub ram_bus_circuit: as TableCircuit>::TableConfig, pub params: ProgramParams, } -impl MmuConfig { +impl MmuConfig<'_, E> { pub fn construct_circuits(cs: &mut ZKVMConstraintSystem) -> Self { - let reg_config = cs.register_table_circuit::>(); + let reg_init_config = cs.register_table_circuit::>(); - let static_mem_config = cs.register_table_circuit::>(); + let static_mem_init_config = cs.register_table_circuit::>(); let public_io_config = cs.register_table_circuit::>(); let hints_config = cs.register_table_circuit::>(); - let stack_config = cs.register_table_circuit::>(); - let heap_config = cs.register_table_circuit::>(); + let stack_init_config = cs.register_table_circuit::>(); + let heap_init_config = cs.register_table_circuit::>(); + let local_final_circuit = cs.register_table_circuit::>(); + let ram_bus_circuit = cs.register_table_circuit::>(); Self { - reg_config, - static_mem_config, + reg_init_config, + static_mem_init_config, public_io_config, hints_config, - stack_config, - heap_config, + stack_init_config, + heap_init_config, + local_final_circuit, + ram_bus_circuit, params: cs.params.clone(), } } @@ -78,24 +81,27 @@ impl MmuConfig { "memory addresses must be unique" ); - fixed.register_table_circuit::>(cs, &self.reg_config, reg_init); + fixed.register_table_circuit::>(cs, &self.reg_init_config, reg_init); - fixed.register_table_circuit::>( + fixed.register_table_circuit::>( cs, - &self.static_mem_config, + &self.static_mem_init_config, static_mem_init, ); fixed.register_table_circuit::>(cs, &self.public_io_config, io_addrs); fixed.register_table_circuit::>(cs, &self.hints_config, &()); - fixed.register_table_circuit::>(cs, &self.stack_config, &()); - fixed.register_table_circuit::>(cs, &self.heap_config, &()); + fixed.register_table_circuit::>(cs, &self.stack_init_config, &()); + fixed.register_table_circuit::>(cs, &self.heap_init_config, &()); + fixed.register_table_circuit::>(cs, &self.local_final_circuit, &()); + fixed.register_table_circuit::>(cs, &self.ram_bus_circuit, &()); } #[allow(clippy::too_many_arguments)] pub fn assign_table_circuit( &self, cs: &ZKVMConstraintSystem, + shard_ctx: &ShardContext, witness: &mut ZKVMWitnesses, reg_final: &[MemFinalRecord], static_mem_final: &[MemFinalRecord], @@ -104,18 +110,57 @@ impl MmuConfig { stack_final: &[MemFinalRecord], heap_final: &[MemFinalRecord], ) -> Result<(), ZKVMError> { - witness.assign_table_circuit::>(cs, &self.reg_config, reg_final)?; + witness.assign_table_circuit::>( + cs, + &self.reg_init_config, + reg_final, + )?; - witness.assign_table_circuit::>( + witness.assign_table_circuit::>( cs, - &self.static_mem_config, + &self.static_mem_init_config, static_mem_final, )?; witness.assign_table_circuit::>(cs, &self.public_io_config, io_cycles)?; witness.assign_table_circuit::>(cs, &self.hints_config, hints_final)?; - witness.assign_table_circuit::>(cs, &self.stack_config, stack_final)?; - witness.assign_table_circuit::>(cs, &self.heap_config, heap_final)?; + witness.assign_table_circuit::>( + cs, + &self.stack_init_config, + stack_final, + )?; + witness.assign_table_circuit::>( + cs, + &self.heap_init_config, + heap_final, + )?; + + let all_records = vec![ + (InstancePaddingStrategy::Default, reg_final), + (InstancePaddingStrategy::Default, static_mem_final), + ( + InstancePaddingStrategy::Custom({ + let params = cs.params.clone(); + Arc::new(move |row: u64, _: u64| StackTable::addr(¶ms, row as usize) as u64) + }), + stack_final, + ), + ( + InstancePaddingStrategy::Custom({ + let params = cs.params.clone(); + Arc::new(move |row: u64, _: u64| HeapTable::addr(¶ms, row as usize) as u64) + }), + heap_final, + ), + ]; + // take all mem result and + witness.assign_table_circuit::>( + cs, + &self.local_final_circuit, + &(shard_ctx, all_records.as_slice()), + )?; + + witness.assign_table_circuit::>(cs, &self.ram_bus_circuit, todo!())?; Ok(()) } diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index 416ce83ff..b8ee97f16 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -8,9 +8,11 @@ use crate::{ mod ram_circuit; mod ram_impl; -use crate::tables::ram::ram_impl::{ - DynVolatileRamTableConfig, DynVolatileRamTableFinalConfig, DynVolatileRamTableInitConfig, - NonVolatileFinalTableConfig, NonVolatileInitTableConfig, +use crate::tables::ram::{ + ram_circuit::{LocalFinalRamCircuit, RamBusCircuit}, + ram_impl::{ + DynVolatileRamTableConfig, DynVolatileRamTableInitConfig, NonVolatileInitTableConfig, + }, }; pub use ram_circuit::{DynVolatileRamTable, MemFinalRecord, MemInitRecord, NonVolatileTable}; @@ -38,8 +40,6 @@ impl DynVolatileRamTable for HeapTable { pub type HeapInitCircuit = DynVolatileRamCircuit>; -pub type HeapFinalCircuit = - DynVolatileRamCircuit>; #[derive(Clone)] pub struct StackTable; @@ -75,8 +75,6 @@ impl DynVolatileRamTable for StackTable { pub type StackInitCircuit = DynVolatileRamCircuit>; -pub type StackFinalCircuit = - DynVolatileRamCircuit>; #[derive(Clone)] pub struct HintsTable; @@ -121,8 +119,6 @@ impl NonVolatileTable for RegTable { pub type RegTableInitCircuit = NonVolatileRamCircuit>; -pub type RegTableFinalCircuit = - NonVolatileRamCircuit>; #[derive(Clone)] pub struct StaticMemTable; @@ -143,8 +139,6 @@ impl NonVolatileTable for StaticMemTable { pub type StaticMemInitCircuit = NonVolatileRamCircuit>; -pub type StaticMemFinalCircuit = - NonVolatileRamCircuit>; #[derive(Clone)] pub struct PubIOTable; @@ -164,3 +158,5 @@ impl NonVolatileTable for PubIOTable { } pub type PubIOCircuit = PubIORamCircuit; +pub type LocalFinalCircuit<'a, E> = LocalFinalRamCircuit<'a, UINT_LIMBS, E>; +pub type RBCircuit = RamBusCircuit; diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 854548c95..160050988 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -1,7 +1,11 @@ use std::{collections::HashMap, marker::PhantomData}; +use super::ram_impl::{ + LocalRAMTableFinalConfig, NonVolatileTableConfigTrait, PubIOTableConfig, RAMBusConfig, +}; use crate::{ circuit_builder::CircuitBuilder, + e2e::{RAMRecord, ShardContext}, error::ZKVMError, structs::{ProgramParams, RAMType}, tables::{RMMCollections, TableCircuit}, @@ -11,11 +15,6 @@ use ff_ext::{ExtensionField, SmallField}; use gkr_iop::error::CircuitBuilderError; use witness::{InstancePaddingStrategy, RowMajorMatrix}; -use super::ram_impl::{ - DynVolatileRamTableConfig, NonVolatileInitTableConfig, NonVolatileTableConfig, - NonVolatileTableConfigTrait, PubIOTableConfig, -}; - #[derive(Clone, Debug)] pub struct MemInitRecord { pub addr: Addr, @@ -24,6 +23,7 @@ pub struct MemInitRecord { #[derive(Clone, Debug)] pub struct MemFinalRecord { + pub ram_type: RAMType, pub addr: Addr, pub cycle: Cycle, pub value: Word, @@ -268,3 +268,105 @@ impl< ) } } + +/// This circuit is generalized version to handle all mmio records +pub struct LocalFinalRamCircuit<'a, const V_LIMBS: usize, E>(PhantomData<(&'a (), E)>); + +impl<'a, E: ExtensionField, const V_LIMBS: usize> TableCircuit + for LocalFinalRamCircuit<'a, V_LIMBS, E> +{ + type TableConfig = LocalRAMTableFinalConfig; + type FixedInput = (); + type WitnessInput = ( + &'a ShardContext<'a>, + &'a [(InstancePaddingStrategy, &'a [MemFinalRecord])], + ); + + fn name() -> String { + "LocalRAMTableFinal".to_string() + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result { + Ok(cb.namespace( + || Self::name(), + |cb| Self::TableConfig::construct_circuit(cb, params), + )?) + } + + fn generate_fixed_traces( + _config: &Self::TableConfig, + _num_fixed: usize, + _init_v: &Self::FixedInput, + ) -> RowMajorMatrix { + RowMajorMatrix::::new(0, 0, InstancePaddingStrategy::Default) + } + + fn assign_instances( + config: &Self::TableConfig, + num_witin: usize, + num_structural_witin: usize, + _multiplicity: &[HashMap], + (shard_ctx, final_mem): &Self::WitnessInput, + ) -> Result, ZKVMError> { + // assume returned table is well-formed include padding + Ok(Self::TableConfig::assign_instances( + config, + shard_ctx, + num_witin, + num_structural_witin, + final_mem, + )?) + } +} + +/// This circuit is generalized version to handle all mmio records +pub struct RamBusCircuit(PhantomData); + +impl TableCircuit for RamBusCircuit { + type TableConfig = RAMBusConfig; + type FixedInput = (); + type WitnessInput = (&'static [RAMRecord], &'static [RAMRecord]); + + fn name() -> String { + "RamBusCircuit".to_string() + } + + fn construct_circuit( + cb: &mut CircuitBuilder, + params: &ProgramParams, + ) -> Result { + Ok(cb.namespace( + || Self::name(), + |cb| Self::TableConfig::construct_circuit(cb, params), + )?) + } + + fn generate_fixed_traces( + _config: &Self::TableConfig, + _num_fixed: usize, + _init_v: &Self::FixedInput, + ) -> RowMajorMatrix { + RowMajorMatrix::::new(0, 0, InstancePaddingStrategy::Default) + } + + fn assign_instances( + config: &Self::TableConfig, + num_witin: usize, + num_structural_witin: usize, + _multiplicity: &[HashMap], + final_v: &Self::WitnessInput, + ) -> Result, ZKVMError> { + let (global_read_mem, global_write_mem) = *final_v; + // assume returned table is well-formed include padding + Ok(Self::TableConfig::assign_instances( + config, + num_witin, + num_structural_witin, + global_read_mem, + global_write_mem, + )?) + } +} diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 8ef622539..ede32944c 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -3,8 +3,10 @@ use either::Either; use ff_ext::{ExtensionField, SmallField}; use gkr_iop::error::CircuitBuilderError; use itertools::Itertools; -use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; -use std::{marker::PhantomData, ops::Neg, sync::Arc}; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; +use std::{marker::PhantomData, sync::Arc}; use witness::{ InstancePaddingStrategy, RowMajorMatrix, next_pow2_instance_padding, set_fixed_val, set_val, }; @@ -16,12 +18,13 @@ use super::{ use crate::{ chip_handler::general::PublicIOQuery, circuit_builder::{CircuitBuilder, SetTableSpec}, - e2e::RAMRecord, + e2e::{RAMRecord, ShardContext}, instructions::riscv::constants::{LIMB_BITS, LIMB_MASK}, structs::ProgramParams, tables::ram::ram_circuit::DynVolatileRamTableConfigTrait, }; use ff_ext::FieldInto; +use gkr_iop::RAMType; use multilinear_extensions::{ Expression, Fixed, StructuralWitIn, StructuralWitInType, ToExpr, WitIn, }; @@ -50,168 +53,6 @@ pub trait NonVolatileTableConfigTrait: Sized + Send + Sync { ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError>; } -/// define a non-volatile memory with init value -#[derive(Clone, Debug)] -pub struct NonVolatileTableConfig { - init_v: Vec, - addr: Fixed, - - final_v: Option>, - final_cycle: WitIn, - - phantom: PhantomData, - params: ProgramParams, -} - -impl NonVolatileTableConfigTrait - for NonVolatileTableConfig -{ - type Config = NonVolatileTableConfig; - fn construct_circuit( - cb: &mut CircuitBuilder, - params: &ProgramParams, - ) -> Result { - let init_v = (0..NVRAM::V_LIMBS) - .map(|i| cb.create_fixed(|| format!("init_v_limb_{i}"))) - .collect_vec(); - let addr = cb.create_fixed(|| "addr"); - - let final_cycle = cb.create_witin(|| "final_cycle"); - let final_v = if NVRAM::WRITABLE { - Some( - (0..NVRAM::V_LIMBS) - .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) - .collect::>(), - ) - } else { - None - }; - - let init_table = [ - vec![(NVRAM::RAM_TYPE as usize).into()], - vec![Expression::Fixed(addr)], - init_v.iter().map(|v| v.expr()).collect_vec(), - vec![Expression::ZERO], // Initial cycle. - ] - .concat(); - - let final_table = [ - // a v t - vec![(NVRAM::RAM_TYPE as usize).into()], - vec![Expression::Fixed(addr)], - final_v - .as_ref() - .map(|v_limb| v_limb.iter().map(|v| v.expr()).collect_vec()) - .unwrap_or_else(|| init_v.iter().map(|v| v.expr()).collect_vec()), - vec![final_cycle.expr()], - ] - .concat(); - - cb.w_table_record( - || "init_table", - NVRAM::RAM_TYPE, - SetTableSpec { - len: Some(NVRAM::len(params)), - structural_witins: vec![], - }, - init_table, - )?; - cb.r_table_record( - || "final_table", - NVRAM::RAM_TYPE, - SetTableSpec { - len: Some(NVRAM::len(params)), - structural_witins: vec![], - }, - final_table, - )?; - - Ok(Self { - init_v, - final_v, - addr, - final_cycle, - phantom: PhantomData, - params: params.clone(), - }) - } - - fn gen_init_state( - config: &Self::Config, - num_fixed: usize, - init_mem: &[MemInitRecord], - ) -> RowMajorMatrix { - assert!( - NVRAM::len(&config.params).is_power_of_two(), - "{} len {} must be a power of 2", - NVRAM::name(), - NVRAM::len(&config.params) - ); - - let mut init_table = RowMajorMatrix::::new( - NVRAM::len(&config.params), - num_fixed, - InstancePaddingStrategy::Default, - ); - assert_eq!(init_table.num_padding_instances(), 0); - - init_table - .par_rows_mut() - .zip_eq(init_mem) - .for_each(|(row, rec)| { - if config.init_v.len() == 1 { - // Assign value directly. - set_fixed_val!(row, config.init_v[0], (rec.value as u64).into_f()); - } else { - // Assign value limbs. - config.init_v.iter().enumerate().for_each(|(l, limb)| { - let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; - set_fixed_val!(row, limb, (val as u64).into_f()); - }); - } - set_fixed_val!(row, config.addr, (rec.addr as u64).into_f()); - }); - - init_table - } - - /// TODO consider taking RowMajorMatrix as argument to save allocations. - fn assign_instances( - config: &Self::Config, - num_witin: usize, - num_structural_witin: usize, - final_mem: &[MemFinalRecord], - ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { - assert_eq!(num_structural_witin, 0); - let mut final_table = RowMajorMatrix::::new( - NVRAM::len(&config.params), - num_witin, - InstancePaddingStrategy::Default, - ); - - final_table - .par_rows_mut() - .zip_eq(final_mem) - .for_each(|(row, rec)| { - if let Some(final_v) = &config.final_v { - if final_v.len() == 1 { - // Assign value directly. - set_val!(row, final_v[0], rec.value as u64); - } else { - // Assign value limbs. - final_v.iter().enumerate().for_each(|(l, limb)| { - let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; - set_val!(row, limb, val as u64); - }); - } - } - set_val!(row, config.final_cycle, rec.cycle); - }); - - Ok([final_table, RowMajorMatrix::empty()]) - } -} - /// define a non-volatile memory with init value #[derive(Clone, Debug)] pub struct NonVolatileInitTableConfig { @@ -310,112 +151,10 @@ impl NonVolatileTableConfigTrait< _final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { assert_eq!(num_structural_witin, 0); - assert!(_final_mem.is_empty()); - Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]) } } -/// define a non-volatile memory with init value -#[derive(Clone, Debug)] -pub struct NonVolatileFinalTableConfig { - addr: WitIn, - - final_v: Vec, - final_cycle: WitIn, - - phantom: PhantomData, - params: ProgramParams, -} - -impl NonVolatileTableConfigTrait - for NonVolatileFinalTableConfig -{ - type Config = NonVolatileFinalTableConfig; - fn construct_circuit( - cb: &mut CircuitBuilder, - params: &ProgramParams, - ) -> Result { - assert!(NVRAM::WRITABLE); - let addr = cb.create_witin(|| "addr"); - - let final_cycle = cb.create_witin(|| "final_cycle"); - let final_v = (0..NVRAM::V_LIMBS) - .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) - .collect::>(); - - let final_table = [ - // a v t - vec![(NVRAM::RAM_TYPE as usize).into()], - vec![addr.expr()], - final_v.iter().map(|v| v.expr()).collect_vec(), - vec![final_cycle.expr()], - ] - .concat(); - - cb.r_table_record( - || "final_table", - NVRAM::RAM_TYPE, - SetTableSpec { - len: Some(NVRAM::len(params)), - structural_witins: vec![], - }, - final_table, - )?; - - Ok(Self { - final_v, - addr, - final_cycle, - phantom: PhantomData, - params: params.clone(), - }) - } - - fn gen_init_state( - config: &Self::Config, - num_fixed: usize, - init_mem: &[MemInitRecord], - ) -> RowMajorMatrix { - RowMajorMatrix::empty() - } - - /// TODO consider taking RowMajorMatrix as argument to save allocations. - fn assign_instances( - config: &Self::Config, - num_witin: usize, - num_structural_witin: usize, - final_mem: &[MemFinalRecord], - ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { - assert_eq!(num_structural_witin, 0); - let mut final_table = RowMajorMatrix::::new( - NVRAM::len(&config.params), - num_witin, - InstancePaddingStrategy::Default, - ); - - final_table - .par_rows_mut() - .zip_eq(final_mem) - .for_each(|(row, rec)| { - if config.final_v.len() == 1 { - // Assign value directly. - set_val!(row, config.final_v[0], rec.value as u64); - } else { - // Assign value limbs. - config.final_v.iter().enumerate().for_each(|(l, limb)| { - let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; - set_val!(row, limb, val as u64); - }); - } - set_val!(row, config.addr, rec.addr as u64); - set_val!(row, config.final_cycle, rec.cycle); - }); - - Ok([final_table, RowMajorMatrix::empty()]) - } -} - /// define public io /// init value set by instance #[derive(Clone, Debug)] @@ -756,20 +495,17 @@ impl DynVolatileRamTableConfig DVRAM::addr(¶ms, row as usize) as u64 }; - let mut witness = - RowMajorMatrix::::new(final_mem.len(), num_witin, InstancePaddingStrategy::Default); let mut structural_witness = RowMajorMatrix::::new( final_mem.len(), num_structural_witin, InstancePaddingStrategy::Custom(Arc::new(addr_padding_fn)), ); - witness + structural_witness .par_rows_mut() - .zip(structural_witness.par_rows_mut()) .zip(final_mem) .enumerate() - .for_each(|(i, ((row, structural_row), rec))| { + .for_each(|(i, (structural_row, rec))| { assert_eq!( rec.addr, DVRAM::addr(&config.params, i), @@ -781,47 +517,39 @@ impl DynVolatileRamTableConfig }); structural_witness.padding_by_strategy(); - Ok([witness, structural_witness]) + Ok([RowMajorMatrix::empty(), structural_witness]) } } -/// volatile with all init value as 0 -/// dynamic address as witin, relied on augment of knowledge to prove address form +/// This table is generalized version to handle all mmio records #[derive(Clone, Debug)] -pub struct DynVolatileRamTableFinalConfig { - // addr is subset and could be any form - // TODO check soundness issue +pub struct LocalRAMTableFinalConfig { addr_subset: WitIn, sel: StructuralWitIn, final_v: Vec, final_cycle: WitIn, - - phantom: PhantomData, - params: ProgramParams, } -impl DynVolatileRamTableConfigTrait - for DynVolatileRamTableFinalConfig -{ - type Config = DynVolatileRamTableFinalConfig; - fn construct_circuit( +impl LocalRAMTableFinalConfig { + pub fn construct_circuit( cb: &mut CircuitBuilder, - params: &ProgramParams, + _params: &ProgramParams, ) -> Result { let addr_subset = cb.create_witin(|| "addr_subset"); + let ram_type = cb.create_witin(|| "ram_type"); let sel = cb.create_structural_witin( || "sel", StructuralWitInType::EqualDistanceSequence { max_len: 0, - offset: DVRAM::offset_addr(params), + offset: 0, multi_factor: WORD_SIZE, - descending: DVRAM::DESCENDING, + descending: false, }, ); - let final_v = (0..DVRAM::V_LIMBS) + let final_v = (0..V_LIMBS) .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) .collect::>(); let final_cycle = cb.create_witin(|| "final_cycle"); @@ -830,7 +558,7 @@ impl DynVolatileRamTableConfig let final_expr = final_v.iter().map(|v| v.expr()).collect_vec(); let raw_final_table = [ // a v t - vec![(DVRAM::RAM_TYPE as usize).into()], + vec![ram_type.expr()], vec![addr_subset.expr()], final_expr, vec![final_cycle.expr()], @@ -840,7 +568,8 @@ impl DynVolatileRamTableConfig + (Expression::Constant(Either::Left(E::BaseField::ONE)) - sel.expr()); cb.r_table_rlc_record( || "final_table", - DVRAM::RAM_TYPE, + // XXX we mixed all ram type here to save column allocation + RAMType::Undefined, SetTableSpec { len: None, structural_witins: vec![sel], @@ -854,60 +583,155 @@ impl DynVolatileRamTableConfig sel, final_v, final_cycle, - phantom: PhantomData, - params: params.clone(), }) } /// TODO consider taking RowMajorMatrix as argument to save allocations. - fn assign_instances( - config: &Self::Config, + pub fn assign_instances( + &self, + shard_ctx: &ShardContext, num_witin: usize, num_structural_witin: usize, - final_mem: &[MemFinalRecord], + final_mem: &[(InstancePaddingStrategy, &[MemFinalRecord])], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { assert_eq!(num_structural_witin, 1); - assert!(final_mem.len() <= DVRAM::max_len(&config.params)); - assert!(DVRAM::max_len(&config.params).is_power_of_two()); + + // collect each raw mem belong to this shard, BEFORE padding length + let current_shard_mems_len: Vec = final_mem + .par_iter() + .map(|(_, mem)| { + mem.par_iter() + .filter(|record| shard_ctx.is_current_shard_cycle(record.cycle)) + .count() + }) + .collect(); + + // deal with non-pow2 padding for first shard + // format Vec<(pad_len, pad_start_index)> + let padding_info = if shard_ctx.is_first_shard() { + final_mem + .iter() + .map(|(_, mem)| (next_pow2_instance_padding(mem.len()) - mem.len(), mem.len())) + .collect_vec() + } else { + vec![(0, 0); final_mem.len()] + }; + + // calculate mem length + let mem_lens = current_shard_mems_len + .iter() + .zip_eq(&padding_info) + .map(|(raw_len, (pad_len, _))| raw_len + pad_len) + .collect_vec(); + let total_records = mem_lens.iter().sum(); let mut witness = - RowMajorMatrix::::new(final_mem.len(), num_witin, InstancePaddingStrategy::Default); + RowMajorMatrix::::new(total_records, num_witin, InstancePaddingStrategy::Default); let mut structural_witness = RowMajorMatrix::::new( - final_mem.len(), + total_records, num_structural_witin, InstancePaddingStrategy::Default, ); - witness - .par_rows_mut() - .zip(structural_witness.par_rows_mut()) - .zip(final_mem) + let mut witness_mut_slices = Vec::with_capacity(final_mem.len()); + let mut structural_witness_mut_slices = Vec::with_capacity(final_mem.len()); + let mut witness_value_rest = witness.values.as_mut_slice(); + let mut structural_witness_value_rest = structural_witness.values.as_mut_slice(); + + for mem_len in mem_lens { + let witness_length = mem_len * num_witin; + let structural_witness_length = mem_len * num_structural_witin; + assert!( + witness_length <= witness_value_rest.len(), + "chunk size exceeds remaining data" + ); + assert!( + structural_witness_length <= structural_witness_value_rest.len(), + "chunk size exceeds remaining data" + ); + let (witness_left, witness_r) = witness_value_rest.split_at_mut(witness_length); + let (structural_witness_left, structural_witness_r) = + structural_witness_value_rest.split_at_mut(structural_witness_length); + witness_mut_slices.push(witness_left); + structural_witness_mut_slices.push(structural_witness_left); + witness_value_rest = witness_r; + structural_witness_value_rest = structural_witness_r; + } + + witness_mut_slices + .par_iter_mut() + .zip_eq(structural_witness_mut_slices.par_iter_mut()) + .zip_eq(final_mem.par_iter()) + .zip_eq(padding_info.par_iter()) .enumerate() - .for_each(|(i, ((row, structural_row), rec))| { - if config.final_v.len() == 1 { - // Assign value directly. - set_val!(row, config.final_v[0], rec.value as u64); - } else { - // Assign value limbs. - config.final_v.iter().enumerate().for_each(|(l, limb)| { - let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; - set_val!(row, limb, val as u64); - }); - } - set_val!(row, config.final_cycle, rec.cycle); + .for_each( + |( + i, + ( + ((witness, structural_witness), (padding_strategy, final_mem)), + (pad_size, pad_start_index), + ), + )| { + witness + .chunks_mut(num_witin) + .zip_eq(structural_witness.chunks_mut(num_structural_witin)) + .zip( + final_mem + .iter() + .filter(|record| shard_ctx.is_current_shard_cycle(record.cycle)), + ) + .enumerate() + .for_each(|(i, ((row, structural_row), rec))| { + if self.final_v.len() == 1 { + // Assign value directly. + set_val!(row, self.final_v[0], rec.value as u64); + } else { + // Assign value limbs. + self.final_v.iter().enumerate().for_each(|(l, limb)| { + let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }); + } + set_val!(row, self.final_cycle, rec.cycle); + + set_val!(row, self.addr_subset, rec.addr as u64); + set_val!(row, self.sel, 1u64); + }); - set_val!(row, config.addr_subset, rec.addr as u64); - set_val!(row, config.sel, 1u64); - }); + if *pad_size > 0 && shard_ctx.is_first_shard() { + match padding_strategy { + InstancePaddingStrategy::Custom(pad_func) => { + witness[pad_size * num_witin..] + .chunks_mut(num_witin) + .zip_eq( + structural_witness[pad_size * num_structural_witin..] + .chunks_mut(num_structural_witin), + ) + .zip(std::iter::successors(Some(*pad_start_index), |n| { + Some(*n + 1) + })) + .for_each(|((row, structural_row), pad_index)| { + set_val!( + row, + self.addr_subset, + pad_func(pad_index as u64, self.addr_subset.id as u64) + ); + set_val!(row, self.sel, 1u64); + }); + } + _ => unimplemented!(), + } + } + }, + ); Ok([witness, structural_witness]) } } -/// volatile with all init value as 0 -/// dynamic address as witin, relied on augment of knowledge to prove address form +/// The general config to handle ram bus across all records #[derive(Clone, Debug)] -pub struct DynVolatileRAMBusConfig { +pub struct RAMBusConfig { addr_subset: WitIn, sel_read: StructuralWitIn, @@ -915,48 +739,44 @@ pub struct DynVolatileRAMBusConfig, local_read_v: Vec, local_read_cycle: WitIn, - - phantom: PhantomData, - params: ProgramParams, } -impl DynVolatileRAMBusConfig { +impl RAMBusConfig { pub fn construct_circuit( cb: &mut CircuitBuilder, - params: &ProgramParams, + _params: &ProgramParams, ) -> Result { + let ram_type = cb.create_witin(|| "ram_type"); let one = Expression::Constant(Either::Left(E::BaseField::ONE)); - let mem_bus_with_read = cb.query_mem_bus_with_read()?; - let mem_bus_with_write = cb.query_mem_bus_with_write()?; let addr_subset = cb.create_witin(|| "addr_subset"); // TODO add new selector to support sel_rw let sel_read = cb.create_structural_witin( || "sel_read", StructuralWitInType::EqualDistanceSequence { max_len: 0, - offset: DVRAM::offset_addr(params), + offset: 0, multi_factor: WORD_SIZE, - descending: DVRAM::DESCENDING, + descending: false, }, ); let sel_write = cb.create_structural_witin( || "sel_write", StructuralWitInType::EqualDistanceSequence { max_len: 0, - offset: DVRAM::offset_addr(params), + offset: 0, multi_factor: WORD_SIZE, - descending: DVRAM::DESCENDING, + descending: false, }, ); // local write - let local_write_v = (0..DVRAM::V_LIMBS) + let local_write_v = (0..V_LIMBS) .map(|i| cb.create_witin(|| format!("local_write_v_limb_{i}"))) .collect::>(); let local_write_v_expr = local_write_v.iter().map(|v| v.expr()).collect_vec(); // local read - let local_read_v = (0..DVRAM::V_LIMBS) + let local_read_v = (0..V_LIMBS) .map(|i| cb.create_witin(|| format!("local_read_v_limb_{i}"))) .collect::>(); let local_read_v_expr: Vec> = @@ -968,21 +788,20 @@ impl DynVolatileRAMBusConfig DynVolatileRAMBusConfig = mem_bus_with_write.expr() - * (sel_write.expr() * local_read_record + (one.clone() - sel_write.expr())) - + (one.clone() - mem_bus_with_write.expr()); + let local_read: Expression = + sel_write.expr() * local_read_record + (one.clone() - sel_write.expr()); + cb.r_table_rlc_record( || "local_read_record", - DVRAM::RAM_TYPE, + RAMType::Undefined, SetTableSpec { len: None, structural_witins: vec![sel_write], @@ -1024,8 +843,6 @@ impl DynVolatileRAMBusConfig DynVolatileRAMBusConfig Result<[RowMajorMatrix; 2], CircuitBuilderError> { - assert!(global_read_mem.len() <= DVRAM::max_len(&self.params)); - assert!(DVRAM::max_len(&self.params).is_power_of_two()); let witness_length = { let max_len = global_read_mem.len().max(global_write_mem.len()); // first half write, second half read @@ -1130,6 +945,7 @@ mod tests { use ceno_emul::WORD_SIZE; use ff_ext::GoldilocksExt2 as E; + use gkr_iop::RAMType; use itertools::Itertools; use multilinear_extensions::mle::MultilinearExtension; use p3::{field::FieldAlgebra, goldilocks::Goldilocks as F}; @@ -1148,6 +964,7 @@ mod tests { let some_non_2_pow = 26; let input = (0..some_non_2_pow) .map(|i| MemFinalRecord { + ram_type: RAMType::Memory, addr: HintsTable::addr(&def_params, i), cycle: 0, value: 0, diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index b1bffbb08..a5e20f704 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -84,6 +84,7 @@ pub enum RAMType { GlobalState = 0, Register, Memory, + Undefined, } impl_expr_from_unsigned!(RAMType); From d32c71f44723afa6109598c44ffe77cc2eebeca1 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 15 Oct 2025 11:08:51 +0800 Subject: [PATCH 47/91] aligned step cycle and prev_cycle to local version --- ceno_emul/src/shards.rs | 10 +-- ceno_emul/src/tracer.rs | 2 +- ceno_zkvm/src/bin/e2e.rs | 4 +- ceno_zkvm/src/e2e.rs | 41 +++++++++--- ceno_zkvm/src/instructions/riscv/b_insn.rs | 2 +- .../instructions/riscv/dummy/dummy_circuit.rs | 2 +- .../src/instructions/riscv/ecall/keccak.rs | 4 +- .../riscv/ecall/weierstrass_add.rs | 4 +- .../riscv/ecall/weierstrass_double.rs | 4 +- .../src/instructions/riscv/ecall_base.rs | 19 ++++-- ceno_zkvm/src/instructions/riscv/i_insn.rs | 2 +- ceno_zkvm/src/instructions/riscv/im_insn.rs | 2 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 50 ++++++++++----- ceno_zkvm/src/instructions/riscv/j_insn.rs | 2 +- ceno_zkvm/src/instructions/riscv/r_insn.rs | 2 +- ceno_zkvm/src/instructions/riscv/rv32im.rs | 63 +++++++++++-------- ceno_zkvm/src/instructions/riscv/s_insn.rs | 2 +- ceno_zkvm/src/precompiles/lookup_keccakf.rs | 1 + .../weierstrass/weierstrass_add.rs | 1 + .../weierstrass/weierstrass_double.rs | 1 + 20 files changed, 142 insertions(+), 76 deletions(-) diff --git a/ceno_emul/src/shards.rs b/ceno_emul/src/shards.rs index fd34baf85..935623fe3 100644 --- a/ceno_emul/src/shards.rs +++ b/ceno_emul/src/shards.rs @@ -1,14 +1,14 @@ pub struct Shards { pub shard_id: usize, - pub num_shards: usize, + pub max_num_shards: usize, } impl Shards { - pub fn new(shard_id: usize, num_shards: usize) -> Self { - assert!(shard_id < num_shards); + pub fn new(shard_id: usize, max_num_shards: usize) -> Self { + assert!(shard_id < max_num_shards); Self { shard_id, - num_shards, + max_num_shards, } } @@ -17,6 +17,6 @@ impl Shards { } pub fn is_last_shard(&self) -> bool { - self.shard_id == self.num_shards - 1 + self.shard_id == self.max_num_shards - 1 } } diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 9dc9a0b12..c667e07d2 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -25,7 +25,7 @@ use crate::{ /// - Any pair of `rs1 / rs2 / rd` **may be the same**. Then, one op will point to the other op in the same instruction but a different subcycle. The circuits may follow the operations **without special handling** of repeated registers. #[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct StepRecord { - cycle: Cycle, + pub cycle: Cycle, pc: Change, pub insn: Instruction, diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index a2c2ffde2..3496477cb 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -115,7 +115,7 @@ struct Args { // number of total shards #[arg(long, default_value = "1")] - num_shards: u32, + max_num_shards: u32, } fn main() { @@ -248,7 +248,7 @@ fn main() { .unwrap_or_default(); let max_steps = args.max_steps.unwrap_or(usize::MAX); - let shards = Shards::new(args.shard_id as usize, args.num_shards as usize); + let shards = Shards::new(args.shard_id as usize, args.max_num_shards as usize); match (args.pcs, args.field) { (PcsKind::Basefold, FieldType::Goldilocks) => { diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index b6f51c362..6ebd9852a 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -112,7 +112,7 @@ pub struct RAMRecord { pub struct ShardContext<'a> { shard_id: usize, - num_shards: usize, + max_num_shards: usize, max_cycle: Cycle, addr_future_accesses: Cow<'a, HashMap<(WordAddr, Cycle), Cycle>>, read_thread_based_record_storage: Either< @@ -131,7 +131,7 @@ impl<'a> Default for ShardContext<'a> { let max_threads = max_usable_threads(); Self { shard_id: 0, - num_shards: 1, + max_num_shards: 1, max_cycle: Cycle::default(), addr_future_accesses: Cow::Owned(HashMap::new()), read_thread_based_record_storage: Either::Left( @@ -154,20 +154,27 @@ impl<'a> Default for ShardContext<'a> { impl<'a> ShardContext<'a> { pub fn new( shard_id: usize, - num_shards: usize, + max_num_shards: usize, executed_instructions: usize, addr_future_accesses: HashMap<(WordAddr, Cycle), Cycle>, ) -> Self { + // current strategy: at least each shard deal with one instruction + let max_num_shards = max_num_shards.min(executed_instructions); + assert!( + shard_id < max_num_shards, + "implement mechanism to skip current shard proof" + ); + let max_threads = max_usable_threads(); // let max_record_per_thread = max_insts.div_ceil(max_threads as u64); - let expected_inst_per_shard = executed_instructions.div_ceil(num_shards) as usize; + let expected_inst_per_shard = executed_instructions.div_ceil(max_num_shards) as usize; let max_cycle = (executed_instructions + 1) * 4; // cycle start from 4 - let cur_shard_cycle_range = (shard_id * expected_inst_per_shard * 4).max(4) - ..((shard_id + 1) * expected_inst_per_shard * 4).min(max_cycle); + let cur_shard_cycle_range = (shard_id * expected_inst_per_shard * 4 + 4) + ..((shard_id + 1) * expected_inst_per_shard * 4 + 4).min(max_cycle); ShardContext { shard_id, - num_shards, + max_num_shards, max_cycle: max_cycle as Cycle, addr_future_accesses: Cow::Owned(addr_future_accesses), // TODO with_capacity optimisation @@ -201,7 +208,7 @@ impl<'a> ShardContext<'a> { .zip(write_thread_based_record_storage.iter_mut()) .map(|(read, write)| ShardContext { shard_id: self.shard_id, - num_shards: self.num_shards, + max_num_shards: self.max_num_shards, max_cycle: self.max_cycle, addr_future_accesses: Cow::Borrowed(self.addr_future_accesses.as_ref()), read_thread_based_record_storage: Either::Right(read), @@ -220,7 +227,7 @@ impl<'a> ShardContext<'a> { #[inline(always)] pub fn is_last_shard(&self) -> bool { - self.shard_id == self.num_shards - 1 + self.shard_id == self.max_num_shards - 1 } #[inline(always)] @@ -228,6 +235,20 @@ impl<'a> ShardContext<'a> { self.cur_shard_cycle_range.contains(&(cycle as usize)) } + #[inline(always)] + pub fn aligned_prev_ts(&self, prev_cycle: Cycle) -> Cycle { + let mut ts = prev_cycle.saturating_sub(self.cur_shard_cycle_range.start as Cycle); + if ts < 4 { + ts = 0 + } + ts + } + + pub fn current_shard_offset_cycle(&self) -> Cycle { + // `-4` as cycle of each local shard start from 4 + (self.cur_shard_cycle_range.start as Cycle) - 4 + } + #[inline(always)] pub fn send( &mut self, @@ -475,7 +496,7 @@ pub fn emulate_program<'a>( let shard_ctx = ShardContext::new( shards.shard_id, - shards.num_shards, + shards.max_num_shards, insts, vm.take_tracer().next_accesses(), ); diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index e84d6a1a2..cdc1db56d 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -93,7 +93,7 @@ impl BInstructionConfig { lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; self.rs1 .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; self.rs2 diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs index e2396942d..1df279dd9 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs @@ -248,7 +248,7 @@ impl DummyConfig { step: &StepRecord, ) -> Result<(), ZKVMError> { // State in and out - self.vm_state.assign_instance(instance, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; // Fetch instruction lk_multiplicity.fetch(step.pc().before.0); diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index 2d0e0c2fd..dccdf34a2 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -223,7 +223,9 @@ impl Instruction for KeccakInstruction { [round_index as usize * num_witin..][..num_witin]; // vm_state - config.vm_state.assign_instance(instance, step)?; + config + .vm_state + .assign_instance(instance, &shard_ctx, step)?; config.ecall_id.assign_op( instance, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index 03a27a47a..f2fa49bc3 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -274,7 +274,9 @@ impl Instruction let ops = &step.syscall().expect("syscall step"); // vm_state - config.vm_state.assign_instance(instance, step)?; + config + .vm_state + .assign_instance(instance, &shard_ctx, step)?; config.ecall_id.assign_op( instance, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 1281eba7d..0b82b904f 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -246,7 +246,9 @@ impl Instruction OpFixedRS Result<(), ZKVMError> { - set_val!(instance, self.prev_ts, op.previous_cycle); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = cycle - current_shard_offset_cycle; + set_val!(instance, self.prev_ts, shard_prev_cycle); // Register state if let Some(prev_value) = self.prev_value.as_ref() { @@ -79,14 +82,20 @@ impl OpFixedRS IInstructionConfig { lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; self.rs1 .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; self.rd diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index 567833b2f..c7f6cace0 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -72,7 +72,7 @@ impl IMInstructionConfig { lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; self.rs1 .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; self.rd diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 03c654f98..4877df9d1 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -60,14 +60,17 @@ impl StateInOut { pub fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &ShardContext, // lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + set_val!(instance, self.pc, step.pc().before.0 as u64); if let Some(n_pc) = self.next_pc { set_val!(instance, n_pc, step.pc().after.0 as u64); } - set_val!(instance, self.ts, step.cycle()); + set_val!(instance, self.ts, step.cycle() - current_shard_offset_cycle); Ok(()) } @@ -113,15 +116,18 @@ impl ReadRS1 { step: &StepRecord, ) -> Result<(), ZKVMError> { let op = step.rs1().expect("rs1 op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = step.cycle() - current_shard_offset_cycle; set_val!(instance, self.id, op.register_index() as u64); - set_val!(instance, self.prev_ts, op.previous_cycle); + set_val!(instance, self.prev_ts, shard_prev_cycle); // Register read self.lt_cfg.assign_instance( instance, lk_multiplicity, - op.previous_cycle, - step.cycle() + Tracer::SUBCYCLE_RS1, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS1, )?; shard_ctx.send( RAMType::Register, @@ -177,15 +183,18 @@ impl ReadRS2 { step: &StepRecord, ) -> Result<(), ZKVMError> { let op = step.rs2().expect("rs2 op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = step.cycle() - current_shard_offset_cycle; set_val!(instance, self.id, op.register_index() as u64); - set_val!(instance, self.prev_ts, op.previous_cycle); + set_val!(instance, self.prev_ts, shard_prev_cycle); // Register read self.lt_cfg.assign_instance( instance, lk_multiplicity, - op.previous_cycle, - step.cycle() + Tracer::SUBCYCLE_RS2, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS2, )?; shard_ctx.send( @@ -255,8 +264,11 @@ impl WriteRD { cycle: Cycle, op: &WriteOp, ) -> Result<(), ZKVMError> { + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = cycle - current_shard_offset_cycle; set_val!(instance, self.id, op.register_index() as u64); - set_val!(instance, self.prev_ts, op.previous_cycle); + set_val!(instance, self.prev_ts, shard_prev_cycle); // Register state self.prev_value.assign_limbs( @@ -268,8 +280,8 @@ impl WriteRD { self.lt_cfg.assign_instance( instance, lk_multiplicity, - op.previous_cycle, - cycle + Tracer::SUBCYCLE_RD, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RD, )?; shard_ctx.send( RAMType::Register, @@ -323,15 +335,18 @@ impl ReadMEM { step: &StepRecord, ) -> Result<(), ZKVMError> { let op = step.memory_op().unwrap(); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = step.cycle() - current_shard_offset_cycle; // Memory state - set_val!(instance, self.prev_ts, op.previous_cycle); + set_val!(instance, self.prev_ts, shard_prev_cycle); // Memory read self.lt_cfg.assign_instance( instance, lk_multiplicity, - op.previous_cycle, - step.cycle() + Tracer::SUBCYCLE_MEM, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_MEM, )?; shard_ctx.send( @@ -395,13 +410,16 @@ impl WriteMEM { cycle: Cycle, op: &WriteOp, ) -> Result<(), ZKVMError> { - set_val!(instance, self.prev_ts, op.previous_cycle); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = cycle - current_shard_offset_cycle; + set_val!(instance, self.prev_ts, shard_prev_cycle); self.lt_cfg.assign_instance( instance, lk_multiplicity, - op.previous_cycle, - cycle + Tracer::SUBCYCLE_MEM, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_MEM, )?; shard_ctx.send( diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs index 81a954893..84cb84679 100644 --- a/ceno_zkvm/src/instructions/riscv/j_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -59,7 +59,7 @@ impl JInstructionConfig { lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; self.rd .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index 1d559a941..a4b9bb128 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -68,7 +68,7 @@ impl RInstructionConfig { lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; self.rs1 .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; self.rs2 diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index cc9810d45..a5913f495 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -414,35 +414,44 @@ impl Rv32imConfig { let mut bn254_double_records = Vec::new(); let mut secp256k1_add_records = Vec::new(); let mut secp256k1_double_records = Vec::new(); - steps.into_iter().for_each(|record| { - let insn_kind = record.insn.kind; - match insn_kind { - // ecall / halt - InsnKind::ECALL if record.rs1().unwrap().value == Platform::ecall_halt() => { - halt_records.push(record); + steps + .into_iter() + .filter_map(|mut step| { + if shard_ctx.is_current_shard_cycle(step.cycle()) { + Some(step) + } else { + None } - InsnKind::ECALL if record.rs1().unwrap().value == KeccakSpec::CODE => { - keccak_records.push(record); + }) + .for_each(|record| { + let insn_kind = record.insn.kind; + match insn_kind { + // ecall / halt + InsnKind::ECALL if record.rs1().unwrap().value == Platform::ecall_halt() => { + halt_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == KeccakSpec::CODE => { + keccak_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == Bn254AddSpec::CODE => { + bn254_add_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == Bn254DoubleSpec::CODE => { + bn254_double_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1AddSpec::CODE => { + secp256k1_add_records.push(record); + } + InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DoubleSpec::CODE => { + secp256k1_double_records.push(record); + } + // other type of ecalls are handled by dummy ecall instruction + _ => { + // it's safe to unwrap as all_records are initialized with Vec::new() + all_records.get_mut(&insn_kind).unwrap().push(record); + } } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254AddSpec::CODE => { - bn254_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254DoubleSpec::CODE => { - bn254_double_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1AddSpec::CODE => { - secp256k1_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DoubleSpec::CODE => { - secp256k1_double_records.push(record); - } - // other type of ecalls are handled by dummy ecall instruction - _ => { - // it's safe to unwrap as all_records are initialized with Vec::new() - all_records.get_mut(&insn_kind).unwrap().push(record); - } - } - }); + }); for (insn_kind, (_, records)) in izip!(InsnKind::iter(), &all_records).sorted_by_key(|(_, (_, a))| Reverse(a.len())) diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index 23a5ff810..f252a7c60 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -78,7 +78,7 @@ impl SInstructionConfig { lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.vm_state.assign_instance(instance, step)?; + self.vm_state.assign_instance(instance, shard_ctx, step)?; self.rs1 .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; self.rs2 diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index e1823bc70..5b2c1867f 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -1091,6 +1091,7 @@ pub fn run_faster_keccakf .vm_state .assign_instance( instance, + &shard_ctx, &StepRecord::new_ecall_any(10, ByteAddr::from(0)), ) .expect("assign vm_state error"); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs index f9c76fbf1..5eda6aed8 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs @@ -609,6 +609,7 @@ pub fn run_weierstrass_add< .vm_state .assign_instance( instance, + &shard_ctx, &StepRecord::new_ecall_any(10, ByteAddr::from(0)), ) .expect("assign vm_state error"); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs index 3922b6c22..ee9bc0b65 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs @@ -613,6 +613,7 @@ pub fn run_weierstrass_double< .vm_state .assign_instance( instance, + &shard_ctx, &StepRecord::new_ecall_any(10, ByteAddr::from(0)), ) .expect("assign vm_state error"); From f347310c4efd29521b218a51b37f998e076a2901 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 15 Oct 2025 15:47:15 +0800 Subject: [PATCH 48/91] with mem bus chip build pass --- ceno_zkvm/src/e2e.rs | 42 ++-- ceno_zkvm/src/instructions/riscv/rv32im.rs | 2 +- .../src/instructions/riscv/rv32im/mmu.rs | 4 +- ceno_zkvm/src/tables/ram.rs | 2 +- ceno_zkvm/src/tables/ram/ram_circuit.rs | 16 +- ceno_zkvm/src/tables/ram/ram_impl.rs | 192 ++++++++++++------ 6 files changed, 172 insertions(+), 86 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 6ebd9852a..be842f155 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -34,7 +34,6 @@ use serde::Serialize; use std::{ borrow::Cow, collections::{BTreeMap, BTreeSet, HashMap, HashSet}, - mem, sync::Arc, }; use transcript::BasicTranscript as Transcript; @@ -102,6 +101,7 @@ pub struct EmulationResult<'a> { } pub struct RAMRecord { + pub ram_type: RAMType, pub id: u64, pub addr: WordAddr, pub prev_cycle: Cycle, @@ -115,14 +115,10 @@ pub struct ShardContext<'a> { max_num_shards: usize, max_cycle: Cycle, addr_future_accesses: Cow<'a, HashMap<(WordAddr, Cycle), Cycle>>, - read_thread_based_record_storage: Either< - Vec<[BTreeMap; mem::variant_count::()]>, - &'a mut [BTreeMap; mem::variant_count::()], - >, - write_thread_based_record_storage: Either< - Vec<[BTreeMap; mem::variant_count::()]>, - &'a mut [BTreeMap; mem::variant_count::()], - >, + read_thread_based_record_storage: + Either>, &'a mut BTreeMap>, + write_thread_based_record_storage: + Either>, &'a mut BTreeMap>, pub cur_shard_cycle_range: std::ops::Range, } @@ -137,13 +133,13 @@ impl<'a> Default for ShardContext<'a> { read_thread_based_record_storage: Either::Left( (0..max_threads) .into_par_iter() - .map(|_| std::array::from_fn(|_| BTreeMap::new())) + .map(|_| BTreeMap::new()) .collect::>(), ), write_thread_based_record_storage: Either::Left( (0..max_threads) .into_par_iter() - .map(|_| std::array::from_fn(|_| BTreeMap::new())) + .map(|_| BTreeMap::new()) .collect::>(), ), cur_shard_cycle_range: 0..usize::MAX, @@ -181,14 +177,14 @@ impl<'a> ShardContext<'a> { read_thread_based_record_storage: Either::Left( (0..max_threads) .into_par_iter() - .map(|_| std::array::from_fn(|_| BTreeMap::new())) + .map(|_| BTreeMap::new()) .collect::>(), ), // TODO with_capacity optimisation write_thread_based_record_storage: Either::Left( (0..max_threads) .into_par_iter() - .map(|_| std::array::from_fn(|_| BTreeMap::new())) + .map(|_| BTreeMap::new()) .collect::>(), ), cur_shard_cycle_range, @@ -220,6 +216,20 @@ impl<'a> ShardContext<'a> { } } + pub fn read_records(&self) -> &[BTreeMap] { + match &self.read_thread_based_record_storage { + Either::Left(m) => m, + Either::Right(_) => panic!("undefined behaviour"), + } + } + + pub fn write_records(&self) -> &[BTreeMap] { + match &self.write_thread_based_record_storage { + Either::Left(m) => m, + Either::Right(_) => panic!("undefined behaviour"), + } + } + #[inline(always)] pub fn is_first_shard(&self) -> bool { self.shard_id == 0 @@ -269,9 +279,10 @@ impl<'a> ShardContext<'a> { .as_mut() .right() .expect("illegal type"); - ram_record[ram_type as usize].insert( + ram_record.insert( addr, RAMRecord { + ram_type, id, addr, prev_cycle, @@ -291,9 +302,10 @@ impl<'a> ShardContext<'a> { .as_mut() .right() .expect("illegal type"); - ram_record[ram_type as usize].insert( + ram_record.insert( addr, RAMRecord { + ram_type, id, addr, prev_cycle, diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index a5913f495..7a8463569 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -416,7 +416,7 @@ impl Rv32imConfig { let mut secp256k1_double_records = Vec::new(); steps .into_iter() - .filter_map(|mut step| { + .filter_map(|step| { if shard_ctx.is_current_shard_cycle(step.cycle()) { Some(step) } else { diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index c37aa1615..e15335810 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -31,7 +31,7 @@ pub struct MmuConfig<'a, E: ExtensionField> { /// finalized circuit for all MMIO pub local_final_circuit: as TableCircuit>::TableConfig, /// ram bus to deal with cross shard read/write - pub ram_bus_circuit: as TableCircuit>::TableConfig, + pub ram_bus_circuit: as TableCircuit>::TableConfig, pub params: ProgramParams, } @@ -160,7 +160,7 @@ impl MmuConfig<'_, E> { &(shard_ctx, all_records.as_slice()), )?; - witness.assign_table_circuit::>(cs, &self.ram_bus_circuit, todo!())?; + witness.assign_table_circuit::>(cs, &self.ram_bus_circuit, shard_ctx)?; Ok(()) } diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index b8ee97f16..6075b0440 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -159,4 +159,4 @@ impl NonVolatileTable for PubIOTable { pub type PubIOCircuit = PubIORamCircuit; pub type LocalFinalCircuit<'a, E> = LocalFinalRamCircuit<'a, UINT_LIMBS, E>; -pub type RBCircuit = RamBusCircuit; +pub type RBCircuit<'a, E> = RamBusCircuit<'a, UINT_LIMBS, E>; diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 160050988..ff5e9a783 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -5,7 +5,7 @@ use super::ram_impl::{ }; use crate::{ circuit_builder::CircuitBuilder, - e2e::{RAMRecord, ShardContext}, + e2e::ShardContext, error::ZKVMError, structs::{ProgramParams, RAMType}, tables::{RMMCollections, TableCircuit}, @@ -323,12 +323,14 @@ impl<'a, E: ExtensionField, const V_LIMBS: usize> TableCircuit } /// This circuit is generalized version to handle all mmio records -pub struct RamBusCircuit(PhantomData); +pub struct RamBusCircuit<'a, const V_LIMBS: usize, E>(PhantomData<(&'a (), E)>); -impl TableCircuit for RamBusCircuit { +impl<'a, E: ExtensionField, const V_LIMBS: usize> TableCircuit + for RamBusCircuit<'a, V_LIMBS, E> +{ type TableConfig = RAMBusConfig; type FixedInput = (); - type WitnessInput = (&'static [RAMRecord], &'static [RAMRecord]); + type WitnessInput = ShardContext<'a>; fn name() -> String { "RamBusCircuit".to_string() @@ -357,16 +359,14 @@ impl TableCircuit for RamBusCircuit< num_witin: usize, num_structural_witin: usize, _multiplicity: &[HashMap], - final_v: &Self::WitnessInput, + shard_ctx: &Self::WitnessInput, ) -> Result, ZKVMError> { - let (global_read_mem, global_write_mem) = *final_v; // assume returned table is well-formed include padding Ok(Self::TableConfig::assign_instances( config, + shard_ctx, num_witin, num_structural_witin, - global_read_mem, - global_write_mem, )?) } } diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index ede32944c..2c66dc06d 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -18,7 +18,7 @@ use super::{ use crate::{ chip_handler::general::PublicIOQuery, circuit_builder::{CircuitBuilder, SetTableSpec}, - e2e::{RAMRecord, ShardContext}, + e2e::ShardContext, instructions::riscv::constants::{LIMB_BITS, LIMB_MASK}, structs::ProgramParams, tables::ram::ram_circuit::DynVolatileRamTableConfigTrait, @@ -29,7 +29,6 @@ use multilinear_extensions::{ Expression, Fixed, StructuralWitIn, StructuralWitInType, ToExpr, WitIn, }; use p3::field::FieldAlgebra; -use rayon::prelude::{ParallelSlice, ParallelSliceMut}; pub trait NonVolatileTableConfigTrait: Sized + Send + Sync { type Config: Sized + Send + Sync; @@ -481,7 +480,7 @@ impl DynVolatileRamTableConfig /// TODO consider taking RowMajorMatrix as argument to save allocations. fn assign_instances( config: &Self::Config, - num_witin: usize, + _num_witin: usize, num_structural_witin: usize, final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { @@ -663,14 +662,10 @@ impl LocalRAMTableFinalConfig { .zip_eq(structural_witness_mut_slices.par_iter_mut()) .zip_eq(final_mem.par_iter()) .zip_eq(padding_info.par_iter()) - .enumerate() .for_each( |( - i, - ( - ((witness, structural_witness), (padding_strategy, final_mem)), - (pad_size, pad_start_index), - ), + ((witness, structural_witness), (padding_strategy, final_mem)), + (pad_size, pad_start_index), )| { witness .chunks_mut(num_witin) @@ -680,8 +675,7 @@ impl LocalRAMTableFinalConfig { .iter() .filter(|record| shard_ctx.is_current_shard_cycle(record.cycle)), ) - .enumerate() - .for_each(|(i, ((row, structural_row), rec))| { + .for_each(|((row, structural_row), rec)| { if self.final_v.len() == 1 { // Assign value directly. set_val!(row, self.final_v[0], rec.value as u64); @@ -695,7 +689,7 @@ impl LocalRAMTableFinalConfig { set_val!(row, self.final_cycle, rec.cycle); set_val!(row, self.addr_subset, rec.addr as u64); - set_val!(row, self.sel, 1u64); + set_val!(structural_row, self.sel, 1u64); }); if *pad_size > 0 && shard_ctx.is_first_shard() { @@ -716,7 +710,7 @@ impl LocalRAMTableFinalConfig { self.addr_subset, pad_func(pad_index as u64, self.addr_subset.id as u64) ); - set_val!(row, self.sel, 1u64); + set_val!(structural_row, self.sel, 1u64); }); } _ => unimplemented!(), @@ -799,6 +793,7 @@ impl RAMBusConfig { let local_write_record = cb.rlc_chip_record(local_raw_write_record.clone()); let local_write = sel_read.expr() * local_write_record + (one.clone() - sel_read.expr()).expr(); + // local write, global read cb.w_table_rlc_record( || "local_write_record", RAMType::Undefined, @@ -824,6 +819,7 @@ impl RAMBusConfig { let local_read: Expression = sel_write.expr() * local_read_record + (one.clone() - sel_write.expr()); + // local read, global write cb.r_table_rlc_record( || "local_read_record", RAMType::Undefined, @@ -849,81 +845,159 @@ impl RAMBusConfig { /// TODO consider taking RowMajorMatrix as argument to save allocations. pub fn assign_instances( &self, + shard_ctx: &ShardContext, num_witin: usize, num_structural_witin: usize, - global_read_mem: &[RAMRecord], - global_write_mem: &[RAMRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { + let (global_read_records, global_write_records) = + (shard_ctx.read_records(), shard_ctx.write_records()); + assert_eq!(global_read_records.len(), global_write_records.len()); + let witness_length = { - let max_len = global_read_mem.len().max(global_write_mem.len()); + let raw_write_len: usize = global_write_records.iter().map(|m| m.len()).sum(); + let raw_read_len: usize = global_read_records.iter().map(|m| m.len()).sum(); + let max_len = raw_read_len.max(raw_write_len); // first half write, second half read next_pow2_instance_padding(max_len) * 2 }; - let mut witness = RowMajorMatrix::::new(witness_length, num_witin, InstancePaddingStrategy::Default); - let witness_mid = witness.values.len() / 2; let mut structural_witness = RowMajorMatrix::::new( witness_length, num_structural_witin, InstancePaddingStrategy::Default, ); + let witness_mid = witness.values.len() / 2; let (witness_write, witness_read) = witness.values.split_at_mut(witness_mid); let structural_witness_mid = structural_witness.values.len() / 2; let (structural_witness_write, structural_witness_read) = structural_witness .values .split_at_mut(structural_witness_mid); + let mut witness_write_mut_slices = Vec::with_capacity(global_write_records.len()); + let mut witness_read_mut_slices = Vec::with_capacity(global_read_records.len()); + let mut structural_witness_write_mut_slices = + Vec::with_capacity(global_write_records.len()); + let mut structural_witness_read_mut_slices = Vec::with_capacity(global_read_records.len()); + let mut witness_write_value_rest = witness_write; + let mut witness_read_value_rest = witness_read; + let mut structural_witness_write_value_rest = structural_witness_write; + let mut structural_witness_read_value_rest = structural_witness_read; + + for (global_read_record, global_write_record) in + global_read_records.iter().zip_eq(global_write_records) + { + let witness_write_length = global_write_record.len() * num_witin; + let witness_read_length = global_read_record.len() * num_witin; + let structural_witness_write_length = global_write_record.len() * num_structural_witin; + let structural_witness_read_length = global_read_record.len() * num_structural_witin; + assert!( + witness_write_length <= witness_write_value_rest.len(), + "chunk size exceeds remaining data" + ); + assert!( + witness_read_length <= witness_read_value_rest.len(), + "chunk size exceeds remaining data" + ); + assert!( + structural_witness_write_length <= structural_witness_write_value_rest.len(), + "chunk size exceeds remaining data" + ); + assert!( + structural_witness_read_length <= structural_witness_read_value_rest.len(), + "chunk size exceeds remaining data" + ); + let (witness_write, witness_write_r) = + witness_write_value_rest.split_at_mut(witness_write_length); + witness_write_mut_slices.push(witness_write); + witness_write_value_rest = witness_write_r; + + let (witness_read, witness_read_r) = + witness_read_value_rest.split_at_mut(witness_read_length); + witness_read_mut_slices.push(witness_read); + witness_read_value_rest = witness_read_r; + + let (structural_witness_write, structural_witness_write_r) = + structural_witness_write_value_rest.split_at_mut(structural_witness_write_length); + structural_witness_write_mut_slices.push(structural_witness_write); + structural_witness_write_value_rest = structural_witness_write_r; + + let (structural_witness_read, structural_witness_read_r) = + structural_witness_read_value_rest.split_at_mut(structural_witness_read_length); + structural_witness_read_mut_slices.push(structural_witness_read); + structural_witness_read_value_rest = structural_witness_read_r; + } + rayon::join( // global write, local read || { - witness_write - .par_chunks_mut(num_witin) - .zip(structural_witness_write.par_chunks_mut(num_structural_witin)) - .zip(global_write_mem) - .enumerate() - .for_each(|(i, ((row, structural_row), rec))| { - if self.local_read_v.len() == 1 { - // Assign value directly. - set_val!(row, self.local_read_v[0], rec.value as u64); - } else { - // Assign value limbs. - self.local_read_v.iter().enumerate().for_each(|(l, limb)| { - let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; - set_val!(row, limb, val as u64); - }); - } - set_val!(row, self.local_read_cycle, rec.cycle); + witness_write_mut_slices + .par_iter_mut() + .zip_eq(structural_witness_write_mut_slices.par_iter_mut()) + .zip_eq(global_write_records.par_iter()) + .for_each( + |((witness_write, structural_witness_write), global_write_mem)| { + witness_write + .chunks_mut(num_witin) + .zip_eq(structural_witness_write.chunks_mut(num_structural_witin)) + .zip_eq(global_write_mem.values()) + .for_each(|((row, structural_row), rec)| { + if self.local_read_v.len() == 1 { + // Assign value directly. + set_val!(row, self.local_read_v[0], rec.value as u64); + } else { + // Assign value limbs. + self.local_read_v.iter().enumerate().for_each( + |(l, limb)| { + let val = + (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }, + ); + } + set_val!(row, self.local_read_cycle, rec.cycle); - set_val!(row, self.addr_subset, rec.addr.baddr().0 as u64); - set_val!(structural_row, self.sel_write, 1u64); + set_val!(row, self.addr_subset, rec.addr.baddr().0 as u64); + set_val!(structural_row, self.sel_write, 1u64); - // TODO assign W_{global} - }); + // TODO assign W_{global} + }); + }, + ); }, // global read, local write || { - witness_read - .par_chunks_mut(num_witin) - .zip(structural_witness_read.par_chunks_mut(num_structural_witin)) - .zip(global_read_mem) - .enumerate() - .for_each(|(i, ((row, structural_row), rec))| { - if self.local_write_v.len() == 1 { - // Assign value directly. - set_val!(row, self.local_write_v[0], rec.value as u64); - } else { - // Assign value limbs. - self.local_write_v.iter().enumerate().for_each(|(l, limb)| { - let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; - set_val!(row, limb, val as u64); - }); - } - set_val!(row, self.addr_subset, rec.addr.baddr().0 as u64); - set_val!(structural_row, self.sel_read, 1u64); + witness_read_mut_slices + .par_iter_mut() + .zip_eq(structural_witness_read_mut_slices.par_iter_mut()) + .zip_eq(global_read_records.par_iter()) + .for_each( + |((witness_read, structural_witness_read), global_read_mem)| { + witness_read + .chunks_mut(num_witin) + .zip_eq(structural_witness_read.chunks_mut(num_structural_witin)) + .zip_eq(global_read_mem.values()) + .for_each(|((row, structural_row), rec)| { + if self.local_write_v.len() == 1 { + // Assign value directly. + set_val!(row, self.local_write_v[0], rec.value as u64); + } else { + // Assign value limbs. + self.local_write_v.iter().enumerate().for_each( + |(l, limb)| { + let val = + (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }, + ); + } + set_val!(row, self.addr_subset, rec.addr.baddr().0 as u64); + set_val!(structural_row, self.sel_read, 1u64); - // TODO assign R_{global} - }); + // TODO assign R_{global} + }); + }, + ); }, ); From 4d5a42169d9de58d94dbc7f0a41189e1eafee590 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 15 Oct 2025 16:21:56 +0800 Subject: [PATCH 49/91] cleanup --- ceno_cli/src/commands/common_args/ceno.rs | 11 +++++++++++ ceno_emul/src/tracer.rs | 2 +- ceno_zkvm/src/e2e.rs | 19 ++++++++++--------- ceno_zkvm/src/scheme.rs | 8 -------- ceno_zkvm/src/scheme/tests.rs | 2 +- 5 files changed, 23 insertions(+), 19 deletions(-) diff --git a/ceno_cli/src/commands/common_args/ceno.rs b/ceno_cli/src/commands/common_args/ceno.rs index 9632986a8..d3df66e77 100644 --- a/ceno_cli/src/commands/common_args/ceno.rs +++ b/ceno_cli/src/commands/common_args/ceno.rs @@ -13,6 +13,7 @@ use ceno_zkvm::{ use clap::Args; use ff_ext::{BabyBearExt4, ExtensionField, GoldilocksExt2}; +use ceno_emul::shards::Shards; use mpcs::{ Basefold, BasefoldRSParams, PolynomialCommitmentScheme, SecurityLevel, Whir, WhirDefaultSpec, }; @@ -78,6 +79,14 @@ pub struct CenoOptions { #[arg(long)] pub out_vk: Option, + /// shard id + #[arg(long, default_value = "0")] + shard_id: u32, + + /// number of total shards. + #[arg(long, default_value = "1")] + max_num_shards: u32, + /// Profiling granularity. /// Setting any value restricts logs to profiling information #[arg(long)] @@ -337,6 +346,7 @@ fn run_elf_inner< std::fs::read(elf_path).context(format!("failed to read {}", elf_path.display()))?; let program = Program::load_elf(&elf_bytes, u32::MAX).context("failed to load elf")?; print_cargo_message("Loaded", format_args!("{}", elf_path.display())); + let shards = Shards::new(options.shard_id as usize, options.max_num_shards as usize); let public_io = options .read_public_io() @@ -385,6 +395,7 @@ fn run_elf_inner< create_prover(backend.clone()), program, platform, + shards, &hints, &public_io, options.max_steps, diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index c667e07d2..9dc9a0b12 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -25,7 +25,7 @@ use crate::{ /// - Any pair of `rs1 / rs2 / rd` **may be the same**. Then, one op will point to the other op in the same instruction but a different subcycle. The circuits may follow the operations **without special handling** of repeated registers. #[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct StepRecord { - pub cycle: Cycle, + cycle: Cycle, pc: Change, pub insn: Instruction, diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index be842f155..1b08cdf1d 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -142,7 +142,7 @@ impl<'a> Default for ShardContext<'a> { .map(|_| BTreeMap::new()) .collect::>(), ), - cur_shard_cycle_range: 0..usize::MAX, + cur_shard_cycle_range: Tracer::SUBCYCLES_PER_INSN as usize..usize::MAX, } } } @@ -161,12 +161,15 @@ impl<'a> ShardContext<'a> { "implement mechanism to skip current shard proof" ); + let subcycle_per_insn = Tracer::SUBCYCLES_PER_INSN as usize; let max_threads = max_usable_threads(); // let max_record_per_thread = max_insts.div_ceil(max_threads as u64); let expected_inst_per_shard = executed_instructions.div_ceil(max_num_shards) as usize; - let max_cycle = (executed_instructions + 1) * 4; // cycle start from 4 - let cur_shard_cycle_range = (shard_id * expected_inst_per_shard * 4 + 4) - ..((shard_id + 1) * expected_inst_per_shard * 4 + 4).min(max_cycle); + let max_cycle = (executed_instructions + 1) * subcycle_per_insn; // cycle start from subcycle_per_insn + let cur_shard_cycle_range = (shard_id * expected_inst_per_shard * subcycle_per_insn + + subcycle_per_insn) + ..((shard_id + 1) * expected_inst_per_shard * subcycle_per_insn + subcycle_per_insn) + .min(max_cycle); ShardContext { shard_id, @@ -248,15 +251,15 @@ impl<'a> ShardContext<'a> { #[inline(always)] pub fn aligned_prev_ts(&self, prev_cycle: Cycle) -> Cycle { let mut ts = prev_cycle.saturating_sub(self.cur_shard_cycle_range.start as Cycle); - if ts < 4 { + if ts < Tracer::SUBCYCLES_PER_INSN { ts = 0 } ts } pub fn current_shard_offset_cycle(&self) -> Cycle { - // `-4` as cycle of each local shard start from 4 - (self.cur_shard_cycle_range.start as Cycle) - 4 + // cycle of each local shard start from Tracer::SUBCYCLES_PER_INSN + (self.cur_shard_cycle_range.start as Cycle) - Tracer::SUBCYCLES_PER_INSN } #[inline(always)] @@ -383,8 +386,6 @@ pub fn emulate_program<'a>( vm.get_pc().into(), end_cycle, shards.shard_id as u32, - !shards.is_first_shard(), // first shard disable global read - !shards.is_last_shard(), // last shard disable global write io_init.iter().map(|rec| rec.value).collect_vec(), ); diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index f2f81c096..b36759d10 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -73,8 +73,6 @@ pub struct PublicValues { end_pc: u32, end_cycle: u64, shard_id: u32, - mem_bus_with_read: bool, - mem_bus_with_write: bool, public_io: Vec, } @@ -86,8 +84,6 @@ impl PublicValues { end_pc: u32, end_cycle: u64, shard_id: u32, - mem_bus_with_read: bool, - mem_bus_with_write: bool, public_io: Vec, ) -> Self { Self { @@ -97,8 +93,6 @@ impl PublicValues { end_pc, end_cycle, shard_id, - mem_bus_with_read, - mem_bus_with_write, public_io, } } @@ -113,8 +107,6 @@ impl PublicValues { vec![E::BaseField::from_canonical_u32(self.end_pc)], vec![E::BaseField::from_canonical_u64(self.end_cycle)], vec![E::BaseField::from_canonical_u32(self.shard_id)], - vec![E::BaseField::from_bool(self.mem_bus_with_read)], - vec![E::BaseField::from_bool(self.mem_bus_with_write)], ] .into_iter() .chain( diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index f7970b413..5cce8f4db 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -370,7 +370,7 @@ fn test_single_add_instance_e2e() { .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); - let pi = PublicValues::new(0, 0, 0, 0, 0, vec![0]); + let pi = PublicValues::new(0, 0, 0, 0, 0, 0, vec![0]); let transcript = BasicTranscript::new(b"riscv"); let zkvm_proof = prover .create_proof(zkvm_witness, pi, transcript) From 3c2115825721a3b8af3e792088345da25c3379cb Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 15 Oct 2025 21:04:13 +0800 Subject: [PATCH 50/91] add table circuit cpu sumcheck --- ceno_zkvm/src/scheme/cpu/mod.rs | 115 +++++++++++++++++++++++---- ceno_zkvm/src/scheme/prover.rs | 1 + ceno_zkvm/src/tables/ram/ram_impl.rs | 6 +- 3 files changed, 104 insertions(+), 18 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 414cf1068..a141f8818 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -20,16 +20,19 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, gkr::{self, Evaluation, GKRProof, GKRProverOutput, layer::LayerWitness}, hal::ProverBackend, + selector::SelectorType, }; use itertools::{Itertools, chain}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ - Expression, Instance, WitnessId, + ChallengeId, Expression, Instance, ToExpr, WitnessId, mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, + monomialize_expr_to_wit_terms, util::ceil_log2, virtual_poly::build_eq_x_r_vec, virtual_polys::VirtualPolynomialsBuilder, }; +use p3::field::FieldAlgebra; use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; use std::{collections::BTreeMap, sync::Arc}; use sumcheck::{ @@ -567,7 +570,6 @@ impl> MainSumcheckProver> MainSumcheckProver>(); - let fixed_in_evals = evals.split_off(input.witness.len()); - let wits_in_evals = evals; - exit_span!(span); + let (wits_in_evals, fixed_in_evals, main_sumcheck_proof, rt) = if num_instances + .is_power_of_two() + { + let span = entered_span!("fixed::evals + witin::evals"); + let mut evals = input + .witness + .par_iter() + .chain(input.fixed.par_iter()) + .map(|poly| poly.evaluate(&rt_tower[..poly.num_vars()])) + .collect::>(); + let fixed_in_evals = evals.split_off(input.witness.len()); + let wits_in_evals = evals; + exit_span!(span); + (wits_in_evals, fixed_in_evals, None, rt_tower) + } else { + assert!(cs.w_table_expressions.len() <= 1); + assert!(cs.r_table_expressions.len() <= 1); + + let sel_type = SelectorType::Prefix(E::BaseField::ZERO, 0.into()); + let mut sel_mle = sel_type.compute(&rt_tower, num_instances).unwrap(); + + // `wit` := witin ++ fixed + // we concat eq in between `wit` := witin ++ eqs ++ fixed + let all_witins = input + .witness + .iter() + .map(|mle| Either::Left(mle.as_ref())) + .chain(vec![Either::Right(&mut sel_mle)]) + .chain(input.fixed.iter().map(|mle| Either::Left(mle.as_ref()))) + .collect_vec(); + assert_eq!( + all_witins.len() as WitnessId, + cs.num_witin + cs.num_structural_witin + cs.num_fixed as WitnessId, + "all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {}", + all_witins.len(), + cs.num_witin, + cs.num_structural_witin, + cs.num_fixed, + ); + let builder = VirtualPolynomialsBuilder::new_with_mles( + num_threads, + rt_tower.len(), + all_witins, + ); + + let alpha_pows_expr = (2..) + .take(cs.w_table_expressions.len() + cs.r_table_expressions.len()) + .map(|id| Expression::Challenge(id as ChallengeId, 1, E::ONE, E::ZERO)) + .collect_vec(); + let zero_check_expr: Expression = cs + .w_table_expressions + .iter() + .take(1) + .chain(cs.r_table_expressions.iter().take(1)) + .zip_eq(&alpha_pows_expr) + .map(|(expr, alpha)| alpha * expr.expr.expr()) + .sum(); + let zero_check_monomial = monomialize_expr_to_wit_terms( + &zero_check_expr, + cs.num_witin as WitnessId, + cs.num_structural_witin as WitnessId, + cs.num_fixed as WitnessId, + ); + let main_sumcheck_challenges = chain!( + challenges.iter().copied(), + get_challenge_pows( + cs.w_table_expressions.len() + cs.r_table_expressions.len(), + transcript, + ) + ) + .collect_vec(); + + let span = entered_span!("IOPProverState::prove", profiling_4 = true); + let (proof, prover_state) = IOPProverState::prove( + builder.to_virtual_polys_with_monomial_terms( + &zero_check_monomial, + &[], + &main_sumcheck_challenges, + ), + transcript, + ); + exit_span!(span); + let rt = prover_state + .challenges + .iter() + .map(|c| c.elements) + .collect_vec(); + let mut evals = prover_state.get_mle_flatten_final_evaluations(); + let fixed_in_evals = evals.split_off(cs.num_fixed); + let _ = evals.split_off(cs.num_structural_witin as usize); + let wits_in_evals = evals; + (wits_in_evals, fixed_in_evals, Some(proof.proofs), rt) + }; Ok(( - rt_tower, + rt, MainSumcheckEvals { wits_in_evals, fixed_in_evals, }, - None, + main_sumcheck_proof, None, )) } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 1a1c4f17e..1d449fe52 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -323,6 +323,7 @@ impl< transcript: &mut impl Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { + println!("create_proof {name}"); let cs = circuit_pk.get_cs(); let log2_num_instances = input.log2_num_instances(); let num_var_with_rotation = log2_num_instances + cs.rotation_vars().unwrap_or(0); diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 2c66dc06d..251f47676 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -554,6 +554,7 @@ impl LocalRAMTableFinalConfig { let final_cycle = cb.create_witin(|| "final_cycle"); // R_{local} = sel * rlc_final_table + (1 - sel) * ONE + // => R_{local} - ONE = sel * (rlc_final_table - ONE) let final_expr = final_v.iter().map(|v| v.expr()).collect_vec(); let raw_final_table = [ // a v t @@ -563,8 +564,9 @@ impl LocalRAMTableFinalConfig { vec![final_cycle.expr()], ] .concat(); - let final_table_expr = sel.expr() * cb.rlc_chip_record(raw_final_table.clone()) - + (Expression::Constant(Either::Left(E::BaseField::ONE)) - sel.expr()); + let final_table_expr = sel.expr() + * (cb.rlc_chip_record(raw_final_table.clone()) + - Expression::Constant(Either::Left(E::BaseField::ONE))); cb.r_table_rlc_record( || "final_table", // XXX we mixed all ram type here to save column allocation From 4900dc840a5b3beb51d7d3092498873867ee9013 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 16 Oct 2025 09:57:39 +0800 Subject: [PATCH 51/91] global chip unit test wip2 --- ceno_zkvm/src/instructions/global.rs | 210 +++++++++++++++++++++++---- ceno_zkvm/src/scheme/septic_curve.rs | 30 ++-- gkr_iop/src/lib.rs | 3 +- 3 files changed, 208 insertions(+), 35 deletions(-) diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 8f49c7ed7..f2a1078de 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -1,17 +1,24 @@ use std::iter::repeat; use crate::{ + Value, chip_handler::general::PublicIOQuery, gadgets::{Poseidon2Config, RoundConstants}, + scheme::septic_curve::{SepticExtension, SepticPoint}, structs::RAMType, witness::LkMultiplicity, }; use ceno_emul::StepRecord; -use ff_ext::ExtensionField; +use ff_ext::{ExtensionField, FieldInto, POSEIDON2_BABYBEAR_WIDTH, SmallField}; use gkr_iop::{circuit_builder::CircuitBuilder, error::CircuitBuilderError}; use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::FieldAlgebra; +use p3::{ + field::{Field, FieldAlgebra}, + symmetric::Permutation, +}; +use std::ops::Deref; +use witness::set_val; use crate::{ instructions::{Instruction, riscv::constants::UInt}, @@ -23,7 +30,7 @@ use crate::{ // // global chip: read from and write into a global set shared // among multiple shards -pub struct GlobalConfig { +pub struct GlobalConfig { addr: WitIn, is_ram_register: WitIn, value: UInt, @@ -39,14 +46,16 @@ pub struct GlobalConfig { w_record: WitIn, x: Vec, y: Vec, - poseidon2: Poseidon2Config, + perm_config: Poseidon2Config, + perm: P, } -impl GlobalConfig { +impl GlobalConfig { // TODO: make `WIDTH`, `HALF_FULL_ROUNDS`, `PARTIAL_ROUNDS` generic parameters pub fn configure( cb: &mut CircuitBuilder, rc: RoundConstants, + perm: P, ) -> Result { let x: Vec = (0..SEPTIC_EXTENSION_DEGREE) .map(|i| cb.create_witin(|| format!("x{}", i))) @@ -69,7 +78,7 @@ impl GlobalConfig { let reg: Expression = RAMType::Register.into(); let mem: Expression = RAMType::Memory.into(); let ram_type: Expression = is_ram_reg.clone() * reg + (1 - is_ram_reg) * mem; - let hasher = Poseidon2Config::construct(cb, rc); + let perm_config = Poseidon2Config::construct(cb, rc); let mut input = vec![]; input.push(addr.expr()); @@ -145,10 +154,11 @@ impl GlobalConfig { ); // enforces x = poseidon2([addr, ram_type, value[0], value[1], shard, global_clk, nonce, 0, ..., 0]) - for (input_expr, hasher_input) in input.into_iter().zip_eq(hasher.inputs().into_iter()) { + for (input_expr, hasher_input) in input.into_iter().zip_eq(perm_config.inputs().into_iter()) + { cb.require_equal(|| "poseidon2 input", input_expr, hasher_input)?; } - for (xi, hasher_output) in x.iter().zip(hasher.output().into_iter()) { + for (xi, hasher_output) in x.iter().zip(perm_config.output().into_iter()) { cb.require_equal(|| "x = poseidon2's output", xi.expr(), hasher_output)?; } @@ -172,19 +182,125 @@ impl GlobalConfig { is_global_write, r_record, w_record, - poseidon2: hasher, + perm_config, + perm, }) } } +#[derive(Default)] +pub struct GlobalRecord { + pub addr: u32, + pub ram_type: RAMType, + pub value: u32, + pub shard: u64, + pub local_clk: u64, + pub global_clk: u64, + pub is_write: bool, +} + +impl GlobalRecord { + pub fn to_ec_point< + E: ExtensionField, + P: Permutation<[E::BaseField; POSEIDON2_BABYBEAR_WIDTH]>, + >( + &self, + hasher: &P, + ) -> (u32, SepticPoint) { + let mut nonce = 0; + let mut input = [ + E::BaseField::from_canonical_u32(self.addr), + E::BaseField::from_canonical_u32(self.ram_type as u32), + E::BaseField::from_canonical_u32(self.value & 0xFFFF), // lower 16 bits + E::BaseField::from_canonical_u32((self.value >> 16) & 0xFFFF), // higher 16 bits + E::BaseField::from_canonical_u64(self.shard), + E::BaseField::from_canonical_u64(self.global_clk), + E::BaseField::from_canonical_u32(nonce), + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + ]; + + let prime = E::BaseField::order().to_u64_digits()[0]; + loop { + let x: SepticExtension = + hasher.permute(input)[0..SEPTIC_EXTENSION_DEGREE].into(); + if let Some(p) = SepticPoint::from_x(x) { + let y6 = (p.y.0)[SEPTIC_EXTENSION_DEGREE - 1].to_canonical_u64(); + let is_y_in_2nd_half = y6 >= (prime / 2); + + // we negate y if needed + let negate = match (self.is_write, is_y_in_2nd_half) { + (true, false) => true, // write, y in [0, p/2) + (false, true) => true, // read, y in [p/2, p) + _ => false, + }; + + if negate { + return (nonce, -p); + } else { + return (nonce, p); + } + } else { + // try again with different nonce + nonce += 1; + input[6] = E::BaseField::from_canonical_u32(nonce); + } + } + } +} + +impl From for GlobalRecord { + fn from(step: StepRecord) -> Self { + let mut record = GlobalRecord::default(); + match step.memory_op() { + None => { + record.ram_type = RAMType::Register; + } + Some(_) => { + record.ram_type = RAMType::Memory; + } + }; + if let Some(op) = step.rs1() { + // read from previous shard + record.addr = op.addr.into(); + record.value = op.value; + record.global_clk = 0; // FIXME + record.shard = 0; // FIXME + record.local_clk = 0; + record.is_write = false; + } else { + // propagate local write to global for future shards + let op = step.rd().unwrap(); + record.addr = op.addr.into(); + record.value = op.value.after; + record.shard = 0; // FIXME + record.global_clk = step.cycle(); + record.local_clk = step.cycle(); + record.is_write = true; + } + + record + } +} + // This chip is used to manage read/write into a global set // shared among multiple shards -pub struct GlobalChip { - rc: RoundConstants, +pub struct GlobalChip { + rc: RoundConstants, + perm: P, } -impl Instruction for GlobalChip { - type InstructionConfig = GlobalConfig; +impl + Send> + Instruction for GlobalChip +{ + type InstructionConfig = GlobalConfig; fn name() -> String { "Global".to_string() @@ -195,39 +311,83 @@ impl Instruction for GlobalChip { cb: &mut CircuitBuilder, _param: &crate::structs::ProgramParams, ) -> Result { - let config = GlobalConfig::configure(cb, self.rc.clone())?; + let config = GlobalConfig::configure(cb, self.rc.clone(), self.perm.clone())?; Ok(config) } fn assign_instance( - _config: &Self::InstructionConfig, - _instance: &mut [E::BaseField], + config: &Self::InstructionConfig, + instance: &mut [E::BaseField], _lk_multiplicity: &mut LkMultiplicity, _step: &StepRecord, ) -> Result<(), crate::error::ZKVMError> { - // assign (x, y) - - // assign [addr, ram_type, value, shard, clk, is_write] - - // assign poseidon2 hasher - - todo!() + let record: GlobalRecord = _step.clone().into(); + + // assign basic fields + let is_ram_register = match record.ram_type { + RAMType::Register => 1, + RAMType::Memory => 0, + RAMType::GlobalState => unreachable!(), + }; + set_val!(instance, config.addr, record.addr as u64); + set_val!(instance, config.is_ram_register, is_ram_register as u64); + config + .value + .assign_limbs(instance, Value::new_unchecked(record.value).as_u16_limbs()); + set_val!(instance, config.shard, record.shard); + set_val!(instance, config.global_clk, record.global_clk); + set_val!(instance, config.local_clk, record.local_clk); + set_val!(instance, config.is_global_write, record.is_write as u64); + + // assign (x, y) and nonce + let (nonce, point) = record.to_ec_point::(&config.perm); + set_val!(instance, config.nonce, nonce as u64); + config + .x + .iter() + .chain(config.y.iter()) + .zip_eq((point.x.deref()).iter().chain((point.y.deref()).iter())) + .for_each(|(witin, fe)| { + set_val!(instance, *witin, fe.to_canonical_u64()); + }); + + // TODO: assign poseidon2 hasher + + Ok(()) } } #[cfg(test)] mod tests { + use ff_ext::{BabyBearExt4, PoseidonField}; + use mpcs::{BasefoldDefault, SecurityLevel}; + use p3::babybear::BabyBear; + + use crate::{ + gadgets::horizen_round_consts, + instructions::global::GlobalChip, + scheme::{create_backend, create_prover}, + }; + + type E = BabyBearExt4; + type F = BabyBear; + type PERM = ::P; + type PCS = BasefoldDefault; + #[test] fn test_global_chip() { - // Test the GlobalChip functionality here - // init global chip with horizen_rc_consts + let rc = horizen_round_consts(); + let perm = ::get_default_perm(); + let global_chip = GlobalChip:: { rc, perm }; // create a bunch of random memory read/write records // assign witness // create chip proof for global chip + let backend = create_backend::(20, SecurityLevel::Conjecture100bits); + let prover = create_prover(backend); } } diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index ed3030d23..12b07fcaf 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -741,6 +741,25 @@ pub struct SepticPoint { } impl SepticPoint { + // if there exists y such that (x, y) is on the curve, return one of them + pub fn from_x(x: SepticExtension) -> Option { + let b: SepticExtension = [0, 0, 0, 0, 0, 26, 0].into(); + let a: F = F::from_canonical_u32(2); + + let y2 = x.square() * &x + (&x * a) + &b; + if y2.is_square() { + let y = y2.sqrt().unwrap(); + + Some(Self { + x, + y, + is_infinity: false, + }) + } else { + None + } + } + pub fn from_affine(x: SepticExtension, y: SepticExtension) -> Self { let is_infinity = if x.is_zero() && y.is_zero() { true @@ -867,15 +886,8 @@ impl SepticPoint { loop { let x = SepticExtension::random(&mut rng); - let y2 = x.square() * &x + (&x * a) + &b; - if y2.is_square() { - let y = y2.sqrt().unwrap(); - - return Self { - x, - y, - is_infinity: false, - }; + if let Some(point) = Self::from_x(x) { + return point; } } } diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index fc69037ff..f2d504fce 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -77,9 +77,10 @@ pub struct ProtocolVerifier, PCS>( PhantomData<(E, Trans, PCS)>, ); -#[derive(Clone, Debug, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Debug, Default, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[repr(usize)] pub enum RAMType { + #[default] GlobalState, Register, Memory, From 64ab29eb6fb3d617cf0940c6b40af593841f0b81 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 16 Oct 2025 13:44:57 +0800 Subject: [PATCH 52/91] remove r_record / w_record --- ceno_zkvm/src/instructions/global.rs | 43 +++++++--------------------- 1 file changed, 10 insertions(+), 33 deletions(-) diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index f2a1078de..1db0c2246 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -42,8 +42,6 @@ pub struct GlobalConfig { // s.t. local offline memory checking can cancel out // this serves as propagating local write to global. is_global_write: WitIn, - r_record: WitIn, - w_record: WitIn, x: Vec, y: Vec, perm_config: Poseidon2Config, @@ -71,8 +69,6 @@ impl GlobalConfig { let local_clk = cb.create_witin(|| "local_clk"); let nonce = cb.create_witin(|| "nonce"); let is_global_write = cb.create_witin(|| "is_global_write"); - let r_record = cb.create_witin(|| "r_record"); - let w_record = cb.create_witin(|| "w_record"); let is_ram_reg: Expression = is_ram_register.expr(); let reg: Expression = RAMType::Register.into(); @@ -104,28 +100,11 @@ impl GlobalConfig { // otherwise, we insert a padding value 1 to avoid affecting local memory checking cb.assert_bit(|| "is_global_write must be boolean", is_global_write.expr())?; - // r_record = select(is_global_write, rlc, 1) - cb.condition_require_equal( - || "r_record = select(is_global_write, rlc, 1)", - is_global_write.expr(), - r_record.expr(), - rlc.clone(), - E::BaseField::ONE.expr(), - )?; // if we are reading from global set, then this record should be // considered as a initial local write to that address. // otherwise, we insert a padding value 1 as if we are not writing anything - // w_record = select(is_global_write, 1, rlc) - cb.condition_require_equal( - || "w_record = select(is_global_write, 1, rlc)", - is_global_write.expr(), - w_record.expr(), - E::BaseField::ONE.expr(), - rlc, - )?; - // local read/write consistency cb.condition_require_zero( || "is_global_read => local_clk = 0", @@ -134,16 +113,16 @@ impl GlobalConfig { )?; // TODO: enforce shard = shard_id in the public values - cb.read_record( - || "r_record", - gkr_iop::RAMType::Register, // TODO fixme - vec![r_record.expr()], - )?; - cb.write_record( - || "w_record", - gkr_iop::RAMType::Register, // TODO fixme - vec![w_record.expr()], - )?; + // cb.read_record( + // || "r_record", + // gkr_iop::RAMType::Register, // TODO fixme + // vec![r_record.expr()], + // )?; + // cb.write_record( + // || "w_record", + // gkr_iop::RAMType::Register, // TODO fixme + // vec![w_record.expr()], + // )?; // enforces final_sum = \sum_i (x_i, y_i) using ecc quark protocol let final_sum = cb.query_global_rw_sum()?; @@ -180,8 +159,6 @@ impl GlobalConfig { local_clk, nonce, is_global_write, - r_record, - w_record, perm_config, perm, }) From ea6f8ed0d456641fcf9bad883a7fe83a5c41cc76 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 16 Oct 2025 20:57:43 +0800 Subject: [PATCH 53/91] one shard prover pass --- ceno_zkvm/src/e2e.rs | 4 ++ ceno_zkvm/src/keygen.rs | 5 +- ceno_zkvm/src/scheme/cpu/mod.rs | 52 ++++++++-------- ceno_zkvm/src/scheme/gpu/mod.rs | 50 ++++++++-------- ceno_zkvm/src/scheme/hal.rs | 4 +- ceno_zkvm/src/scheme/prover.rs | 90 ++++++++++++++-------------- ceno_zkvm/src/structs.rs | 11 ++++ ceno_zkvm/src/tables/ram/ram_impl.rs | 8 ++- 8 files changed, 122 insertions(+), 102 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 1b08cdf1d..642ad5c34 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -114,6 +114,7 @@ pub struct ShardContext<'a> { shard_id: usize, max_num_shards: usize, max_cycle: Cycle, + // TODO this map is super huge addr_future_accesses: Cow<'a, HashMap<(WordAddr, Cycle), Cycle>>, read_thread_based_record_storage: Either>, &'a mut BTreeMap>, @@ -274,8 +275,10 @@ impl<'a> ShardContext<'a> { prev_value: Option, ) { // check read from external mem bus + // exclude first shard if prev_cycle < self.cur_shard_cycle_range.start as Cycle && self.is_current_shard_cycle(cycle) + && !self.is_first_shard() { let ram_record = self .read_thread_based_record_storage @@ -295,6 +298,7 @@ impl<'a> ShardContext<'a> { }, ); } + // check write to external mem bus if let Some(future_touch_cycle) = self.addr_future_accesses.get(&(addr, cycle)) { if *future_touch_cycle >= self.cur_shard_cycle_range.end as Cycle diff --git a/ceno_zkvm/src/keygen.rs b/ceno_zkvm/src/keygen.rs index 17ab9e72c..0ced182b8 100644 --- a/ceno_zkvm/src/keygen.rs +++ b/ceno_zkvm/src/keygen.rs @@ -26,8 +26,11 @@ impl ZKVMConstraintSystem { .remove(&c_name) .flatten() .ok_or(ZKVMError::FixedTraceNotFound(c_name.clone().into()))?; + vm_pk + .circuit_index_fixed_num_instances + .insert(circuit_index, fixed_trace_rmm.num_instances()); fixed_traces.insert(circuit_index, fixed_trace_rmm); - }; + } let circuit_pk = cs.key_gen(); assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none()); diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index a141f8818..53d78d0e3 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -796,38 +796,38 @@ impl> OpeningProver>, points: Vec>, - mut evals: Vec>, // where each inner Vec = wit_evals + fixed_evals - circuit_num_polys: &[(usize, usize)], - num_instances: &[(usize, usize)], + mut evals: Vec>>, // where each inner vec![wit_evals, fixed_evals] transcript: &mut impl Transcript, ) -> PCS::Proof { let mut rounds = vec![]; - rounds.push(( - &witness_data, - points - .iter() - .zip_eq(evals.iter_mut()) - .zip_eq(num_instances.iter()) - .map(|((point, evals), (chip_idx, _))| { - let (num_witin, _) = circuit_num_polys[*chip_idx]; - (point.clone(), evals.drain(..num_witin).collect_vec()) + rounds.push((&witness_data, { + evals + .iter_mut() + .zip(&points) + .filter_map(|(evals, point)| { + let witin_evals = evals.remove(0); + if !witin_evals.is_empty() { + Some((point.clone(), witin_evals)) + } else { + None + } }) - .collect_vec(), - )); + .collect_vec() + })); if let Some(fixed_data) = fixed_data.as_ref().map(|f| f.as_ref()) { - rounds.push(( - fixed_data, - points - .iter() - .zip_eq(evals.iter_mut()) - .zip_eq(num_instances.iter()) - .filter(|(_, (chip_idx, _))| { - let (_, num_fixed) = circuit_num_polys[*chip_idx]; - num_fixed > 0 + rounds.push((fixed_data, { + evals + .iter_mut() + .zip(points) + .filter_map(|(evals, point)| { + if !evals.is_empty() && !evals[0].is_empty() { + Some((point.clone(), evals.remove(0))) + } else { + None + } }) - .map(|((point, evals), _)| (point.clone(), evals.to_vec())) - .collect_vec(), - )); + .collect_vec() + })); } PCS::batch_open(&self.backend.pp, rounds, transcript).unwrap() } diff --git a/ceno_zkvm/src/scheme/gpu/mod.rs b/ceno_zkvm/src/scheme/gpu/mod.rs index 89dff2160..d1b21f65d 100644 --- a/ceno_zkvm/src/scheme/gpu/mod.rs +++ b/ceno_zkvm/src/scheme/gpu/mod.rs @@ -739,8 +739,6 @@ impl> OpeningProver as ProverBackend>::PcsData>>, points: Vec>, mut evals: Vec>, // where each inner Vec = wit_evals + fixed_evals - circuit_num_polys: &[(usize, usize)], - num_instances: &[(usize, usize)], transcript: &mut (impl Transcript + 'static), ) -> PCS::Proof { if std::any::TypeId::of::() @@ -750,32 +748,34 @@ impl> OpeningProver 0 + rounds.push((fixed_data, { + evals + .iter_mut() + .zip(points) + .filter_map(|(evals, point)| { + if !evals.is_empty() && !evals[0].is_empty() { + Some((point.clone(), evals.remove(0))) + } else { + None + } }) - .map(|((point, evals), _)| (point.clone(), evals.to_vec())) - .collect_vec(), - )); + .collect_vec() + })); } // use ceno_gpu::{ diff --git a/ceno_zkvm/src/scheme/hal.rs b/ceno_zkvm/src/scheme/hal.rs index 85cb5ce45..ef118f912 100644 --- a/ceno_zkvm/src/scheme/hal.rs +++ b/ceno_zkvm/src/scheme/hal.rs @@ -147,9 +147,7 @@ pub trait OpeningProver { witness_data: PB::PcsData, fixed_data: Option>, points: Vec>, - evals: Vec>, - circuit_num_polys: &[(usize, usize)], - num_instances: &[(usize, usize)], + evals: Vec>>, transcript: &mut (impl Transcript + 'static), ) -> >::Proof; } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 1d449fe52..91c4fe006 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -113,14 +113,20 @@ impl< // only keep track of circuits that have non-zero instances let mut num_instances = Vec::with_capacity(self.pk.circuit_pks.len()); let mut num_instances_with_rotation = Vec::with_capacity(self.pk.circuit_pks.len()); + let mut circuit_name_num_instances_mapping = BTreeMap::new(); for (index, (circuit_name, ProvingKey { vk, .. })) in self.pk.circuit_pks.iter().enumerate() { // num_instance from witness might include rotation if let Some(num_instance) = witnesses .get_opcode_witness(circuit_name) .or_else(|| witnesses.get_table_witness(circuit_name)) - .map(|rmms| &rmms[0]) - .map(|rmm| rmm.num_instances()) + .map(|rmms| { + if rmms[0].num_instances() == 0 { + rmms[1].num_instances() + } else { + rmms[0].num_instances() + } + }) .and_then(|num_instance| { if num_instance > 0 { Some(num_instance) @@ -128,12 +134,25 @@ impl< None } }) + .or_else(|| { + vk.get_cs().is_static_circuit().then(|| { + self.pk + .circuit_index_fixed_num_instances + .get(&index) + .copied() + .unwrap_or(0) + }) + }) { num_instances.push(( index, num_instance >> vk.get_cs().rotation_vars().unwrap_or(0), )); - num_instances_with_rotation.push((index, num_instance)) + num_instances_with_rotation.push((index, num_instance)); + circuit_name_num_instances_mapping.insert( + circuit_name, + num_instance >> vk.get_cs().rotation_vars().unwrap_or(0), + ); } } @@ -144,7 +163,6 @@ impl< } let commit_to_traces_span = entered_span!("batch commit to traces", profiling_1 = true); - let mut wits_instances = BTreeMap::new(); let mut wits_rmms = BTreeMap::new(); let mut structural_wits = BTreeMap::new(); @@ -157,31 +175,19 @@ impl< } else { RowMajorMatrix::empty() }; - let rotation_vars = self - .pk - .circuit_pks - .get(&circuit_name) - .unwrap() - .vk - .get_cs() - .rotation_vars(); - let num_instances = witness_rmm.num_instances() >> (rotation_vars.unwrap_or(0)); - assert!( - wits_instances - .insert(circuit_name.clone(), num_instances) - .is_none() - ); - if num_instances == 0 { - continue; - } - let structural_witness = structural_witness_rmm.to_mles(); - wits_rmms.insert(circuit_name_index_mapping[&circuit_name], witness_rmm); - structural_wits.insert(circuit_name, (structural_witness, num_instances)); + if witness_rmm.num_instances() > 0 { + wits_rmms.insert(circuit_name_index_mapping[&circuit_name], witness_rmm); + } + if structural_witness_rmm.num_instances() > 0 { + let num_instances = circuit_name_num_instances_mapping + .get(&circuit_name) + .unwrap(); + let structural_witness = structural_witness_rmm.to_mles(); + structural_wits.insert(circuit_name, (structural_witness, num_instances)); + } } - debug_assert_eq!(num_instances.len(), wits_rmms.len()); - // commit to witness traces in batch let (mut witness_mles, witness_data, witin_commit) = self.device.commit_traces(wits_rmms); PCS::write_commitment(&witin_commit, &mut transcript).map_err(ZKVMError::PCSError)?; @@ -204,9 +210,10 @@ impl< let (points, evaluations) = self.pk.circuit_pks.iter().enumerate().try_fold( (vec![], vec![]), |(mut points, mut evaluations), (index, (circuit_name, pk))| { - let num_instances = *wits_instances - .get(circuit_name) - .ok_or(ZKVMError::WitnessNotFound(circuit_name.to_string().into()))?; + let num_instances = circuit_name_num_instances_mapping + .get(&circuit_name) + .copied() + .unwrap_or(0); let cs = pk.get_cs(); if num_instances == 0 { // we need to drain respective fixed when num_instances is 0 @@ -251,21 +258,23 @@ impl< num_instances ); points.push(input_opening_point); - evaluations.push(opcode_proof.wits_in_evals.clone()); + evaluations.push(vec![opcode_proof.wits_in_evals.clone()]); chip_proofs.insert(index, opcode_proof); } else { // FIXME: PROGRAM table circuit is not guaranteed to have 2^n instances input.num_instances = 1 << input.log2_num_instances(); let (mut table_proof, pi_in_evals, input_opening_point) = self .create_chip_proof(circuit_name, pk, input, &mut transcript, &challenges)?; - points.push(input_opening_point); - evaluations.push( - [ + if cs.num_witin() > 0 || cs.num_fixed() > 0 { + points.push(input_opening_point); + evaluations.push(vec![ table_proof.wits_in_evals.clone(), table_proof.fixed_in_evals.clone(), - ] - .concat(), - ); + ]); + } else { + assert!(table_proof.wits_in_evals.is_empty()); + assert!(table_proof.fixed_in_evals.is_empty()); + } // FIXME: PROGRAM table circuit is not guaranteed to have 2^n instances table_proof.num_instances = num_instances; chip_proofs.insert(index, table_proof); @@ -280,20 +289,12 @@ impl< // batch opening pcs // generate static info from prover key for expected num variable - let circuit_num_polys = self - .pk - .circuit_pks - .values() - .map(|pk| (pk.get_cs().num_witin(), pk.get_cs().num_fixed())) - .collect_vec(); let pcs_opening = entered_span!("pcs_opening", profiling_1 = true); let mpcs_opening_proof = self.device.open( witness_data, Some(device_pk.pcs_data), points, evaluations, - &circuit_num_polys, - &num_instances_with_rotation, &mut transcript, ); exit_span!(pcs_opening); @@ -323,7 +324,6 @@ impl< transcript: &mut impl Transcript, challenges: &[E; 2], ) -> Result, ZKVMError> { - println!("create_proof {name}"); let cs = circuit_pk.get_cs(); let log2_num_instances = input.log2_num_instances(); let num_var_with_rotation = log2_num_instances + cs.rotation_vars().unwrap_or(0); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 04842f349..ff5bd15a7 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -109,10 +109,19 @@ impl ComposedConstrainSystem { self.zkvm_v1_css.num_witin.into() } + pub fn num_structural_witin(&self) -> usize { + self.zkvm_v1_css.num_structural_witin.into() + } + pub fn num_fixed(&self) -> usize { self.zkvm_v1_css.num_fixed } + /// static circuit means there is only fixed column + pub fn is_static_circuit(&self) -> bool { + (self.num_witin() + self.num_structural_witin()) == 0 && self.num_fixed() > 0 + } + pub fn num_reads(&self) -> usize { self.zkvm_v1_css.r_expressions.len() + self.zkvm_v1_css.r_table_expressions.len() } @@ -407,6 +416,7 @@ pub struct ZKVMProvingKey> pub circuit_pks: BTreeMap>, pub fixed_commit_wd: Option>::CommitmentWithWitness>>, pub fixed_commit: Option<>::Commitment>, + pub circuit_index_fixed_num_instances: BTreeMap, // expression for global state in/out pub initial_global_state_expr: Expression, @@ -421,6 +431,7 @@ impl> ZKVMProvingKey RAMBusConfig { let (global_read_records, global_write_records) = (shard_ctx.read_records(), shard_ctx.write_records()); assert_eq!(global_read_records.len(), global_write_records.len()); + let raw_write_len: usize = global_write_records.iter().map(|m| m.len()).sum(); + let raw_read_len: usize = global_read_records.iter().map(|m| m.len()).sum(); + if raw_read_len + raw_write_len == 0 { + return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); + } + // TODO refactor to deal with only read/write let witness_length = { - let raw_write_len: usize = global_write_records.iter().map(|m| m.len()).sum(); - let raw_read_len: usize = global_read_records.iter().map(|m| m.len()).sum(); let max_len = raw_read_len.max(raw_write_len); // first half write, second half read next_pow2_instance_padding(max_len) * 2 From 82403f20e3450f989c0aba963a64e6ae9c052554 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 17 Oct 2025 15:02:58 +0800 Subject: [PATCH 54/91] fix most of local final table issue in e2e --- ceno_zkvm/benches/fibonacci.rs | 2 +- ceno_zkvm/src/e2e.rs | 3 +- ceno_zkvm/src/scheme/cpu/mod.rs | 181 ++++++++++--------- ceno_zkvm/src/scheme/mock_prover.rs | 7 +- ceno_zkvm/src/scheme/prover.rs | 6 +- ceno_zkvm/src/scheme/verifier.rs | 249 +++++++++++++-------------- ceno_zkvm/src/tables/program.rs | 3 +- ceno_zkvm/src/tables/ram/ram_impl.rs | 46 +++-- 8 files changed, 250 insertions(+), 247 deletions(-) diff --git a/ceno_zkvm/benches/fibonacci.rs b/ceno_zkvm/benches/fibonacci.rs index 878502f8e..eb7133344 100644 --- a/ceno_zkvm/benches/fibonacci.rs +++ b/ceno_zkvm/benches/fibonacci.rs @@ -13,7 +13,7 @@ use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; -use ceno_zkvm::scheme::verifier::ZKVMVerifier; +use ceno_zkvm::{e2e::ShardContext, scheme::verifier::ZKVMVerifier}; use mpcs::BasefoldDefault; use transcript::BasicTranscript; diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 642ad5c34..d1c14bb9d 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -37,6 +37,7 @@ use std::{ sync::Arc, }; use transcript::BasicTranscript as Transcript; +use witness::next_pow2_instance_padding; /// The polynomial commitment scheme kind #[derive( @@ -815,7 +816,7 @@ pub fn setup_program<'a, E: ExtensionField>( let pubio_len = platform.public_io.iter_addresses().len(); let program_params = ProgramParams { platform: platform.clone(), - program_size: program.instructions.len(), + program_size: next_pow2_instance_padding(program.instructions.len()), static_memory_len: static_addrs.len(), pubio_len, }; diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 53d78d0e3..f6f0683c9 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -681,99 +681,98 @@ impl> MainSumcheckProver>(); - let fixed_in_evals = evals.split_off(input.witness.len()); - let wits_in_evals = evals; - exit_span!(span); - (wits_in_evals, fixed_in_evals, None, rt_tower) - } else { - assert!(cs.w_table_expressions.len() <= 1); - assert!(cs.r_table_expressions.len() <= 1); - - let sel_type = SelectorType::Prefix(E::BaseField::ZERO, 0.into()); - let mut sel_mle = sel_type.compute(&rt_tower, num_instances).unwrap(); - - // `wit` := witin ++ fixed - // we concat eq in between `wit` := witin ++ eqs ++ fixed - let all_witins = input - .witness - .iter() - .map(|mle| Either::Left(mle.as_ref())) - .chain(vec![Either::Right(&mut sel_mle)]) - .chain(input.fixed.iter().map(|mle| Either::Left(mle.as_ref()))) - .collect_vec(); - assert_eq!( - all_witins.len() as WitnessId, - cs.num_witin + cs.num_structural_witin + cs.num_fixed as WitnessId, - "all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {}", - all_witins.len(), - cs.num_witin, - cs.num_structural_witin, - cs.num_fixed, - ); - let builder = VirtualPolynomialsBuilder::new_with_mles( - num_threads, - rt_tower.len(), - all_witins, - ); - - let alpha_pows_expr = (2..) - .take(cs.w_table_expressions.len() + cs.r_table_expressions.len()) - .map(|id| Expression::Challenge(id as ChallengeId, 1, E::ONE, E::ZERO)) - .collect_vec(); - let zero_check_expr: Expression = cs - .w_table_expressions - .iter() - .take(1) - .chain(cs.r_table_expressions.iter().take(1)) - .zip_eq(&alpha_pows_expr) - .map(|(expr, alpha)| alpha * expr.expr.expr()) - .sum(); - let zero_check_monomial = monomialize_expr_to_wit_terms( - &zero_check_expr, - cs.num_witin as WitnessId, - cs.num_structural_witin as WitnessId, - cs.num_fixed as WitnessId, - ); - let main_sumcheck_challenges = chain!( - challenges.iter().copied(), - get_challenge_pows( - cs.w_table_expressions.len() + cs.r_table_expressions.len(), - transcript, + let (wits_in_evals, fixed_in_evals, main_sumcheck_proof, rt) = + if next_pow2_instance_padding(num_instances) == num_instances { + let span = entered_span!("fixed::evals + witin::evals"); + let mut evals = input + .witness + .par_iter() + .chain(input.fixed.par_iter()) + .map(|poly| poly.evaluate(&rt_tower[..poly.num_vars()])) + .collect::>(); + let fixed_in_evals = evals.split_off(input.witness.len()); + let wits_in_evals = evals; + exit_span!(span); + (wits_in_evals, fixed_in_evals, None, rt_tower) + } else { + assert!(cs.r_table_expressions.len() <= 1); + assert!(cs.w_table_expressions.len() <= 1); + + let sel_type = SelectorType::Prefix(E::BaseField::ZERO, 0.into()); + let mut sel_mle = sel_type.compute(&rt_tower, num_instances).unwrap(); + + // `wit` := witin ++ fixed + // we concat eq in between `wit` := witin ++ eqs ++ fixed + let all_witins = input + .witness + .iter() + .map(|mle| Either::Left(mle.as_ref())) + .chain(vec![Either::Right(&mut sel_mle)]) + .chain(input.fixed.iter().map(|mle| Either::Left(mle.as_ref()))) + .collect_vec(); + assert_eq!( + all_witins.len() as WitnessId, + cs.num_witin + cs.num_structural_witin + cs.num_fixed as WitnessId, + "all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {}", + all_witins.len(), + cs.num_witin, + cs.num_structural_witin, + cs.num_fixed, + ); + let builder = VirtualPolynomialsBuilder::new_with_mles( + num_threads, + rt_tower.len(), + all_witins, + ); + + let alpha_pows_expr = (2..) + .take(cs.r_table_expressions.len() + cs.w_table_expressions.len()) + .map(|id| Expression::Challenge(id as ChallengeId, 1, E::ONE, E::ZERO)) + .collect_vec(); + let zero_check_expr: Expression = cs + .r_table_expressions + .iter() + .take(1) + .chain(cs.w_table_expressions.iter().take(1)) + .zip_eq(&alpha_pows_expr) + .map(|(expr, alpha)| alpha * expr.expr.expr()) + .sum(); + let zero_check_monomial = monomialize_expr_to_wit_terms( + &zero_check_expr, + cs.num_witin as WitnessId, + cs.num_structural_witin as WitnessId, + cs.num_fixed as WitnessId, + ); + let main_sumcheck_challenges = chain!( + challenges.iter().copied(), + get_challenge_pows( + cs.w_table_expressions.len() + cs.r_table_expressions.len(), + transcript, + ) ) - ) - .collect_vec(); - - let span = entered_span!("IOPProverState::prove", profiling_4 = true); - let (proof, prover_state) = IOPProverState::prove( - builder.to_virtual_polys_with_monomial_terms( - &zero_check_monomial, - &[], - &main_sumcheck_challenges, - ), - transcript, - ); - exit_span!(span); - let rt = prover_state - .challenges - .iter() - .map(|c| c.elements) .collect_vec(); - let mut evals = prover_state.get_mle_flatten_final_evaluations(); - let fixed_in_evals = evals.split_off(cs.num_fixed); - let _ = evals.split_off(cs.num_structural_witin as usize); - let wits_in_evals = evals; - (wits_in_evals, fixed_in_evals, Some(proof.proofs), rt) - }; + + let span = entered_span!("IOPProverState::prove", profiling_4 = true); + let (proof, prover_state) = IOPProverState::prove( + builder.to_virtual_polys_with_monomial_terms( + &zero_check_monomial, + &[], + &main_sumcheck_challenges, + ), + transcript, + ); + exit_span!(span); + let rt = prover_state + .challenges + .iter() + .map(|c| c.elements) + .collect_vec(); + let mut wits_in_evals = prover_state.get_mle_flatten_final_evaluations(); + let mut rest = wits_in_evals.split_off(cs.num_witin as usize); + let rest = rest.split_off(cs.num_structural_witin as usize); + let fixed_in_evals = rest; + (wits_in_evals, fixed_in_evals, Some(proof.proofs), rt) + }; Ok(( rt, diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 028f844a6..da5e8fe00 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -42,6 +42,7 @@ use std::{ }; use strum::IntoEnumIterator; use tiny_keccak::{Hasher, Keccak}; +use witness::next_pow2_instance_padding; const MAX_CONSTRAINT_DEGREE: usize = 3; const MOCK_PROGRAM_SIZE: usize = 32; @@ -828,7 +829,10 @@ impl<'a, E: ExtensionField + Hash> MockProver { let mut cs = ConstraintSystem::::new(|| "mock_program"); let params = ProgramParams { platform: CENO_PLATFORM, - program_size: max(program.instructions.len(), MOCK_PROGRAM_SIZE), + program_size: max( + next_pow2_instance_padding(program.instructions.len()), + MOCK_PROGRAM_SIZE, + ), ..ProgramParams::default() }; let mut cb = CircuitBuilder::new(&mut cs); @@ -1487,7 +1491,6 @@ fn filter_mle_by_selector_mle( #[cfg(test)] mod tests { - use super::*; use crate::{ ROMType, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 91c4fe006..8df281186 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -236,7 +236,7 @@ impl< ); let fixed = fixed_mles.drain(..cs.num_fixed()).collect_vec(); let public_input = self.device.transport_mles(pi.clone()); - let mut input = ProofInput { + let input = ProofInput { witness: witness_mle, fixed, structural_witness, @@ -262,7 +262,7 @@ impl< chip_proofs.insert(index, opcode_proof); } else { // FIXME: PROGRAM table circuit is not guaranteed to have 2^n instances - input.num_instances = 1 << input.log2_num_instances(); + // input.num_instances = 1 << input.log2_num_instances(); let (mut table_proof, pi_in_evals, input_opening_point) = self .create_chip_proof(circuit_name, pk, input, &mut transcript, &challenges)?; if cs.num_witin() > 0 || cs.num_fixed() > 0 { @@ -276,7 +276,7 @@ impl< assert!(table_proof.fixed_in_evals.is_empty()); } // FIXME: PROGRAM table circuit is not guaranteed to have 2^n instances - table_proof.num_instances = num_instances; + // table_proof.num_instances = num_instances; chip_proofs.insert(index, table_proof); for (idx, eval) in pi_in_evals { pi_evals[idx] = eval; diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index f8c1c8a2a..246e7e73e 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,14 +1,26 @@ -use std::marker::PhantomData; - +use either::Either; use ff_ext::ExtensionField; +use std::marker::PhantomData; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; -use gkr_iop::gkr::GKRClaims; +use super::{ZKVMChipProof, ZKVMProof}; +use crate::{ + error::ZKVMError, + scheme::constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, + structs::{ComposedConstrainSystem, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, + utils::{ + eval_inner_repeated_incremental_vec, eval_outer_repeated_incremental_vec, + eval_stacked_constant_vec, eval_stacked_wellform_address_vec, eval_wellform_address_vec, + }, +}; +use gkr_iop::{gkr::GKRClaims, selector::SelectorType, utils::eq_eval_less_or_equal_than}; use itertools::{Itertools, chain, interleave, izip}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ + Expression, + Expression::WitIn, Instance, StructuralWitIn, StructuralWitInType, mle::IntoMLE, util::ceil_log2, @@ -21,19 +33,7 @@ use sumcheck::{ util::get_challenge_pows, }; use transcript::{ForkableTranscript, Transcript}; -use witness::next_pow2_instance_padding; - -use crate::{ - error::ZKVMError, - scheme::constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, - structs::{ComposedConstrainSystem, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, - utils::{ - eval_inner_repeated_incremental_vec, eval_outer_repeated_incremental_vec, - eval_stacked_constant_vec, eval_stacked_wellform_address_vec, eval_wellform_address_vec, - }, -}; - -use super::{ZKVMChipProof, ZKVMProof}; +use witness::{InstancePaddingStrategy::Default, next_pow2_instance_padding}; pub struct ZKVMVerifier> { pub vk: ZKVMVerifyingKey, @@ -162,6 +162,7 @@ impl> ZKVMVerifier let mut witin_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); let mut fixed_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); for (index, proof) in &vm_proof.chip_proofs { + assert!(proof.num_instances > 0); let circuit_name = &self.vk.circuit_index_to_name[index]; let circuit_vk = &self.vk.circuit_vks[circuit_name]; @@ -437,43 +438,43 @@ impl> ZKVMVerifier let ComposedConstrainSystem { zkvm_v1_css: cs, .. } = circuit_vk.get_cs(); - debug_assert!( - cs.r_table_expressions - .iter() - .zip_eq(cs.w_table_expressions.iter()) - .all(|(r, w)| r.table_spec.len == w.table_spec.len) - ); - + let with_rw = !cs.r_table_expressions.is_empty() && !cs.w_table_expressions.is_empty(); + if with_rw { + debug_assert!( + cs.r_table_expressions + .iter() + .zip_eq(cs.w_table_expressions.iter()) + .all(|(r, w)| r.table_spec.len == w.table_spec.len) + ); + } let log2_num_instances = next_pow2_instance_padding(proof.num_instances).ilog2() as usize; - // in table proof, we always skip same point sumcheck for now - // as tower sumcheck batch product argument/logup in same length - let is_skip_same_point_sumcheck = true; - // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; // NOTE: for all structural witness within same constrain system should got same hints num variable via `log2_num_instances` - let expected_rounds = cs - // only iterate r set, as read/write set round should match - .r_table_expressions - .iter() - .flat_map(|r| { + let expected_rounds = interleave(&cs.r_table_expressions, &cs.w_table_expressions) + .map(|set_table_expr| { // iterate through structural witins and collect max round. - let num_vars = r.table_spec.len.map(ceil_log2).unwrap_or_else(|| { - r.table_spec - .structural_witins - .iter() - .map(|StructuralWitIn { witin_type, .. }| { - let hint_num_vars = log2_num_instances; - assert!((1 << hint_num_vars) <= witin_type.max_len()); - hint_num_vars - }) - .max() - .unwrap() - }); + let num_vars = set_table_expr + .table_spec + .len + .map(ceil_log2) + .unwrap_or_else(|| { + set_table_expr + .table_spec + .structural_witins + .iter() + .map(|StructuralWitIn { witin_type, .. }| { + let hint_num_vars = log2_num_instances; + assert!((1 << hint_num_vars) <= witin_type.max_len()); + hint_num_vars + }) + .max() + .unwrap() + }); assert_eq!(num_vars, log2_num_instances); - [num_vars, num_vars] // format: [read_round, write_round] + num_vars }) .chain(cs.lk_table_expressions.iter().map(|l| { // iterate through structural witins and collect max round. @@ -497,11 +498,8 @@ impl> ZKVMVerifier let expected_max_rounds = expected_rounds.iter().cloned().max().unwrap(); let (rt_tower, prod_point_and_eval, logup_p_point_and_eval, logup_q_point_and_eval) = TowerVerify::verify( - proof - .r_out_evals - .iter() - .zip(proof.w_out_evals.iter()) - .flat_map(|(r_evals, w_evals)| [r_evals.to_vec(), w_evals.to_vec()]) + interleave(&proof.r_out_evals, &proof.w_out_evals) + .map(|eval| eval.to_vec()) .collect_vec(), proof .lk_out_evals @@ -530,13 +528,18 @@ impl> ZKVMVerifier cs.r_table_expressions.len() + cs.w_table_expressions.len(), "[prod_record] mismatch length" ); - let num_rw_records = cs.r_table_expressions.len() + cs.w_table_expressions.len(); - // evaluate the evaluation of structural mles at input_opening_point by verifier - let structural_evals = cs - .r_table_expressions - .iter() - .map(|r| &r.table_spec) + let input_opening_point = if next_pow2_instance_padding(proof.num_instances) + == proof.num_instances + { + // evaluate the evaluation of structural mles at input_opening_point by verifier + let structural_evals = if with_rw { + // only iterate r set, as read/write set round should match + Either::Left(cs.r_table_expressions.iter()) + } else { + Either::Right(cs.r_table_expressions.iter().chain(&cs.w_table_expressions)) + } + .map(|set_table_expr| &set_table_expr.table_spec) .chain(cs.lk_table_expressions.iter().map(|r| &r.table_spec)) .flat_map(|table_spec| { table_spec @@ -571,32 +574,30 @@ impl> ZKVMVerifier }) .collect_vec(); - // verify records (degree = 1) statement, thus no sumcheck - let expected_evals = interleave( - &cs.r_table_expressions, // r - &cs.w_table_expressions, // w - ) - .map(|rw| &rw.expr) - .chain( - cs.lk_table_expressions - .iter() - .flat_map(|lk| vec![&lk.multiplicity, &lk.values]), // p, q - ) - .map(|expr| { - eval_by_expr_with_instance( - &proof.fixed_in_evals, - &proof.wits_in_evals, - &structural_evals, - pi, - challenges, - expr, + // verify records (degree = 1) statement, thus no sumcheck + let expected_evals = interleave( + &cs.r_table_expressions, // r + &cs.w_table_expressions, // w ) - .right() - .unwrap() - }) - .collect_vec(); - - let input_opening_point = if is_skip_same_point_sumcheck { + .map(|rw| &rw.expr) + .chain( + cs.lk_table_expressions + .iter() + .flat_map(|lk| vec![&lk.multiplicity, &lk.values]), // p, q + ) + .map(|expr| { + eval_by_expr_with_instance( + &proof.fixed_in_evals, + &proof.wits_in_evals, + &structural_evals, + pi, + challenges, + expr, + ) + .right() + .unwrap() + }) + .collect_vec(); for (expected_eval, eval) in expected_evals.iter().zip( prod_point_and_eval .into_iter() @@ -619,29 +620,24 @@ impl> ZKVMVerifier } rt_tower } else { + assert_eq!(cs.lk_table_expressions.len(), 0); assert!(proof.main_sumcheck_proofs.is_some()); + assert_eq!(cs.num_structural_witin, 1); + assert_eq!(prod_point_and_eval.len(), 1); // verify opening same point layer sumcheck let alpha_pow = get_challenge_pows( - cs.r_table_expressions.len() - + cs.w_table_expressions.len() - + cs.lk_table_expressions.len() * 2, // 2 for lk numerator and denominator + cs.r_table_expressions.len() + cs.w_table_expressions.len(), transcript, ); - // \sum_i alpha_{i} * (out_r_eval{i}) - // + \sum_i alpha_{i} * (out_w_eval{i}) - // + \sum_i alpha_{i} * (out_lk_n{i}) - // + \sum_i alpha_{i} * (out_lk_d{i}) + // \sum_i alpha_{i} * (out_r_eval{i} - ONE) + // + \sum_i alpha_{i} * (out_w_eval{i} - ONE) let claim_sum = prod_point_and_eval .iter() .zip(alpha_pow.iter()) - .map(|(point_and_eval, alpha)| *alpha * point_and_eval.eval) - .sum::() - + interleave(&logup_p_point_and_eval, &logup_q_point_and_eval) - .zip_eq(alpha_pow.iter().skip(num_rw_records)) - .map(|(point_n_eval, alpha)| *alpha * point_n_eval.eval) - .sum::(); + .map(|(point_and_eval, alpha)| *alpha * (point_and_eval.eval - E::ONE)) + .sum::(); let sel_subclaim = IOPVerifierState::verify( claim_sum, &IOPProof { @@ -654,44 +650,37 @@ impl> ZKVMVerifier }, transcript, ); - let (input_opening_point, expected_evaluation) = ( + let (input_opening_point, sumcheck_eval) = ( sel_subclaim.point.iter().map(|c| c.elements).collect_vec(), sel_subclaim.expected_evaluation, ); - - let computed_evals = [ - // r, w - prod_point_and_eval - .into_iter() - .zip_eq(&expected_evals[0..num_rw_records]) - .zip(alpha_pow.iter()) - .map(|((point_and_eval, in_eval), alpha)| { - let eq = eq_eval( - &point_and_eval.point, - &input_opening_point[0..point_and_eval.point.len()], - ); - // TODO times multiplication factor - *alpha * eq * *in_eval - }) - .sum::(), - interleave(logup_p_point_and_eval, logup_q_point_and_eval) - .zip_eq(&expected_evals[num_rw_records..]) - .zip_eq(alpha_pow.iter().skip(num_rw_records)) - .map(|((point_and_eval, in_eval), alpha)| { - let eq = eq_eval( - &point_and_eval.point, - &input_opening_point[0..point_and_eval.point.len()], - ); - // TODO times multiplication factor - *alpha * eq * *in_eval - }) - .sum::(), - ] - .iter() - .copied() + let structural_evals = vec![eq_eval_less_or_equal_than( + proof.num_instances - 1, + &prod_point_and_eval[0].point, + &input_opening_point, + )]; + + let expected_evals = interleave( + &cs.r_table_expressions, // r + &cs.w_table_expressions, // w + ) + .map(|rw| &rw.expr) + .zip(alpha_pow.iter()) + .map(|(expr, alpha)| { + *alpha + * eval_by_expr_with_instance( + &proof.fixed_in_evals, + &proof.wits_in_evals, + &structural_evals, + pi, + challenges, + expr, + ) + .right() + .unwrap() + }) .sum::(); - - if computed_evals != expected_evaluation { + if expected_evals != sumcheck_eval { return Err(ZKVMError::VerifyError( "sel evaluation verify failed".into(), )); @@ -749,9 +738,9 @@ impl TowerVerify { let log2_num_fanin = ceil_log2(num_fanin); // sanity check - assert!(num_prod_spec == tower_proofs.prod_spec_size()); + assert_eq!(num_prod_spec, tower_proofs.prod_spec_size()); assert!(prod_out_evals.iter().all(|evals| evals.len() == num_fanin)); - assert!(num_logup_spec == tower_proofs.logup_spec_size()); + assert_eq!(num_logup_spec, tower_proofs.logup_spec_size()); assert!(logup_out_evals.iter().all(|evals| { evals.len() == 4 // [p1, p2, q1, q2] })); diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 41890200e..833663e74 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -182,6 +182,7 @@ impl TableCircuit for ProgramTableCircuit { cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { + assert!(params.program_size.is_power_of_two()); #[cfg(not(feature = "u16limb_circuit"))] let record = InsnRecord([ cb.create_fixed(|| "pc"), @@ -214,7 +215,7 @@ impl TableCircuit for ProgramTableCircuit { cb.lk_table_record( || "prog table", SetTableSpec { - len: Some(params.program_size.next_power_of_two()), + len: Some(params.program_size), structural_witins: vec![], }, ROMType::Instruction, diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index d737e86bb..b055283c2 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -484,38 +484,48 @@ impl DynVolatileRamTableConfig num_structural_witin: usize, final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { + if final_mem.is_empty() { + return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); + } assert!(final_mem.len() <= DVRAM::max_len(&config.params)); assert!(DVRAM::max_len(&config.params).is_power_of_two()); let params = config.params.clone(); - let addr_id = config.addr.id as u64; - let addr_padding_fn = move |row: u64, col: u64| { - assert_eq!(col, addr_id); - DVRAM::addr(¶ms, row as usize) as u64 - }; + let num_instances_padded = next_pow2_instance_padding(final_mem.len()); + // let addr_id = config.addr.id as u64; + // let addr_padding_fn = move |row: u64, col: u64| { + // assert_eq!(col, addr_id); + // DVRAM::addr(¶ms, row as usize) as u64 + // }; let mut structural_witness = RowMajorMatrix::::new( - final_mem.len(), + num_instances_padded, num_structural_witin, - InstancePaddingStrategy::Custom(Arc::new(addr_padding_fn)), + InstancePaddingStrategy::Default, ); structural_witness .par_rows_mut() - .zip(final_mem) .enumerate() - .for_each(|(i, (structural_row, rec))| { - assert_eq!( - rec.addr, - DVRAM::addr(&config.params, i), - "rec.addr {:x} != expected {:x}", - rec.addr, - DVRAM::addr(&config.params, i), + .for_each(|(i, structural_row)| { + if cfg!(debug_assertions) { + if let Some(addr) = final_mem.get(i).map(|rec| rec.addr) { + debug_assert_eq!( + addr, + DVRAM::addr(&config.params, i), + "rec.addr {:x} != expected {:x}", + addr, + DVRAM::addr(&config.params, i), + ); + } + } + set_val!( + structural_row, + config.addr, + DVRAM::addr(&config.params, i) as u64 ); - set_val!(structural_row, config.addr, rec.addr as u64); }); - structural_witness.padding_by_strategy(); Ok([RowMajorMatrix::empty(), structural_witness]) } } @@ -541,7 +551,7 @@ impl LocalRAMTableFinalConfig { let sel = cb.create_structural_witin( || "sel", StructuralWitInType::EqualDistanceSequence { - max_len: 0, + max_len: u32::MAX as usize, offset: 0, multi_factor: WORD_SIZE, descending: false, From aeea15d2d11378af1e622d02eb2596cfbf30e9f3 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 17 Oct 2025 22:59:35 +0800 Subject: [PATCH 55/91] chores: cosmetics --- ceno_zkvm/src/scheme/cpu/mod.rs | 3 +-- ceno_zkvm/src/scheme/utils.rs | 2 +- ceno_zkvm/src/scheme/verifier.rs | 3 ++- ceno_zkvm/src/tables/ram/ram_impl.rs | 4 +++- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index f6f0683c9..174bddeb1 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -743,8 +743,7 @@ impl> MainSumcheckProver>::table_witness(device, input, cs, challenges), - false, + input.num_instances > 1 && input.num_instances.is_power_of_two(), ) } }; diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 246e7e73e..0153e5191 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -620,6 +620,7 @@ impl> ZKVMVerifier } rt_tower } else { + // TODO LocalFinalTable goes here, merge flow into gkr_iop assert_eq!(cs.lk_table_expressions.len(), 0); assert!(proof.main_sumcheck_proofs.is_some()); assert_eq!(cs.num_structural_witin, 1); @@ -635,7 +636,7 @@ impl> ZKVMVerifier // + \sum_i alpha_{i} * (out_w_eval{i} - ONE) let claim_sum = prod_point_and_eval .iter() - .zip(alpha_pow.iter()) + .zip_eq(alpha_pow.iter()) .map(|(point_and_eval, alpha)| *alpha * (point_and_eval.eval - E::ONE)) .sum::(); let sel_subclaim = IOPVerifierState::verify( diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index b055283c2..cf38184da 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -563,8 +563,10 @@ impl LocalRAMTableFinalConfig { .collect::>(); let final_cycle = cb.create_witin(|| "final_cycle"); - // R_{local} = sel * rlc_final_table + (1 - sel) * ONE + // R_{local} = sel * rlc_final_table + (ONE - sel) * ONE // => R_{local} - ONE = sel * (rlc_final_table - ONE) + // so we put `sel * (rlc_final_table - ONE)` in expression + // and `R_{local} - ONE` can be derived from verifier let final_expr = final_v.iter().map(|v| v.expr()).collect_vec(); let raw_final_table = [ // a v t From 40887fc1397dde0c74f92c8e2a8f6646241c348c Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Sun, 19 Oct 2025 09:43:27 +0800 Subject: [PATCH 56/91] gkr iop support table circuit --- ceno_zkvm/src/scheme/cpu/mod.rs | 3 +- gkr_iop/src/gkr/layer.rs | 64 +++++++++++++++++++++++---------- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 174bddeb1..f6f0683c9 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -743,7 +743,8 @@ impl> MainSumcheckProver Layer { n_challenges: usize, out_evals: OutEvalGroups, ) -> Layer { - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); + let w_len = cb.cs.w_expressions.len() + cb.cs.w_table_expressions.len(); + let r_len = cb.cs.r_expressions.len() + cb.cs.r_table_expressions.len(); + let lk_len = cb.cs.lk_expressions.len() + cb.cs.lk_table_expressions.len() * 2; // logup lk table include p, q let zero_len = cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); @@ -331,9 +332,12 @@ impl Layer { assert_eq!(lookup_evals.len(), lk_len); assert_eq!(zero_evals.len(), zero_len); - let non_zero_expr_len = cb.cs.w_expressions_namespace_map.len() - + cb.cs.r_expressions_namespace_map.len() - + cb.cs.lk_expressions.len(); + let non_zero_expr_len = cb.cs.w_expressions.len() + + cb.cs.w_table_expressions.len() + + cb.cs.r_expressions.len() + + cb.cs.r_table_expressions.len() + + cb.cs.lk_expressions.len() + + cb.cs.lk_table_expressions.len() * 2; let zero_expr_len = cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); @@ -344,13 +348,19 @@ impl Layer { // process r_record let evals = Self::dedup_last_selector_evals(cb.cs.r_selector.as_ref().unwrap(), &mut expr_evals); - for (idx, ((ram_expr, name), ram_eval)) in cb + for (idx, ((ram_expr, name), ram_eval)) in (cb .cs .r_expressions .iter() - .zip_eq(&cb.cs.r_expressions_namespace_map) - .zip_eq(&r_record_evals) - .enumerate() + .chain(cb.cs.r_table_expressions.iter().map(|t| &t.expr))) + .zip_eq( + cb.cs + .r_expressions_namespace_map + .iter() + .chain(&cb.cs.r_table_expressions_namespace_map), + ) + .zip_eq(&r_record_evals) + .enumerate() { expressions.push(ram_expr - E::BaseField::ONE.expr()); evals.push(EvalExpression::::Linear( @@ -365,13 +375,19 @@ impl Layer { // process w_record let evals = Self::dedup_last_selector_evals(cb.cs.w_selector.as_ref().unwrap(), &mut expr_evals); - for (idx, ((ram_expr, name), ram_eval)) in cb + for (idx, ((ram_expr, name), ram_eval)) in (cb .cs .w_expressions .iter() - .zip_eq(&cb.cs.w_expressions_namespace_map) - .zip_eq(&w_record_evals) - .enumerate() + .chain(cb.cs.w_table_expressions.iter().map(|t| &t.expr))) + .zip_eq( + cb.cs + .w_expressions_namespace_map + .iter() + .chain(&cb.cs.w_table_expressions_namespace_map), + ) + .zip_eq(&w_record_evals) + .enumerate() { expressions.push(ram_expr - E::BaseField::ONE.expr()); evals.push(EvalExpression::::Linear( @@ -386,13 +402,25 @@ impl Layer { // process lookup records let evals = Self::dedup_last_selector_evals(cb.cs.lk_selector.as_ref().unwrap(), &mut expr_evals); - for (idx, ((lookup, name), lookup_eval)) in cb + for (idx, ((lookup, name), lookup_eval)) in (cb .cs .lk_expressions .iter() - .zip_eq(&cb.cs.lk_expressions_namespace_map) - .zip_eq(&lookup_evals) - .enumerate() + .chain(cb.cs.lk_table_expressions.iter().map(|t| &t.multiplicity)) + .chain(cb.cs.lk_table_expressions.iter().map(|t| &t.values))) + .zip_eq(if cb.cs.lk_table_expressions.is_empty() { + Either::Left(cb.cs.lk_expressions_namespace_map.iter()) + } else { + // repeat expressions_namespace_map twice to deal with lk p, q + Either::Right( + cb.cs + .lk_expressions_namespace_map + .iter() + .chain(&cb.cs.lk_expressions_namespace_map), + ) + }) + .zip_eq(&lookup_evals) + .enumerate() { expressions.push(lookup - cb.cs.chip_record_alpha.clone()); evals.push(EvalExpression::::Linear( From 9649633ef1aa21f4f970e20ea9353ed7443496b2 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Sun, 19 Oct 2025 12:20:50 +0800 Subject: [PATCH 57/91] wip3 --- ceno_zkvm/src/gadgets/poseidon2.rs | 8 +- ceno_zkvm/src/instructions.rs | 7 +- ceno_zkvm/src/instructions/global.rs | 224 ++++++++++++++++-- .../weierstrass/weierstrass_add.rs | 12 +- .../weierstrass/weierstrass_decompress.rs | 12 +- .../weierstrass/weierstrass_double.rs | 12 +- ceno_zkvm/src/scheme.rs | 3 + ceno_zkvm/src/scheme/cpu/mod.rs | 5 +- ceno_zkvm/src/scheme/prover.rs | 2 + gkr_iop/src/gkr/layer/zerocheck_layer.rs | 5 +- gkr_iop/src/selector.rs | 130 ++++++++-- 11 files changed, 360 insertions(+), 60 deletions(-) diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs index aafeb3c32..7ecaeabc7 100644 --- a/ceno_zkvm/src/gadgets/poseidon2.rs +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -24,15 +24,15 @@ use crate::circuit_builder::CircuitBuilder; // copied from poseidon2-air/src/constants.rs // as the original one cannot be accessed here #[derive(Debug, Clone)] -pub(crate) struct RoundConstants< +pub struct RoundConstants< F: Field, const WIDTH: usize, const HALF_FULL_ROUNDS: usize, const PARTIAL_ROUNDS: usize, > { - pub(crate) beginning_full_round_constants: [[F; WIDTH]; HALF_FULL_ROUNDS], - pub(crate) partial_round_constants: [F; PARTIAL_ROUNDS], - pub(crate) ending_full_round_constants: [[F; WIDTH]; HALF_FULL_ROUNDS], + pub beginning_full_round_constants: [[F; WIDTH]; HALF_FULL_ROUNDS], + pub partial_round_constants: [F; PARTIAL_ROUNDS], + pub ending_full_round_constants: [[F; WIDTH]; HALF_FULL_ROUNDS], } pub type Poseidon2BabyBearConfig = Poseidon2Config; diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index af83d0695..976eba333 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -59,7 +59,10 @@ pub trait Instruction { descending: false, }, ); - let selector_type = SelectorType::Prefix(E::BaseField::ZERO, selector.expr()); + let selector_type = SelectorType::Prefix { + offset: 0, + expression: selector.expr(), + }; // all shared the same selector let (out_evals, mut chip) = ( @@ -82,7 +85,7 @@ pub trait Instruction { cb.cs.lk_selector = Some(selector_type.clone()); cb.cs.zero_selector = Some(selector_type.clone()); - let layer = Layer::from_circuit_builder(cb, "Rounds".to_string(), 0, out_evals); + let layer = Layer::from_circuit_builder(cb, format!("{}_main", Self::name()), 0, out_evals); chip.add_layer(layer); Ok((config, chip.gkr_circuit())) diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 1db0c2246..5cdfc83e0 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -5,14 +5,19 @@ use crate::{ chip_handler::general::PublicIOQuery, gadgets::{Poseidon2Config, RoundConstants}, scheme::septic_curve::{SepticExtension, SepticPoint}, - structs::RAMType, + structs::{ProgramParams, RAMType}, witness::LkMultiplicity, }; use ceno_emul::StepRecord; use ff_ext::{ExtensionField, FieldInto, POSEIDON2_BABYBEAR_WIDTH, SmallField}; -use gkr_iop::{circuit_builder::CircuitBuilder, error::CircuitBuilderError}; +use gkr_iop::{ + chip::Chip, circuit_builder::CircuitBuilder, error::CircuitBuilderError, gkr::layer::Layer, + selector::SelectorType, +}; use itertools::Itertools; -use multilinear_extensions::{Expression, ToExpr, WitIn}; +use multilinear_extensions::{ + Expression, StructuralWitInType::EqualDistanceSequence, ToExpr, WitIn, +}; use p3::{ field::{Field, FieldAlgebra}, symmetric::Permutation, @@ -44,7 +49,7 @@ pub struct GlobalConfig { is_global_write: WitIn, x: Vec, y: Vec, - perm_config: Poseidon2Config, + // perm_config: Poseidon2Config, perm: P, } @@ -74,7 +79,7 @@ impl GlobalConfig { let reg: Expression = RAMType::Register.into(); let mem: Expression = RAMType::Memory.into(); let ram_type: Expression = is_ram_reg.clone() * reg + (1 - is_ram_reg) * mem; - let perm_config = Poseidon2Config::construct(cb, rc); + // let perm_config = Poseidon2Config::construct(cb, rc); let mut input = vec![]; input.push(addr.expr()); @@ -93,7 +98,6 @@ impl GlobalConfig { record.extend(value.memory_expr()); record.push(shard.expr()); record.push(local_clk.expr()); - let rlc = cb.rlc_chip_record(record); // if is_global_write = 1, then it means we are propagating a local write to global // so we need to insert a local read record to cancel out this local write @@ -113,16 +117,16 @@ impl GlobalConfig { )?; // TODO: enforce shard = shard_id in the public values - // cb.read_record( - // || "r_record", - // gkr_iop::RAMType::Register, // TODO fixme - // vec![r_record.expr()], - // )?; - // cb.write_record( - // || "w_record", - // gkr_iop::RAMType::Register, // TODO fixme - // vec![w_record.expr()], - // )?; + cb.read_record( + || "r_record", + gkr_iop::RAMType::Register, // TODO fixme + record.clone(), + )?; + cb.write_record( + || "w_record", + gkr_iop::RAMType::Register, // TODO fixme + record.clone(), + )?; // enforces final_sum = \sum_i (x_i, y_i) using ecc quark protocol let final_sum = cb.query_global_rw_sum()?; @@ -159,7 +163,7 @@ impl GlobalConfig { local_clk, nonce, is_global_write, - perm_config, + // perm_config, perm, }) } @@ -286,13 +290,99 @@ impl fn construct_circuit( &self, cb: &mut CircuitBuilder, - _param: &crate::structs::ProgramParams, + _param: &ProgramParams, ) -> Result { let config = GlobalConfig::configure(cb, self.rc.clone(), self.perm.clone())?; Ok(config) } + fn build_gkr_iop_circuit( + &self, + cb: &mut CircuitBuilder, + param: &ProgramParams, + ) -> Result<(Self::InstructionConfig, gkr_iop::gkr::GKRCircuit), crate::error::ZKVMError> + { + let config = self.construct_circuit(cb, param)?; + + let w_len = cb.cs.w_expressions.len(); + let r_len = cb.cs.r_expressions.len(); + let lk_len = cb.cs.lk_expressions.len(); + let zero_len = + cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); + + // create three selectors: selector_r, selector_w, selector_zero + let selector_r = cb.create_structural_witin( + || "selector_r", + // this is just a placeholder, the actural type is SelectorType::Prefix() + EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: 0, + descending: false, + }, + ); + let selector_w = cb.create_structural_witin( + || "selector_w", + EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: 0, + descending: false, + }, + ); + let selector_zero = cb.create_structural_witin( + || "selector_zero", + EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: 0, + descending: false, + }, + ); + let selector_r = SelectorType::Prefix { + offset: 0, + expression: selector_r.expr(), + }; + // note that the actual offset should be set by prover + // depending on the number of local read instances + let selector_w = SelectorType::Prefix { + offset: 0, + expression: selector_w.expr(), + }; + // TODO: when selector_r = 1 => selector_zero = 1 + // when selector_w = 1 => selector_zero = 1 + let selector_zero = SelectorType::Prefix { + offset: 0, + expression: selector_zero.expr(), + }; + + cb.cs.r_selector = Some(selector_r); + cb.cs.w_selector = Some(selector_w); + cb.cs.zero_selector = Some(selector_zero.clone()); + cb.cs.lk_selector = Some(selector_zero); + + // all shared the same selector + let (out_evals, mut chip) = ( + [ + // r_record + (0..r_len).collect_vec(), + // w_record + (r_len..r_len + w_len).collect_vec(), + // lk_record + (r_len + w_len..r_len + w_len + lk_len).collect_vec(), + // zero_record + (0..zero_len).collect_vec(), + ], + Chip::new_from_cb(cb, 0), + ); + + let layer = Layer::from_circuit_builder(cb, format!("{}_main", Self::name()), 0, out_evals); + chip.add_layer(layer); + + Ok((config, chip.gkr_circuit())) + } + fn assign_instance( config: &Self::InstructionConfig, instance: &mut [E::BaseField], @@ -338,13 +428,22 @@ impl #[cfg(test)] mod tests { use ff_ext::{BabyBearExt4, PoseidonField}; - use mpcs::{BasefoldDefault, SecurityLevel}; - use p3::babybear::BabyBear; + use mpcs::{BasefoldDefault, PolynomialCommitmentScheme, SecurityLevel}; + use p3::{babybear::BabyBear, field::FieldAlgebra}; + use transcript::BasicTranscript; use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, gadgets::horizen_round_consts, - instructions::global::GlobalChip, - scheme::{create_backend, create_prover}, + instructions::{ + Instruction, + global::{GlobalChip, GlobalRecord}, + }, + scheme::{ + create_backend, create_prover, hal::ProofInput, prover::ZKVMProver, + septic_curve::SepticPoint, + }, + structs::{ComposedConstrainSystem, ProgramParams, RAMType, ZKVMProvingKey}, }; type E = BabyBearExt4; @@ -359,12 +458,91 @@ mod tests { let perm = ::get_default_perm(); let global_chip = GlobalChip:: { rc, perm }; + let mut cs = ConstraintSystem::new(|| "global chip test"); + let mut cb = CircuitBuilder::new(&mut cs); + + let (config, gkr_circuit) = global_chip + .build_gkr_iop_circuit(&mut cb, &ProgramParams::default()) + .unwrap(); + let composed_cs = ComposedConstrainSystem { + zkvm_v1_css: cs, + gkr_circuit: Some(gkr_circuit), + }; + let pk = composed_cs.key_gen(); + // create a bunch of random memory read/write records + let n_reads = 10; + let n_writes = 10; + let global_reads = (0..n_reads) + .map(|i| { + let addr = i * 8; + let value = (i + 1) * 8; + + GlobalRecord { + addr: addr as u32, + ram_type: RAMType::Memory, + value: value as u32, + shard: 1, + local_clk: 0, + global_clk: i, + is_write: false, + } + }) + .collect::>(); + + let global_writes = (0..n_writes) + .map(|i| { + let addr = i * 8; + let value = (i + 1) * 8; + + GlobalRecord { + addr: addr as u32, + ram_type: RAMType::Memory, + value: value as u32, + shard: 1, + local_clk: i, + global_clk: i, + is_write: true, + } + }) + .collect::>(); + let global_ec_sum: SepticPoint = global_reads + .iter() + .chain(global_writes.iter()) + .map(|record| record.to_ec_point::(&global_chip.perm).1) + .sum(); + + assert!(global_ec_sum.is_infinity == true); // assign witness // create chip proof for global chip + let pcs_param = PCS::setup(1 << 20, SecurityLevel::Conjecture100bits).unwrap(); + let (pp, vp) = PCS::trim(pcs_param, 1 << 20).unwrap(); let backend = create_backend::(20, SecurityLevel::Conjecture100bits); - let prover = create_prover(backend); + let pd = create_prover(backend); + + // let pk = prover.create_chip_proof(); + let mut zkvm_pk = ZKVMProvingKey::new(pp, vp); + let zkvm_prover = ZKVMProver::new(zkvm_pk, pd); + let mut transcript = BasicTranscript::new(b"global chip test"); + + let proof_input = ProofInput { + witness: todo!(), + structural_witness: todo!(), + fixed: todo!(), + public_input: todo!(), + num_instances: todo!(), + }; + let challenges = [E::ONE, E::ONE]; + let proof = zkvm_prover + .create_chip_proof( + "global chip", + &pk, + proof_input, + &mut transcript, + &challenges, + ) + .unwrap(); } } diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs index 3be11dbd0..ccbda01f9 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs @@ -140,11 +140,15 @@ impl WeierstrassAddAssignLayout { descending: false, }, ); + let sel = SelectorType::Prefix { + offset: 0, + expression: eq.expr(), + }; let selector_type_layout = SelectorTypeLayout { - sel_mem_read: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_mem_write: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_lookup: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_zero: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), + sel_mem_read: sel.clone(), + sel_mem_write: sel.clone(), + sel_lookup: sel.clone(), + sel_zero: sel.clone(), }; // Default expression, will be updated in build_layer_logic diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs index 52496e869..0d6406431 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs @@ -158,11 +158,15 @@ impl descending: false, }, ); + let sel = SelectorType::Prefix { + offset: 0, + expression: eq.expr(), + }; let selector_type_layout = SelectorTypeLayout { - sel_mem_read: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_mem_write: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_lookup: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_zero: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), + sel_mem_read: sel.clone(), + sel_mem_write: sel.clone(), + sel_lookup: sel.clone(), + sel_zero: sel.clone(), }; let input32_exprs: GenericArray< diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs index e5f16ba2f..decaa317f 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs @@ -142,11 +142,15 @@ impl descending: false, }, ); + let sel = SelectorType::Prefix { + offset: 0, + expression: eq.expr(), + }; let selector_type_layout = SelectorTypeLayout { - sel_mem_read: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_mem_write: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_lookup: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), - sel_zero: SelectorType::Prefix(E::BaseField::ZERO, eq.expr()), + sel_mem_read: sel.clone(), + sel_mem_write: sel.clone(), + sel_lookup: sel.clone(), + sel_zero: sel.clone(), }; let input32_exprs: GenericArray< diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index a33b890e9..98393d2e1 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -60,7 +60,10 @@ pub struct ZKVMChipProof { pub tower_proof: TowerProofs, + pub num_read_instances: usize, + pub num_write_instances: usize, pub num_instances: usize, + pub fixed_in_evals: Vec, pub wits_in_evals: Vec, } diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 029051f34..06dd70ef6 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -83,7 +83,10 @@ impl CpuEccProver { let mut expr_builder = VirtualPolynomialsBuilder::new(num_threads, out_rt.len()); - let sel = SelectorType::Prefix(E::BaseField::ZERO, 0.into()); + let sel = SelectorType::Prefix { + offset: 0, + expression: 0.into(), + }; let num_instances = (1 << n) - 1; let mut sel_mle: MultilinearExtension<'_, E> = sel.compute(&out_rt, num_instances).unwrap(); let sel_expr = expr_builder.lift(sel_mle.to_either()); diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 1a1c4f17e..cd2f64fde 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -378,6 +378,8 @@ impl< tower_proof, fixed_in_evals, wits_in_evals, + num_read_instances: input.num_instances, + num_write_instances: input.num_instances, num_instances: input.num_instances, }, pi_in_evals, diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index 1d4e6c56a..8a00132cb 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -450,7 +450,10 @@ pub fn extend_exprs_with_rotation( let expr = match sel_type { SelectorType::None => zero_check_expr, SelectorType::Whole(sel) - | SelectorType::Prefix(_, sel) + | SelectorType::Prefix { + offset: _, + expression: sel, + } | SelectorType::OrderedSparse32 { expression: sel, .. } => match_expr(sel) * zero_check_expr, diff --git a/gkr_iop/src/selector.rs b/gkr_iop/src/selector.rs index bc57295f1..05dac293d 100644 --- a/gkr_iop/src/selector.rs +++ b/gkr_iop/src/selector.rs @@ -1,12 +1,19 @@ +use std::iter::repeat_n; + use rayon::iter::IndexedParallelIterator; use ff_ext::ExtensionField; use multilinear_extensions::{ Expression, mle::{IntoMLE, MultilinearExtension, Point}, + util::ceil_log2, virtual_poly::{build_eq_x_r_vec, eq_eval}, }; -use rayon::{iter::ParallelIterator, slice::ParallelSliceMut}; +use p3::field::FieldAlgebra; +use rayon::{ + iter::{IntoParallelIterator, ParallelIterator}, + slice::ParallelSliceMut, +}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use crate::{gkr::booleanhypercube::CYCLIC_POW2_5, utils::eq_eval_less_or_equal_than}; @@ -21,7 +28,14 @@ pub enum SelectorType { None, Whole(Expression), /// Select a prefix as the instances, padded with a field element. - Prefix(E::BaseField, Expression), + /// 1. [0, offset) are zeros; + /// 2. [offset, offset + num_instances) are ones, + /// 3. [offset + num_instances, 2^n) are zeros. + Prefix { + // offset is not fixed at setup time. + offset: usize, + expression: Expression, + }, /// selector activates on the specified `indices`, which are assumed to be in ascending order. /// each index corresponds to a position within a fixed-size chunk (e.g., size 32), OrderedSparse32 { @@ -31,6 +45,77 @@ pub enum SelectorType { } impl SelectorType { + pub fn as_mle( + &self, + num_instances: usize, + num_vars: usize, + ) -> Option> { + match self { + SelectorType::None => None, + SelectorType::Whole(_) => { + assert_eq!(ceil_log2(num_instances), num_vars); + Some( + (0..(1 << num_vars)) + .into_par_iter() + .map(|_| E::BaseField::ONE) + .collect::>() + .into_mle(), + ) + } + SelectorType::Prefix { + offset, + expression: _, + } => { + assert!(*offset + num_instances <= (1 << num_vars)); + let end = *offset + num_instances; + Some( + (0..*offset) + .into_par_iter() + .map(|_| E::BaseField::ZERO) + .chain((*offset..end).into_par_iter().map(|_| E::BaseField::ONE)) + .chain( + (end..(1 << num_vars)) + .into_par_iter() + .map(|_| E::BaseField::ZERO), + ) + .collect::>() + .into_mle(), + ) + } + SelectorType::OrderedSparse32 { + indices, + expression: _, + } => { + assert_eq!(ceil_log2(num_instances), num_vars); + Some( + (0..(1 << num_vars)) + .into_par_iter() + .flat_map(|chunk_index| { + if chunk_index >= num_instances { + vec![E::ZERO; 32] + } else { + let mut chunk = vec![E::ZERO; 32]; + let mut indices_iter = indices.iter().copied(); + let mut next_keep = indices_iter.next(); + + for (i, e) in chunk.iter_mut().enumerate() { + if let Some(idx) = next_keep + && i == idx + { + *e = E::ONE; + next_keep = indices_iter.next(); // Keep this one + } + } + chunk + } + }) + .collect::>() + .into_mle(), + ) + } + } + } + /// Compute true and false mle eq(1; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (eq() - sel(y; b[5..])) pub fn compute( &self, @@ -39,18 +124,23 @@ impl SelectorType { ) -> Option> { match self { SelectorType::None => None, - SelectorType::Whole(_expr) => Some(build_eq_x_r_vec(out_point).into_mle()), - SelectorType::Prefix(_, _expr) => { + SelectorType::Whole(_) => Some(build_eq_x_r_vec(out_point).into_mle()), + SelectorType::Prefix { + offset, + expression: _expr, + } => { + let num_vars = out_point.len(); + let end = *offset + num_instances; + assert!(end <= (1 << num_vars)); + let mut sel = build_eq_x_r_vec(out_point); - if num_instances < sel.len() { - sel.splice( - num_instances..sel.len(), - std::iter::repeat_n(E::ZERO, sel.len() - num_instances), - ); - } + sel.splice(0..*offset, repeat_n(E::ZERO, *offset)); + sel.splice(end..sel.len(), repeat_n(E::ZERO, sel.len() - end)); Some(sel.into_mle()) } SelectorType::OrderedSparse32 { indices, .. } => { + assert_eq!(out_point.len(), ceil_log2(num_instances) + 5); + let mut sel = build_eq_x_r_vec(out_point); sel.par_chunks_exact_mut(CYCLIC_POW2_5.len()) .enumerate() @@ -93,12 +183,15 @@ impl SelectorType { debug_assert_eq!(out_point.len(), in_point.len()); (expr, eq_eval(out_point, in_point)) } - SelectorType::Prefix(_, expr) => { - debug_assert!(num_instances <= (1 << out_point.len())); - ( - expr, - eq_eval_less_or_equal_than(num_instances - 1, out_point, in_point), - ) + SelectorType::Prefix { offset, expression } => { + let end = *offset + num_instances; + + assert_eq!(in_point.len(), out_point.len()); + assert!(end <= (1 << out_point.len())); + + let eq_start = eq_eval_less_or_equal_than(*offset - 1, out_point, in_point); + let eq_end = eq_eval_less_or_equal_than(end - 1, out_point, in_point); + (expression, eq_end - eq_start) } SelectorType::OrderedSparse32 { indices, @@ -137,7 +230,10 @@ impl SelectorType { match self { Self::OrderedSparse32 { expression, .. } | Self::Whole(expression) - | Self::Prefix(_, expression) => expression, + | Self::Prefix { + offset: _, + expression, + } => expression, e => unimplemented!("no selector expression in {:?}", e), } } From 08783ca606bf09fba7565028b91e8d2caeeeced7 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Sun, 19 Oct 2025 16:36:17 +0800 Subject: [PATCH 58/91] wip convert local final ram circuit to gkr-iop circuit --- .../src/instructions/riscv/rv32im/mmu.rs | 2 + ceno_zkvm/src/scheme/cpu/mod.rs | 105 ++-------- ceno_zkvm/src/scheme/prover.rs | 1 + ceno_zkvm/src/scheme/utils.rs | 14 +- ceno_zkvm/src/scheme/verifier.rs | 80 +------ ceno_zkvm/src/structs.rs | 23 +- ceno_zkvm/src/tables/mod.rs | 15 ++ ceno_zkvm/src/tables/ram/ram_circuit.rs | 57 ++++- ceno_zkvm/src/tables/ram/ram_impl.rs | 39 ++-- gkr_iop/src/chip.rs | 8 +- gkr_iop/src/gkr/layer.rs | 198 +++++++++--------- 11 files changed, 229 insertions(+), 313 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index e15335810..a95b8e03f 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -46,7 +46,9 @@ impl MmuConfig<'_, E> { let hints_config = cs.register_table_circuit::>(); let stack_init_config = cs.register_table_circuit::>(); let heap_init_config = cs.register_table_circuit::>(); + println!("register LocalFinalCircuit"); let local_final_circuit = cs.register_table_circuit::>(); + println!("end register LocalFinalCircuit"); let ram_bus_circuit = cs.register_table_circuit::>(); Self { diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index f6f0683c9..54fbac475 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -681,98 +681,19 @@ impl> MainSumcheckProver>(); - let fixed_in_evals = evals.split_off(input.witness.len()); - let wits_in_evals = evals; - exit_span!(span); - (wits_in_evals, fixed_in_evals, None, rt_tower) - } else { - assert!(cs.r_table_expressions.len() <= 1); - assert!(cs.w_table_expressions.len() <= 1); - - let sel_type = SelectorType::Prefix(E::BaseField::ZERO, 0.into()); - let mut sel_mle = sel_type.compute(&rt_tower, num_instances).unwrap(); - - // `wit` := witin ++ fixed - // we concat eq in between `wit` := witin ++ eqs ++ fixed - let all_witins = input - .witness - .iter() - .map(|mle| Either::Left(mle.as_ref())) - .chain(vec![Either::Right(&mut sel_mle)]) - .chain(input.fixed.iter().map(|mle| Either::Left(mle.as_ref()))) - .collect_vec(); - assert_eq!( - all_witins.len() as WitnessId, - cs.num_witin + cs.num_structural_witin + cs.num_fixed as WitnessId, - "all_witins.len() {} != layer.n_witin {} + layer.n_structural_witin {} + layer.n_fixed {}", - all_witins.len(), - cs.num_witin, - cs.num_structural_witin, - cs.num_fixed, - ); - let builder = VirtualPolynomialsBuilder::new_with_mles( - num_threads, - rt_tower.len(), - all_witins, - ); - - let alpha_pows_expr = (2..) - .take(cs.r_table_expressions.len() + cs.w_table_expressions.len()) - .map(|id| Expression::Challenge(id as ChallengeId, 1, E::ONE, E::ZERO)) - .collect_vec(); - let zero_check_expr: Expression = cs - .r_table_expressions - .iter() - .take(1) - .chain(cs.w_table_expressions.iter().take(1)) - .zip_eq(&alpha_pows_expr) - .map(|(expr, alpha)| alpha * expr.expr.expr()) - .sum(); - let zero_check_monomial = monomialize_expr_to_wit_terms( - &zero_check_expr, - cs.num_witin as WitnessId, - cs.num_structural_witin as WitnessId, - cs.num_fixed as WitnessId, - ); - let main_sumcheck_challenges = chain!( - challenges.iter().copied(), - get_challenge_pows( - cs.w_table_expressions.len() + cs.r_table_expressions.len(), - transcript, - ) - ) - .collect_vec(); - - let span = entered_span!("IOPProverState::prove", profiling_4 = true); - let (proof, prover_state) = IOPProverState::prove( - builder.to_virtual_polys_with_monomial_terms( - &zero_check_monomial, - &[], - &main_sumcheck_challenges, - ), - transcript, - ); - exit_span!(span); - let rt = prover_state - .challenges - .iter() - .map(|c| c.elements) - .collect_vec(); - let mut wits_in_evals = prover_state.get_mle_flatten_final_evaluations(); - let mut rest = wits_in_evals.split_off(cs.num_witin as usize); - let rest = rest.split_off(cs.num_structural_witin as usize); - let fixed_in_evals = rest; - (wits_in_evals, fixed_in_evals, Some(proof.proofs), rt) - }; + let (wits_in_evals, fixed_in_evals, main_sumcheck_proof, rt) = { + let span = entered_span!("fixed::evals + witin::evals"); + let mut evals = input + .witness + .par_iter() + .chain(input.fixed.par_iter()) + .map(|poly| poly.evaluate(&rt_tower[..poly.num_vars()])) + .collect::>(); + let fixed_in_evals = evals.split_off(input.witness.len()); + let wits_in_evals = evals; + exit_span!(span); + (wits_in_evals, fixed_in_evals, None, rt_tower) + }; Ok(( rt, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 8df281186..2a583eae1 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -210,6 +210,7 @@ impl< let (points, evaluations) = self.pk.circuit_pks.iter().enumerate().try_fold( (vec![], vec![]), |(mut points, mut evaluations), (index, (circuit_name, pk))| { + println!("prove circuit_name {circuit_name}"); let num_instances = circuit_name_num_instances_mapping .get(&circuit_name) .copied() diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 24137eb3a..0bc1e8f09 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -345,12 +345,16 @@ pub fn build_main_witness< } if let Some(gkr_circuit) = gkr_circuit { - // opcode must have at least one read/write/lookup + // circuit must have at least one read/write/lookup assert!( - cs.lk_expressions.is_empty() - || !cs.r_expressions.is_empty() - || !cs.w_expressions.is_empty(), - "assert opcode circuit" + cs.r_expressions.len() + + cs.w_expressions.len() + + cs.lk_expressions.len() + + cs.r_table_expressions.len() + + cs.w_table_expressions.len() + + cs.lk_table_expressions.len() + > 0, + "assert circuit" ); let (_, gkr_circuit_out) = gkr_witness::( diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 0153e5191..f29440cbc 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -164,6 +164,7 @@ impl> ZKVMVerifier for (index, proof) in &vm_proof.chip_proofs { assert!(proof.num_instances > 0); let circuit_name = &self.vk.circuit_index_to_name[index]; + println!("verify circuit_name {circuit_name}"); let circuit_vk = &self.vk.circuit_vks[circuit_name]; // check chip proof is well-formed @@ -356,9 +357,9 @@ impl> ZKVMVerifier } = &composed_cs; let num_instances = proof.num_instances; let (r_counts_per_instance, w_counts_per_instance, lk_counts_per_instance) = ( - cs.r_expressions.len(), - cs.w_expressions.len(), - cs.lk_expressions.len(), + cs.r_expressions.len() + cs.r_table_expressions.len(), + cs.w_expressions.len() + cs.w_table_expressions.len(), + cs.lk_expressions.len() + cs.lk_table_expressions.len() * 2, ); let num_batched = r_counts_per_instance + w_counts_per_instance + lk_counts_per_instance; @@ -529,9 +530,8 @@ impl> ZKVMVerifier "[prod_record] mismatch length" ); - let input_opening_point = if next_pow2_instance_padding(proof.num_instances) - == proof.num_instances - { + let ram_bus_circuit = false; + let input_opening_point = if !ram_bus_circuit { // evaluate the evaluation of structural mles at input_opening_point by verifier let structural_evals = if with_rw { // only iterate r set, as read/write set round should match @@ -620,73 +620,7 @@ impl> ZKVMVerifier } rt_tower } else { - // TODO LocalFinalTable goes here, merge flow into gkr_iop - assert_eq!(cs.lk_table_expressions.len(), 0); - assert!(proof.main_sumcheck_proofs.is_some()); - assert_eq!(cs.num_structural_witin, 1); - assert_eq!(prod_point_and_eval.len(), 1); - - // verify opening same point layer sumcheck - let alpha_pow = get_challenge_pows( - cs.r_table_expressions.len() + cs.w_table_expressions.len(), - transcript, - ); - - // \sum_i alpha_{i} * (out_r_eval{i} - ONE) - // + \sum_i alpha_{i} * (out_w_eval{i} - ONE) - let claim_sum = prod_point_and_eval - .iter() - .zip_eq(alpha_pow.iter()) - .map(|(point_and_eval, alpha)| *alpha * (point_and_eval.eval - E::ONE)) - .sum::(); - let sel_subclaim = IOPVerifierState::verify( - claim_sum, - &IOPProof { - proofs: proof.main_sumcheck_proofs.clone().unwrap(), - }, - &VPAuxInfo { - max_degree: SEL_DEGREE, - max_num_variables: expected_max_rounds, - phantom: PhantomData, - }, - transcript, - ); - let (input_opening_point, sumcheck_eval) = ( - sel_subclaim.point.iter().map(|c| c.elements).collect_vec(), - sel_subclaim.expected_evaluation, - ); - let structural_evals = vec![eq_eval_less_or_equal_than( - proof.num_instances - 1, - &prod_point_and_eval[0].point, - &input_opening_point, - )]; - - let expected_evals = interleave( - &cs.r_table_expressions, // r - &cs.w_table_expressions, // w - ) - .map(|rw| &rw.expr) - .zip(alpha_pow.iter()) - .map(|(expr, alpha)| { - *alpha - * eval_by_expr_with_instance( - &proof.fixed_in_evals, - &proof.wits_in_evals, - &structural_evals, - pi, - challenges, - expr, - ) - .right() - .unwrap() - }) - .sum::(); - if expected_evals != sumcheck_eval { - return Err(ZKVMError::VerifyError( - "sel evaluation verify failed".into(), - )); - } - input_opening_point + unimplemented!("shard ram bus circuit go here"); }; // assume public io is tiny vector, so we evaluate it directly without PCS diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index ff5bd15a7..8c92036ae 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -135,9 +135,7 @@ impl ComposedConstrainSystem { } pub fn is_opcode_circuit(&self) -> bool { - self.zkvm_v1_css.lk_table_expressions.is_empty() - && self.zkvm_v1_css.r_table_expressions.is_empty() - && self.zkvm_v1_css.w_table_expressions.is_empty() + self.gkr_circuit.is_some() } /// return number of lookup operation @@ -219,18 +217,13 @@ impl ZKVMConstraintSystem { pub fn register_table_circuit>(&mut self) -> TC::TableConfig { let mut cs = ConstraintSystem::new(|| format!("riscv_table/{}", TC::name())); let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let config = TC::construct_circuit(&mut circuit_builder, &self.params).unwrap(); - assert!( - self.circuit_css - .insert( - TC::name(), - ComposedConstrainSystem { - zkvm_v1_css: cs, - gkr_circuit: None - } - ) - .is_none() - ); + let (config, gkr_iop_circuit) = + TC::build_gkr_iop_circuit(&mut circuit_builder, &self.params).unwrap(); + let cs = ComposedConstrainSystem { + zkvm_v1_css: cs, + gkr_circuit: gkr_iop_circuit, + }; + assert!(self.circuit_css.insert(TC::name(), cs).is_none()); config } diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 33ce3bf4a..05be820f2 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -1,7 +1,14 @@ use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, structs::ProgramParams}; use ff_ext::ExtensionField; use std::collections::HashMap; +use itertools::Itertools; +use multilinear_extensions::{StructuralWitInType, ToExpr}; use witness::RowMajorMatrix; +use gkr_iop::chip::Chip; +use gkr_iop::gkr::GKRCircuit; +use gkr_iop::gkr::layer::Layer; +use gkr_iop::selector::SelectorType; + mod range; pub use range::*; @@ -29,6 +36,14 @@ pub trait TableCircuit { params: &ProgramParams, ) -> Result; + fn build_gkr_iop_circuit( + cb: &mut CircuitBuilder, + param: &ProgramParams, + ) -> Result<(Self::TableConfig, Option>), ZKVMError> { + let config = Self::construct_circuit(cb, param)?; + Ok((config, None)) + } + fn generate_fixed_traces( config: &Self::TableConfig, num_fixed: usize, diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index ff5e9a783..8fc43e348 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, marker::PhantomData}; use super::ram_impl::{ - LocalRAMTableFinalConfig, NonVolatileTableConfigTrait, PubIOTableConfig, RAMBusConfig, + LocalFinalRAMTableConfig, NonVolatileTableConfigTrait, PubIOTableConfig, RAMBusConfig, }; use crate::{ circuit_builder::CircuitBuilder, @@ -12,7 +12,15 @@ use crate::{ }; use ceno_emul::{Addr, Cycle, GetAddr, WORD_SIZE, Word}; use ff_ext::{ExtensionField, SmallField}; -use gkr_iop::error::CircuitBuilderError; +use gkr_iop::{ + chip::Chip, + error::CircuitBuilderError, + gkr::{GKRCircuit, layer::Layer}, + selector::SelectorType, +}; +use itertools::Itertools; +use multilinear_extensions::{StructuralWitInType, ToExpr}; +use p3::field::FieldAlgebra; use witness::{InstancePaddingStrategy, RowMajorMatrix}; #[derive(Clone, Debug)] @@ -275,7 +283,7 @@ pub struct LocalFinalRamCircuit<'a, const V_LIMBS: usize, E>(PhantomData<(&'a () impl<'a, E: ExtensionField, const V_LIMBS: usize> TableCircuit for LocalFinalRamCircuit<'a, V_LIMBS, E> { - type TableConfig = LocalRAMTableFinalConfig; + type TableConfig = LocalFinalRAMTableConfig; type FixedInput = (); type WitnessInput = ( &'a ShardContext<'a>, @@ -296,6 +304,49 @@ impl<'a, E: ExtensionField, const V_LIMBS: usize> TableCircuit )?) } + fn build_gkr_iop_circuit( + cb: &mut CircuitBuilder, + param: &ProgramParams, + ) -> Result<(Self::TableConfig, Option>), ZKVMError> { + let config = Self::construct_circuit(cb, param)?; + let r_table_len = cb.cs.r_table_expressions.len(); + + let selector = cb.create_structural_witin( + || "selector", + StructuralWitInType::EqualDistanceSequence { + // TODO determin proper size of max length + max_len: u32::MAX as usize, + offset: 0, + multi_factor: 0, + descending: false, + }, + ); + let selector_type = SelectorType::Prefix(E::BaseField::ZERO, selector.expr()); + + // all shared the same selector + let (out_evals, mut chip) = ( + [ + // r_record + (0..r_table_len).collect_vec(), + // w_record + vec![], + // lk_record + vec![], + // zero_record + vec![], + ], + Chip::new_from_cb(cb, 0), + ); + + // register selector to legacy constrain system + cb.cs.r_selector = Some(selector_type.clone()); + + let layer = Layer::from_circuit_builder(cb, "Rounds".to_string(), 0, out_evals); + chip.add_layer(layer); + + Ok((config, Some(chip.gkr_circuit()))) + } + fn generate_fixed_traces( _config: &Self::TableConfig, _num_fixed: usize, diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index cf38184da..7a24ef66f 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -532,15 +532,15 @@ impl DynVolatileRamTableConfig /// This table is generalized version to handle all mmio records #[derive(Clone, Debug)] -pub struct LocalRAMTableFinalConfig { +pub struct LocalFinalRAMTableConfig { addr_subset: WitIn, - sel: StructuralWitIn, + ram_type: WitIn, final_v: Vec, final_cycle: WitIn, } -impl LocalRAMTableFinalConfig { +impl LocalFinalRAMTableConfig { pub fn construct_circuit( cb: &mut CircuitBuilder, _params: &ProgramParams, @@ -548,25 +548,11 @@ impl LocalRAMTableFinalConfig { let addr_subset = cb.create_witin(|| "addr_subset"); let ram_type = cb.create_witin(|| "ram_type"); - let sel = cb.create_structural_witin( - || "sel", - StructuralWitInType::EqualDistanceSequence { - max_len: u32::MAX as usize, - offset: 0, - multi_factor: WORD_SIZE, - descending: false, - }, - ); - let final_v = (0..V_LIMBS) .map(|i| cb.create_witin(|| format!("final_v_limb_{i}"))) .collect::>(); let final_cycle = cb.create_witin(|| "final_cycle"); - // R_{local} = sel * rlc_final_table + (ONE - sel) * ONE - // => R_{local} - ONE = sel * (rlc_final_table - ONE) - // so we put `sel * (rlc_final_table - ONE)` in expression - // and `R_{local} - ONE` can be derived from verifier let final_expr = final_v.iter().map(|v| v.expr()).collect_vec(); let raw_final_table = [ // a v t @@ -576,24 +562,20 @@ impl LocalRAMTableFinalConfig { vec![final_cycle.expr()], ] .concat(); - let final_table_expr = sel.expr() - * (cb.rlc_chip_record(raw_final_table.clone()) - - Expression::Constant(Either::Left(E::BaseField::ONE))); - cb.r_table_rlc_record( + cb.r_table_record( || "final_table", // XXX we mixed all ram type here to save column allocation RAMType::Undefined, SetTableSpec { len: None, - structural_witins: vec![sel], + structural_witins: vec![], }, raw_final_table, - final_table_expr, )?; Ok(Self { addr_subset, - sel, + ram_type, final_v, final_cycle, }) @@ -607,7 +589,9 @@ impl LocalRAMTableFinalConfig { num_structural_witin: usize, final_mem: &[(InstancePaddingStrategy, &[MemFinalRecord])], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { - assert_eq!(num_structural_witin, 1); + assert!(num_structural_witin == 0 || num_structural_witin == 1); + let num_structural_witin = num_structural_witin.max(1); + let selector_witin = WitIn { id: 0 }; // collect each raw mem belong to this shard, BEFORE padding length let current_shard_mems_len: Vec = final_mem @@ -702,8 +686,9 @@ impl LocalRAMTableFinalConfig { } set_val!(row, self.final_cycle, rec.cycle); + set_val!(row, self.ram_type, rec.ram_type as u64); set_val!(row, self.addr_subset, rec.addr as u64); - set_val!(structural_row, self.sel, 1u64); + set_val!(structural_row, selector_witin, 1u64); }); if *pad_size > 0 && shard_ctx.is_first_shard() { @@ -724,7 +709,7 @@ impl LocalRAMTableFinalConfig { self.addr_subset, pad_func(pad_index as u64, self.addr_subset.id as u64) ); - set_val!(structural_row, self.sel, 1u64); + set_val!(structural_row, selector_witin, 1u64); }); } _ => unimplemented!(), diff --git a/gkr_iop/src/chip.rs b/gkr_iop/src/chip.rs index 1b33bb1de..10048418e 100644 --- a/gkr_iop/src/chip.rs +++ b/gkr_iop/src/chip.rs @@ -40,11 +40,17 @@ impl Chip { n_evaluations: cb.cs.w_expressions.len() + cb.cs.r_expressions.len() + cb.cs.lk_expressions.len() + + cb.cs.w_table_expressions.len() + + cb.cs.r_table_expressions.len() + + cb.cs.lk_table_expressions.len() * 2 + cb.cs.num_fixed + cb.cs.num_witin as usize, final_out_evals: (0..cb.cs.w_expressions.len() + cb.cs.r_expressions.len() - + cb.cs.lk_expressions.len()) + + cb.cs.lk_expressions.len() + + cb.cs.w_table_expressions.len() + + cb.cs.r_table_expressions.len() + + cb.cs.lk_table_expressions.len() * 2) .collect_vec(), layers: vec![], } diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 4f88147c4..6bd76af68 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -345,112 +345,116 @@ impl Layer { let mut expr_names = Vec::with_capacity(non_zero_expr_len + zero_expr_len); let mut expressions = Vec::with_capacity(non_zero_expr_len + zero_expr_len); - // process r_record - let evals = - Self::dedup_last_selector_evals(cb.cs.r_selector.as_ref().unwrap(), &mut expr_evals); - for (idx, ((ram_expr, name), ram_eval)) in (cb - .cs - .r_expressions - .iter() - .chain(cb.cs.r_table_expressions.iter().map(|t| &t.expr))) - .zip_eq( - cb.cs - .r_expressions_namespace_map - .iter() - .chain(&cb.cs.r_table_expressions_namespace_map), - ) - .zip_eq(&r_record_evals) - .enumerate() - { - expressions.push(ram_expr - E::BaseField::ONE.expr()); - evals.push(EvalExpression::::Linear( - // evaluation = claim * one - one (padding) - *ram_eval, - E::BaseField::ONE.expr().into(), - E::BaseField::ONE.neg().expr().into(), - )); - expr_names.push(format!("{}/{idx}", name)); - } - - // process w_record - let evals = - Self::dedup_last_selector_evals(cb.cs.w_selector.as_ref().unwrap(), &mut expr_evals); - for (idx, ((ram_expr, name), ram_eval)) in (cb - .cs - .w_expressions - .iter() - .chain(cb.cs.w_table_expressions.iter().map(|t| &t.expr))) - .zip_eq( - cb.cs - .w_expressions_namespace_map + if let Some(r_selector) = cb.cs.r_selector.as_ref() { + // process r_record + let evals = Self::dedup_last_selector_evals(r_selector, &mut expr_evals); + for (idx, ((ram_expr, name), ram_eval)) in (cb + .cs + .r_expressions .iter() - .chain(&cb.cs.w_table_expressions_namespace_map), - ) - .zip_eq(&w_record_evals) - .enumerate() - { - expressions.push(ram_expr - E::BaseField::ONE.expr()); - evals.push(EvalExpression::::Linear( - // evaluation = claim * one - one (padding) - *ram_eval, - E::BaseField::ONE.expr().into(), - E::BaseField::ONE.neg().expr().into(), - )); - expr_names.push(format!("{}/{idx}", name)); - } - - // process lookup records - let evals = - Self::dedup_last_selector_evals(cb.cs.lk_selector.as_ref().unwrap(), &mut expr_evals); - for (idx, ((lookup, name), lookup_eval)) in (cb - .cs - .lk_expressions - .iter() - .chain(cb.cs.lk_table_expressions.iter().map(|t| &t.multiplicity)) - .chain(cb.cs.lk_table_expressions.iter().map(|t| &t.values))) - .zip_eq(if cb.cs.lk_table_expressions.is_empty() { - Either::Left(cb.cs.lk_expressions_namespace_map.iter()) - } else { - // repeat expressions_namespace_map twice to deal with lk p, q - Either::Right( + .chain(cb.cs.r_table_expressions.iter().map(|t| &t.expr))) + .zip_eq( cb.cs - .lk_expressions_namespace_map + .r_expressions_namespace_map .iter() - .chain(&cb.cs.lk_expressions_namespace_map), + .chain(&cb.cs.r_table_expressions_namespace_map), ) - }) - .zip_eq(&lookup_evals) - .enumerate() - { - expressions.push(lookup - cb.cs.chip_record_alpha.clone()); - evals.push(EvalExpression::::Linear( - // evaluation = claim * one - alpha (padding) - *lookup_eval, - E::BaseField::ONE.expr().into(), - cb.cs.chip_record_alpha.clone().neg().into(), - )); - expr_names.push(format!("{}/{idx}", name)); + .zip_eq(&r_record_evals) + .enumerate() + { + expressions.push(ram_expr - E::BaseField::ONE.expr()); + evals.push(EvalExpression::::Linear( + // evaluation = claim * one - one (padding) + *ram_eval, + E::BaseField::ONE.expr().into(), + E::BaseField::ONE.neg().expr().into(), + )); + expr_names.push(format!("{}/{idx}", name)); + } } - // process zero_record - let evals = - Self::dedup_last_selector_evals(cb.cs.zero_selector.as_ref().unwrap(), &mut expr_evals); - for (idx, (zero_expr, name)) in izip!( - 0.., - chain!( - cb.cs - .assert_zero_expressions - .iter() - .zip_eq(&cb.cs.assert_zero_expressions_namespace_map), + if let Some(w_selector) = cb.cs.w_selector.as_ref() { + // process w_record + let evals = Self::dedup_last_selector_evals(w_selector, &mut expr_evals); + for (idx, ((ram_expr, name), ram_eval)) in (cb + .cs + .w_expressions + .iter() + .chain(cb.cs.w_table_expressions.iter().map(|t| &t.expr))) + .zip_eq( cb.cs - .assert_zero_sumcheck_expressions + .w_expressions_namespace_map .iter() - .zip_eq(&cb.cs.assert_zero_sumcheck_expressions_namespace_map) + .chain(&cb.cs.w_table_expressions_namespace_map), ) - ) { - expressions.push(zero_expr.clone()); - evals.push(EvalExpression::Zero); - expr_names.push(format!("{}/{idx}", name)); + .zip_eq(&w_record_evals) + .enumerate() + { + expressions.push(ram_expr - E::BaseField::ONE.expr()); + evals.push(EvalExpression::::Linear( + // evaluation = claim * one - one (padding) + *ram_eval, + E::BaseField::ONE.expr().into(), + E::BaseField::ONE.neg().expr().into(), + )); + expr_names.push(format!("{}/{idx}", name)); + } + } + + if let Some(lk_selector) = cb.cs.lk_selector.as_ref() { + // process lookup records + let evals = Self::dedup_last_selector_evals(lk_selector, &mut expr_evals); + for (idx, ((lookup, name), lookup_eval)) in (cb + .cs + .lk_expressions + .iter() + .chain(cb.cs.lk_table_expressions.iter().map(|t| &t.multiplicity)) + .chain(cb.cs.lk_table_expressions.iter().map(|t| &t.values))) + .zip_eq(if cb.cs.lk_table_expressions.is_empty() { + Either::Left(cb.cs.lk_expressions_namespace_map.iter()) + } else { + // repeat expressions_namespace_map twice to deal with lk p, q + Either::Right( + cb.cs + .lk_expressions_namespace_map + .iter() + .chain(&cb.cs.lk_expressions_namespace_map), + ) + }) + .zip_eq(&lookup_evals) + .enumerate() + { + expressions.push(lookup - cb.cs.chip_record_alpha.clone()); + evals.push(EvalExpression::::Linear( + // evaluation = claim * one - alpha (padding) + *lookup_eval, + E::BaseField::ONE.expr().into(), + cb.cs.chip_record_alpha.clone().neg().into(), + )); + expr_names.push(format!("{}/{idx}", name)); + } + } + + if let Some(zero_selector) = cb.cs.zero_selector.as_ref() { + // process zero_record + let evals = Self::dedup_last_selector_evals(zero_selector, &mut expr_evals); + for (idx, (zero_expr, name)) in izip!( + 0.., + chain!( + cb.cs + .assert_zero_expressions + .iter() + .zip_eq(&cb.cs.assert_zero_expressions_namespace_map), + cb.cs + .assert_zero_sumcheck_expressions + .iter() + .zip_eq(&cb.cs.assert_zero_sumcheck_expressions_namespace_map) + ) + ) { + expressions.push(zero_expr.clone()); + evals.push(EvalExpression::Zero); + expr_names.push(format!("{}/{idx}", name)); + } } // Sort expressions, expr_names, and evals according to eval.0 and classify evals. From 6ac69fc37fc33fc6f1a9d282f5a42dbc9fef6a61 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Sun, 19 Oct 2025 16:45:30 +0800 Subject: [PATCH 59/91] chores: rename config --- ceno_zkvm/src/scheme/constants.rs | 1 - ceno_zkvm/src/scheme/verifier.rs | 8 +++++--- ceno_zkvm/src/tables/mod.rs | 11 ++++++----- gkr_iop/src/gkr/layer/cpu/mod.rs | 1 + 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/ceno_zkvm/src/scheme/constants.rs b/ceno_zkvm/src/scheme/constants.rs index 3cc212e9f..191fdf103 100644 --- a/ceno_zkvm/src/scheme/constants.rs +++ b/ceno_zkvm/src/scheme/constants.rs @@ -1,5 +1,4 @@ pub(crate) const MIN_PAR_SIZE: usize = 64; -pub(crate) const SEL_DEGREE: usize = 2; pub const NUM_FANIN: usize = 2; pub const NUM_FANIN_LOGUP: usize = 2; diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index f29440cbc..222e44c9a 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -8,7 +8,7 @@ use ff_ext::{Instrumented, PoseidonField}; use super::{ZKVMChipProof, ZKVMProof}; use crate::{ error::ZKVMError, - scheme::constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEL_DEGREE}, + scheme::constants::{NUM_FANIN, NUM_FANIN_LOGUP}, structs::{ComposedConstrainSystem, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey}, utils::{ eval_inner_repeated_incremental_vec, eval_outer_repeated_incremental_vec, @@ -530,8 +530,10 @@ impl> ZKVMVerifier "[prod_record] mismatch length" ); - let ram_bus_circuit = false; - let input_opening_point = if !ram_bus_circuit { + // TODO differentiate `ram_bus` via cs + let is_shard_ram_bus_circuit = false; + + let input_opening_point = if !is_shard_ram_bus_circuit { // evaluate the evaluation of structural mles at input_opening_point by verifier let structural_evals = if with_rw { // only iterate r set, as read/write set round should match diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 05be820f2..6e9b4d9d2 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -1,13 +1,14 @@ use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, structs::ProgramParams}; use ff_ext::ExtensionField; -use std::collections::HashMap; +use gkr_iop::{ + chip::Chip, + gkr::{GKRCircuit, layer::Layer}, + selector::SelectorType, +}; use itertools::Itertools; use multilinear_extensions::{StructuralWitInType, ToExpr}; +use std::collections::HashMap; use witness::RowMajorMatrix; -use gkr_iop::chip::Chip; -use gkr_iop::gkr::GKRCircuit; -use gkr_iop::gkr::layer::Layer; -use gkr_iop::selector::SelectorType; mod range; pub use range::*; diff --git a/gkr_iop/src/gkr/layer/cpu/mod.rs b/gkr_iop/src/gkr/layer/cpu/mod.rs index fa4c33c5e..41807e6b7 100644 --- a/gkr_iop/src/gkr/layer/cpu/mod.rs +++ b/gkr_iop/src/gkr/layer/cpu/mod.rs @@ -168,6 +168,7 @@ impl> ZerocheckLayerProver ) ) .collect_vec(); + // zero check eq || rotation eq let mut eqs = layer .out_sel_and_eval_exprs From f20e9702a1e1e7e488938b301b89b99f4902100c Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Sun, 19 Oct 2025 18:27:20 +0800 Subject: [PATCH 60/91] make record to be trait Instruction's associated type --- ceno_zkvm/src/instructions.rs | 5 +- ceno_zkvm/src/instructions/global.rs | 168 ++++++++++++------ ceno_zkvm/src/instructions/riscv/arith.rs | 1 + .../riscv/arith_imm/arith_imm_circuit_v2.rs | 1 + ceno_zkvm/src/instructions/riscv/auipc.rs | 3 +- .../riscv/branch/branch_circuit.rs | 5 +- .../riscv/branch/branch_circuit_v2.rs | 1 + ceno_zkvm/src/instructions/riscv/div.rs | 10 +- .../instructions/riscv/div/div_circuit_v2.rs | 1 + .../instructions/riscv/dummy/dummy_circuit.rs | 1 + .../instructions/riscv/dummy/dummy_ecall.rs | 1 + .../src/instructions/riscv/ecall/halt.rs | 1 + .../src/instructions/riscv/ecall/keccak.rs | 1 + .../riscv/ecall/weierstrass_add.rs | 1 + .../riscv/ecall/weierstrass_decompress.rs | 1 + .../riscv/ecall/weierstrass_double.rs | 1 + .../src/instructions/riscv/jump/jal_v2.rs | 3 +- .../src/instructions/riscv/jump/jalr_v2.rs | 3 +- .../instructions/riscv/logic/logic_circuit.rs | 1 + .../riscv/logic_imm/logic_imm_circuit_v2.rs | 1 + ceno_zkvm/src/instructions/riscv/lui.rs | 3 +- .../src/instructions/riscv/memory/load_v2.rs | 1 + .../src/instructions/riscv/memory/store_v2.rs | 1 + .../src/instructions/riscv/memory/test.rs | 9 +- .../riscv/mulh/mulh_circuit_v2.rs | 1 + .../riscv/shift/shift_circuit_v2.rs | 4 +- .../instructions/riscv/slt/slt_circuit_v2.rs | 1 + .../riscv/slti/slti_circuit_v2.rs | 1 + ceno_zkvm/src/scheme/tests.rs | 1 + ceno_zkvm/src/structs.rs | 2 +- 30 files changed, 163 insertions(+), 71 deletions(-) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 976eba333..e85643c6d 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -24,6 +24,7 @@ pub mod riscv; pub trait Instruction { type InstructionConfig: Send + Sync; + type Record: Sync; fn padding_strategy() -> InstancePaddingStrategy { InstancePaddingStrategy::Default @@ -103,14 +104,14 @@ pub trait Instruction { config: &Self::InstructionConfig, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, + step: &Self::Record, ) -> Result<(), ZKVMError>; fn assign_instances( config: &Self::InstructionConfig, num_witin: usize, num_structural_witin: usize, - steps: Vec, + steps: Vec, ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { // FIXME selector is the only structural witness // this is workaround, as call `construct_circuit` will not initialized selector diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 5cdfc83e0..2a14eccb2 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -3,27 +3,33 @@ use std::iter::repeat; use crate::{ Value, chip_handler::general::PublicIOQuery, + error::ZKVMError, gadgets::{Poseidon2Config, RoundConstants}, scheme::septic_curve::{SepticExtension, SepticPoint}, structs::{ProgramParams, RAMType}, + tables::RMMCollections, witness::LkMultiplicity, }; use ceno_emul::StepRecord; use ff_ext::{ExtensionField, FieldInto, POSEIDON2_BABYBEAR_WIDTH, SmallField}; use gkr_iop::{ chip::Chip, circuit_builder::CircuitBuilder, error::CircuitBuilderError, gkr::layer::Layer, - selector::SelectorType, + selector::SelectorType, utils::lk_multiplicity::Multiplicity, }; use itertools::Itertools; use multilinear_extensions::{ - Expression, StructuralWitInType::EqualDistanceSequence, ToExpr, WitIn, + Expression, StructuralWitInType::EqualDistanceSequence, ToExpr, WitIn, util::max_usable_threads, }; use p3::{ field::{Field, FieldAlgebra}, symmetric::Permutation, }; +use rayon::{ + iter::{IndexedParallelIterator, ParallelIterator}, + slice::ParallelSlice, +}; use std::ops::Deref; -use witness::set_val; +use witness::{RowMajorMatrix, set_val}; use crate::{ instructions::{Instruction, riscv::constants::UInt}, @@ -137,13 +143,13 @@ impl GlobalConfig { ); // enforces x = poseidon2([addr, ram_type, value[0], value[1], shard, global_clk, nonce, 0, ..., 0]) - for (input_expr, hasher_input) in input.into_iter().zip_eq(perm_config.inputs().into_iter()) - { - cb.require_equal(|| "poseidon2 input", input_expr, hasher_input)?; - } - for (xi, hasher_output) in x.iter().zip(perm_config.output().into_iter()) { - cb.require_equal(|| "x = poseidon2's output", xi.expr(), hasher_output)?; - } + // for (input_expr, hasher_input) in input.into_iter().zip_eq(perm_config.inputs().into_iter()) + // { + // cb.require_equal(|| "poseidon2 input", input_expr, hasher_input)?; + // } + // for (xi, hasher_output) in x.iter().zip(perm_config.output().into_iter()) { + // cb.require_equal(|| "x = poseidon2's output", xi.expr(), hasher_output)?; + // } // both (x, y) and (x, -y) are valid ec points // if is_global_write = 1, then y should be in [0, p/2) @@ -237,40 +243,6 @@ impl GlobalRecord { } } -impl From for GlobalRecord { - fn from(step: StepRecord) -> Self { - let mut record = GlobalRecord::default(); - match step.memory_op() { - None => { - record.ram_type = RAMType::Register; - } - Some(_) => { - record.ram_type = RAMType::Memory; - } - }; - if let Some(op) = step.rs1() { - // read from previous shard - record.addr = op.addr.into(); - record.value = op.value; - record.global_clk = 0; // FIXME - record.shard = 0; // FIXME - record.local_clk = 0; - record.is_write = false; - } else { - // propagate local write to global for future shards - let op = step.rd().unwrap(); - record.addr = op.addr.into(); - record.value = op.value.after; - record.shard = 0; // FIXME - record.global_clk = step.cycle(); - record.local_clk = step.cycle(); - record.is_write = true; - } - - record - } -} - // This chip is used to manage read/write into a global set // shared among multiple shards pub struct GlobalChip { @@ -282,6 +254,7 @@ impl Instruction for GlobalChip { type InstructionConfig = GlobalConfig; + type Record = GlobalRecord; fn name() -> String { "Global".to_string() @@ -387,10 +360,8 @@ impl config: &Self::InstructionConfig, instance: &mut [E::BaseField], _lk_multiplicity: &mut LkMultiplicity, - _step: &StepRecord, + record: &GlobalRecord, ) -> Result<(), crate::error::ZKVMError> { - let record: GlobalRecord = _step.clone().into(); - // assign basic fields let is_ram_register = match record.ram_type { RAMType::Register => 1, @@ -423,10 +394,82 @@ impl Ok(()) } + + fn assign_instances( + config: &Self::InstructionConfig, + num_witin: usize, + num_structural_witin: usize, + steps: Vec, + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + // FIXME selector is the only structural witness + // this is workaround, as call `construct_circuit` will not initialized selector + // we can remove this one all opcode unittest migrate to call `build_gkr_iop_circuit` + assert!(num_structural_witin == 3); + let selector_r_witin = WitIn { id: 0 }; + let selector_w_witin = WitIn { id: 1 }; + let selector_zero_witin = WitIn { id: 2 }; + + let nthreads = max_usable_threads(); + + let num_local_reads = steps.iter().filter(|s| s.is_write).count(); + let num_local_writes = steps.len() - num_local_reads; + + let num_instance_per_batch = if steps.len() > 256 { + steps.len().div_ceil(nthreads) + } else { + steps.len() + } + .max(1); + let lk_multiplicity = LkMultiplicity::default(); + let mut raw_witin = + RowMajorMatrix::::new(steps.len(), num_witin, Self::padding_strategy()); + let mut raw_structual_witin = RowMajorMatrix::::new( + steps.len(), + num_structural_witin, + Self::padding_strategy(), + ); + let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); + let raw_structual_witin_iter = + raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); + + raw_witin_iter + .zip_eq(raw_structual_witin_iter) + .zip_eq(steps.par_chunks(num_instance_per_batch)) + .flat_map(|((instances, structural_instance), steps)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + instances + .chunks_mut(num_witin) + .zip_eq(structural_instance.chunks_mut(num_structural_witin)) + .zip_eq(steps) + .enumerate() + .map(|(i, ((instance, structural_instance), step))| { + let (sel_r, sel_w) = if i < num_local_reads { + (E::BaseField::ONE, E::BaseField::ZERO) + } else { + (E::BaseField::ZERO, E::BaseField::ONE) + }; + set_val!(structural_instance, selector_r_witin, sel_r); + set_val!(structural_instance, selector_w_witin, sel_w); + set_val!(structural_instance, selector_zero_witin, E::BaseField::ONE); + Self::assign_instance(config, instance, &mut lk_multiplicity, step) + }) + .collect::>() + }) + .collect::>()?; + + raw_witin.padding_by_strategy(); + raw_structual_witin.padding_by_strategy(); + Ok(( + [raw_witin, raw_structual_witin], + lk_multiplicity.into_finalize_result(), + )) + } } #[cfg(test)] mod tests { + use std::sync::Arc; + use ff_ext::{BabyBearExt4, PoseidonField}; use mpcs::{BasefoldDefault, PolynomialCommitmentScheme, SecurityLevel}; use p3::{babybear::BabyBear, field::FieldAlgebra}; @@ -464,11 +507,6 @@ mod tests { let (config, gkr_circuit) = global_chip .build_gkr_iop_circuit(&mut cb, &ProgramParams::default()) .unwrap(); - let composed_cs = ComposedConstrainSystem { - zkvm_v1_css: cs, - gkr_circuit: Some(gkr_circuit), - }; - let pk = composed_cs.key_gen(); // create a bunch of random memory read/write records let n_reads = 10; @@ -515,6 +553,22 @@ mod tests { assert!(global_ec_sum.is_infinity == true); // assign witness + let (witness, lk) = GlobalChip::assign_instances( + &config, + cs.num_witin as usize, + cs.num_structural_witin as usize, + global_reads + .into_iter() + .chain(global_writes.into_iter()) + .collect::>(), + ) + .unwrap(); + + let composed_cs = ComposedConstrainSystem { + zkvm_v1_css: cs, + gkr_circuit: Some(gkr_circuit), + }; + let pk = composed_cs.key_gen(); // create chip proof for global chip let pcs_param = PCS::setup(1 << 20, SecurityLevel::Conjecture100bits).unwrap(); @@ -528,11 +582,11 @@ mod tests { let mut transcript = BasicTranscript::new(b"global chip test"); let proof_input = ProofInput { - witness: todo!(), - structural_witness: todo!(), - fixed: todo!(), - public_input: todo!(), - num_instances: todo!(), + witness: witness[0].to_mles().into_iter().map(Arc::new).collect(), + structural_witness: witness[1].to_mles().into_iter().map(Arc::new).collect(), + fixed: vec![], + public_input: vec![], + num_instances: (n_reads + n_writes) as usize, }; let challenges = [E::ONE, E::ONE]; let proof = zkvm_prover diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index a71147f18..ba43becfb 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -39,6 +39,7 @@ pub type SubInstruction = ArithInstruction; impl Instruction for ArithInstruction { type InstructionConfig = ArithConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index cbadc807d..d5834b93e 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -32,6 +32,7 @@ pub struct InstructionConfig { impl Instruction for AddiInstruction { type InstructionConfig = InstructionConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", Self::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 5304016e0..fc3f74a72 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -17,7 +17,7 @@ use crate::{ utils::split_to_u8, witness::LkMultiplicity, }; -use ceno_emul::InsnKind; +use ceno_emul::{InsnKind, StepRecord}; use gkr_iop::tables::{LookupTable, ops::XorTable}; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; @@ -37,6 +37,7 @@ pub struct AuipcInstruction(PhantomData); impl Instruction for AuipcInstruction { type InstructionConfig = AuipcConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", InsnKind::AUIPC) diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs index efbe64d5a..974f3aa84 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs @@ -35,12 +35,13 @@ pub struct BranchConfig { } impl Instruction for BranchCircuit { + type InstructionConfig = BranchConfig; + type Record = StepRecord; + fn name() -> String { format!("{:?}", I::INST_KIND) } - type InstructionConfig = BranchConfig; - fn construct_circuit( &self, circuit_builder: &mut CircuitBuilder, diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 3d4ed92de..ee475feb3 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -29,6 +29,7 @@ pub struct BranchConfig { impl Instruction for BranchCircuit { type InstructionConfig = BranchConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index 9fb0695bc..94e0b0fe7 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -166,7 +166,10 @@ mod test { const INSN_KIND: InsnKind = InsnKind::REMU; } - fn verify + TestInstance + Default>( + fn verify< + E: ExtensionField, + Insn: Instruction + TestInstance + Default, + >( name: &str, dividend: >::NumType, divisor: >::NumType, @@ -229,7 +232,10 @@ mod test { } // shortcut to verify given pair produces correct output - fn verify_positive + TestInstance + Default>( + fn verify_positive< + E: ExtensionField, + Insn: Instruction + TestInstance + Default, + >( name: &str, dividend: >::NumType, divisor: >::NumType, diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index d11330f27..e94d6fad6 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -48,6 +48,7 @@ pub struct ArithInstruction(PhantomData<(E, I)>); impl Instruction for ArithInstruction { type InstructionConfig = DivRemConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs index 04e59cc96..28b8d72f7 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs @@ -25,6 +25,7 @@ pub struct DummyInstruction(PhantomData<(E, I)>); impl Instruction for DummyInstruction { type InstructionConfig = DummyConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}_DUMMY", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 662f5a0e1..43f3b0f50 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -29,6 +29,7 @@ pub struct LargeEcallDummy(PhantomData<(E, S)>); impl Instruction for LargeEcallDummy { type InstructionConfig = LargeEcallConfig; + type Record = StepRecord; fn name() -> String { S::NAME.to_owned() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index e5709adc0..9f3311383 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -31,6 +31,7 @@ pub struct HaltInstruction(PhantomData); impl Instruction for HaltInstruction { type InstructionConfig = HaltConfig; + type Record = StepRecord; fn name() -> String { "ECALL_HALT".into() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index 57bd13897..38028c14f 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -54,6 +54,7 @@ pub struct KeccakInstruction(PhantomData); impl Instruction for KeccakInstruction { type InstructionConfig = EcallKeccakConfig; + type Record = StepRecord; fn name() -> String { "Ecall_Keccak".to_string() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index e2fa19e7a..502b5b4d0 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -59,6 +59,7 @@ impl Instruction for WeierstrassAddAssignInstruction { type InstructionConfig = EcallWeierstrassAddAssignConfig; + type Record = StepRecord; fn name() -> String { "Ecall_WeierstrassAddAssign_".to_string() + format!("{:?}", EC::CURVE_TYPE).as_str() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index 7d094e612..17c59544c 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -66,6 +66,7 @@ impl Instruction { type InstructionConfig = EcallWeierstrassDecompressConfig; + type Record = StepRecord; fn name() -> String { "Ecall_WeierstrassDecompress_".to_string() + format!("{:?}", EC::CURVE_TYPE).as_str() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 210fe81c9..8915c71e6 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -61,6 +61,7 @@ impl Instruction { type InstructionConfig = EcallWeierstrassDoubleAssignConfig; + type Record = StepRecord; fn name() -> String { "Ecall_WeierstrassDoubleAssign_".to_string() + format!("{:?}", EC::CURVE_TYPE).as_str() diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index 9a6830e6f..cc575a1ae 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -16,7 +16,7 @@ use crate::{ utils::split_to_u8, witness::LkMultiplicity, }; -use ceno_emul::{InsnKind, PC_STEP_SIZE}; +use ceno_emul::{InsnKind, PC_STEP_SIZE, StepRecord}; use gkr_iop::tables::{LookupTable, ops::XorTable}; use multilinear_extensions::{Expression, ToExpr}; use p3::field::FieldAlgebra; @@ -42,6 +42,7 @@ pub struct JalInstruction(PhantomData); /// of native WitIn values for address space arithmetic. impl Instruction for JalInstruction { type InstructionConfig = JalConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", InsnKind::JAL) diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index f3ff2990a..524d4ab22 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -19,7 +19,7 @@ use crate::{ utils::imm_sign_extend, witness::{LkMultiplicity, set_val}, }; -use ceno_emul::{InsnKind, PC_STEP_SIZE}; +use ceno_emul::{InsnKind, PC_STEP_SIZE, StepRecord}; use ff_ext::FieldInto; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; @@ -42,6 +42,7 @@ pub struct JalrInstruction(PhantomData); /// the program table impl Instruction for JalrInstruction { type InstructionConfig = JalrConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", InsnKind::JALR) diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index c57a20b8e..0bf185a6e 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -29,6 +29,7 @@ pub struct LogicInstruction(PhantomData<(E, I)>); impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index 6f98e6c74..b676469f1 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -31,6 +31,7 @@ pub struct LogicInstruction(PhantomData<(E, I)>); impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index c495cfb04..f76650f7b 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -17,7 +17,7 @@ use crate::{ utils::split_to_u8, witness::LkMultiplicity, }; -use ceno_emul::InsnKind; +use ceno_emul::{InsnKind, StepRecord}; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::FieldAlgebra; use witness::set_val; @@ -34,6 +34,7 @@ pub struct LuiInstruction(PhantomData); impl Instruction for LuiInstruction { type InstructionConfig = LuiConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", InsnKind::LUI) diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 4a008f009..29c6b7da8 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -42,6 +42,7 @@ pub struct LoadInstruction(PhantomData<(E, I)>); impl Instruction for LoadInstruction { type InstructionConfig = LoadConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index ce85ced04..ff74a7833 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -42,6 +42,7 @@ impl Instruction for StoreInstruction { type InstructionConfig = StoreConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index 031d27398..da519027f 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -21,6 +21,7 @@ use ff_ext::BabyBearExt4; use ff_ext::{ExtensionField, GoldilocksExt2}; use gkr_iop::circuit_builder::DebugIndex; use std::hash::Hash; +use tracing::span::Record; fn sb(prev: Word, rs2: Word, shift: u32) -> Word { let shift = (shift * 8) as usize; @@ -78,7 +79,7 @@ fn load(mem_value: Word, insn: InsnKind, shift: u32) -> Word { fn impl_opcode_store< E: ExtensionField + Hash, I: RIVInstruction, - Inst: Instruction + Default, + Inst: Instruction + Default, >( imm: i32, ) { @@ -144,7 +145,11 @@ fn impl_opcode_store< MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } -fn impl_opcode_load + Default>( +fn impl_opcode_load< + E: ExtensionField + Hash, + I: RIVInstruction, + Inst: Instruction + Default, +>( imm: i32, ) { let mut cs = ConstraintSystem::::new(|| "riscv"); diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index 6b3d5c13c..7b2588b5b 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -38,6 +38,7 @@ pub struct MulhConfig { impl Instruction for MulhInstructionBase { type InstructionConfig = MulhConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index c7915ca74..fcf68da53 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -12,7 +12,7 @@ use crate::{ structs::ProgramParams, utils::{split_to_limb, split_to_u8}, }; -use ceno_emul::InsnKind; +use ceno_emul::{InsnKind, StepRecord}; use ff_ext::{ExtensionField, FieldInto}; use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn}; @@ -276,6 +276,7 @@ pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); impl Instruction for ShiftLogicalInstruction { type InstructionConfig = ShiftRTypeConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) @@ -373,6 +374,7 @@ pub struct ShiftImmInstruction(PhantomData<(E, I)>); impl Instruction for ShiftImmInstruction { type InstructionConfig = ShiftImmConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index 16050e733..30999b679 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -30,6 +30,7 @@ pub struct SetLessThanConfig { } impl Instruction for SetLessThanInstruction { type InstructionConfig = SetLessThanConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 1e1b1c9b7..f2dc080db 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -41,6 +41,7 @@ pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); impl Instruction for SetLessThanImmInstruction { type InstructionConfig = SetLessThanImmConfig; + type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 4f3404fd9..09246cb02 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -63,6 +63,7 @@ struct TestCircuit { impl Instruction for TestCircuit { type InstructionConfig = TestConfig; + type Record = StepRecord; fn name() -> String { "TEST".into() diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 0e6c085a9..47820680d 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -328,7 +328,7 @@ impl ZKVMWitnesses { self.lk_mlts.get(name) } - pub fn assign_opcode_circuit>( + pub fn assign_opcode_circuit>( &mut self, cs: &ZKVMConstraintSystem, config: &OC::InstructionConfig, From cfd9870cadce4db79d2d831b443cb6357b502d0f Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Sun, 19 Oct 2025 21:49:54 +0800 Subject: [PATCH 61/91] wip4 --- ceno_zkvm/src/e2e.rs | 1 + ceno_zkvm/src/instructions/global.rs | 25 +++++++++++++++++++++++-- ceno_zkvm/src/scheme.rs | 9 +++++++++ ceno_zkvm/src/scheme/tests.rs | 2 +- 4 files changed, 34 insertions(+), 3 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 226231c2b..892de16d9 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -157,6 +157,7 @@ pub fn emulate_program( vm.get_pc().into(), end_cycle, io_init.iter().map(|rec| rec.value).collect_vec(), + vec![0; 14], // point_at_infinity ); // Find the final register values and cycles. diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 2a14eccb2..0030cfe62 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -471,6 +471,7 @@ mod tests { use std::sync::Arc; use ff_ext::{BabyBearExt4, PoseidonField}; + use itertools::Itertools; use mpcs::{BasefoldDefault, PolynomialCommitmentScheme, SecurityLevel}; use p3::{babybear::BabyBear, field::FieldAlgebra}; use transcript::BasicTranscript; @@ -483,11 +484,13 @@ mod tests { global::{GlobalChip, GlobalRecord}, }, scheme::{ - create_backend, create_prover, hal::ProofInput, prover::ZKVMProver, + PublicValues, create_backend, create_prover, hal::ProofInput, prover::ZKVMProver, septic_curve::SepticPoint, }, structs::{ComposedConstrainSystem, ProgramParams, RAMType, ZKVMProvingKey}, }; + use multilinear_extensions::mle::IntoMLE; + use p3::field::PrimeField32; type E = BabyBearExt4; type F = BabyBear; @@ -551,6 +554,20 @@ mod tests { .map(|record| record.to_ec_point::(&global_chip.perm).1) .sum(); + let public_value = PublicValues::new( + 0, + 0, + 0, + 0, + 0, + vec![0], // dummy + global_ec_sum + .x + .iter() + .chain(global_ec_sum.y.iter()) + .map(|fe| fe.as_canonical_u32()) + .collect_vec(), + ); assert!(global_ec_sum.is_infinity == true); // assign witness let (witness, lk) = GlobalChip::assign_instances( @@ -585,7 +602,11 @@ mod tests { witness: witness[0].to_mles().into_iter().map(Arc::new).collect(), structural_witness: witness[1].to_mles().into_iter().map(Arc::new).collect(), fixed: vec![], - public_input: vec![], + public_input: public_value + .to_vec::() + .into_iter() + .map(|v| Arc::new(v.into_mle())) + .collect_vec(), num_instances: (n_reads + n_writes) as usize, }; let challenges = [E::ONE, E::ONE]; diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 98393d2e1..1a27160bc 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -77,6 +77,7 @@ pub struct PublicValues { end_pc: u32, end_cycle: u64, public_io: Vec, + global_sum: Vec, } impl PublicValues { @@ -87,6 +88,7 @@ impl PublicValues { end_pc: u32, end_cycle: u64, public_io: Vec, + global_sum: Vec, ) -> Self { Self { exit_code, @@ -95,6 +97,7 @@ impl PublicValues { end_pc, end_cycle, public_io, + global_sum, } } pub fn to_vec(&self) -> Vec> { @@ -124,6 +127,12 @@ impl PublicValues { }) .collect_vec(), ) + .chain( + self.global_sum + .iter() + .map(|value| vec![E::BaseField::from_canonical_u32(*value)]) + .collect_vec(), + ) .collect::>() } } diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 09246cb02..44e88836b 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -360,7 +360,7 @@ fn test_single_add_instance_e2e() { .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); - let pi = PublicValues::new(0, 0, 0, 0, 0, vec![0]); + let pi = PublicValues::new(0, 0, 0, 0, 0, vec![0], vec![0; 14]); let transcript = BasicTranscript::new(b"riscv"); let zkvm_proof = prover .create_proof(zkvm_witness, pi, transcript) From 48d5f93900203182caef6723c221509f0b456d95 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Sun, 19 Oct 2025 22:10:33 +0800 Subject: [PATCH 62/91] fix few bugs in e2e --- ceno_zkvm/src/e2e.rs | 4 +- .../src/instructions/riscv/rv32im/mmu.rs | 8 ++- ceno_zkvm/src/scheme/verifier.rs | 20 +++---- ceno_zkvm/src/tables/ram/ram_impl.rs | 55 +++++++++++-------- gkr_iop/src/gkr/layer/cpu/mod.rs | 7 ++- 5 files changed, 52 insertions(+), 42 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index d1c14bb9d..7e38b158e 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -402,7 +402,7 @@ pub fn emulate_program<'a>( if index < VMState::REG_COUNT { let vma: WordAddr = Platform::register_vma(index).into(); MemFinalRecord { - ram_type: RAMType::Memory, + ram_type: RAMType::Register, addr: rec.addr, value: vm.peek_register(index), cycle: *final_access.get(&vma).unwrap_or(&0), @@ -410,7 +410,7 @@ pub fn emulate_program<'a>( } else { // The table is padded beyond the number of registers. MemFinalRecord { - ram_type: RAMType::Memory, + ram_type: RAMType::Register, addr: rec.addr, value: 0, cycle: 0, diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index a95b8e03f..217080293 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -12,6 +12,7 @@ use crate::{ use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word}; use ff_ext::ExtensionField; use itertools::{Itertools, chain}; +use multilinear_extensions::mle::IntoInstanceIterMut; use std::{collections::HashSet, iter::zip, ops::Range, sync::Arc}; use witness::InstancePaddingStrategy; @@ -46,9 +47,7 @@ impl MmuConfig<'_, E> { let hints_config = cs.register_table_circuit::>(); let stack_init_config = cs.register_table_circuit::>(); let heap_init_config = cs.register_table_circuit::>(); - println!("register LocalFinalCircuit"); let local_final_circuit = cs.register_table_circuit::>(); - println!("end register LocalFinalCircuit"); let ram_bus_circuit = cs.register_table_circuit::>(); Self { @@ -154,7 +153,10 @@ impl MmuConfig<'_, E> { }), heap_final, ), - ]; + ] + .into_iter() + .filter(|(_, record)| !record.is_empty()) + .collect_vec(); // take all mem result and witness.assign_table_circuit::>( cs, diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 222e44c9a..0c897298e 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -157,8 +157,6 @@ impl> ZKVMVerifier let dummy_table_item = challenges[0]; let mut dummy_table_item_multiplicity = 0; let point_eval = PointAndEval::default(); - let mut rt_points = Vec::with_capacity(vm_proof.chip_proofs.len()); - let mut evaluations = Vec::with_capacity(vm_proof.chip_proofs.len()); let mut witin_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); let mut fixed_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); for (index, proof) in &vm_proof.chip_proofs { @@ -256,22 +254,18 @@ impl> ZKVMVerifier &challenges, )? }; - rt_points.push((*index, input_opening_point.clone())); - evaluations.push(( - *index, - [proof.wits_in_evals.clone(), proof.fixed_in_evals.clone()].concat(), - )); - witin_openings.push(( - input_opening_point.len(), - (input_opening_point.clone(), proof.wits_in_evals.clone()), - )); - if !proof.fixed_in_evals.is_empty() { + if circuit_vk.get_cs().num_witin() > 0 { + witin_openings.push(( + input_opening_point.len(), + (input_opening_point.clone(), proof.wits_in_evals.clone()), + )); + } + if circuit_vk.get_cs().num_fixed() > 0 { fixed_openings.push(( input_opening_point.len(), (input_opening_point.clone(), proof.fixed_in_evals.clone()), )); } - prod_w *= proof.w_out_evals.iter().flatten().copied().product::(); prod_r *= proof.r_out_evals.iter().flatten().copied().product::(); tracing::debug!("verified proof for circuit {}", circuit_name); diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 7a24ef66f..fc6f50c7b 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -593,14 +593,15 @@ impl LocalFinalRAMTableConfig { let num_structural_witin = num_structural_witin.max(1); let selector_witin = WitIn { id: 0 }; + let is_current_shard_mem_record = |record: &&MemFinalRecord| -> bool { + (shard_ctx.is_first_shard() && record.cycle == 0) + || shard_ctx.is_current_shard_cycle(record.cycle) + }; + // collect each raw mem belong to this shard, BEFORE padding length let current_shard_mems_len: Vec = final_mem .par_iter() - .map(|(_, mem)| { - mem.par_iter() - .filter(|record| shard_ctx.is_current_shard_cycle(record.cycle)) - .count() - }) + .map(|(_, mem)| mem.par_iter().filter(is_current_shard_mem_record).count()) .collect(); // deal with non-pow2 padding for first shard @@ -608,17 +609,24 @@ impl LocalFinalRAMTableConfig { let padding_info = if shard_ctx.is_first_shard() { final_mem .iter() - .map(|(_, mem)| (next_pow2_instance_padding(mem.len()) - mem.len(), mem.len())) + .map(|(_, mem)| { + assert!(!mem.is_empty()); + ( + next_pow2_instance_padding(mem.len()) - mem.len(), + mem.len(), + mem[0].ram_type, + ) + }) .collect_vec() } else { - vec![(0, 0); final_mem.len()] + vec![(0, 0, RAMType::Undefined); final_mem.len()] }; // calculate mem length let mem_lens = current_shard_mems_len .iter() .zip_eq(&padding_info) - .map(|(raw_len, (pad_len, _))| raw_len + pad_len) + .map(|(raw_len, (pad_len, _, _))| raw_len + pad_len) .collect_vec(); let total_records = mem_lens.iter().sum(); @@ -663,17 +671,13 @@ impl LocalFinalRAMTableConfig { .for_each( |( ((witness, structural_witness), (padding_strategy, final_mem)), - (pad_size, pad_start_index), + (pad_size, pad_start_index, ram_type), )| { - witness + let mem_record_count = witness .chunks_mut(num_witin) .zip_eq(structural_witness.chunks_mut(num_structural_witin)) - .zip( - final_mem - .iter() - .filter(|record| shard_ctx.is_current_shard_cycle(record.cycle)), - ) - .for_each(|((row, structural_row), rec)| { + .zip(final_mem.iter().filter(is_current_shard_mem_record)) + .map(|((row, structural_row), rec)| { if self.final_v.len() == 1 { // Assign value directly. set_val!(row, self.final_v[0], rec.value as u64); @@ -689,26 +693,33 @@ impl LocalFinalRAMTableConfig { set_val!(row, self.ram_type, rec.ram_type as u64); set_val!(row, self.addr_subset, rec.addr as u64); set_val!(structural_row, selector_witin, 1u64); - }); + () + }) + .count(); if *pad_size > 0 && shard_ctx.is_first_shard() { match padding_strategy { InstancePaddingStrategy::Custom(pad_func) => { - witness[pad_size * num_witin..] + witness[mem_record_count * num_witin..] .chunks_mut(num_witin) .zip_eq( - structural_witness[pad_size * num_structural_witin..] + structural_witness + [mem_record_count * num_structural_witin..] .chunks_mut(num_structural_witin), ) - .zip(std::iter::successors(Some(*pad_start_index), |n| { - Some(*n + 1) - })) + .zip_eq( + std::iter::successors(Some(*pad_start_index), |n| { + Some(*n + 1) + }) + .take(*pad_size), + ) .for_each(|((row, structural_row), pad_index)| { set_val!( row, self.addr_subset, pad_func(pad_index as u64, self.addr_subset.id as u64) ); + set_val!(row, self.ram_type, *ram_type as u64); set_val!(structural_row, selector_witin, 1u64); }); } diff --git a/gkr_iop/src/gkr/layer/cpu/mod.rs b/gkr_iop/src/gkr/layer/cpu/mod.rs index 41807e6b7..3719e3029 100644 --- a/gkr_iop/src/gkr/layer/cpu/mod.rs +++ b/gkr_iop/src/gkr/layer/cpu/mod.rs @@ -20,6 +20,7 @@ use multilinear_extensions::{ monomial::Term, virtual_poly::build_eq_x_r_vec, virtual_polys::VirtualPolynomialsBuilder, + wit_infer_by_monomial_expr, }; use rayon::{ iter::{ @@ -27,6 +28,7 @@ use rayon::{ }, slice::ParallelSlice, }; +use std::sync::Arc; use sumcheck::{ macros::{entered_span, exit_span}, structs::{IOPProof, IOPProverState}, @@ -222,15 +224,16 @@ impl> ZerocheckLayerProver layer.n_structural_witin, layer.n_fixed, ); + let builder = VirtualPolynomialsBuilder::new_with_mles(num_threads, max_num_variables, all_witins); let span = entered_span!("IOPProverState::prove", profiling_4 = true); let (proof, prover_state) = IOPProverState::prove( builder.to_virtual_polys_with_monomial_terms( - &layer + layer .main_sumcheck_expression_monomial_terms - .clone() + .as_ref() .unwrap(), pub_io_evals, &main_sumcheck_challenges, From 253043f8689f9c99ab8cd51a7a0ffbcf33ecbc96 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 20 Oct 2025 00:39:16 +0800 Subject: [PATCH 63/91] wip5 --- ceno_zkvm/src/instructions/global.rs | 54 +++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 0030cfe62..adb4ba35a 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -474,6 +474,10 @@ mod tests { use itertools::Itertools; use mpcs::{BasefoldDefault, PolynomialCommitmentScheme, SecurityLevel}; use p3::{babybear::BabyBear, field::FieldAlgebra}; + use tracing_forest::{ForestLayer, util::LevelFilter}; + use tracing_subscriber::{ + EnvFilter, Registry, fmt, layer::SubscriberExt, util::SubscriberInitExt, + }; use transcript::BasicTranscript; use crate::{ @@ -485,9 +489,9 @@ mod tests { }, scheme::{ PublicValues, create_backend, create_prover, hal::ProofInput, prover::ZKVMProver, - septic_curve::SepticPoint, + septic_curve::SepticPoint, verifier::ZKVMVerifier, }, - structs::{ComposedConstrainSystem, ProgramParams, RAMType, ZKVMProvingKey}, + structs::{ComposedConstrainSystem, PointAndEval, ProgramParams, RAMType, ZKVMProvingKey}, }; use multilinear_extensions::mle::IntoMLE; use p3::field::PrimeField32; @@ -499,6 +503,22 @@ mod tests { #[test] fn test_global_chip() { + // default filter + let default_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::DEBUG.into()) + .from_env_lossy(); + let fmt_layer = fmt::layer() + .compact() + .with_thread_ids(false) + .with_thread_names(false) + .without_time(); + + Registry::default() + .with(ForestLayer::default()) + .with(fmt_layer) + .with(default_filter) + .init(); + // init global chip with horizen_rc_consts let rc = horizen_round_consts(); let perm = ::get_default_perm(); @@ -595,22 +615,24 @@ mod tests { // let pk = prover.create_chip_proof(); let mut zkvm_pk = ZKVMProvingKey::new(pp, vp); + let zkvm_vk = zkvm_pk.get_vk_slow(); let zkvm_prover = ZKVMProver::new(zkvm_pk, pd); let mut transcript = BasicTranscript::new(b"global chip test"); + let public_input_mles = public_value + .to_vec::() + .into_iter() + .map(|v| Arc::new(v.into_mle())) + .collect_vec(); let proof_input = ProofInput { witness: witness[0].to_mles().into_iter().map(Arc::new).collect(), structural_witness: witness[1].to_mles().into_iter().map(Arc::new).collect(), fixed: vec![], - public_input: public_value - .to_vec::() - .into_iter() - .map(|v| Arc::new(v.into_mle())) - .collect_vec(), + public_input: public_input_mles.clone(), num_instances: (n_reads + n_writes) as usize, }; let challenges = [E::ONE, E::ONE]; - let proof = zkvm_prover + let (proof, pi_evals, point) = zkvm_prover .create_chip_proof( "global chip", &pk, @@ -619,5 +641,21 @@ mod tests { &challenges, ) .unwrap(); + + let mut transcript = BasicTranscript::new(b"global chip test"); + let verifier = ZKVMVerifier::new(zkvm_vk); + let pi_evals = pi_evals.into_iter().map(|(k, v)| v).collect_vec(); + let opening_point = verifier + .verify_opcode_proof( + "global", + &pk.vk, + &proof, + &pi_evals, + &mut transcript, + 2, + &PointAndEval::default(), + &challenges, + ) + .expect("verify global chip proof"); } } From ae46b8e7f20f4ff1c773fe20155cda673243b88e Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Mon, 20 Oct 2025 01:05:25 +0800 Subject: [PATCH 64/91] wip6 --- ceno_zkvm/src/instructions/global.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index adb4ba35a..8f0abd164 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -632,7 +632,7 @@ mod tests { num_instances: (n_reads + n_writes) as usize, }; let challenges = [E::ONE, E::ONE]; - let (proof, pi_evals, point) = zkvm_prover + let (proof, _, point) = zkvm_prover .create_chip_proof( "global chip", &pk, @@ -644,7 +644,10 @@ mod tests { let mut transcript = BasicTranscript::new(b"global chip test"); let verifier = ZKVMVerifier::new(zkvm_vk); - let pi_evals = pi_evals.into_iter().map(|(k, v)| v).collect_vec(); + let pi_evals = public_input_mles + .iter() + .map(|mle| mle.evaluate(&point[..mle.num_vars()])) + .collect_vec(); let opening_point = verifier .verify_opcode_proof( "global", From 306e0dfb0d8f43eca9aab24ba85cdfd6f81150cf Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 20 Oct 2025 13:58:39 +0800 Subject: [PATCH 65/91] debug log --- ceno_zkvm/src/tables/ram/ram_impl.rs | 32 ++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index fc6f50c7b..1fbbd9bbd 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -115,6 +115,12 @@ impl NonVolatileTableConfigTrait< NVRAM::len(&config.params) ); + println!( + "Init: NVRAM::RAM_TYPE {:?}, raw len {}", + NVRAM::RAM_TYPE, + init_mem.len(), + ); + let mut init_table = RowMajorMatrix::::new( NVRAM::len(&config.params), num_fixed, @@ -503,6 +509,12 @@ impl DynVolatileRamTableConfig num_structural_witin, InstancePaddingStrategy::Default, ); + println!( + "Init: DVRAM::RAM_TYPE {:?}, raw len {}, padded {}", + DVRAM::RAM_TYPE, + final_mem.len(), + num_instances_padded - final_mem.len() + ); structural_witness .par_rows_mut() @@ -604,6 +616,16 @@ impl LocalFinalRAMTableConfig { .map(|(_, mem)| mem.par_iter().filter(is_current_shard_mem_record).count()) .collect(); + current_shard_mems_len + .iter() + .zip(final_mem.iter()) + .for_each(|(raw_len, (_, mem))| { + println!( + "Final: DVRAM::RAM_TYPE {:?}, raw len {}", + mem[0].ram_type, raw_len + ) + }); + // deal with non-pow2 padding for first shard // format Vec<(pad_len, pad_start_index)> let padding_info = if shard_ctx.is_first_shard() { @@ -622,6 +644,16 @@ impl LocalFinalRAMTableConfig { vec![(0, 0, RAMType::Undefined); final_mem.len()] }; + padding_info + .iter() + .zip(final_mem.iter()) + .for_each(|((pad_size, ..), (_, mem))| { + println!( + "Final: DVRAM::RAM_TYPE {:?}, pad_size {}", + mem[0].ram_type, pad_size + ) + }); + // calculate mem length let mem_lens = current_shard_mems_len .iter() From 8e9c2d2bd37c147014a61dadd2fed5f385a55492 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 20 Oct 2025 16:16:48 +0800 Subject: [PATCH 66/91] chores: mock_proving non-static ram type --- .github/workflows/lints.yml | 2 + Cargo.lock | 16 ++--- ceno_zkvm/src/scheme/mock_prover.rs | 91 +++++++++++++++++++++++++--- gkr_iop/src/circuit_builder.rs | 94 ++++++++++++++++++++++++----- 4 files changed, 169 insertions(+), 34 deletions(-) diff --git a/.github/workflows/lints.yml b/.github/workflows/lints.yml index 29ff53880..8c08a3662 100644 --- a/.github/workflows/lints.yml +++ b/.github/workflows/lints.yml @@ -66,3 +66,5 @@ jobs: run: taplo --version || cargo install taplo-cli - name: Run taplo run: taplo fmt --check --diff + - name: Ensure Cargo.lock not modified by build + run: git diff --exit-code Cargo.lock diff --git a/Cargo.lock b/Cargo.lock index d3cb607be..c837a88ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -939,7 +939,7 @@ dependencies = [ [[package]] name = "ceno_crypto_primitives" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno#4a9dff21fd408e93c21edb6e874a09b0171b0c8b" +source = "git+https://github.com/scroll-tech/ceno#050108047aad24101fcb010da4e7d29e9d72678a" dependencies = [ "ceno_syscall 0.1.0 (git+https://github.com/scroll-tech/ceno)", "elliptic-curve", @@ -1013,7 +1013,7 @@ version = "0.1.0" [[package]] name = "ceno_syscall" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno#4a9dff21fd408e93c21edb6e874a09b0171b0c8b" +source = "git+https://github.com/scroll-tech/ceno#050108047aad24101fcb010da4e7d29e9d72678a" [[package]] name = "ceno_zkvm" @@ -1851,7 +1851,7 @@ dependencies = [ "ceno_syscall 0.1.0", "getrandom 0.3.2", "rand 0.8.5", - "revm-precompile 28.1.0", + "revm-precompile 28.1.1", "rkyv", "substrate-bn 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", "substrate-bn 0.6.0 (git+https://github.com/scroll-tech/bn?branch=ceno)", @@ -3874,9 +3874,9 @@ dependencies = [ [[package]] name = "revm-precompile" -version = "28.1.0" +version = "28.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "176169b39beb1f57b11f2ea3900c404b8498a56dfd8394e66f4d24f66cea368e" +checksum = "e57aadd7a2087705f653b5aaacc8ad4f8e851f5d330661e3f4c43b5475bbceae" dependencies = [ "ark-bls12-381", "ark-bn254", @@ -3888,7 +3888,7 @@ dependencies = [ "cfg-if", "k256 0.13.4 (registry+https://github.com/rust-lang/crates.io-index)", "p256", - "revm-primitives 21.0.0", + "revm-primitives 21.0.1", "ripemd", "sha2 0.10.9", ] @@ -3907,9 +3907,9 @@ dependencies = [ [[package]] name = "revm-primitives" -version = "21.0.0" +version = "21.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38271b8b85f00154bdcf9f2ab0a3ec7a8100377d2c7a0d8eb23e19389b42c795" +checksum = "536f30e24c3c2bf0d3d7d20fa9cf99b93040ed0f021fd9301c78cddb0dacda13" dependencies = [ "alloy-primitives", "num_enum 0.7.4", diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 028f844a6..911649e12 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -1147,7 +1147,7 @@ Hints: .into() }; - for ((w_rlc_expr, annotation), _) in (cs + for ((w_rlc_expr, annotation), (ram_type_expr, _)) in (cs .w_expressions .iter() .chain(cs.w_table_expressions.iter().map(|expr| &expr.expr))) @@ -1157,8 +1157,19 @@ Hints: .chain(cs.w_table_expressions_namespace_map.iter()), ) .zip_eq(cs.w_ram_types.iter()) - .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { + let ram_type_mle = wit_infer_by_expr( + ram_type_expr, + cs.num_witin, + cs.num_structural_witin, + cs.num_fixed as WitnessId, + fixed, + witness, + structural_witness, + &pi_mles, + &challenges, + ); + let ram_type_vec = ram_type_mle.get_ext_field_vec(); let write_rlc_records = wit_infer_by_expr( w_rlc_expr, cs.num_witin, @@ -1170,13 +1181,32 @@ Hints: &pi_mles, &challenges, ); + let w_selector_vec = w_selector.get_base_field_vec(); let write_rlc_records = - filter_mle_by_selector_mle(write_rlc_records, w_selector.clone()); + filter_mle_by_predicate(write_rlc_records, |i, _v| { + ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) + && w_selector_vec[i] == E::BaseField::ONE + }); + if write_rlc_records.is_empty() { + continue; + } let mut records = vec![]; + let mut writes_within_expr_dedup = HashSet::new(); for (row, record_rlc) in enumerate(write_rlc_records) { // TODO: report error - assert_eq!(writes.insert(record_rlc), true); + assert_eq!( + writes_within_expr_dedup.insert(record_rlc), + true, + "within expression write duplicated on RAMType {:?}", + $ram_type + ); + assert_eq!( + writes.insert(record_rlc), + true, + "crossing-chip write duplicated on RAMType {:?}", + $ram_type + ); records.push((record_rlc, row)); } writes_grp_by_annotations @@ -1212,7 +1242,7 @@ Hints: ) .into() }; - for ((r_rlc_expr, annotation), (_, r_exprs)) in (cs + for ((r_rlc_expr, annotation), (ram_type_expr, r_exprs)) in (cs .r_expressions .iter() .chain(cs.r_table_expressions.iter().map(|expr| &expr.expr))) @@ -1222,8 +1252,19 @@ Hints: .chain(cs.r_table_expressions_namespace_map.iter()), ) .zip_eq(cs.r_ram_types.iter()) - .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { + let ram_type_mle = wit_infer_by_expr( + ram_type_expr, + cs.num_witin, + cs.num_structural_witin, + cs.num_fixed as WitnessId, + fixed, + witness, + structural_witness, + &pi_mles, + &challenges, + ); + let ram_type_vec = ram_type_mle.get_ext_field_vec(); let read_records = wit_infer_by_expr( r_rlc_expr, cs.num_witin, @@ -1235,8 +1276,14 @@ Hints: &pi_mles, &challenges, ); - let read_records = - filter_mle_by_selector_mle(read_records, r_selector.clone()); + let r_selector_vec = r_selector.get_base_field_vec(); + let read_records = filter_mle_by_predicate(read_records, |i, _v| { + ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) + && r_selector_vec[i] == E::BaseField::ONE + }); + if read_records.is_empty() { + continue; + } if $ram_type == RAMType::GlobalState { // r_exprs = [GlobalState, pc, timestamp] @@ -1269,9 +1316,21 @@ Hints: }; let mut records = vec![]; + let mut reads_within_expr_dedup = HashSet::new(); for (row, record) in enumerate(read_records) { // TODO: return error - assert_eq!(reads.insert(record), true); + assert_eq!( + reads_within_expr_dedup.insert(record), + true, + "within expression read duplicated on RAMType {:?}", + $ram_type + ); + assert_eq!( + reads.insert(record), + true, + "crossing-chip read duplicated on RAMType {:?}", + $ram_type + ); records.push((record, row)); } reads_grp_by_annotations @@ -1467,6 +1526,19 @@ fn print_errors( } } +fn filter_mle_by_predicate(target_mle: ArcMultilinearExtension, mut predicate: F) -> Vec +where + E: ExtensionField, + F: FnMut(usize, &E) -> bool, +{ + target_mle + .get_ext_field_vec() + .iter() + .enumerate() + .filter_map(|(i, v)| if predicate(i, v) { Some(*v) } else { None }) + .collect_vec() +} + fn filter_mle_by_selector_mle( target_mle: ArcMultilinearExtension, selector: ArcMultilinearExtension, @@ -1487,7 +1559,6 @@ fn filter_mle_by_selector_mle( #[cfg(test)] mod tests { - use super::*; use crate::{ ROMType, diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index e4129bfe8..395b9e6c9 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -108,14 +108,14 @@ pub struct ConstraintSystem { pub r_expressions_namespace_map: Vec, // for each read expression we store its ram type and original value before doing RLC // the original value will be used for debugging - pub r_ram_types: Vec<(RAMType, Vec>)>, + pub r_ram_types: Vec<(Expression, Vec>)>, pub w_selector: Option>, pub w_expressions: Vec>, pub w_expressions_namespace_map: Vec, // for each write expression we store its ram type and original value before doing RLC // the original value will be used for debugging - pub w_ram_types: Vec<(RAMType, Vec>)>, + pub w_ram_types: Vec<(Expression, Vec>)>, /// init/final ram expression pub r_table_expressions: Vec>, @@ -329,12 +329,27 @@ impl ConstraintSystem { N: FnOnce() -> NR, { let rlc_record = self.rlc_chip_record(record.clone()); - assert_eq!( - rlc_record.degree(), - 1, - "rlc record degree {} != 1", - rlc_record.degree() - ); + self.r_table_rlc_record( + name_fn, + (ram_type as u64).into(), + table_spec, + record, + rlc_record, + ) + } + + pub fn r_table_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { self.r_table_expressions.push(SetTableExpression { expr: rlc_record, table_spec, @@ -358,12 +373,27 @@ impl ConstraintSystem { N: FnOnce() -> NR, { let rlc_record = self.rlc_chip_record(record.clone()); - assert_eq!( - rlc_record.degree(), - 1, - "rlc record degree {} != 1", - rlc_record.degree() - ); + self.w_table_rlc_record( + name_fn, + (ram_type as u64).into(), + table_spec, + record, + rlc_record, + ) + } + + pub fn w_table_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { self.w_table_expressions.push(SetTableExpression { expr: rlc_record, table_spec, @@ -387,7 +417,7 @@ impl ConstraintSystem { self.r_expressions_namespace_map.push(path); // Since r_expression is RLC(record) and when we're debugging // it's helpful to recover the value of record itself. - self.r_ram_types.push((ram_type, record)); + self.r_ram_types.push(((ram_type as u64).into(), record)); Ok(()) } @@ -401,7 +431,7 @@ impl ConstraintSystem { self.w_expressions.push(rlc_record); let path = self.ns.compute_path(name_fn().into()); self.w_expressions_namespace_map.push(path); - self.w_ram_types.push((ram_type, record)); + self.w_ram_types.push(((ram_type as u64).into(), record)); Ok(()) } @@ -579,6 +609,22 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { .r_table_record(name_fn, ram_type, table_spec, record) } + pub fn r_table_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs + .r_table_rlc_record(name_fn, ram_type, table_spec, record, rlc_record) + } + pub fn w_table_record( &mut self, name_fn: N, @@ -594,6 +640,22 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { .w_table_record(name_fn, ram_type, table_spec, record) } + pub fn w_table_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + table_spec: SetTableSpec, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs + .w_table_rlc_record(name_fn, ram_type, table_spec, record, rlc_record) + } + pub fn read_record( &mut self, name_fn: N, From dc15ff463fc7bbdcc8bc3ee2328ce3c6aa62a16a Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 20 Oct 2025 16:16:48 +0800 Subject: [PATCH 67/91] chores: mock_proving non-static ram type --- .github/workflows/lints.yml | 2 + ceno_zkvm/src/scheme/mock_prover.rs | 90 ++++++++++++++++++++++++++--- gkr_iop/src/circuit_builder.rs | 32 ++++++---- 3 files changed, 105 insertions(+), 19 deletions(-) diff --git a/.github/workflows/lints.yml b/.github/workflows/lints.yml index 29ff53880..8c08a3662 100644 --- a/.github/workflows/lints.yml +++ b/.github/workflows/lints.yml @@ -66,3 +66,5 @@ jobs: run: taplo --version || cargo install taplo-cli - name: Run taplo run: taplo fmt --check --diff + - name: Ensure Cargo.lock not modified by build + run: git diff --exit-code Cargo.lock diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index da5e8fe00..296ce3267 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -1151,7 +1151,7 @@ Hints: .into() }; - for ((w_rlc_expr, annotation), _) in (cs + for ((w_rlc_expr, annotation), (ram_type_expr, _)) in (cs .w_expressions .iter() .chain(cs.w_table_expressions.iter().map(|expr| &expr.expr))) @@ -1161,8 +1161,19 @@ Hints: .chain(cs.w_table_expressions_namespace_map.iter()), ) .zip_eq(cs.w_ram_types.iter()) - .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { + let ram_type_mle = wit_infer_by_expr( + ram_type_expr, + cs.num_witin, + cs.num_structural_witin, + cs.num_fixed as WitnessId, + fixed, + witness, + structural_witness, + &pi_mles, + &challenges, + ); + let ram_type_vec = ram_type_mle.get_ext_field_vec(); let write_rlc_records = wit_infer_by_expr( w_rlc_expr, cs.num_witin, @@ -1174,13 +1185,32 @@ Hints: &pi_mles, &challenges, ); + let w_selector_vec = w_selector.get_base_field_vec(); let write_rlc_records = - filter_mle_by_selector_mle(write_rlc_records, w_selector.clone()); + filter_mle_by_predicate(write_rlc_records, |i, _v| { + ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) + && w_selector_vec[i] == E::BaseField::ONE + }); + if write_rlc_records.is_empty() { + continue; + } let mut records = vec![]; + let mut writes_within_expr_dedup = HashSet::new(); for (row, record_rlc) in enumerate(write_rlc_records) { // TODO: report error - assert_eq!(writes.insert(record_rlc), true); + assert_eq!( + writes_within_expr_dedup.insert(record_rlc), + true, + "within expression write duplicated on RAMType {:?}", + $ram_type + ); + assert_eq!( + writes.insert(record_rlc), + true, + "crossing-chip write duplicated on RAMType {:?}", + $ram_type + ); records.push((record_rlc, row)); } writes_grp_by_annotations @@ -1216,7 +1246,7 @@ Hints: ) .into() }; - for ((r_rlc_expr, annotation), (_, r_exprs)) in (cs + for ((r_rlc_expr, annotation), (ram_type_expr, r_exprs)) in (cs .r_expressions .iter() .chain(cs.r_table_expressions.iter().map(|expr| &expr.expr))) @@ -1226,8 +1256,19 @@ Hints: .chain(cs.r_table_expressions_namespace_map.iter()), ) .zip_eq(cs.r_ram_types.iter()) - .filter(|((_, _), (ram_type, _))| *ram_type == $ram_type) { + let ram_type_mle = wit_infer_by_expr( + ram_type_expr, + cs.num_witin, + cs.num_structural_witin, + cs.num_fixed as WitnessId, + fixed, + witness, + structural_witness, + &pi_mles, + &challenges, + ); + let ram_type_vec = ram_type_mle.get_ext_field_vec(); let read_records = wit_infer_by_expr( r_rlc_expr, cs.num_witin, @@ -1239,8 +1280,14 @@ Hints: &pi_mles, &challenges, ); - let read_records = - filter_mle_by_selector_mle(read_records, r_selector.clone()); + let r_selector_vec = r_selector.get_base_field_vec(); + let read_records = filter_mle_by_predicate(read_records, |i, _v| { + ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) + && r_selector_vec[i] == E::BaseField::ONE + }); + if read_records.is_empty() { + continue; + } if $ram_type == RAMType::GlobalState { // r_exprs = [GlobalState, pc, timestamp] @@ -1273,9 +1320,21 @@ Hints: }; let mut records = vec![]; + let mut reads_within_expr_dedup = HashSet::new(); for (row, record) in enumerate(read_records) { // TODO: return error - assert_eq!(reads.insert(record), true); + assert_eq!( + reads_within_expr_dedup.insert(record), + true, + "within expression read duplicated on RAMType {:?}", + $ram_type + ); + assert_eq!( + reads.insert(record), + true, + "crossing-chip read duplicated on RAMType {:?}", + $ram_type + ); records.push((record, row)); } reads_grp_by_annotations @@ -1471,6 +1530,19 @@ fn print_errors( } } +fn filter_mle_by_predicate(target_mle: ArcMultilinearExtension, mut predicate: F) -> Vec +where + E: ExtensionField, + F: FnMut(usize, &E) -> bool, +{ + target_mle + .get_ext_field_vec() + .iter() + .enumerate() + .filter_map(|(i, v)| if predicate(i, v) { Some(*v) } else { None }) + .collect_vec() +} + fn filter_mle_by_selector_mle( target_mle: ArcMultilinearExtension, selector: ArcMultilinearExtension, diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 2da0271ff..395b9e6c9 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -108,14 +108,14 @@ pub struct ConstraintSystem { pub r_expressions_namespace_map: Vec, // for each read expression we store its ram type and original value before doing RLC // the original value will be used for debugging - pub r_ram_types: Vec<(RAMType, Vec>)>, + pub r_ram_types: Vec<(Expression, Vec>)>, pub w_selector: Option>, pub w_expressions: Vec>, pub w_expressions_namespace_map: Vec, // for each write expression we store its ram type and original value before doing RLC // the original value will be used for debugging - pub w_ram_types: Vec<(RAMType, Vec>)>, + pub w_ram_types: Vec<(Expression, Vec>)>, /// init/final ram expression pub r_table_expressions: Vec>, @@ -329,13 +329,19 @@ impl ConstraintSystem { N: FnOnce() -> NR, { let rlc_record = self.rlc_chip_record(record.clone()); - self.r_table_rlc_record(name_fn, ram_type, table_spec, record, rlc_record) + self.r_table_rlc_record( + name_fn, + (ram_type as u64).into(), + table_spec, + record, + rlc_record, + ) } pub fn r_table_rlc_record( &mut self, name_fn: N, - ram_type: RAMType, + ram_type: Expression, table_spec: SetTableSpec, record: Vec>, rlc_record: Expression, @@ -367,13 +373,19 @@ impl ConstraintSystem { N: FnOnce() -> NR, { let rlc_record = self.rlc_chip_record(record.clone()); - self.w_table_rlc_record(name_fn, ram_type, table_spec, record, rlc_record) + self.w_table_rlc_record( + name_fn, + (ram_type as u64).into(), + table_spec, + record, + rlc_record, + ) } pub fn w_table_rlc_record( &mut self, name_fn: N, - ram_type: RAMType, + ram_type: Expression, table_spec: SetTableSpec, record: Vec>, rlc_record: Expression, @@ -405,7 +417,7 @@ impl ConstraintSystem { self.r_expressions_namespace_map.push(path); // Since r_expression is RLC(record) and when we're debugging // it's helpful to recover the value of record itself. - self.r_ram_types.push((ram_type, record)); + self.r_ram_types.push(((ram_type as u64).into(), record)); Ok(()) } @@ -419,7 +431,7 @@ impl ConstraintSystem { self.w_expressions.push(rlc_record); let path = self.ns.compute_path(name_fn().into()); self.w_expressions_namespace_map.push(path); - self.w_ram_types.push((ram_type, record)); + self.w_ram_types.push(((ram_type as u64).into(), record)); Ok(()) } @@ -600,7 +612,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn r_table_rlc_record( &mut self, name_fn: N, - ram_type: RAMType, + ram_type: Expression, table_spec: SetTableSpec, record: Vec>, rlc_record: Expression, @@ -631,7 +643,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn w_table_rlc_record( &mut self, name_fn: N, - ram_type: RAMType, + ram_type: Expression, table_spec: SetTableSpec, record: Vec>, rlc_record: Expression, From c2fee659be3ed747f4f1e1f4cf0bf08f0b4b4023 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 20 Oct 2025 17:35:54 +0800 Subject: [PATCH 68/91] e2e test pass --- ceno_zkvm/src/e2e.rs | 2 +- ceno_zkvm/src/scheme/mock_prover.rs | 36 +++++++++++++++------------- ceno_zkvm/src/scheme/prover.rs | 1 - ceno_zkvm/src/scheme/verifier.rs | 1 - ceno_zkvm/src/tables/ram/ram_impl.rs | 10 ++++---- 5 files changed, 27 insertions(+), 23 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 7e38b158e..45dac669f 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -252,7 +252,7 @@ impl<'a> ShardContext<'a> { #[inline(always)] pub fn aligned_prev_ts(&self, prev_cycle: Cycle) -> Cycle { - let mut ts = prev_cycle.saturating_sub(self.cur_shard_cycle_range.start as Cycle); + let mut ts = prev_cycle - self.current_shard_offset_cycle(); if ts < Tracer::SUBCYCLES_PER_INSN { ts = 0 } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 296ce3267..21adee608 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -982,14 +982,11 @@ Hints: let mut lkm_opcodes = LkMultiplicityRaw::::default(); // Process all circuits. - for ( - circuit_name, - ComposedConstrainSystem { + for (circuit_name, composed_cs) in &cs.circuit_css { + let ComposedConstrainSystem { zkvm_v1_css: cs, gkr_circuit, - }, - ) in &cs.circuit_css - { + } = &composed_cs; let is_opcode = gkr_circuit.is_some(); let [witness, structural_witness] = witnesses .get_opcode_witness(circuit_name) @@ -997,11 +994,13 @@ Hints: .unwrap_or_else(|| panic!("witness for {} should not be None", circuit_name)); let num_rows = witness.num_instances(); - if witness.num_instances() == 0 { + if witness.num_instances() + structural_witness.num_instances() == 0 + && (!composed_cs.is_static_circuit()) + { wit_mles.insert(circuit_name.clone(), vec![]); structural_wit_mles.insert(circuit_name.clone(), vec![]); fixed_mles.insert(circuit_name.clone(), vec![]); - num_instances.insert(circuit_name.clone(), num_rows); + num_instances.insert(circuit_name.clone(), 0); continue; } let mut witness = witness @@ -1202,14 +1201,16 @@ Hints: assert_eq!( writes_within_expr_dedup.insert(record_rlc), true, - "within expression write duplicated on RAMType {:?}", - $ram_type + "within expression write duplicated on RAMType {:?} annotation {:?}", + $ram_type, + annotation ); assert_eq!( writes.insert(record_rlc), true, - "crossing-chip write duplicated on RAMType {:?}", - $ram_type + "crossing-chip write duplicated on RAMType {:?} annotation {:?}", + $ram_type, + annotation ); records.push((record_rlc, row)); } @@ -1227,6 +1228,7 @@ Hints: }, ) in &cs.circuit_css { + println!("process read {circuit_name}"); let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let structural_witness = structural_wit_mles.get(circuit_name).unwrap(); @@ -1326,14 +1328,16 @@ Hints: assert_eq!( reads_within_expr_dedup.insert(record), true, - "within expression read duplicated on RAMType {:?}", - $ram_type + "within expression read duplicated on RAMType {:?} annotation {:?}", + $ram_type, + annotation, ); assert_eq!( reads.insert(record), true, - "crossing-chip read duplicated on RAMType {:?}", - $ram_type + "crossing-chip read duplicated on RAMType {:?} annotation {:?}", + $ram_type, + annotation, ); records.push((record, row)); } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 2a583eae1..8df281186 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -210,7 +210,6 @@ impl< let (points, evaluations) = self.pk.circuit_pks.iter().enumerate().try_fold( (vec![], vec![]), |(mut points, mut evaluations), (index, (circuit_name, pk))| { - println!("prove circuit_name {circuit_name}"); let num_instances = circuit_name_num_instances_mapping .get(&circuit_name) .copied() diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 0c897298e..1a9d29b69 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -162,7 +162,6 @@ impl> ZKVMVerifier for (index, proof) in &vm_proof.chip_proofs { assert!(proof.num_instances > 0); let circuit_name = &self.vk.circuit_index_to_name[index]; - println!("verify circuit_name {circuit_name}"); let circuit_vk = &self.vk.circuit_vks[circuit_name]; // check chip proof is well-formed diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 1fbbd9bbd..085a6c127 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -574,15 +574,17 @@ impl LocalFinalRAMTableConfig { vec![final_cycle.expr()], ] .concat(); - cb.r_table_record( + let rlc_record = cb.rlc_chip_record(raw_final_table.clone()); + cb.r_table_rlc_record( || "final_table", // XXX we mixed all ram type here to save column allocation - RAMType::Undefined, + ram_type.expr(), SetTableSpec { len: None, structural_witins: vec![], }, raw_final_table, + rlc_record, )?; Ok(Self { @@ -838,7 +840,7 @@ impl RAMBusConfig { // local write, global read cb.w_table_rlc_record( || "local_write_record", - RAMType::Undefined, + ram_type.expr(), SetTableSpec { len: None, structural_witins: vec![sel_read], @@ -864,7 +866,7 @@ impl RAMBusConfig { // local read, global write cb.r_table_rlc_record( || "local_read_record", - RAMType::Undefined, + ram_type.expr(), SetTableSpec { len: None, structural_witins: vec![sel_write], From 74ca4f1890df1594aecedca0870da3302e60d32e Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 20 Oct 2025 17:47:41 +0800 Subject: [PATCH 69/91] cosmetics and fix lint --- ceno_emul/src/shards.rs | 9 ++++ ceno_zkvm/benches/fibonacci.rs | 5 ++- ceno_zkvm/benches/fibonacci_witness.rs | 2 + ceno_zkvm/benches/is_prime.rs | 2 + ceno_zkvm/benches/keccak.rs | 3 ++ ceno_zkvm/benches/quadratic_sorting.rs | 2 + ceno_zkvm/src/chip_handler/general.rs | 1 + ceno_zkvm/src/e2e.rs | 45 +++++++++---------- .../src/instructions/riscv/rv32im/mmu.rs | 1 - ceno_zkvm/src/scheme/cpu/mod.rs | 5 +-- ceno_zkvm/src/scheme/prover.rs | 9 +++- ceno_zkvm/src/scheme/verifier.rs | 7 +-- ceno_zkvm/src/tables/mod.rs | 8 +--- ceno_zkvm/src/tables/ram/ram_impl.rs | 27 +++++------ gkr_iop/src/gkr/layer/cpu/mod.rs | 2 - 15 files changed, 66 insertions(+), 62 deletions(-) diff --git a/ceno_emul/src/shards.rs b/ceno_emul/src/shards.rs index 935623fe3..eba152504 100644 --- a/ceno_emul/src/shards.rs +++ b/ceno_emul/src/shards.rs @@ -20,3 +20,12 @@ impl Shards { self.shard_id == self.max_num_shards - 1 } } + +impl Default for Shards { + fn default() -> Self { + Self { + shard_id: 0, + max_num_shards: 1, + } + } +} diff --git a/ceno_zkvm/benches/fibonacci.rs b/ceno_zkvm/benches/fibonacci.rs index eb7133344..ea359f7ed 100644 --- a/ceno_zkvm/benches/fibonacci.rs +++ b/ceno_zkvm/benches/fibonacci.rs @@ -13,7 +13,8 @@ use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; -use ceno_zkvm::{e2e::ShardContext, scheme::verifier::ZKVMVerifier}; +use ceno_emul::shards::Shards; +use ceno_zkvm::scheme::verifier::ZKVMVerifier; use mpcs::BasefoldDefault; use transcript::BasicTranscript; @@ -54,6 +55,7 @@ fn fibonacci_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &Vec::from(&hints), &[], max_steps, @@ -91,6 +93,7 @@ fn fibonacci_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &Vec::from(&hints), &[], max_steps, diff --git a/ceno_zkvm/benches/fibonacci_witness.rs b/ceno_zkvm/benches/fibonacci_witness.rs index 483b690d5..cc224edd2 100644 --- a/ceno_zkvm/benches/fibonacci_witness.rs +++ b/ceno_zkvm/benches/fibonacci_witness.rs @@ -9,6 +9,7 @@ use std::{fs, path::PathBuf, time::Duration}; mod alloc; use criterion::*; +use ceno_emul::shards::Shards; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; use mpcs::BasefoldDefault; @@ -65,6 +66,7 @@ fn fibonacci_witness(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &Vec::from(&hints), &[], max_steps, diff --git a/ceno_zkvm/benches/is_prime.rs b/ceno_zkvm/benches/is_prime.rs index b55805fb7..9d4765af1 100644 --- a/ceno_zkvm/benches/is_prime.rs +++ b/ceno_zkvm/benches/is_prime.rs @@ -8,6 +8,7 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; +use ceno_emul::shards::Shards; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; @@ -62,6 +63,7 @@ fn is_prime_1(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &hints, &[], max_steps, diff --git a/ceno_zkvm/benches/keccak.rs b/ceno_zkvm/benches/keccak.rs index c1a889594..5194cba05 100644 --- a/ceno_zkvm/benches/keccak.rs +++ b/ceno_zkvm/benches/keccak.rs @@ -8,6 +8,7 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; +use ceno_emul::shards::Shards; use ceno_zkvm::scheme::verifier::ZKVMVerifier; use criterion::*; use ff_ext::BabyBearExt4; @@ -51,6 +52,7 @@ fn keccak_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &Vec::from(&hints), &[], max_steps, @@ -85,6 +87,7 @@ fn keccak_prove(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &Vec::from(&hints), &[], max_steps, diff --git a/ceno_zkvm/benches/quadratic_sorting.rs b/ceno_zkvm/benches/quadratic_sorting.rs index dc234a03a..2f652e36e 100644 --- a/ceno_zkvm/benches/quadratic_sorting.rs +++ b/ceno_zkvm/benches/quadratic_sorting.rs @@ -13,6 +13,7 @@ use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; use mpcs::BasefoldDefault; use rand::{RngCore, SeedableRng}; +use ceno_emul::shards::Shards; criterion_group! { name = quadratic_sorting; @@ -63,6 +64,7 @@ fn quadratic_sorting_1(c: &mut Criterion) { create_prover(backend.clone()), program.clone(), platform.clone(), + Shards::default(), &hints, &[], max_steps, diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index d04c50329..dbd9961a9 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -22,6 +22,7 @@ pub trait PublicIOQuery { fn query_end_pc(&mut self) -> Result; fn query_end_cycle(&mut self) -> Result; fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError>; + #[allow(dead_code)] fn query_shard_id(&mut self) -> Result; } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 45dac669f..8e213118c 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -165,8 +165,7 @@ impl<'a> ShardContext<'a> { let subcycle_per_insn = Tracer::SUBCYCLES_PER_INSN as usize; let max_threads = max_usable_threads(); - // let max_record_per_thread = max_insts.div_ceil(max_threads as u64); - let expected_inst_per_shard = executed_instructions.div_ceil(max_num_shards) as usize; + let expected_inst_per_shard = executed_instructions.div_ceil(max_num_shards); let max_cycle = (executed_instructions + 1) * subcycle_per_insn; // cycle start from subcycle_per_insn let cur_shard_cycle_range = (shard_id * expected_inst_per_shard * subcycle_per_insn + subcycle_per_insn) @@ -265,6 +264,7 @@ impl<'a> ShardContext<'a> { } #[inline(always)] + #[allow(clippy::too_many_arguments)] pub fn send( &mut self, ram_type: crate::structs::RAMType, @@ -301,28 +301,27 @@ impl<'a> ShardContext<'a> { } // check write to external mem bus - if let Some(future_touch_cycle) = self.addr_future_accesses.get(&(addr, cycle)) { - if *future_touch_cycle >= self.cur_shard_cycle_range.end as Cycle - && self.is_current_shard_cycle(cycle) - { - let ram_record = self - .write_thread_based_record_storage - .as_mut() - .right() - .expect("illegal type"); - ram_record.insert( + if let Some(future_touch_cycle) = self.addr_future_accesses.get(&(addr, cycle)) + && *future_touch_cycle >= self.cur_shard_cycle_range.end as Cycle + && self.is_current_shard_cycle(cycle) + { + let ram_record = self + .write_thread_based_record_storage + .as_mut() + .right() + .expect("illegal type"); + ram_record.insert( + addr, + RAMRecord { + ram_type, + id, addr, - RAMRecord { - ram_type, - id, - addr, - prev_cycle, - cycle, - prev_value, - value, - }, - ); - } + prev_cycle, + cycle, + prev_value, + value, + }, + ); } } } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index 217080293..82a8d0c91 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -12,7 +12,6 @@ use crate::{ use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word}; use ff_ext::ExtensionField; use itertools::{Itertools, chain}; -use multilinear_extensions::mle::IntoInstanceIterMut; use std::{collections::HashSet, iter::zip, ops::Range, sync::Arc}; use witness::InstancePaddingStrategy; diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 54fbac475..10afb51ec 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -20,19 +20,16 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, gkr::{self, Evaluation, GKRProof, GKRProverOutput, layer::LayerWitness}, hal::ProverBackend, - selector::SelectorType, }; use itertools::{Itertools, chain}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ - ChallengeId, Expression, Instance, ToExpr, WitnessId, + Expression, Instance, WitnessId, mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, - monomialize_expr_to_wit_terms, util::ceil_log2, virtual_poly::build_eq_x_r_vec, virtual_polys::VirtualPolynomialsBuilder, }; -use p3::field::FieldAlgebra; use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; use std::{collections::BTreeMap, sync::Arc}; use sumcheck::{ diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 8df281186..2c327da63 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -263,8 +263,13 @@ impl< } else { // FIXME: PROGRAM table circuit is not guaranteed to have 2^n instances // input.num_instances = 1 << input.log2_num_instances(); - let (mut table_proof, pi_in_evals, input_opening_point) = self - .create_chip_proof(circuit_name, pk, input, &mut transcript, &challenges)?; + let (table_proof, pi_in_evals, input_opening_point) = self.create_chip_proof( + circuit_name, + pk, + input, + &mut transcript, + &challenges, + )?; if cs.num_witin() > 0 || cs.num_fixed() > 0 { points.push(input_opening_point); evaluations.push(vec![ diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 1a9d29b69..b38c6e589 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -15,12 +15,10 @@ use crate::{ eval_stacked_constant_vec, eval_stacked_wellform_address_vec, eval_wellform_address_vec, }, }; -use gkr_iop::{gkr::GKRClaims, selector::SelectorType, utils::eq_eval_less_or_equal_than}; +use gkr_iop::gkr::GKRClaims; use itertools::{Itertools, chain, interleave, izip}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ - Expression, - Expression::WitIn, Instance, StructuralWitIn, StructuralWitInType, mle::IntoMLE, util::ceil_log2, @@ -33,7 +31,7 @@ use sumcheck::{ util::get_challenge_pows, }; use transcript::{ForkableTranscript, Transcript}; -use witness::{InstancePaddingStrategy::Default, next_pow2_instance_padding}; +use witness::next_pow2_instance_padding; pub struct ZKVMVerifier> { pub vk: ZKVMVerifyingKey, @@ -489,7 +487,6 @@ impl> ZKVMVerifier })) .collect_vec(); - let expected_max_rounds = expected_rounds.iter().cloned().max().unwrap(); let (rt_tower, prod_point_and_eval, logup_p_point_and_eval, logup_q_point_and_eval) = TowerVerify::verify( interleave(&proof.r_out_evals, &proof.w_out_evals) diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 6e9b4d9d2..d55a6a907 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -1,12 +1,6 @@ use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, structs::ProgramParams}; use ff_ext::ExtensionField; -use gkr_iop::{ - chip::Chip, - gkr::{GKRCircuit, layer::Layer}, - selector::SelectorType, -}; -use itertools::Itertools; -use multilinear_extensions::{StructuralWitInType, ToExpr}; +use gkr_iop::gkr::GKRCircuit; use std::collections::HashMap; use witness::RowMajorMatrix; diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 085a6c127..76b4e012a 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -496,13 +496,7 @@ impl DynVolatileRamTableConfig assert!(final_mem.len() <= DVRAM::max_len(&config.params)); assert!(DVRAM::max_len(&config.params).is_power_of_two()); - let params = config.params.clone(); let num_instances_padded = next_pow2_instance_padding(final_mem.len()); - // let addr_id = config.addr.id as u64; - // let addr_padding_fn = move |row: u64, col: u64| { - // assert_eq!(col, addr_id); - // DVRAM::addr(¶ms, row as usize) as u64 - // }; let mut structural_witness = RowMajorMatrix::::new( num_instances_padded, @@ -520,16 +514,16 @@ impl DynVolatileRamTableConfig .par_rows_mut() .enumerate() .for_each(|(i, structural_row)| { - if cfg!(debug_assertions) { - if let Some(addr) = final_mem.get(i).map(|rec| rec.addr) { - debug_assert_eq!( - addr, - DVRAM::addr(&config.params, i), - "rec.addr {:x} != expected {:x}", - addr, - DVRAM::addr(&config.params, i), - ); - } + if cfg!(debug_assertions) + && let Some(addr) = final_mem.get(i).map(|rec| rec.addr) + { + debug_assert_eq!( + addr, + DVRAM::addr(&config.params, i), + "rec.addr {:x} != expected {:x}", + addr, + DVRAM::addr(&config.params, i), + ); } set_val!( structural_row, @@ -727,7 +721,6 @@ impl LocalFinalRAMTableConfig { set_val!(row, self.ram_type, rec.ram_type as u64); set_val!(row, self.addr_subset, rec.addr as u64); set_val!(structural_row, selector_witin, 1u64); - () }) .count(); diff --git a/gkr_iop/src/gkr/layer/cpu/mod.rs b/gkr_iop/src/gkr/layer/cpu/mod.rs index 3719e3029..95d315f25 100644 --- a/gkr_iop/src/gkr/layer/cpu/mod.rs +++ b/gkr_iop/src/gkr/layer/cpu/mod.rs @@ -20,7 +20,6 @@ use multilinear_extensions::{ monomial::Term, virtual_poly::build_eq_x_r_vec, virtual_polys::VirtualPolynomialsBuilder, - wit_infer_by_monomial_expr, }; use rayon::{ iter::{ @@ -28,7 +27,6 @@ use rayon::{ }, slice::ParallelSlice, }; -use std::sync::Arc; use sumcheck::{ macros::{entered_span, exit_span}, structs::{IOPProof, IOPProverState}, From e819dc777952956dba87fe70ef8e8f8184ac8b28 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 20 Oct 2025 21:08:08 +0800 Subject: [PATCH 70/91] chores: fix test and ci --- ceno_zkvm/benches/quadratic_sorting.rs | 2 +- .../weierstrass/weierstrass_decompress.rs | 2 +- ceno_zkvm/src/scheme/mock_prover.rs | 39 ++++++++++++++----- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/ceno_zkvm/benches/quadratic_sorting.rs b/ceno_zkvm/benches/quadratic_sorting.rs index 2f652e36e..35eb494f2 100644 --- a/ceno_zkvm/benches/quadratic_sorting.rs +++ b/ceno_zkvm/benches/quadratic_sorting.rs @@ -8,12 +8,12 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; +use ceno_emul::shards::Shards; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; use mpcs::BasefoldDefault; use rand::{RngCore, SeedableRng}; -use ceno_emul::shards::Shards; criterion_group! { name = quadratic_sorting; diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs index 04268d038..de03a829e 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs @@ -582,7 +582,7 @@ pub fn run_weierstrass_decompress< let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter .zip_eq(instances.par_chunks(num_instance_per_batch)) - .zip_eq(shard_ctx_vec) + .zip(shard_ctx_vec) .for_each(|((instances, steps), mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); instances diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 21adee608..278bee488 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -26,13 +26,14 @@ use itertools::{Itertools, chain, enumerate, izip}; use multilinear_extensions::{ Expression, WitnessId, fmt, mle::{ArcMultilinearExtension, IntoMLEs, MultilinearExtension}, + util::ceil_log2, utils::{eval_by_expr, eval_by_expr_with_fixed, eval_by_expr_with_instance}, }; use p3::field::{Field, FieldAlgebra}; use rand::thread_rng; use std::{ cmp::max, - collections::{BTreeSet, HashMap, HashSet}, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, fmt::Debug, fs::File, hash::Hash, @@ -978,6 +979,16 @@ Hints: let mut fixed_mles = HashMap::new(); let mut num_instances = HashMap::new(); + let circuit_index_fixed_num_instances: BTreeMap = fixed_trace + .circuit_fixed_traces + .iter() + .map(|(circuit_name, rmm)| { + ( + circuit_name.clone(), + rmm.as_ref().map(|rmm| rmm.num_instances()).unwrap_or(0), + ) + }) + .collect(); let mut lkm_tables = LkMultiplicityRaw::::default(); let mut lkm_opcodes = LkMultiplicityRaw::::default(); @@ -992,11 +1003,20 @@ Hints: .get_opcode_witness(circuit_name) .or_else(|| witnesses.get_table_witness(circuit_name)) .unwrap_or_else(|| panic!("witness for {} should not be None", circuit_name)); - let num_rows = witness.num_instances(); + let num_rows = if witness.num_instances() > 0 { + witness.num_instances() + } else if structural_witness.num_instances() > 0 { + structural_witness.num_instances() + } else if composed_cs.is_static_circuit() { + circuit_index_fixed_num_instances + .get(circuit_name) + .copied() + .unwrap_or(0) + } else { + 0 + }; - if witness.num_instances() + structural_witness.num_instances() == 0 - && (!composed_cs.is_static_circuit()) - { + if num_rows == 0 { wit_mles.insert(circuit_name.clone(), vec![]); structural_wit_mles.insert(circuit_name.clone(), vec![]); fixed_mles.insert(circuit_name.clone(), vec![]); @@ -1136,15 +1156,14 @@ Hints: if *num_rows == 0 { continue; } - let w_selector: ArcMultilinearExtension<_> = if let Some(w_selector) = &cs.w_selector { structural_witness[w_selector.selector_expr().id()].clone() } else { let mut selector = vec![E::BaseField::ONE; *num_rows]; - selector.resize(witness[0].evaluations().len(), E::BaseField::ZERO); + selector.resize(next_pow2_instance_padding(*num_rows), E::BaseField::ZERO); MultilinearExtension::from_evaluation_vec_smart( - witness[0].num_vars(), + ceil_log2(next_pow2_instance_padding(*num_rows)), selector, ) .into() @@ -1241,9 +1260,9 @@ Hints: structural_witness[r_selector.selector_expr().id()].clone() } else { let mut selector = vec![E::BaseField::ONE; *num_rows]; - selector.resize(witness[0].evaluations().len(), E::BaseField::ZERO); + selector.resize(next_pow2_instance_padding(*num_rows), E::BaseField::ZERO); MultilinearExtension::from_evaluation_vec_smart( - witness[0].num_vars(), + ceil_log2(next_pow2_instance_padding(*num_rows)), selector, ) .into() From 2632e5b58987e9d1d5a5dbbc645ee4f8094fc9a4 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 20 Oct 2025 21:12:59 +0800 Subject: [PATCH 71/91] log cleanup --- ceno_zkvm/src/scheme/mock_prover.rs | 1 - ceno_zkvm/src/tables/ram/ram_impl.rs | 32 ---------------------------- 2 files changed, 33 deletions(-) diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 278bee488..edf7a63f1 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -1247,7 +1247,6 @@ Hints: }, ) in &cs.circuit_css { - println!("process read {circuit_name}"); let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let structural_witness = structural_wit_mles.get(circuit_name).unwrap(); diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 76b4e012a..6a4b5c04a 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -115,12 +115,6 @@ impl NonVolatileTableConfigTrait< NVRAM::len(&config.params) ); - println!( - "Init: NVRAM::RAM_TYPE {:?}, raw len {}", - NVRAM::RAM_TYPE, - init_mem.len(), - ); - let mut init_table = RowMajorMatrix::::new( NVRAM::len(&config.params), num_fixed, @@ -503,12 +497,6 @@ impl DynVolatileRamTableConfig num_structural_witin, InstancePaddingStrategy::Default, ); - println!( - "Init: DVRAM::RAM_TYPE {:?}, raw len {}, padded {}", - DVRAM::RAM_TYPE, - final_mem.len(), - num_instances_padded - final_mem.len() - ); structural_witness .par_rows_mut() @@ -612,16 +600,6 @@ impl LocalFinalRAMTableConfig { .map(|(_, mem)| mem.par_iter().filter(is_current_shard_mem_record).count()) .collect(); - current_shard_mems_len - .iter() - .zip(final_mem.iter()) - .for_each(|(raw_len, (_, mem))| { - println!( - "Final: DVRAM::RAM_TYPE {:?}, raw len {}", - mem[0].ram_type, raw_len - ) - }); - // deal with non-pow2 padding for first shard // format Vec<(pad_len, pad_start_index)> let padding_info = if shard_ctx.is_first_shard() { @@ -640,16 +618,6 @@ impl LocalFinalRAMTableConfig { vec![(0, 0, RAMType::Undefined); final_mem.len()] }; - padding_info - .iter() - .zip(final_mem.iter()) - .for_each(|((pad_size, ..), (_, mem))| { - println!( - "Final: DVRAM::RAM_TYPE {:?}, pad_size {}", - mem[0].ram_type, pad_size - ) - }); - // calculate mem length let mem_lens = current_shard_mems_len .iter() From ee5196431ab279916b0da28ef25efc44f856bb49 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 20 Oct 2025 21:31:11 +0800 Subject: [PATCH 72/91] fix goldilocks circuit --- ceno_zkvm/src/e2e.rs | 2 +- .../src/instructions/riscv/arith_imm/arith_imm_circuit.rs | 4 +++- ceno_zkvm/src/instructions/riscv/div/div_circuit.rs | 6 +++++- ceno_zkvm/src/instructions/riscv/jump/jal.rs | 4 +++- ceno_zkvm/src/instructions/riscv/jump/jalr.rs | 4 +++- .../src/instructions/riscv/logic_imm/logic_imm_circuit.rs | 8 ++++++-- ceno_zkvm/src/instructions/riscv/memory/load.rs | 4 +++- ceno_zkvm/src/instructions/riscv/memory/store.rs | 4 +++- ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs | 4 +++- ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs | 4 +++- .../src/instructions/riscv/shift_imm/shift_imm_circuit.rs | 4 +++- ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs | 6 +++++- ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs | 6 +++++- 13 files changed, 46 insertions(+), 14 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 8e213118c..8ba7b4fd1 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -115,7 +115,7 @@ pub struct ShardContext<'a> { shard_id: usize, max_num_shards: usize, max_cycle: Cycle, - // TODO this map is super huge + // TODO optimize this map as it's super huge addr_future_accesses: Cow<'a, HashMap<(WordAddr, Cycle), Cycle>>, read_thread_based_record_storage: Either>, &'a mut BTreeMap>, diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs index 8a4722a08..11d93242c 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -58,6 +59,7 @@ impl Instruction for AddiInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -77,7 +79,7 @@ impl Instruction for AddiInstruction { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs index ef5b9d936..99a73a8a4 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs @@ -75,6 +75,7 @@ use super::{ }; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{AssertLtConfig, IsEqualConfig, IsLtConfig, IsZeroConfig, Signed}, instructions::{Instruction, riscv::constants::LIMB_BITS}, @@ -310,6 +311,7 @@ impl Instruction for ArithInstruction Instruction for ArithInstruction Instruction for JalInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, ) -> Result<(), ZKVMError> { config .j_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; let rd_written = Value::new(step.rd().unwrap().value.after, lk_multiplicity); config.rd_written.assign_value(instance, rd_written); diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index f1ba94aa7..77f6ad1f8 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -5,6 +5,7 @@ use ff_ext::ExtensionField; use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -111,6 +112,7 @@ impl Instruction for JalrInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &ceno_emul::StepRecord, @@ -150,7 +152,7 @@ impl Instruction for JalrInstruction { config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs index aad60b43b..596792ad8 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -6,6 +6,7 @@ use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, @@ -48,6 +49,7 @@ impl Instruction for LogicInstruction { fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, @@ -58,7 +60,7 @@ impl Instruction for LogicInstruction { InsnRecord::::imm_internal(&step.insn()).0 as u64, ); - config.assign_instance(instance, lkm, step) + config.assign_instance(instance, shard_ctx, lkm, step) } } @@ -102,10 +104,12 @@ impl LogicConfig { fn assign_instance( &self, instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - self.i_insn.assign_instance(instance, lkm, step)?; + self.i_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1_read = split_to_u8(step.rs1().unwrap().value); self.rs1_read.assign_limbs(instance, &rs1_read); diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 5945f26bd..41fbf0059 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, instructions::{ @@ -165,6 +166,7 @@ impl Instruction for LoadInstruction Instruction for LoadInstruction Instruction fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -124,7 +126,7 @@ impl Instruction let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); config .s_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; config.rs1_read.assign_value(instance, rs1); config.rs2_read.assign_value(instance, rs2); set_val!(instance, config.imm, imm.1); diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs index bc5bc9ed4..dd919dd3e 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit.rs @@ -86,6 +86,7 @@ use p3::{field::FieldAlgebra, goldilocks::Goldilocks}; use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::{IsEqualConfig, Signed}, instructions::{ @@ -286,6 +287,7 @@ impl Instruction for MulhInstructionBas fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -312,7 +314,7 @@ impl Instruction for MulhInstructionBas // R-type instruction config .r_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; // Assign signed values, if any, and compute low 32-bit limb of product let prod_lo_hi = match &config.sign_deps { diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs index 87374b20e..c1d83ce87 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs @@ -1,5 +1,6 @@ use crate::{ Value, + e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, instructions::{ @@ -151,6 +152,7 @@ impl Instruction for ShiftLogicalInstru fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut crate::witness::LkMultiplicity, step: &ceno_emul::StepRecord, @@ -211,7 +213,7 @@ impl Instruction for ShiftLogicalInstru config .r_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs index 0bba35411..a2fa8d032 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, instructions::{ @@ -132,6 +133,7 @@ impl Instruction for ShiftImmInstructio fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, @@ -168,7 +170,7 @@ impl Instruction for ShiftImmInstructio config .i_insn - .assign_instance(instance, lk_multiplicity, step)?; + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs index 3ffd9de69..b9b63acaf 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs @@ -1,5 +1,6 @@ use crate::{ Value, + e2e::ShardContext, error::ZKVMError, gadgets::SignedLtConfig, instructions::{ @@ -92,11 +93,14 @@ impl Instruction for SetLessThanInstruc fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - config.r_insn.assign_instance(instance, lkm, step)?; + config + .r_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1 = step.rs1().unwrap().value; let rs2 = step.rs2().unwrap().value; diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs index 266faeed3..8b93f593c 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs @@ -1,6 +1,7 @@ use crate::{ Value, circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, instructions::{ @@ -94,11 +95,14 @@ impl Instruction for SetLessThanImmInst fn assign_instance( config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, instance: &mut [E::BaseField], lkm: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - config.i_insn.assign_instance(instance, lkm, step)?; + config + .i_insn + .assign_instance(instance, shard_ctx, lkm, step)?; let rs1 = step.rs1().unwrap().value; let rs1_value = Value::new_unchecked(rs1 as Word); From 0ffa915bcf522b792b569e5f07b8f809a89ef162 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 20 Oct 2025 21:40:53 +0800 Subject: [PATCH 73/91] refactor and clippy --- ceno_cli/src/commands/common_args/ceno.rs | 1 - ceno_emul/src/lib.rs | 1 - ceno_emul/src/shards.rs | 31 ---------- ceno_zkvm/benches/fibonacci.rs | 3 +- ceno_zkvm/benches/fibonacci_witness.rs | 2 +- ceno_zkvm/benches/is_prime.rs | 2 +- ceno_zkvm/benches/keccak.rs | 3 +- ceno_zkvm/benches/quadratic_sorting.rs | 2 +- ceno_zkvm/src/bin/e2e.rs | 4 +- ceno_zkvm/src/e2e.rs | 69 +++++++++++++++-------- 10 files changed, 53 insertions(+), 65 deletions(-) delete mode 100644 ceno_emul/src/shards.rs diff --git a/ceno_cli/src/commands/common_args/ceno.rs b/ceno_cli/src/commands/common_args/ceno.rs index d3df66e77..d73841080 100644 --- a/ceno_cli/src/commands/common_args/ceno.rs +++ b/ceno_cli/src/commands/common_args/ceno.rs @@ -13,7 +13,6 @@ use ceno_zkvm::{ use clap::Args; use ff_ext::{BabyBearExt4, ExtensionField, GoldilocksExt2}; -use ceno_emul::shards::Shards; use mpcs::{ Basefold, BasefoldRSParams, PolynomialCommitmentScheme, SecurityLevel, Whir, WhirDefaultSpec, }; diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 38bd6fcb2..8f439d036 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -45,4 +45,3 @@ pub mod utils; pub mod test_utils; pub mod host_utils; -pub mod shards; diff --git a/ceno_emul/src/shards.rs b/ceno_emul/src/shards.rs deleted file mode 100644 index eba152504..000000000 --- a/ceno_emul/src/shards.rs +++ /dev/null @@ -1,31 +0,0 @@ -pub struct Shards { - pub shard_id: usize, - pub max_num_shards: usize, -} - -impl Shards { - pub fn new(shard_id: usize, max_num_shards: usize) -> Self { - assert!(shard_id < max_num_shards); - Self { - shard_id, - max_num_shards, - } - } - - pub fn is_first_shard(&self) -> bool { - self.shard_id == 0 - } - - pub fn is_last_shard(&self) -> bool { - self.shard_id == self.max_num_shards - 1 - } -} - -impl Default for Shards { - fn default() -> Self { - Self { - shard_id: 0, - max_num_shards: 1, - } - } -} diff --git a/ceno_zkvm/benches/fibonacci.rs b/ceno_zkvm/benches/fibonacci.rs index ea359f7ed..325c59f46 100644 --- a/ceno_zkvm/benches/fibonacci.rs +++ b/ceno_zkvm/benches/fibonacci.rs @@ -13,8 +13,7 @@ use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; -use ceno_emul::shards::Shards; -use ceno_zkvm::scheme::verifier::ZKVMVerifier; +use ceno_zkvm::{e2e::Shards, scheme::verifier::ZKVMVerifier}; use mpcs::BasefoldDefault; use transcript::BasicTranscript; diff --git a/ceno_zkvm/benches/fibonacci_witness.rs b/ceno_zkvm/benches/fibonacci_witness.rs index cc224edd2..d942743db 100644 --- a/ceno_zkvm/benches/fibonacci_witness.rs +++ b/ceno_zkvm/benches/fibonacci_witness.rs @@ -9,7 +9,7 @@ use std::{fs, path::PathBuf, time::Duration}; mod alloc; use criterion::*; -use ceno_emul::shards::Shards; +use ceno_zkvm::e2e::Shards; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; use mpcs::BasefoldDefault; diff --git a/ceno_zkvm/benches/is_prime.rs b/ceno_zkvm/benches/is_prime.rs index 9d4765af1..6d66ff859 100644 --- a/ceno_zkvm/benches/is_prime.rs +++ b/ceno_zkvm/benches/is_prime.rs @@ -8,7 +8,7 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; -use ceno_emul::shards::Shards; +use ceno_zkvm::e2e::Shards; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; diff --git a/ceno_zkvm/benches/keccak.rs b/ceno_zkvm/benches/keccak.rs index 5194cba05..19011d460 100644 --- a/ceno_zkvm/benches/keccak.rs +++ b/ceno_zkvm/benches/keccak.rs @@ -8,8 +8,7 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; -use ceno_emul::shards::Shards; -use ceno_zkvm::scheme::verifier::ZKVMVerifier; +use ceno_zkvm::{e2e::Shards, scheme::verifier::ZKVMVerifier}; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; diff --git a/ceno_zkvm/benches/quadratic_sorting.rs b/ceno_zkvm/benches/quadratic_sorting.rs index 35eb494f2..93389c388 100644 --- a/ceno_zkvm/benches/quadratic_sorting.rs +++ b/ceno_zkvm/benches/quadratic_sorting.rs @@ -8,7 +8,7 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; -use ceno_emul::shards::Shards; +use ceno_zkvm::e2e::Shards; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 3496477cb..52df7e6da 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -1,10 +1,10 @@ -use ceno_emul::{IterAddresses, Platform, Program, WORD_SIZE, Word, shards::Shards}; +use ceno_emul::{IterAddresses, Platform, Program, WORD_SIZE, Word}; use ceno_host::{CenoStdin, memory_from_file}; #[cfg(all(feature = "jemalloc", unix, not(test)))] use ceno_zkvm::print_allocated_bytes; use ceno_zkvm::{ e2e::{ - Checkpoint, FieldType, PcsKind, Preset, run_e2e_with_checkpoint, setup_platform, + Checkpoint, FieldType, PcsKind, Preset, Shards, run_e2e_with_checkpoint, setup_platform, setup_platform_debug, verify, }, scheme::{ diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 8ba7b4fd1..912bcf658 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -18,7 +18,6 @@ use crate::{ use ceno_emul::{ Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, InsnKind, IterAddresses, Platform, Program, StepRecord, Tracer, VMState, WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, - shards::Shards, }; use clap::ValueEnum; use either::Either; @@ -111,9 +110,41 @@ pub struct RAMRecord { pub value: Word, } +#[derive(Clone, Debug)] +pub struct Shards { + pub shard_id: usize, + pub max_num_shards: usize, +} + +impl Shards { + pub fn new(shard_id: usize, max_num_shards: usize) -> Self { + assert!(shard_id < max_num_shards); + Self { + shard_id, + max_num_shards, + } + } + + pub fn is_first_shard(&self) -> bool { + self.shard_id == 0 + } + + pub fn is_last_shard(&self) -> bool { + self.shard_id == self.max_num_shards - 1 + } +} + +impl Default for Shards { + fn default() -> Self { + Self { + shard_id: 0, + max_num_shards: 1, + } + } +} + pub struct ShardContext<'a> { - shard_id: usize, - max_num_shards: usize, + shards: Shards, max_cycle: Cycle, // TODO optimize this map as it's super huge addr_future_accesses: Cow<'a, HashMap<(WordAddr, Cycle), Cycle>>, @@ -128,8 +159,7 @@ impl<'a> Default for ShardContext<'a> { fn default() -> Self { let max_threads = max_usable_threads(); Self { - shard_id: 0, - max_num_shards: 1, + shards: Shards::default(), max_cycle: Cycle::default(), addr_future_accesses: Cow::Owned(HashMap::new()), read_thread_based_record_storage: Either::Left( @@ -151,15 +181,14 @@ impl<'a> Default for ShardContext<'a> { impl<'a> ShardContext<'a> { pub fn new( - shard_id: usize, - max_num_shards: usize, + shards: Shards, executed_instructions: usize, addr_future_accesses: HashMap<(WordAddr, Cycle), Cycle>, ) -> Self { // current strategy: at least each shard deal with one instruction - let max_num_shards = max_num_shards.min(executed_instructions); + let max_num_shards = shards.max_num_shards.min(executed_instructions); assert!( - shard_id < max_num_shards, + shards.shard_id < max_num_shards, "implement mechanism to skip current shard proof" ); @@ -167,14 +196,14 @@ impl<'a> ShardContext<'a> { let max_threads = max_usable_threads(); let expected_inst_per_shard = executed_instructions.div_ceil(max_num_shards); let max_cycle = (executed_instructions + 1) * subcycle_per_insn; // cycle start from subcycle_per_insn - let cur_shard_cycle_range = (shard_id * expected_inst_per_shard * subcycle_per_insn + let cur_shard_cycle_range = (shards.shard_id * expected_inst_per_shard * subcycle_per_insn + subcycle_per_insn) - ..((shard_id + 1) * expected_inst_per_shard * subcycle_per_insn + subcycle_per_insn) + ..((shards.shard_id + 1) * expected_inst_per_shard * subcycle_per_insn + + subcycle_per_insn) .min(max_cycle); ShardContext { - shard_id, - max_num_shards, + shards, max_cycle: max_cycle as Cycle, addr_future_accesses: Cow::Owned(addr_future_accesses), // TODO with_capacity optimisation @@ -207,8 +236,7 @@ impl<'a> ShardContext<'a> { .iter_mut() .zip(write_thread_based_record_storage.iter_mut()) .map(|(read, write)| ShardContext { - shard_id: self.shard_id, - max_num_shards: self.max_num_shards, + shards: self.shards.clone(), max_cycle: self.max_cycle, addr_future_accesses: Cow::Borrowed(self.addr_future_accesses.as_ref()), read_thread_based_record_storage: Either::Right(read), @@ -236,12 +264,12 @@ impl<'a> ShardContext<'a> { #[inline(always)] pub fn is_first_shard(&self) -> bool { - self.shard_id == 0 + self.shards.shard_id == 0 } #[inline(always)] pub fn is_last_shard(&self) -> bool { - self.shard_id == self.max_num_shards - 1 + self.shards.shard_id == self.shards.max_num_shards - 1 } #[inline(always)] @@ -511,12 +539,7 @@ pub fn emulate_program<'a>( ), ); - let shard_ctx = ShardContext::new( - shards.shard_id, - shards.max_num_shards, - insts, - vm.take_tracer().next_accesses(), - ); + let shard_ctx = ShardContext::new(shards.clone(), insts, vm.take_tracer().next_accesses()); EmulationResult { pi, From 1739d4ab6cb87a2acb41df65dc811272c80585fc Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 20 Oct 2025 22:05:37 +0800 Subject: [PATCH 74/91] chores fix missing padding of DynVolatileRamTableConfig --- ceno_zkvm/src/tables/ram/ram_impl.rs | 84 +++++++++++++++------------- 1 file changed, 46 insertions(+), 38 deletions(-) diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 6a4b5c04a..554c71235 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -6,7 +6,7 @@ use itertools::Itertools; use rayon::iter::{ IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, }; -use std::{marker::PhantomData, sync::Arc}; +use std::marker::PhantomData; use witness::{ InstancePaddingStrategy, RowMajorMatrix, next_pow2_instance_padding, set_fixed_val, set_val, }; @@ -366,54 +366,62 @@ impl DynVolatileRamTableConfig num_structural_witin: usize, final_mem: &[MemFinalRecord], ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { - assert!(final_mem.len() <= DVRAM::max_len(&config.params)); - assert!(DVRAM::max_len(&config.params).is_power_of_two()); + if final_mem.is_empty() { + return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); + } - let params = config.params.clone(); - let addr_id = config.addr.id as u64; - let addr_padding_fn = move |row: u64, col: u64| { - assert_eq!(col, addr_id); - DVRAM::addr(¶ms, row as usize) as u64 - }; + let num_instances_padded = next_pow2_instance_padding(final_mem.len()); + assert!(num_instances_padded <= DVRAM::max_len(&config.params)); + assert!(DVRAM::max_len(&config.params).is_power_of_two()); - let mut witness = - RowMajorMatrix::::new(final_mem.len(), num_witin, InstancePaddingStrategy::Default); + let mut witness = RowMajorMatrix::::new( + num_instances_padded, + num_witin, + InstancePaddingStrategy::Default, + ); let mut structural_witness = RowMajorMatrix::::new( - final_mem.len(), + num_instances_padded, num_structural_witin, - InstancePaddingStrategy::Custom(Arc::new(addr_padding_fn)), + InstancePaddingStrategy::Default, ); witness .par_rows_mut() - .zip(structural_witness.par_rows_mut()) - .zip(final_mem) + .zip_eq(structural_witness.par_rows_mut()) .enumerate() - .for_each(|(i, ((row, structural_row), rec))| { - assert_eq!( - rec.addr, - DVRAM::addr(&config.params, i), - "rec.addr {:x} != expected {:x}", - rec.addr, - DVRAM::addr(&config.params, i), - ); - - if config.final_v.len() == 1 { - // Assign value directly. - set_val!(row, config.final_v[0], rec.value as u64); - } else { - // Assign value limbs. - config.final_v.iter().enumerate().for_each(|(l, limb)| { - let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; - set_val!(row, limb, val as u64); - }); + .for_each(|(i, (row, structural_row))| { + if cfg!(debug_assertions) + && let Some(addr) = final_mem.get(i).map(|rec| rec.addr) + { + debug_assert_eq!( + addr, + DVRAM::addr(&config.params, i), + "rec.addr {:x} != expected {:x}", + addr, + DVRAM::addr(&config.params, i), + ); } - set_val!(row, config.final_cycle, rec.cycle); - set_val!(structural_row, config.addr, rec.addr as u64); + if let Some(rec) = final_mem.get(i) { + if config.final_v.len() == 1 { + // Assign value directly. + set_val!(row, config.final_v[0], rec.value as u64); + } else { + // Assign value limbs. + config.final_v.iter().enumerate().for_each(|(l, limb)| { + let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK; + set_val!(row, limb, val as u64); + }); + } + set_val!(row, config.final_cycle, rec.cycle); + } + set_val!( + structural_row, + config.addr, + DVRAM::addr(&config.params, i) as u64 + ); }); - structural_witness.padding_by_strategy(); Ok([witness, structural_witness]) } } @@ -487,10 +495,10 @@ impl DynVolatileRamTableConfig if final_mem.is_empty() { return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); } - assert!(final_mem.len() <= DVRAM::max_len(&config.params)); - assert!(DVRAM::max_len(&config.params).is_power_of_two()); let num_instances_padded = next_pow2_instance_padding(final_mem.len()); + assert!(num_instances_padded <= DVRAM::max_len(&config.params)); + assert!(DVRAM::max_len(&config.params).is_power_of_two()); let mut structural_witness = RowMajorMatrix::::new( num_instances_padded, From 6601cf193ae6c1acb187a042f8b376aa5f3fddfc Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 21 Oct 2025 00:40:56 +0800 Subject: [PATCH 75/91] evaluate selector with context --- ceno_zkvm/src/instructions.rs | 5 +- ceno_zkvm/src/instructions/global.rs | 60 ++++----- ceno_zkvm/src/precompiles/bitwise_keccakf.rs | 14 ++- ceno_zkvm/src/precompiles/lookup_keccakf.rs | 7 +- .../weierstrass/weierstrass_add.rs | 12 +- .../weierstrass/weierstrass_decompress.rs | 12 +- .../weierstrass/weierstrass_double.rs | 12 +- ceno_zkvm/src/scheme/cpu/mod.rs | 48 +++++++- ceno_zkvm/src/scheme/hal.rs | 3 + ceno_zkvm/src/scheme/prover.rs | 4 + ceno_zkvm/src/scheme/tests.rs | 2 + ceno_zkvm/src/scheme/verifier.rs | 34 ++++- gkr_iop/src/cpu/mod.rs | 2 +- gkr_iop/src/gkr.rs | 9 +- gkr_iop/src/gkr/layer.rs | 10 +- gkr_iop/src/gkr/layer/cpu/mod.rs | 14 ++- gkr_iop/src/gkr/layer/hal.rs | 3 +- gkr_iop/src/gkr/layer/zerocheck_layer.rs | 42 +++---- gkr_iop/src/selector.rs | 116 ++++++++++-------- 19 files changed, 256 insertions(+), 153 deletions(-) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index e85643c6d..2ad63d163 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -60,10 +60,7 @@ pub trait Instruction { descending: false, }, ); - let selector_type = SelectorType::Prefix { - offset: 0, - expression: selector.expr(), - }; + let selector_type = SelectorType::Prefix(selector.expr()); // all shared the same selector let (out_evals, mut chip) = ( diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 8f0abd164..7cf5c86e4 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -125,12 +125,12 @@ impl GlobalConfig { cb.read_record( || "r_record", - gkr_iop::RAMType::Register, // TODO fixme + gkr_iop::RAMType::Memory, // TODO fixme record.clone(), )?; cb.write_record( || "w_record", - gkr_iop::RAMType::Register, // TODO fixme + gkr_iop::RAMType::Memory, // TODO fixme record.clone(), )?; @@ -313,22 +313,13 @@ impl descending: false, }, ); - let selector_r = SelectorType::Prefix { - offset: 0, - expression: selector_r.expr(), - }; + let selector_r = SelectorType::Prefix(selector_r.expr()); // note that the actual offset should be set by prover // depending on the number of local read instances - let selector_w = SelectorType::Prefix { - offset: 0, - expression: selector_w.expr(), - }; + let selector_w = SelectorType::Prefix(selector_w.expr()); // TODO: when selector_r = 1 => selector_zero = 1 // when selector_w = 1 => selector_zero = 1 - let selector_zero = SelectorType::Prefix { - offset: 0, - expression: selector_zero.expr(), - }; + let selector_zero = SelectorType::Prefix(selector_zero.expr()); cb.cs.r_selector = Some(selector_r); cb.cs.w_selector = Some(selector_w); @@ -411,8 +402,8 @@ impl let nthreads = max_usable_threads(); + // local read => global write let num_local_reads = steps.iter().filter(|s| s.is_write).count(); - let num_local_writes = steps.len() - num_local_reads; let num_instance_per_batch = if steps.len() > 256 { steps.len().div_ceil(nthreads) @@ -470,10 +461,11 @@ impl mod tests { use std::sync::Arc; - use ff_ext::{BabyBearExt4, PoseidonField}; + use ff_ext::{BabyBearExt4, FromUniformBytes, PoseidonField}; use itertools::Itertools; use mpcs::{BasefoldDefault, PolynomialCommitmentScheme, SecurityLevel}; use p3::{babybear::BabyBear, field::FieldAlgebra}; + use rand::thread_rng; use tracing_forest::{ForestLayer, util::LevelFilter}; use tracing_subscriber::{ EnvFilter, Registry, fmt, layer::SubscriberExt, util::SubscriberInitExt, @@ -507,15 +499,15 @@ mod tests { let default_filter = EnvFilter::builder() .with_default_directive(LevelFilter::DEBUG.into()) .from_env_lossy(); - let fmt_layer = fmt::layer() - .compact() - .with_thread_ids(false) - .with_thread_names(false) - .without_time(); + // let fmt_layer = fmt::layer() + // .compact() + // .with_thread_ids(false) + // .with_thread_names(false) + // .without_time(); Registry::default() .with(ForestLayer::default()) - .with(fmt_layer) + // .with(fmt_layer) .with(default_filter) .init(); @@ -532,8 +524,8 @@ mod tests { .unwrap(); // create a bunch of random memory read/write records - let n_reads = 10; - let n_writes = 10; + let n_reads = 16; + let n_writes = 16; let global_reads = (0..n_reads) .map(|i| { let addr = i * 8; @@ -544,7 +536,7 @@ mod tests { ram_type: RAMType::Memory, value: value as u32, shard: 1, - local_clk: 0, + local_clk: i, global_clk: i, is_write: false, } @@ -594,9 +586,9 @@ mod tests { &config, cs.num_witin as usize, cs.num_structural_witin as usize, - global_reads + global_writes // local reads .into_iter() - .chain(global_writes.into_iter()) + .chain(global_reads.into_iter()) // local writes .collect::>(), ) .unwrap(); @@ -629,10 +621,13 @@ mod tests { structural_witness: witness[1].to_mles().into_iter().map(Arc::new).collect(), fixed: vec![], public_input: public_input_mles.clone(), + num_read_instances: n_writes as usize, + num_write_instances: n_reads as usize, num_instances: (n_reads + n_writes) as usize, }; - let challenges = [E::ONE, E::ONE]; - let (proof, _, point) = zkvm_prover + let mut rng = thread_rng(); + let challenges = [E::random(&mut rng), E::random(&mut rng)]; + let (proof, _pi_evals, point) = zkvm_prover .create_chip_proof( "global chip", &pk, @@ -648,6 +643,13 @@ mod tests { .iter() .map(|mle| mle.evaluate(&point[..mle.num_vars()])) .collect_vec(); + pi_evals + .iter() + .skip(8) + .zip(_pi_evals.values()) + .for_each(|(a, b)| { + assert_eq!(*a, *b); + }); let opening_point = verifier .verify_opcode_proof( "global", diff --git a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs index e25ee972d..51bf0092a 100644 --- a/ceno_zkvm/src/precompiles/bitwise_keccakf.rs +++ b/ceno_zkvm/src/precompiles/bitwise_keccakf.rs @@ -30,7 +30,7 @@ use gkr_iop::{ layer::Layer, layer_constraint_system::{LayerConstraintSystem, expansion_expr}, }, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, utils::{indices_arr_with_offset, lk_multiplicity::LkMultiplicity, wits_fixed_and_eqs}, }; @@ -963,6 +963,14 @@ pub fn run_keccakf + 'stat }; let span = entered_span!("prove", profiling_1 = true); + let selector_ctxs = vec![ + SelectorContext::new(0, num_instances, log2_num_instances); + gkr_circuit + .layers + .first() + .map(|layer| layer.out_sel_and_eval_exprs.len()) + .unwrap() + ]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -972,7 +980,7 @@ pub fn run_keccakf + 'stat &[], &[], &mut prover_transcript, - num_instances, + &selector_ctxs, ) .expect("Failed to prove phase"); exit_span!(span); @@ -993,7 +1001,7 @@ pub fn run_keccakf + 'stat &[], &[], &mut verifier_transcript, - num_instances, + &selector_ctxs, ) .expect("GKR verify failed"); diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index 2fcd8de79..8ee53e6c7 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -14,7 +14,7 @@ use gkr_iop::{ layer::Layer, mock::MockProver, }, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, utils::lk_multiplicity::LkMultiplicity, }; use itertools::{Itertools, iproduct, izip, zip_eq}; @@ -1222,6 +1222,7 @@ pub fn run_faster_keccakf } let span = entered_span!("create_proof", profiling_2 = true); + let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance_rounds); 3]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -1231,7 +1232,7 @@ pub fn run_faster_keccakf &[], &challenges, &mut prover_transcript, - num_instances, + &selector_ctxs, ) .expect("Failed to prove phase"); exit_span!(span); @@ -1260,7 +1261,7 @@ pub fn run_faster_keccakf &[], &challenges, &mut verifier_transcript, - num_instances, + &selector_ctxs, ) .expect("GKR verify failed"); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs index ccbda01f9..8eb68deb5 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs @@ -36,7 +36,7 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, error::{BackendError, CircuitBuilderError}, gkr::{GKRCircuit, GKRProof, GKRProverOutput, layer::Layer, mock::MockProver}, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, }; use itertools::{Itertools, izip}; use mpcs::PolynomialCommitmentScheme; @@ -140,10 +140,7 @@ impl WeierstrassAddAssignLayout { descending: false, }, ); - let sel = SelectorType::Prefix { - offset: 0, - expression: eq.expr(), - }; + let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { sel_mem_read: sel.clone(), sel_mem_write: sel.clone(), @@ -750,6 +747,7 @@ pub fn run_weierstrass_add< } let span = entered_span!("create_proof", profiling_2 = true); + let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -759,7 +757,7 @@ pub fn run_weierstrass_add< &[], &challenges, &mut prover_transcript, - num_instances, + &selector_ctxs, ) .expect("Failed to prove phase"); exit_span!(span); @@ -784,7 +782,7 @@ pub fn run_weierstrass_add< &[], &challenges, &mut verifier_transcript, - num_instances, + &selector_ctxs, ) .expect("GKR verify failed"); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs index 0d6406431..5e970e7a7 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs @@ -36,7 +36,7 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, error::{BackendError, CircuitBuilderError}, gkr::{GKRCircuit, GKRProof, GKRProverOutput, layer::Layer, mock::MockProver}, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, }; use itertools::{Itertools, izip}; use mpcs::PolynomialCommitmentScheme; @@ -158,10 +158,7 @@ impl descending: false, }, ); - let sel = SelectorType::Prefix { - offset: 0, - expression: eq.expr(), - }; + let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { sel_mem_read: sel.clone(), sel_mem_write: sel.clone(), @@ -730,6 +727,7 @@ pub fn run_weierstrass_decompress< } let span = entered_span!("create_proof", profiling_2 = true); + let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -739,7 +737,7 @@ pub fn run_weierstrass_decompress< &[], &challenges, &mut prover_transcript, - num_instances, + &selector_ctxs, ) .expect("Failed to prove phase"); exit_span!(span); @@ -764,7 +762,7 @@ pub fn run_weierstrass_decompress< &[], &challenges, &mut verifier_transcript, - num_instances, + &selector_ctxs, ) .expect("GKR verify failed"); diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs index decaa317f..4bf0147eb 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs @@ -36,7 +36,7 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, error::{BackendError, CircuitBuilderError}, gkr::{GKRCircuit, GKRProof, GKRProverOutput, layer::Layer, mock::MockProver}, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, }; use itertools::{Itertools, izip}; use mpcs::PolynomialCommitmentScheme; @@ -142,10 +142,7 @@ impl descending: false, }, ); - let sel = SelectorType::Prefix { - offset: 0, - expression: eq.expr(), - }; + let sel = SelectorType::Prefix(eq.expr()); let selector_type_layout = SelectorTypeLayout { sel_mem_read: sel.clone(), sel_mem_write: sel.clone(), @@ -752,6 +749,7 @@ pub fn run_weierstrass_double< } let span = entered_span!("create_proof", profiling_2 = true); + let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -761,7 +759,7 @@ pub fn run_weierstrass_double< &[], &challenges, &mut prover_transcript, - num_instances, + &selector_ctxs, ) .expect("Failed to prove phase"); exit_span!(span); @@ -786,7 +784,7 @@ pub fn run_weierstrass_double< &[], &challenges, &mut verifier_transcript, - num_instances, + &selector_ctxs, ) .expect("GKR verify failed"); diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 06dd70ef6..312c9fa32 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -21,7 +21,7 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, gkr::{self, Evaluation, GKRProof, GKRProverOutput, layer::LayerWitness}, hal::ProverBackend, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, }; use itertools::{Itertools, chain}; use mpcs::{Point, PolynomialCommitmentScheme}; @@ -83,12 +83,14 @@ impl CpuEccProver { let mut expr_builder = VirtualPolynomialsBuilder::new(num_threads, out_rt.len()); - let sel = SelectorType::Prefix { + let sel = SelectorType::Prefix(0.into()); + let num_instances = (1 << n) - 1; + let sel_ctx = SelectorContext { offset: 0, - expression: 0.into(), + num_instances, + num_vars: n, }; - let num_instances = (1 << n) - 1; - let mut sel_mle: MultilinearExtension<'_, E> = sel.compute(&out_rt, num_instances).unwrap(); + let mut sel_mle: MultilinearExtension<'_, E> = sel.compute(&out_rt, &sel_ctx).unwrap(); let sel_expr = expr_builder.lift(sel_mle.to_either()); let mut exprs = vec![]; @@ -859,6 +861,40 @@ impl> MainSumcheckProver> MainSumcheckProver { pub structural_witness: Vec>>, pub fixed: Vec>>, pub public_input: Vec>>, + pub num_read_instances: usize, + pub num_write_instances: usize, pub num_instances: usize, } @@ -47,6 +49,7 @@ impl<'a, PB: ProverBackend> ProofInput<'a, PB> { } } +#[derive(Clone)] pub struct TowerProverSpec<'a, PB: ProverBackend> { pub witness: Vec>>, } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index cd2f64fde..8f24acb7f 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -234,6 +234,8 @@ impl< fixed, structural_witness, public_input, + num_read_instances: num_instances, // TODO: fixme + num_write_instances: num_instances, // TODO: fixme num_instances, }; @@ -344,6 +346,8 @@ impl< num_var_with_rotation, ); + // override cs.gkr_circuit.layers + // 1. prove the main constraints among witness polynomials // 2. prove the relation between last layer in the tower and read/write/logup records let (input_opening_point, evals, main_sumcheck_proofs, gkr_iop_proof) = self diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 44e88836b..99aeb9648 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -198,6 +198,8 @@ fn test_rw_lk_expression_combination() { witness: wits_in, structural_witness: structural_in, public_input: vec![], + num_read_instances: num_instances, + num_write_instances: num_instances, num_instances, }; let (proof, _, _) = prover diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index abe0f6ec2..e3229971d 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -5,7 +5,7 @@ use ff_ext::ExtensionField; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; -use gkr_iop::{gkr::GKRClaims, utils::eq_eval_less_or_equal_than}; +use gkr_iop::{gkr::GKRClaims, selector::SelectorContext, utils::eq_eval_less_or_equal_than}; use itertools::{Itertools, chain, interleave, izip}; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ @@ -415,6 +415,36 @@ impl> ZKVMVerifier debug_assert_eq!(logup_q_evals.len(), lk_counts_per_instance); let gkr_circuit = gkr_circuit.as_ref().unwrap(); + let selector_ctxs = if cs.ec_final_sum.is_empty() { + // it's not global chip + vec![ + SelectorContext::new(0, num_instances, num_var_with_rotation); + gkr_circuit + .layers + .first() + .map(|layer| layer.out_sel_and_eval_exprs.len()) + .unwrap_or(0) + ] + } else { + // it's global chip + vec![ + SelectorContext { + offset: 0, + num_instances: proof.num_read_instances, + num_vars: num_var_with_rotation, + }, + SelectorContext { + offset: proof.num_read_instances, + num_instances: proof.num_write_instances, + num_vars: num_var_with_rotation, + }, + SelectorContext { + offset: 0, + num_instances: proof.num_instances, + num_vars: num_var_with_rotation, + }, + ] + }; let GKRClaims(opening_evaluations) = gkr_circuit.verify( num_var_with_rotation, proof.gkr_iop_proof.clone().unwrap(), @@ -422,7 +452,7 @@ impl> ZKVMVerifier pi, challenges, transcript, - num_instances, + &selector_ctxs, )?; Ok(opening_evaluations[0].point.clone()) } diff --git a/gkr_iop/src/cpu/mod.rs b/gkr_iop/src/cpu/mod.rs index 32ccb77a0..2e22fc5fc 100644 --- a/gkr_iop/src/cpu/mod.rs +++ b/gkr_iop/src/cpu/mod.rs @@ -4,7 +4,7 @@ use crate::{ hal::{MultilinearPolynomial, ProtocolWitnessGeneratorProver, ProverBackend, ProverDevice}, }; use ff_ext::ExtensionField; -use itertools::izip; +use itertools::{Itertools, izip}; use mpcs::{PolynomialCommitmentScheme, SecurityLevel, SecurityLevel::Conjecture100bits}; use multilinear_extensions::{ mle::{ArcMultilinearExtension, MultilinearExtension, Point}, diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index 7d80229fd..b06e8fe71 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -11,6 +11,7 @@ use transcript::Transcript; use crate::{ error::BackendError, hal::{ProverBackend, ProverDevice}, + selector::SelectorContext, }; pub mod booleanhypercube; @@ -77,7 +78,7 @@ impl GKRCircuit { pub_io_evals: &[E], challenges: &[E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> Result>, BackendError> { let mut running_evals = out_evals.to_vec(); // running evals is a global referable within chip @@ -97,7 +98,7 @@ impl GKRCircuit { pub_io_evals, &mut challenges, transcript, - num_instances, + selector_ctxs, ); exit_span!(span); res @@ -122,7 +123,7 @@ impl GKRCircuit { pub_io_evals: &[E], challenges: &[E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> Result>, BackendError> where E: ExtensionField, @@ -141,7 +142,7 @@ impl GKRCircuit { pub_io_evals, &mut challenges, transcript, - num_instances, + selector_ctxs, )?; } diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index a337dde30..b981e2b25 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -20,7 +20,7 @@ use crate::{ error::BackendError, evaluation::EvalExpression, hal::{MultilinearPolynomial, ProverBackend, ProverDevice}, - selector::SelectorType, + selector::{SelectorContext, SelectorType}, }; pub mod cpu; @@ -183,7 +183,7 @@ impl Layer { pub_io_evals: &[E], challenges: &mut Vec, transcript: &mut T, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> LayerProof { self.update_challenges(challenges, transcript); let mut eval_and_dedup_points = self.extract_claim_and_point(claims, challenges); @@ -203,7 +203,7 @@ impl Layer { pub_io_evals, challenges, transcript, - num_instances, + selector_ctxs, ) } LayerType::Linear => { @@ -231,7 +231,7 @@ impl Layer { pub_io_evals: &[E], challenges: &mut Vec, transcript: &mut Trans, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> Result<(), BackendError> { self.update_challenges(challenges, transcript); let mut eval_and_dedup_points = self.extract_claim_and_point(claims, challenges); @@ -245,7 +245,7 @@ impl Layer { pub_io_evals, challenges, transcript, - num_instances, + selector_ctxs, )?, LayerType::Linear => { assert_eq!(eval_and_dedup_points.len(), 1); diff --git a/gkr_iop/src/gkr/layer/cpu/mod.rs b/gkr_iop/src/gkr/layer/cpu/mod.rs index fa4c33c5e..7b6e9adec 100644 --- a/gkr_iop/src/gkr/layer/cpu/mod.rs +++ b/gkr_iop/src/gkr/layer/cpu/mod.rs @@ -8,6 +8,7 @@ use crate::{ zerocheck_layer::RotationPoints, }, }, + selector::SelectorContext, utils::{rotation_next_base_mle, rotation_selector}, }; use either::Either; @@ -113,7 +114,7 @@ impl> ZerocheckLayerProver pub_io_evals: &[ as ProverBackend>::E], challenges: &[ as ProverBackend>::E], transcript: &mut impl Transcript< as ProverBackend>::E>, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> ( LayerProof< as ProverBackend>::E>, Point< as ProverBackend>::E>, @@ -126,6 +127,12 @@ impl> ZerocheckLayerProver layer.out_sel_and_eval_exprs.len(), out_points.len(), ); + assert_eq!( + layer.out_sel_and_eval_exprs.len(), + selector_ctxs.len(), + "selector_ctxs length {}", + selector_ctxs.len() + ); let (_, raw_rotation_exprs) = &layer.rotation_exprs; let (rotation_proof, rotation_left, rotation_right, rotation_point) = @@ -173,7 +180,10 @@ impl> ZerocheckLayerProver .out_sel_and_eval_exprs .par_iter() .zip(out_points.par_iter()) - .filter_map(|((sel_type, _), point)| sel_type.compute(point, num_instances)) + .zip(selector_ctxs.par_iter()) + .filter_map(|(((sel_type, _), point), selector_ctx)| { + sel_type.compute(point, selector_ctx) + }) // for rotation left point .chain(rotation_left.par_iter().map(|rotation_left| { MultilinearExtension::from_evaluations_ext_vec( diff --git a/gkr_iop/src/gkr/layer/hal.rs b/gkr_iop/src/gkr/layer/hal.rs index 06508e298..c6cce26a0 100644 --- a/gkr_iop/src/gkr/layer/hal.rs +++ b/gkr_iop/src/gkr/layer/hal.rs @@ -4,6 +4,7 @@ use transcript::Transcript; use crate::{ gkr::layer::{Layer, LayerWitness, sumcheck_layer::LayerProof}, hal::ProverBackend, + selector::SelectorContext, }; pub trait LinearLayerProver { @@ -37,6 +38,6 @@ pub trait ZerocheckLayerProver { pub_io_evals: &[PB::E], challenges: &[PB::E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> (LayerProof, Point); } diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index 8a00132cb..cd4a8036a 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -27,7 +27,7 @@ use crate::{ }, }, hal::{ProverBackend, ProverDevice}, - selector::SelectorType, + selector::{self, SelectorContext, SelectorType}, utils::rotation_selector_eval, }; @@ -58,7 +58,7 @@ pub trait ZerocheckLayer { pub_io_evals: &[PB::E], challenges: &[PB::E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> (LayerProof, Point); #[allow(clippy::too_many_arguments)] @@ -70,7 +70,7 @@ pub trait ZerocheckLayer { pub_io_evals: &[E], challenges: &[E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> Result, BackendError>; } @@ -177,7 +177,7 @@ impl ZerocheckLayer for Layer { pub_io_evals: &[PB::E], challenges: &[PB::E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> (LayerProof, Point) { >::prove( self, @@ -188,7 +188,7 @@ impl ZerocheckLayer for Layer { pub_io_evals, challenges, transcript, - num_instances, + selector_ctxs, ) } @@ -200,7 +200,7 @@ impl ZerocheckLayer for Layer { pub_io_evals: &[E], challenges: &[E], transcript: &mut impl Transcript, - num_instances: usize, + selector_ctxs: &[SelectorContext], ) -> Result, BackendError> { assert_eq!( self.out_sel_and_eval_exprs.len(), @@ -284,17 +284,20 @@ impl ZerocheckLayer for Layer { let in_point = in_point.into_iter().map(|c| c.elements).collect_vec(); // eval eq and set to respective witin - izip!(&self.out_sel_and_eval_exprs, &eval_and_dedup_points).for_each( - |((sel_type, _), (_, out_point))| { - sel_type.evaluate( - &mut main_evals, - out_point.as_ref().unwrap(), - &in_point, - num_instances, - self.n_witin, - ); - }, - ); + izip!( + &self.out_sel_and_eval_exprs, + &eval_and_dedup_points, + selector_ctxs.iter() + ) + .for_each(|((sel_type, _), (_, out_point), selector_ctx)| { + sel_type.evaluate( + &mut main_evals, + out_point.as_ref().unwrap(), + &in_point, + selector_ctx, + self.n_witin, + ); + }); let got_claim = eval_by_expr_with_instance( &[], @@ -450,10 +453,7 @@ pub fn extend_exprs_with_rotation( let expr = match sel_type { SelectorType::None => zero_check_expr, SelectorType::Whole(sel) - | SelectorType::Prefix { - offset: _, - expression: sel, - } + | SelectorType::Prefix(sel) | SelectorType::OrderedSparse32 { expression: sel, .. } => match_expr(sel) * zero_check_expr, diff --git a/gkr_iop/src/selector.rs b/gkr_iop/src/selector.rs index 05dac293d..f9436bd1d 100644 --- a/gkr_iop/src/selector.rs +++ b/gkr_iop/src/selector.rs @@ -18,6 +18,24 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned}; use crate::{gkr::booleanhypercube::CYCLIC_POW2_5, utils::eq_eval_less_or_equal_than}; +/// Provide context for selector's instantiation at runtime +#[derive(Clone, Debug)] +pub struct SelectorContext { + pub offset: usize, + pub num_instances: usize, + pub num_vars: usize, +} + +impl SelectorContext { + pub fn new(offset: usize, num_instances: usize, num_vars: usize) -> Self { + Self { + offset, + num_instances, + num_vars, + } + } +} + /// Selector selects part of the witnesses in the sumcheck protocol. #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] #[serde(bound( @@ -27,15 +45,8 @@ use crate::{gkr::booleanhypercube::CYCLIC_POW2_5, utils::eq_eval_less_or_equal_t pub enum SelectorType { None, Whole(Expression), - /// Select a prefix as the instances, padded with a field element. - /// 1. [0, offset) are zeros; - /// 2. [offset, offset + num_instances) are ones, - /// 3. [offset + num_instances, 2^n) are zeros. - Prefix { - // offset is not fixed at setup time. - offset: usize, - expression: Expression, - }, + /// Select part of the instances, other parts padded with a field element. + Prefix(Expression), /// selector activates on the specified `indices`, which are assumed to be in ascending order. /// each index corresponds to a position within a fixed-size chunk (e.g., size 32), OrderedSparse32 { @@ -45,36 +56,31 @@ pub enum SelectorType { } impl SelectorType { - pub fn as_mle( - &self, - num_instances: usize, - num_vars: usize, - ) -> Option> { + /// Returns an MultilinearExtension with `ctx.num_vars` variables whenever applicable + pub fn to_mle(&self, ctx: &SelectorContext) -> Option> { match self { SelectorType::None => None, SelectorType::Whole(_) => { - assert_eq!(ceil_log2(num_instances), num_vars); + assert_eq!(ceil_log2(ctx.num_instances), ctx.num_vars); Some( - (0..(1 << num_vars)) + (0..(1 << ctx.num_vars)) .into_par_iter() .map(|_| E::BaseField::ONE) .collect::>() .into_mle(), ) } - SelectorType::Prefix { - offset, - expression: _, - } => { - assert!(*offset + num_instances <= (1 << num_vars)); - let end = *offset + num_instances; + SelectorType::Prefix(_) => { + assert!(ctx.offset + ctx.num_instances <= (1 << ctx.num_vars)); + let start = ctx.offset; + let end = start + ctx.num_instances; Some( - (0..*offset) + (0..start) .into_par_iter() .map(|_| E::BaseField::ZERO) - .chain((*offset..end).into_par_iter().map(|_| E::BaseField::ONE)) + .chain((start..end).into_par_iter().map(|_| E::BaseField::ONE)) .chain( - (end..(1 << num_vars)) + (end..(1 << ctx.num_vars)) .into_par_iter() .map(|_| E::BaseField::ZERO), ) @@ -86,12 +92,12 @@ impl SelectorType { indices, expression: _, } => { - assert_eq!(ceil_log2(num_instances), num_vars); + assert_eq!(ceil_log2(ctx.num_instances) + 5, ctx.num_vars); Some( - (0..(1 << num_vars)) + (0..(1 << (ctx.num_vars - 5))) .into_par_iter() .flat_map(|chunk_index| { - if chunk_index >= num_instances { + if chunk_index >= ctx.num_instances { vec![E::ZERO; 32] } else { let mut chunk = vec![E::ZERO; 32]; @@ -109,7 +115,7 @@ impl SelectorType { chunk } }) - .collect::>() + .collect::>() .into_mle(), ) } @@ -120,32 +126,31 @@ impl SelectorType { pub fn compute( &self, out_point: &Point, - num_instances: usize, + ctx: &SelectorContext, ) -> Option> { + assert_eq!(out_point.len(), ctx.num_vars); + match self { SelectorType::None => None, SelectorType::Whole(_) => Some(build_eq_x_r_vec(out_point).into_mle()), - SelectorType::Prefix { - offset, - expression: _expr, - } => { - let num_vars = out_point.len(); - let end = *offset + num_instances; - assert!(end <= (1 << num_vars)); + SelectorType::Prefix(_) => { + let start = ctx.offset; + let end = start + ctx.num_instances; + assert!(end <= (1 << ctx.num_vars), "start: {}, num_instances: {}, num_vars: {}", start, ctx.num_instances, ctx.num_vars); let mut sel = build_eq_x_r_vec(out_point); - sel.splice(0..*offset, repeat_n(E::ZERO, *offset)); + sel.splice(0..start, repeat_n(E::ZERO, start)); sel.splice(end..sel.len(), repeat_n(E::ZERO, sel.len() - end)); Some(sel.into_mle()) } SelectorType::OrderedSparse32 { indices, .. } => { - assert_eq!(out_point.len(), ceil_log2(num_instances) + 5); + assert_eq!(out_point.len(), ceil_log2(ctx.num_instances) + 5); let mut sel = build_eq_x_r_vec(out_point); sel.par_chunks_exact_mut(CYCLIC_POW2_5.len()) .enumerate() .for_each(|(chunk_index, chunk)| { - if chunk_index >= num_instances { + if chunk_index >= ctx.num_instances { // Zero out the entire chunk if out of instance range chunk.iter_mut().for_each(|e| *e = E::ZERO); return; @@ -174,24 +179,33 @@ impl SelectorType { evals: &mut Vec, out_point: &Point, in_point: &Point, - num_instances: usize, + ctx: &SelectorContext, offset_eq_id: usize, ) { + assert_eq!(in_point.len(), ctx.num_vars); + assert_eq!(out_point.len(), ctx.num_vars); + let (expr, eval) = match self { SelectorType::None => return, SelectorType::Whole(expr) => { debug_assert_eq!(out_point.len(), in_point.len()); (expr, eq_eval(out_point, in_point)) } - SelectorType::Prefix { offset, expression } => { - let end = *offset + num_instances; + SelectorType::Prefix(expression) => { + let start = ctx.offset; + let end = start + ctx.num_instances; assert_eq!(in_point.len(), out_point.len()); assert!(end <= (1 << out_point.len())); - let eq_start = eq_eval_less_or_equal_than(*offset - 1, out_point, in_point); let eq_end = eq_eval_less_or_equal_than(end - 1, out_point, in_point); - (expression, eq_end - eq_start) + let sel = if start > 0 { + let eq_start = eq_eval_less_or_equal_than(start - 1, out_point, in_point); + eq_end - eq_start + } else { + eq_end + }; + (expression, sel) } SelectorType::OrderedSparse32 { indices, @@ -203,8 +217,11 @@ impl SelectorType { for index in indices { eval += out_subgroup_eq[*index] * in_subgroup_eq[*index]; } - let sel = - eq_eval_less_or_equal_than(num_instances - 1, &out_point[5..], &in_point[5..]); + let sel = eq_eval_less_or_equal_than( + ctx.num_instances - 1, + &out_point[5..], + &in_point[5..], + ); (expression, eval * sel) } }; @@ -230,10 +247,7 @@ impl SelectorType { match self { Self::OrderedSparse32 { expression, .. } | Self::Whole(expression) - | Self::Prefix { - offset: _, - expression, - } => expression, + | Self::Prefix(expression) => expression, e => unimplemented!("no selector expression in {:?}", e), } } From 7efb21a256b6db652e670feafcb8e008233b857f Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 21 Oct 2025 10:46:27 +0800 Subject: [PATCH 76/91] switch gkr-backend --- Cargo.lock | 22 +++++++++++----------- Cargo.toml | 20 ++++++++++---------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5691f8624..537b5c262 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1879,7 +1879,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" dependencies = [ "once_cell", "p3", @@ -2723,7 +2723,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" dependencies = [ "bincode", "clap", @@ -2747,7 +2747,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" dependencies = [ "either", "ff_ext", @@ -3068,7 +3068,7 @@ dependencies = [ [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" dependencies = [ "p3-air", "p3-baby-bear", @@ -3544,7 +3544,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" dependencies = [ "ff_ext", "p3", @@ -4484,7 +4484,7 @@ dependencies = [ [[package]] name = "sp1-curves" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" dependencies = [ "cfg-if", "dashu", @@ -4590,7 +4590,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" dependencies = [ "either", "ff_ext", @@ -4608,7 +4608,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" dependencies = [ "itertools 0.13.0", "p3", @@ -5003,7 +5003,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -5275,7 +5275,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" dependencies = [ "bincode", "clap", @@ -5562,7 +5562,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?rev=v1.0.0-alpha.10#a1050f9249e1756c07219201d04883adbb674cdf" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index cd39987af..1571e3af6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,16 +22,16 @@ repository = "https://github.com/scroll-tech/ceno" version = "0.1.0" [workspace.dependencies] -ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", rev = "v1.0.0-alpha.10" } -mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", rev = "v1.0.0-alpha.10" } -multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", rev = "v1.0.0-alpha.10" } -p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", rev = "v1.0.0-alpha.10" } -poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", rev = "v1.0.0-alpha.10" } -sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", rev = "v1.0.0-alpha.10" } -sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", rev = "v1.0.0-alpha.10" } -transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", rev = "v1.0.0-alpha.10" } -whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", rev = "v1.0.0-alpha.10" } -witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", rev = "v1.0.0-alpha.10" } +ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", branch = "chore/sw_curve_default" } +mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", branch = "chore/sw_curve_default" } +multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", branch = "chore/sw_curve_default" } +p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", branch = "chore/sw_curve_default" } +poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", branch = "chore/sw_curve_default" } +sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", branch = "chore/sw_curve_default" } +sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", branch = "chore/sw_curve_default" } +transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", branch = "chore/sw_curve_default" } +whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", branch = "chore/sw_curve_default" } +witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", branch = "chore/sw_curve_default" } alloy-primitives = "1.3" anyhow = { version = "1.0", default-features = false } From 636ab8e9b7c88a36f42007f4079b9d8ed8f4ce02 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 21 Oct 2025 16:23:50 +0800 Subject: [PATCH 77/91] fix --- ceno_zkvm/src/instructions/global.rs | 33 +++++++++------------------- ceno_zkvm/src/scheme/prover.rs | 4 ++-- gkr_iop/src/selector.rs | 2 +- 3 files changed, 13 insertions(+), 26 deletions(-) diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 7cf5c86e4..c4409805e 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -107,14 +107,9 @@ impl GlobalConfig { // if is_global_write = 1, then it means we are propagating a local write to global // so we need to insert a local read record to cancel out this local write - // otherwise, we insert a padding value 1 to avoid affecting local memory checking cb.assert_bit(|| "is_global_write must be boolean", is_global_write.expr())?; - // if we are reading from global set, then this record should be - // considered as a initial local write to that address. - // otherwise, we insert a padding value 1 as if we are not writing anything - // local read/write consistency cb.condition_require_zero( || "is_global_read => local_clk = 0", @@ -524,9 +519,9 @@ mod tests { .unwrap(); // create a bunch of random memory read/write records - let n_reads = 16; - let n_writes = 16; - let global_reads = (0..n_reads) + let n_global_reads = 16; + let n_global_writes = 16; + let global_reads = (0..n_global_reads) .map(|i| { let addr = i * 8; let value = (i + 1) * 8; @@ -535,15 +530,15 @@ mod tests { addr: addr as u32, ram_type: RAMType::Memory, value: value as u32, - shard: 1, - local_clk: i, + shard: 0, + local_clk: 0, global_clk: i, is_write: false, } }) .collect::>(); - let global_writes = (0..n_writes) + let global_writes = (0..n_global_writes) .map(|i| { let addr = i * 8; let value = (i + 1) * 8; @@ -580,7 +575,6 @@ mod tests { .map(|fe| fe.as_canonical_u32()) .collect_vec(), ); - assert!(global_ec_sum.is_infinity == true); // assign witness let (witness, lk) = GlobalChip::assign_instances( &config, @@ -621,13 +615,13 @@ mod tests { structural_witness: witness[1].to_mles().into_iter().map(Arc::new).collect(), fixed: vec![], public_input: public_input_mles.clone(), - num_read_instances: n_writes as usize, - num_write_instances: n_reads as usize, - num_instances: (n_reads + n_writes) as usize, + num_read_instances: n_global_writes as usize, + num_write_instances: n_global_reads as usize, + num_instances: (n_global_reads + n_global_writes) as usize, }; let mut rng = thread_rng(); let challenges = [E::random(&mut rng), E::random(&mut rng)]; - let (proof, _pi_evals, point) = zkvm_prover + let (proof, _, point) = zkvm_prover .create_chip_proof( "global chip", &pk, @@ -643,13 +637,6 @@ mod tests { .iter() .map(|mle| mle.evaluate(&point[..mle.num_vars()])) .collect_vec(); - pi_evals - .iter() - .skip(8) - .zip(_pi_evals.values()) - .for_each(|(a, b)| { - assert_eq!(*a, *b); - }); let opening_point = verifier .verify_opcode_proof( "global", diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 8f24acb7f..6ba250ec3 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -382,8 +382,8 @@ impl< tower_proof, fixed_in_evals, wits_in_evals, - num_read_instances: input.num_instances, - num_write_instances: input.num_instances, + num_read_instances: input.num_read_instances, + num_write_instances: input.num_write_instances, num_instances: input.num_instances, }, pi_in_evals, diff --git a/gkr_iop/src/selector.rs b/gkr_iop/src/selector.rs index f9436bd1d..c6ffea80b 100644 --- a/gkr_iop/src/selector.rs +++ b/gkr_iop/src/selector.rs @@ -196,7 +196,7 @@ impl SelectorType { let end = start + ctx.num_instances; assert_eq!(in_point.len(), out_point.len()); - assert!(end <= (1 << out_point.len())); + assert!(end <= (1 << out_point.len()), "start: {}, num_instances: {}, num_vars: {}", start, ctx.num_instances, ctx.num_vars); let eq_end = eq_eval_less_or_equal_than(end - 1, out_point, in_point); let sel = if start > 0 { From 052b5672fbdbfd4934884e40096d44f2ca8d671f Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 21 Oct 2025 18:48:58 +0800 Subject: [PATCH 78/91] enable poseidon2 --- ceno_zkvm/src/gadgets/poseidon2.rs | 220 +++++++++++++++++++++++++-- ceno_zkvm/src/instructions/global.rs | 45 ++++-- 2 files changed, 235 insertions(+), 30 deletions(-) diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs index 7ecaeabc7..80e6e6728 100644 --- a/ceno_zkvm/src/gadgets/poseidon2.rs +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -12,11 +12,11 @@ use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn}; use num_bigint::BigUint; use p3::{ - babybear::BabyBearInternalLayerParameters, - field::{Field, FieldAlgebra}, + babybear::{BabyBear, BabyBearInternalLayerParameters}, + field::{Field, FieldAlgebra, PrimeField}, monty_31::InternalLayerBaseParameters, - poseidon2::{MDSMat4, mds_light_permutation}, - poseidon2_air::{FullRound, PartialRound, Poseidon2Cols, SBox, generate_trace_rows, num_cols}, + poseidon2::{GenericPoseidon2LinearLayers, MDSMat4, mds_light_permutation}, + poseidon2_air::{FullRound, PartialRound, Poseidon2Cols, SBox, num_cols}, }; use crate::circuit_builder::CircuitBuilder; @@ -48,6 +48,35 @@ pub struct Poseidon2Config< constants: RoundConstants, } +#[derive(Debug, Clone)] +pub struct Poseidon2LinearLayers; + +impl GenericPoseidon2LinearLayers + for Poseidon2LinearLayers +{ + fn internal_linear_layer(state: &mut [F; WIDTH]) { + // this only works when F is BabyBear field for now + let babybear_prime = BigUint::from(0x7800_0001u32); + if F::order() == babybear_prime { + let diag_m1_matrix = &>::INTERNAL_DIAG_MONTY; + let diag_m1_matrix: &[F; WIDTH] = unsafe { transmute(diag_m1_matrix) }; + let sum = state.iter().cloned().sum::(); + for (input, diag_m1) in state.iter_mut().zip(diag_m1_matrix) { + *input = sum.clone() + F::from_f(*diag_m1) * input.clone(); + } + } else { + panic!("Unsupported field"); + } + } + + fn external_linear_layer(state: &mut [F; WIDTH]) { + mds_light_permutation(state, &MDSMat4); + } +} + impl< E: ExtensionField, const STATE_WIDTH: usize, @@ -267,17 +296,178 @@ impl< .unwrap() } - // pub fn assign_instance(&self, input: &[E; STATE_WIDTH]) { - // generate_trace_rows(inputs, constants) - // let poseidon2_cols: &Poseidon2Cols< - // WitIn, - // STATE_WIDTH, - // SBOX_DEGREE, - // SBOX_REGISTERS, - // HALF_FULL_ROUNDS, - // PARTIAL_ROUNDS, - // > = self.cols.as_slice().borrow(); - // } + pub fn assign_instance( + &self, + instance: &mut [E::BaseField], + state: [E::BaseField; STATE_WIDTH], + ) { + let poseidon2_cols: &mut Poseidon2Cols< + E::BaseField, + STATE_WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = instance.borrow_mut(); + + generate_trace_rows_for_perm::< + E::BaseField, + Poseidon2LinearLayers, + STATE_WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >(poseidon2_cols, state, &self.constants); + } +} + +////////////////////////////////////////////////////////////////////////// +/// The following routines are taken from poseidon2-air/src/generation.rs +////////////////////////////////////////////////////////////////////////// + +fn generate_trace_rows_for_perm< + F: PrimeField, + LinearLayers: GenericPoseidon2LinearLayers, + const WIDTH: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +>( + perm: &mut Poseidon2Cols< + F, + WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >, + mut state: [F; WIDTH], + constants: &RoundConstants, +) { + perm.export = F::ONE; + perm.inputs + .iter_mut() + .zip(state.iter()) + .for_each(|(input, &x)| { + *input = x; + }); + + LinearLayers::external_linear_layer(&mut state); + + for (full_round, constants) in perm + .beginning_full_rounds + .iter_mut() + .zip(&constants.beginning_full_round_constants) + { + generate_full_round::( + &mut state, full_round, constants, + ); + } + + for (partial_round, constant) in perm + .partial_rounds + .iter_mut() + .zip(&constants.partial_round_constants) + { + generate_partial_round::( + &mut state, + partial_round, + *constant, + ); + } + + for (full_round, constants) in perm + .ending_full_rounds + .iter_mut() + .zip(&constants.ending_full_round_constants) + { + generate_full_round::( + &mut state, full_round, constants, + ); + } +} + +#[inline] +fn generate_full_round< + F: PrimeField, + LinearLayers: GenericPoseidon2LinearLayers, + const WIDTH: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, +>( + state: &mut [F; WIDTH], + full_round: &mut FullRound, + round_constants: &[F; WIDTH], +) { + for (state_i, const_i) in state.iter_mut().zip(round_constants) { + *state_i += *const_i; + } + for (state_i, sbox_i) in state.iter_mut().zip(full_round.sbox.iter_mut()) { + generate_sbox(sbox_i, state_i); + } + LinearLayers::external_linear_layer(state); + full_round + .post + .iter_mut() + .zip(*state) + .for_each(|(post, x)| { + *post = x; + }); +} + +#[inline] +fn generate_partial_round< + F: PrimeField, + LinearLayers: GenericPoseidon2LinearLayers, + const WIDTH: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, +>( + state: &mut [F; WIDTH], + partial_round: &mut PartialRound, + round_constant: F, +) { + state[0] += round_constant; + generate_sbox(&mut partial_round.sbox, &mut state[0]); + partial_round.post_sbox = state[0]; + LinearLayers::internal_linear_layer(state); +} + +#[inline] +fn generate_sbox( + sbox: &mut SBox, + x: &mut F, +) { + *x = match (DEGREE, REGISTERS) { + (3, 0) => x.cube(), + (5, 0) => x.exp_const_u64::<5>(), + (7, 0) => x.exp_const_u64::<7>(), + (5, 1) => { + let x2 = x.square(); + let x3 = x2 * *x; + sbox.0[0] = x3; + x3 * x2 + } + (7, 1) => { + let x3 = x.cube(); + sbox.0[0] = x3; + x3 * x3 * *x + } + (11, 2) => { + let x2 = x.square(); + let x3 = x2 * *x; + let x9 = x3.cube(); + sbox.0[0] = x3; + sbox.0[1] = x9; + x9 * x2 + } + _ => panic!( + "Unexpected (DEGREE, REGISTERS) of ({}, {})", + DEGREE, REGISTERS + ), + } } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index c4409805e..59f624216 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -5,12 +5,12 @@ use crate::{ chip_handler::general::PublicIOQuery, error::ZKVMError, gadgets::{Poseidon2Config, RoundConstants}, + instructions::riscv::constants::UINT_LIMBS, scheme::septic_curve::{SepticExtension, SepticPoint}, structs::{ProgramParams, RAMType}, tables::RMMCollections, witness::LkMultiplicity, }; -use ceno_emul::StepRecord; use ff_ext::{ExtensionField, FieldInto, POSEIDON2_BABYBEAR_WIDTH, SmallField}; use gkr_iop::{ chip::Chip, circuit_builder::CircuitBuilder, error::CircuitBuilderError, gkr::layer::Layer, @@ -55,7 +55,7 @@ pub struct GlobalConfig { is_global_write: WitIn, x: Vec, y: Vec, - // perm_config: Poseidon2Config, + perm_config: Poseidon2Config, perm: P, } @@ -85,7 +85,7 @@ impl GlobalConfig { let reg: Expression = RAMType::Register.into(); let mem: Expression = RAMType::Memory.into(); let ram_type: Expression = is_ram_reg.clone() * reg + (1 - is_ram_reg) * mem; - // let perm_config = Poseidon2Config::construct(cb, rc); + let perm_config = Poseidon2Config::construct(cb, rc); let mut input = vec![]; input.push(addr.expr()); @@ -138,13 +138,13 @@ impl GlobalConfig { ); // enforces x = poseidon2([addr, ram_type, value[0], value[1], shard, global_clk, nonce, 0, ..., 0]) - // for (input_expr, hasher_input) in input.into_iter().zip_eq(perm_config.inputs().into_iter()) - // { - // cb.require_equal(|| "poseidon2 input", input_expr, hasher_input)?; - // } - // for (xi, hasher_output) in x.iter().zip(perm_config.output().into_iter()) { - // cb.require_equal(|| "x = poseidon2's output", xi.expr(), hasher_output)?; - // } + for (input_expr, hasher_input) in input.into_iter().zip_eq(perm_config.inputs().into_iter()) + { + cb.require_equal(|| "poseidon2 input", input_expr, hasher_input)?; + } + for (xi, hasher_output) in x.iter().zip(perm_config.output().into_iter()) { + cb.require_equal(|| "x = poseidon2's output", xi.expr(), hasher_output)?; + } // both (x, y) and (x, -y) are valid ec points // if is_global_write = 1, then y should be in [0, p/2) @@ -164,7 +164,7 @@ impl GlobalConfig { local_clk, nonce, is_global_write, - // perm_config, + perm_config, perm, }) } @@ -356,9 +356,8 @@ impl }; set_val!(instance, config.addr, record.addr as u64); set_val!(instance, config.is_ram_register, is_ram_register as u64); - config - .value - .assign_limbs(instance, Value::new_unchecked(record.value).as_u16_limbs()); + let value = Value::new_unchecked(record.value); + config.value.assign_limbs(instance, value.as_u16_limbs()); set_val!(instance, config.shard, record.shard); set_val!(instance, config.global_clk, record.global_clk); set_val!(instance, config.local_clk, record.local_clk); @@ -376,7 +375,23 @@ impl set_val!(instance, *witin, fe.to_canonical_u64()); }); - // TODO: assign poseidon2 hasher + let ram_type = E::BaseField::from_canonical_u32(record.ram_type as u32); + let mut input = [E::BaseField::ZERO; 16]; + + let k = UINT_LIMBS; + input[0] = E::BaseField::from_canonical_u32(record.addr); + input[1] = ram_type; + input[2..(k + 2)] + .iter_mut() + .zip(value.as_u16_limbs().iter()) + .for_each(|(i, v)| *i = E::BaseField::from_canonical_u16(*v)); + input[2 + k] = E::BaseField::from_canonical_u64(record.shard); + input[2 + k + 1] = E::BaseField::from_canonical_u64(record.global_clk); + input[2 + k + 2] = E::BaseField::from_canonical_u32(nonce); + + config + .perm_config + .assign_instance(&mut instance[21 + UINT_LIMBS..], input); Ok(()) } From 0ea92e0b5e9372cf6e5c2e6ef3643815a390fe90 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 21 Oct 2025 20:50:39 +0800 Subject: [PATCH 79/91] optimise tracer performance --- Cargo.lock | 7 +++ Cargo.toml | 8 +++ ceno_emul/Cargo.toml | 3 ++ ceno_emul/src/chunked_vec.rs | 89 ++++++++++++++++++++++++++++++++ ceno_emul/src/lib.rs | 5 +- ceno_emul/src/tracer.rs | 28 ++++++---- ceno_emul/tests/test_vm_trace.rs | 10 ++-- ceno_zkvm/Cargo.toml | 1 + ceno_zkvm/src/e2e.rs | 28 +++++++--- 9 files changed, 154 insertions(+), 25 deletions(-) create mode 100644 ceno_emul/src/chunked_vec.rs diff --git a/Cargo.lock b/Cargo.lock index c837a88ee..22db24a3e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -958,9 +958,12 @@ dependencies = [ "multilinear_extensions", "num-derive", "num-traits", + "rayon", "rrs-succinct", + "rustc-hash", "secp", "serde", + "smallvec", "strum", "strum_macros", "substrate-bn 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1049,6 +1052,7 @@ dependencies = [ "proptest", "rand 0.8.5", "rayon", + "rustc-hash", "serde", "serde_json", "sp1-curves", @@ -4432,6 +4436,9 @@ name = "smallvec" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" +dependencies = [ + "serde", +] [[package]] name = "snowbridge-amcl" diff --git a/Cargo.toml b/Cargo.toml index 16e66caaf..ae19534fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,9 +57,17 @@ rand_chacha = { version = "0.3", features = ["serde1"] } rand_core = "0.6" rayon = "1.10" rkyv = { version = "0.8", features = ["pointer_width_32"] } +rustc-hash = "2.0.0" secp = "0.4.1" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" +smallvec = { version = "1.13.2", features = [ + "const_generics", + "const_new", + "serde", + "union", + "write", +] } strum = "0.26" strum_macros = "0.26" substrate-bn = { version = "0.6.0" } diff --git a/ceno_emul/Cargo.toml b/ceno_emul/Cargo.toml index b0af43fe3..6cc12cd17 100644 --- a/ceno_emul/Cargo.toml +++ b/ceno_emul/Cargo.toml @@ -19,9 +19,12 @@ itertools.workspace = true multilinear_extensions.workspace = true num-derive.workspace = true num-traits.workspace = true +rayon.workspace = true rrs_lib = { package = "rrs-succinct", version = "0.1.0" } +rustc-hash.workspace = true secp.workspace = true serde.workspace = true +smallvec.workspace = true strum.workspace = true strum_macros.workspace = true substrate-bn.workspace = true diff --git a/ceno_emul/src/chunked_vec.rs b/ceno_emul/src/chunked_vec.rs new file mode 100644 index 000000000..e53d51a73 --- /dev/null +++ b/ceno_emul/src/chunked_vec.rs @@ -0,0 +1,89 @@ +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use std::ops::{Index, IndexMut}; + +/// a chunked vector that grows in fixed-size chunks. +#[derive(Default, Debug, Clone)] +pub struct ChunkedVec { + chunks: Vec>, + chunk_size: usize, + len: usize, +} + +impl ChunkedVec { + /// create a new ChunkedVec with a given chunk size. + pub fn new(chunk_size: usize) -> Self { + assert!(chunk_size > 0, "chunk_size must be > 0"); + Self { + chunks: Vec::new(), + chunk_size, + len: 0, + } + } + + /// get the current number of elements. + pub fn len(&self) -> usize { + self.len + } + + /// returns true if the vector is empty. + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// access element by index (immutable). + pub fn get(&self, index: usize) -> Option<&T> { + if index >= self.len { + return None; + } + let chunk_idx = index / self.chunk_size; + let within_idx = index % self.chunk_size; + self.chunks.get(chunk_idx)?.get(within_idx) + } + + /// access element by index (mutable). + /// get mutable reference to element at index, auto-creating chunks as needed + pub fn get_or_create(&mut self, index: usize) -> &mut T { + let chunk_idx = index / self.chunk_size; + let within_idx = index % self.chunk_size; + + // Ensure enough chunks exist + if chunk_idx >= self.chunks.len() { + let to_create = chunk_idx + 1 - self.chunks.len(); + + // Use rayon to create all missing chunks in parallel + let mut new_chunks: Vec> = (0..to_create) + .map(|_| { + (0..self.chunk_size) + .into_par_iter() + .map(|_| Default::default()) + .collect::>() + }) + .collect(); + + self.chunks.append(&mut new_chunks); + } + + let chunk = &mut self.chunks[chunk_idx]; + + // Update the overall length + if index >= self.len { + self.len = index + 1; + } + + &mut chunk[within_idx] + } +} + +impl Index for ChunkedVec { + type Output = T; + + fn index(&self, index: usize) -> &Self::Output { + self.get(index).expect("index out of bounds") + } +} + +impl IndexMut for ChunkedVec { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + self.get_or_create(index) + } +} diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 8f439d036..3d88484fa 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -7,7 +7,9 @@ mod platform; pub use platform::{CENO_PLATFORM, Platform}; mod tracer; -pub use tracer::{Change, MemOp, ReadOp, StepRecord, Tracer, WriteOp}; +pub use tracer::{ + Change, MemOp, NextAccessPair, NextCycleAccess, ReadOp, StepRecord, Tracer, WriteOp, +}; mod vm_state; pub use vm_state::VMState; @@ -44,4 +46,5 @@ pub mod utils; pub mod test_utils; +mod chunked_vec; pub mod host_utils; diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 9dc9a0b12..c36bd5bef 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -1,13 +1,13 @@ -use std::{ - collections::{BTreeMap, HashMap}, - fmt, mem, -}; +use rustc_hash::FxHashMap; +use smallvec::SmallVec; +use std::{collections::BTreeMap, fmt, mem}; use ceno_rt::WORD_SIZE; use crate::{ CENO_PLATFORM, InsnKind, Instruction, PC_STEP_SIZE, Platform, addr::{ByteAddr, Cycle, RegIdx, Word, WordAddr}, + chunked_vec::ChunkedVec, encode_rv32, syscalls::{SyscallEffects, SyscallWitness}, }; @@ -39,6 +39,10 @@ pub struct StepRecord { syscall: Option, } +pub type NextAccessPair = SmallVec<[(WordAddr, Cycle); 1]>; +pub type NextCycleAccess = ChunkedVec; +const ACCESSED_CHUNK_SIZE: usize = 1 << 20; + #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct MemOp { /// Virtual Memory Address. @@ -305,8 +309,8 @@ pub struct Tracer { // record each section max access address // (start_addr -> (start_addr, end_addr, min_access_addr, max_access_addr)) mmio_min_max_access: Option>, - latest_accesses: HashMap, - next_accesses: HashMap<(WordAddr, Cycle), Cycle>, + latest_accesses: FxHashMap, + next_accesses: NextCycleAccess, } impl Default for Tracer { @@ -363,8 +367,8 @@ impl Tracer { cycle: Self::SUBCYCLES_PER_INSN, ..StepRecord::default() }, - latest_accesses: HashMap::new(), - next_accesses: HashMap::new(), + latest_accesses: FxHashMap::default(), + next_accesses: NextCycleAccess::new(ACCESSED_CHUNK_SIZE), } } @@ -475,17 +479,19 @@ impl Tracer { pub fn track_access(&mut self, addr: WordAddr, subcycle: Cycle) -> Cycle { let cur_cycle = self.record.cycle + subcycle; let prev_cycle = self.latest_accesses.insert(addr, cur_cycle).unwrap_or(0); - self.next_accesses.insert((addr, prev_cycle), cur_cycle); + self.next_accesses + .get_or_create(prev_cycle as usize) + .push((addr, cur_cycle)); prev_cycle } /// Return all the addresses that were accessed and the cycle when they were last accessed. - pub fn final_accesses(&self) -> &HashMap { + pub fn final_accesses(&self) -> &FxHashMap { &self.latest_accesses } /// Return all the addresses that were accessed and the cycle when they were last accessed. - pub fn next_accesses(self) -> HashMap<(WordAddr, Cycle), Cycle> { + pub fn next_accesses(self) -> NextCycleAccess { self.next_accesses } diff --git a/ceno_emul/tests/test_vm_trace.rs b/ceno_emul/tests/test_vm_trace.rs index 74cc83d4e..14bf7a1fe 100644 --- a/ceno_emul/tests/test_vm_trace.rs +++ b/ceno_emul/tests/test_vm_trace.rs @@ -1,9 +1,7 @@ #![allow(clippy::unusual_byte_groupings)] use anyhow::Result; -use std::{ - collections::{BTreeMap, HashMap}, - sync::Arc, -}; +use rustc_hash::FxHashMap; +use std::{collections::BTreeMap, sync::Arc}; use ceno_emul::{ CENO_PLATFORM, Cycle, EmuContext, InsnKind, Instruction, Platform, Program, StepRecord, Tracer, @@ -111,8 +109,8 @@ fn expected_ops_fibonacci_20() -> Vec { } /// Reconstruct the last access of each register. -fn expected_final_accesses_fibonacci_20() -> HashMap { - let mut accesses = HashMap::new(); +fn expected_final_accesses_fibonacci_20() -> FxHashMap { + let mut accesses = FxHashMap::default(); let x = |i| WordAddr::from(Platform::register_vma(i)); const C: Cycle = Tracer::SUBCYCLES_PER_INSN; diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 3347b38cc..3c1c99ed4 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -34,6 +34,7 @@ witness.workspace = true itertools.workspace = true ndarray.workspace = true prettytable-rs.workspace = true +rustc-hash.workspace = true strum.workspace = true strum_macros.workspace = true tracing.workspace = true diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 912bcf658..712f3b7a1 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -16,8 +16,9 @@ use crate::{ tables::{MemFinalRecord, MemInitRecord, ProgramTableCircuit, ProgramTableConfig}, }; use ceno_emul::{ - Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, InsnKind, IterAddresses, Platform, Program, - StepRecord, Tracer, VMState, WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, + Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, InsnKind, IterAddresses, NextCycleAccess, + Platform, Program, StepRecord, Tracer, VMState, WORD_SIZE, Word, WordAddr, + host_utils::read_all_messages, }; use clap::ValueEnum; use either::Either; @@ -147,7 +148,7 @@ pub struct ShardContext<'a> { shards: Shards, max_cycle: Cycle, // TODO optimize this map as it's super huge - addr_future_accesses: Cow<'a, HashMap<(WordAddr, Cycle), Cycle>>, + addr_future_accesses: Cow<'a, NextCycleAccess>, read_thread_based_record_storage: Either>, &'a mut BTreeMap>, write_thread_based_record_storage: @@ -161,7 +162,7 @@ impl<'a> Default for ShardContext<'a> { Self { shards: Shards::default(), max_cycle: Cycle::default(), - addr_future_accesses: Cow::Owned(HashMap::new()), + addr_future_accesses: Cow::Owned(Default::default()), read_thread_based_record_storage: Either::Left( (0..max_threads) .into_par_iter() @@ -183,7 +184,7 @@ impl<'a> ShardContext<'a> { pub fn new( shards: Shards, executed_instructions: usize, - addr_future_accesses: HashMap<(WordAddr, Cycle), Cycle>, + addr_future_accesses: NextCycleAccess, ) -> Self { // current strategy: at least each shard deal with one instruction let max_num_shards = shards.max_num_shards.min(executed_instructions); @@ -329,8 +330,21 @@ impl<'a> ShardContext<'a> { } // check write to external mem bus - if let Some(future_touch_cycle) = self.addr_future_accesses.get(&(addr, cycle)) - && *future_touch_cycle >= self.cur_shard_cycle_range.end as Cycle + if let Some(future_touch_cycle) = + self.addr_future_accesses + .get(cycle as usize) + .and_then(|res| { + if res.len() == 1 { + Some(res[0].1) + } else if res.len() > 1 { + res.iter() + .find(|(m_addr, _)| *m_addr == addr) + .map(|(_, cycle)| *cycle) + } else { + None + } + }) + && future_touch_cycle >= self.cur_shard_cycle_range.end as Cycle && self.is_current_shard_cycle(cycle) { let ram_record = self From 9e5df7d363770cc8c813f06775f0ec38b5447708 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Thu, 23 Oct 2025 01:51:09 +0800 Subject: [PATCH 80/91] integrate ecc quark iop wip --- ceno_zkvm/src/instructions/global.rs | 74 ++++++++++++++++++++++------ ceno_zkvm/src/scheme.rs | 2 + ceno_zkvm/src/scheme/cpu/mod.rs | 74 +++++++++++++++------------- ceno_zkvm/src/scheme/hal.rs | 23 +++++++-- ceno_zkvm/src/scheme/prover.rs | 41 +++++++++++++-- ceno_zkvm/src/scheme/septic_curve.rs | 6 ++- gkr_iop/src/circuit_builder.rs | 8 ++- gkr_iop/src/selector.rs | 16 +++++- 8 files changed, 187 insertions(+), 57 deletions(-) diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 59f624216..806c81e32 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -22,14 +22,17 @@ use multilinear_extensions::{ }; use p3::{ field::{Field, FieldAlgebra}, + matrix::dense::RowMajorMatrix, symmetric::Permutation, + util::log2_ceil_usize, }; use rayon::{ - iter::{IndexedParallelIterator, ParallelIterator}, + iter::{IndexedParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator}, + prelude::ParallelSliceMut, slice::ParallelSlice, }; use std::ops::Deref; -use witness::{RowMajorMatrix, set_val}; +use witness::{InstancePaddingStrategy, next_pow2_instance_padding, set_val}; use crate::{ instructions::{Instruction, riscv::constants::UInt}, @@ -55,6 +58,7 @@ pub struct GlobalConfig { is_global_write: WitIn, x: Vec, y: Vec, + slope: Vec, perm_config: Poseidon2Config, perm: P, } @@ -72,6 +76,9 @@ impl GlobalConfig { let y: Vec = (0..SEPTIC_EXTENSION_DEGREE) .map(|i| cb.create_witin(|| format!("y{}", i))) .collect(); + let slope: Vec = (0..SEPTIC_EXTENSION_DEGREE) + .map(|i| cb.create_witin(|| format!("slope{}", i))) + .collect(); let addr = cb.create_witin(|| "addr"); let is_ram_register = cb.create_witin(|| "is_ram_register"); let value = UInt::new(|| "value", cb)?; @@ -134,6 +141,7 @@ impl GlobalConfig { cb.ec_sum( x.iter().map(|xi| xi.expr()).collect::>(), y.iter().map(|yi| yi.expr()).collect::>(), + slope.iter().map(|si| si.expr()).collect::>(), final_sum.into_iter().map(|x| x.expr()).collect::>(), ); @@ -156,6 +164,7 @@ impl GlobalConfig { Ok(GlobalConfig { x, y, + slope, addr, is_ram_register, value, @@ -422,16 +431,34 @@ impl } .max(1); let lk_multiplicity = LkMultiplicity::default(); - let mut raw_witin = - RowMajorMatrix::::new(steps.len(), num_witin, Self::padding_strategy()); - let mut raw_structual_witin = RowMajorMatrix::::new( - steps.len(), - num_structural_witin, - Self::padding_strategy(), - ); - let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); - let raw_structual_witin_iter = - raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); + // *2 because we need to store the internal nodes of binary tree for ec point summation + let num_rows_padded = next_pow2_instance_padding(steps.len()) * 2; + + let mut raw_witin = { + let matrix_size = num_rows_padded * num_witin; + let mut value = Vec::with_capacity(matrix_size); + value.par_extend( + (0..matrix_size) + .into_par_iter() + .map(|_| E::BaseField::default()), + ); + RowMajorMatrix::new(value, num_witin) + }; + let mut raw_structual_witin = { + let matrix_size = num_rows_padded * num_structural_witin; + let mut value = Vec::with_capacity(matrix_size); + value.par_extend( + (0..matrix_size) + .into_par_iter() + .map(|_| E::BaseField::default()), + ); + RowMajorMatrix::new(value, num_structural_witin) + }; + let raw_witin_iter = raw_witin.values[0..steps.len() * num_witin] + .par_chunks_mut(num_instance_per_batch * num_witin); + let raw_structual_witin_iter = raw_structual_witin.values + [0..steps.len() * num_structural_witin] + .par_chunks_mut(num_instance_per_batch * num_structural_witin); raw_witin_iter .zip_eq(raw_structual_witin_iter) @@ -458,8 +485,27 @@ impl }) .collect::>()?; - raw_witin.padding_by_strategy(); - raw_structual_witin.padding_by_strategy(); + // assign internal nodes in the binary tree for ec point summation + let half_witin_matrix_size = num_rows_padded / 2 * num_witin; + let raw_witin_iter = raw_witin.values + [half_witin_matrix_size..(2 * half_witin_matrix_size - 1)] + .par_chunks_mut(num_witin) + .for_each(|instance| { + for i in 0..SEPTIC_EXTENSION_DEGREE { + set_val!(instance, config.x[i], E::BaseField::default()); + set_val!(instance, config.y[i], E::BaseField::default()); + set_val!(instance, config.slope[i], E::BaseField::default()); + } + }); + + let raw_witin = witness::RowMajorMatrix::new_by_inner_matrix( + raw_witin, + InstancePaddingStrategy::Default, + ); + let raw_structual_witin = witness::RowMajorMatrix::new_by_inner_matrix( + raw_structual_witin, + InstancePaddingStrategy::Default, + ); Ok(( [raw_witin, raw_structual_witin], lk_multiplicity.into_finalize_result(), diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 1a27160bc..fb8db450e 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -1,3 +1,4 @@ +use crate::structs::EccQuarkProof; use ff_ext::ExtensionField; use gkr_iop::gkr::GKRProof; use itertools::Itertools; @@ -59,6 +60,7 @@ pub struct ZKVMChipProof { pub gkr_iop_proof: Option>, pub tower_proof: TowerProofs, + pub ecc_proof: Option>, pub num_read_instances: usize, pub num_write_instances: usize, diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 312c9fa32..e61d3bc69 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -6,8 +6,8 @@ use crate::{ error::ZKVMError, scheme::{ constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEPTIC_EXTENSION_DEGREE}, - hal::{DeviceProvingKey, MainSumcheckEvals, ProofInput, TowerProverSpec}, - septic_curve::{SepticPoint, SymbolicSepticExtension}, + hal::{DeviceProvingKey, EccQuarkProver, MainSumcheckEvals, ProofInput, TowerProverSpec}, + septic_curve::{SepticExtension, SepticPoint, SymbolicSepticExtension}, utils::{ infer_tower_logup_witness, infer_tower_product_witness, masked_mle_split_to_chunks, wit_infer_by_expr, @@ -56,6 +56,7 @@ pub type TowerRelationOutput = ( // accumulate N=2^n EC points into one EC point using affine coordinates // in one layer which borrows ideas from the [Quark paper](https://eprint.iacr.org/2020/1275.pdf) +#[derive(Default)] pub struct CpuEccProver; impl CpuEccProver { @@ -65,10 +66,9 @@ impl CpuEccProver { pub fn create_ecc_proof<'a, E: ExtensionField>( &self, - mut xs: Vec>, - mut ys: Vec>, - invs: Vec>, - sum: SepticPoint, + xs: Vec>>, + ys: Vec>>, + invs: Vec>>, transcript: &mut impl Transcript, ) -> EccQuarkProof { assert_eq!(xs.len(), SEPTIC_EXTENSION_DEGREE); @@ -95,7 +95,7 @@ impl CpuEccProver { let mut exprs = vec![]; - let filter_bj = |v: &[MultilinearExtension<'_, E>], j: usize| { + let filter_bj = |v: &[Arc>], j: usize| { v.iter() .map(|v| { v.get_base_field_vec() @@ -115,14 +115,8 @@ impl CpuEccProver { let mut x1 = filter_bj(&xs, 1); let mut y1 = filter_bj(&ys, 1); // build x[1,b], y[1,b], s[0,b] - let mut x3 = xs - .iter_mut() - .map(|x| x.as_view_slice_mut(2, 1)) - .collect_vec(); - let mut y3 = ys - .iter_mut() - .map(|x| x.as_view_slice_mut(2, 1)) - .collect_vec(); + let mut x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); + let mut y3 = ys.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); let mut s = invs.iter().map(|x| x.as_view_slice(2, 0)).collect_vec(); let s = SymbolicSepticExtension::new( @@ -209,6 +203,18 @@ impl CpuEccProver { // 7 for x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt], s[0,rt] assert_eq!(evals.len(), 1 + SEPTIC_EXTENSION_DEGREE * 7); + let x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); + let y3 = ys.iter().map(|y| y.as_view_slice(2, 1)).collect_vec(); + let final_sum_x: SepticExtension = (x3.iter()) + .map(|x| x.get_base_field_vec()[num_instances - 1]) // x[1,...,1,0] + .collect_vec() + .into(); + let final_sum_y: SepticExtension = (y3.iter()) + .map(|y| y.get_base_field_vec()[num_instances - 1]) // x[1,...,1,0] + .collect_vec() + .into(); + let final_sum = SepticPoint::from_affine(final_sum_x, final_sum_y); + #[cfg(feature = "sanity-check")] { let s = invs.iter().map(|x| x.as_view_slice(2, 0)).collect_vec(); @@ -216,19 +222,7 @@ impl CpuEccProver { let y0 = filter_bj(&ys, 0); let x1 = filter_bj(&xs, 1); let y1 = filter_bj(&ys, 1); - let x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); - let y3 = ys.iter().map(|y| y.as_view_slice(2, 1)).collect_vec(); - let final_sum_x: SepticExtension = (x3.iter()) - .map(|x| x.get_base_field_vec()[num_instances - 1]) // x[1,...,1,0] - .collect_vec() - .into(); - let final_sum_y: SepticExtension = (y3.iter()) - .map(|y| y.get_base_field_vec()[num_instances - 1]) // x[1,...,1,0] - .collect_vec() - .into(); - let final_sum = SepticPoint::from_affine(final_sum_x, final_sum_y); - assert_eq!(final_sum, sum); // check evaluations assert_eq!( eq_eval_less_or_equal_than(num_instances - 1, &out_rt, &rt), @@ -265,11 +259,25 @@ impl CpuEccProver { zerocheck_proof, num_vars: n, evals, - sum, + sum: final_sum, } } } +impl> EccQuarkProver> + for CpuProver> +{ + fn prove_ec_sum_quark<'a>( + &self, + xs: Vec>>, + ys: Vec>>, + invs: Vec>>, + transcript: &mut impl Transcript, + ) -> Result, ZKVMError> { + Ok(CpuEccProver::new().create_ecc_proof(xs, ys, invs, transcript)) + } +} + pub struct CpuTowerProver; impl CpuTowerProver { @@ -1090,7 +1098,7 @@ where #[cfg(test)] mod tests { - use std::iter::repeat; + use std::{iter::repeat, sync::Arc}; use ff_ext::BabyBearExt4; use itertools::Itertools; @@ -1183,13 +1191,13 @@ mod tests { let mut transcript = BasicTranscript::new(b"test"); let prover = CpuEccProver::new(); let quark_proof = prover.create_ecc_proof( - xs.to_vec(), - ys.to_vec(), - s.to_vec(), - final_sum, + xs.to_vec().into_iter().map(Arc::new).collect_vec(), + ys.to_vec().into_iter().map(Arc::new).collect_vec(), + s.to_vec().into_iter().map(Arc::new).collect_vec(), &mut transcript, ); + assert_eq!(quark_proof.sum, final_sum); let mut transcript = BasicTranscript::new(b"test"); let verifier = EccVerifier::new(); assert!( diff --git a/ceno_zkvm/src/scheme/hal.rs b/ceno_zkvm/src/scheme/hal.rs index 7fb5cd899..4d96043dc 100644 --- a/ceno_zkvm/src/scheme/hal.rs +++ b/ceno_zkvm/src/scheme/hal.rs @@ -3,8 +3,8 @@ use std::{collections::BTreeMap, sync::Arc}; use crate::{ circuit_builder::ConstraintSystem, error::ZKVMError, - scheme::cpu::TowerRelationOutput, - structs::{ComposedConstrainSystem, ZKVMProvingKey}, + scheme::{cpu::TowerRelationOutput, septic_curve::SepticPoint}, + structs::{ComposedConstrainSystem, EccQuarkProof, ZKVMProvingKey}, }; use ff_ext::ExtensionField; use gkr_iop::{ @@ -12,7 +12,7 @@ use gkr_iop::{ hal::{ProtocolWitnessGeneratorProver, ProverBackend}, }; use mpcs::{Point, PolynomialCommitmentScheme}; -use multilinear_extensions::{mle::MultilinearExtension, util::ceil_log2}; +use multilinear_extensions::{Expression, mle::MultilinearExtension, util::ceil_log2}; use sumcheck::structs::IOPProverMessage; use transcript::Transcript; use witness::next_pow2_instance_padding; @@ -24,6 +24,7 @@ pub trait ProverDevice: + OpeningProver + DeviceTransporter + ProtocolWitnessGeneratorProver + + EccQuarkProver // + FixedMLEPadder where PB: ProverBackend, @@ -68,6 +69,22 @@ pub trait TraceCommitter { ); } +/// Accumulate N (not necessarily power of 2) EC points into one EC point using affine coordinates +/// in one layer which borrows ideas from the [Quark paper](https://eprint.iacr.org/2020/1275.pdf) +/// Note that these points are defined over the septic extension field of BabyBear. +/// +/// The main constraint enforced in this quark layer is: +/// p[1,b] = affine_add(p[b,0], p[b,1]) for all b < N +pub trait EccQuarkProver { + fn prove_ec_sum_quark<'a>( + &self, + xs: Vec>>, + ys: Vec>>, + invs: Vec>>, + transcript: &mut impl Transcript, + ) -> Result, ZKVMError>; +} + pub trait TowerProver { // infer read/write/logup records from the read/write/logup expressions and then // build multiple complete binary trees (tower tree) to accumulate these records diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 6ba250ec3..708f1c39b 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -9,12 +9,12 @@ use std::{ sync::Arc, }; -use crate::scheme::hal::MainSumcheckEvals; +use crate::scheme::{constants::SEPTIC_EXTENSION_DEGREE, hal::MainSumcheckEvals}; use gkr_iop::hal::MultilinearPolynomial; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{ - Instance, + Expression, Instance, mle::{IntoMLE, MultilinearExtension}, }; use p3::field::FieldAlgebra; @@ -329,6 +329,33 @@ impl< let log2_num_instances = input.log2_num_instances(); let num_var_with_rotation = log2_num_instances + cs.rotation_vars().unwrap_or(0); + // run ecc quark prover + let ecc_proof = if !cs.zkvm_v1_css.ec_final_sum.is_empty() { + let ec_point_exprs = &cs.zkvm_v1_css.ec_point_exprs; + assert_eq!(ec_point_exprs.len(), SEPTIC_EXTENSION_DEGREE * 2); + let mut xs_ys = ec_point_exprs + .into_iter() + .map(|expr| match expr { + Expression::WitIn(id) => input.witness[*id as usize].clone(), + _ => unreachable!("ec point's expression must be WitIn"), + }) + .collect_vec(); + let ys = xs_ys.split_off(SEPTIC_EXTENSION_DEGREE); + let xs = xs_ys; + let invs = cs + .zkvm_v1_css + .ec_slope_exprs + .iter() + .map(|expr| match expr { + Expression::WitIn(id) => input.witness[*id as usize].clone(), + _ => unreachable!("ec inv's expression must be WitIn"), + }) + .collect_vec(); + Some(self.device.prove_ec_sum_quark(xs, ys, invs, transcript)?) + } else { + None + }; + // build main witness let (records, is_padded) = build_main_witness::(&self.device, cs, &input, challenges); @@ -346,7 +373,14 @@ impl< num_var_with_rotation, ); - // override cs.gkr_circuit.layers + // TODO: batch reduction into main sumcheck + // x[rt,0] = \sum_b eq([rt,0], b) * x[b] + // x[rt,1] = \sum_b eq([rt,1], b) * x[b] + // x[1,rt] = \sum_b eq([1,rt], b) * x[b] + // y[rt,0] = \sum_b eq([rt,0], b) * y[b] + // y[rt,1] = \sum_b eq([rt,1], b) * y[b] + // y[1,rt] = \sum_b eq([1,rt], b) * y[b] + // s[0,rt] = \sum_b eq([0,rt], b) * s[b] // 1. prove the main constraints among witness polynomials // 2. prove the relation between last layer in the tower and read/write/logup records @@ -380,6 +414,7 @@ impl< main_sumcheck_proofs, gkr_iop_proof, tower_proof, + ecc_proof, fixed_in_evals, wits_in_evals, num_read_instances: input.num_read_instances, diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 12b07fcaf..788d8b637 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -717,7 +717,11 @@ impl Mul> for SymbolicSepticExtension { impl SymbolicSepticExtension { pub fn new(exprs: Vec>) -> Self { - assert!(exprs.len() == 7); + assert!( + exprs.len() == 7, + "exprs length must be 7, but got {}", + exprs.len() + ); Self(exprs) } diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index eb1cb4d03..9e08b8b70 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -104,6 +104,7 @@ pub struct ConstraintSystem { pub instance_name_map: HashMap, pub ec_point_exprs: Vec>, + pub ec_slope_exprs: Vec>, pub ec_final_sum: Vec>, pub r_selector: Option>, @@ -171,6 +172,7 @@ impl ConstraintSystem { ns: NameSpace::new(root_name_fn), instance_name_map: HashMap::new(), ec_final_sum: vec![], + ec_slope_exprs: vec![], ec_point_exprs: vec![], r_selector: None, r_expressions: vec![], @@ -414,16 +416,19 @@ impl ConstraintSystem { &mut self, xs: Vec>, ys: Vec>, + slopes: Vec>, final_sum: Vec>, ) { assert_eq!(xs.len(), 7); assert_eq!(ys.len(), 7); + assert_eq!(slopes.len(), 7); assert_eq!(final_sum.len(), 7 * 2); assert_eq!(self.ec_point_exprs.len(), 0); self.ec_point_exprs.extend(xs.into_iter()); self.ec_point_exprs.extend(ys.into_iter()); + self.ec_slope_exprs = slopes; self.ec_final_sum = final_sum; } @@ -650,9 +655,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { &mut self, xs: Vec>, ys: Vec>, + slope: Vec>, final_sum: Vec>, ) { - self.cs.ec_sum(xs, ys, final_sum); + self.cs.ec_sum(xs, ys, slope, final_sum); } pub fn create_bit(&mut self, name_fn: N) -> Result diff --git a/gkr_iop/src/selector.rs b/gkr_iop/src/selector.rs index c6ffea80b..80e4df981 100644 --- a/gkr_iop/src/selector.rs +++ b/gkr_iop/src/selector.rs @@ -136,7 +136,13 @@ impl SelectorType { SelectorType::Prefix(_) => { let start = ctx.offset; let end = start + ctx.num_instances; - assert!(end <= (1 << ctx.num_vars), "start: {}, num_instances: {}, num_vars: {}", start, ctx.num_instances, ctx.num_vars); + assert!( + end <= (1 << ctx.num_vars), + "start: {}, num_instances: {}, num_vars: {}", + start, + ctx.num_instances, + ctx.num_vars + ); let mut sel = build_eq_x_r_vec(out_point); sel.splice(0..start, repeat_n(E::ZERO, start)); @@ -196,7 +202,13 @@ impl SelectorType { let end = start + ctx.num_instances; assert_eq!(in_point.len(), out_point.len()); - assert!(end <= (1 << out_point.len()), "start: {}, num_instances: {}, num_vars: {}", start, ctx.num_instances, ctx.num_vars); + assert!( + end <= (1 << out_point.len()), + "start: {}, num_instances: {}, num_vars: {}", + start, + ctx.num_instances, + ctx.num_vars + ); let eq_end = eq_eval_less_or_equal_than(end - 1, out_point, in_point); let sel = if start > 0 { From de8397f61b59a4b2a5cb5ab617d6fbcdbfcadf15 Mon Sep 17 00:00:00 2001 From: Ming Date: Mon, 27 Oct 2025 16:51:57 +0800 Subject: [PATCH 81/91] non-pow2 septic elliptic curve points add IOP (#1081) Co-authored-by: kunxian xia --- ceno_zkvm/src/scheme/cpu/mod.rs | 159 ++++++++++------ ceno_zkvm/src/scheme/hal.rs | 1 + ceno_zkvm/src/scheme/prover.rs | 5 +- ceno_zkvm/src/scheme/septic_curve.rs | 22 +++ ceno_zkvm/src/scheme/verifier.rs | 222 ++++++++++++++--------- ceno_zkvm/src/structs.rs | 2 +- gkr_iop/src/gkr/layer/zerocheck_layer.rs | 3 +- gkr_iop/src/selector.rs | 166 ++++++++++++++++- 8 files changed, 441 insertions(+), 139 deletions(-) diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index e61d3bc69..8e3554236 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -32,8 +32,10 @@ use multilinear_extensions::{ virtual_poly::build_eq_x_r_vec, virtual_polys::VirtualPolynomialsBuilder, }; -use p3::field::FieldAlgebra; -use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, + IntoParallelRefMutIterator, ParallelIterator, +}; use std::{collections::BTreeMap, sync::Arc}; use sumcheck::{ macros::{entered_span, exit_span}, @@ -66,6 +68,7 @@ impl CpuEccProver { pub fn create_ecc_proof<'a, E: ExtensionField>( &self, + num_instances: usize, xs: Vec>>, ys: Vec>>, invs: Vec>>, @@ -78,22 +81,44 @@ impl CpuEccProver { let out_rt = transcript.sample_and_append_vec(b"ecc", n); let num_threads = optimal_sumcheck_threads(out_rt.len()); - let alpha_pows = - transcript.sample_and_append_challenge_pows(SEPTIC_EXTENSION_DEGREE * 3, b"ecc_alpha"); + // expression with add (3 zero constrains) and bypass (2 zero constrains) + let alpha_pows = transcript.sample_and_append_challenge_pows( + SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 2, + b"ecc_alpha", + ); + let mut alpha_pows_iter = alpha_pows.iter(); let mut expr_builder = VirtualPolynomialsBuilder::new(num_threads, out_rt.len()); - let sel = SelectorType::Prefix(0.into()); - let num_instances = (1 << n) - 1; - let sel_ctx = SelectorContext { + let sel_add = SelectorType::QuarkBinaryTreeLessThan(0.into()); + let sel_add_ctx = SelectorContext { offset: 0, num_instances, num_vars: n, }; - let mut sel_mle: MultilinearExtension<'_, E> = sel.compute(&out_rt, &sel_ctx).unwrap(); - let sel_expr = expr_builder.lift(sel_mle.to_either()); + let mut sel_add_mle: MultilinearExtension<'_, E> = + sel_add.compute(&out_rt, &sel_add_ctx).unwrap(); + // we construct sel_bypass witness here + // verifier can derive it via `sel_bypass = eq - sel_add - sel_last_onehot` + let mut sel_bypass_mle: Vec = build_eq_x_r_vec(&out_rt); + match sel_add_mle.evaluations() { + FieldType::Ext(sel_add_mle) => sel_add_mle + .par_iter() + .zip_eq(sel_bypass_mle.par_iter_mut()) + .for_each(|(sel_add, sel_bypass)| { + if *sel_add != E::ZERO { + *sel_bypass = E::ZERO; + } + }), + _ => unreachable!(), + } + *sel_bypass_mle.last_mut().unwrap() = E::ZERO; + let mut sel_bypass_mle = sel_bypass_mle.into_mle(); + let sel_add_expr = expr_builder.lift(sel_add_mle.to_either()); + let sel_bypass_expr = expr_builder.lift(sel_bypass_mle.to_either()); - let mut exprs = vec![]; + let mut exprs_add = vec![]; + let mut exprs_bypass = vec![]; let filter_bj = |v: &[Arc>], j: usize| { v.iter() @@ -156,43 +181,58 @@ impl CpuEccProver { ); // affine addition // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) with b != (1,...,1) - exprs.extend( + exprs_add.extend( (s.clone() * (&x0 - &x1) - (&y0 - &y1)) .to_exprs() .into_iter() - .zip(alpha_pows.iter().take(SEPTIC_EXTENSION_DEGREE)) + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), ); // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] with b != (1,...,1) - exprs.extend( + exprs_add.extend( ((&s * &s) - &x0 - &x1 - &x3) .to_exprs() .into_iter() - .zip( - alpha_pows[SEPTIC_EXTENSION_DEGREE..] - .iter() - .take(SEPTIC_EXTENSION_DEGREE), - ) + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), ); // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) with b != (1,...,1) - exprs.extend( + exprs_add.extend( (s.clone() * (&x0 - &x3) - (&y0 + &y3)) .to_exprs() .into_iter() - .zip( - alpha_pows[SEPTIC_EXTENSION_DEGREE * 2..] - .iter() - .take(SEPTIC_EXTENSION_DEGREE), - ) + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), ); + let exprs_add = exprs_add.into_iter().sum::>() * sel_add_expr; + + // deal with bypass + // 0 = (x[1,b] - x[b,0]) + exprs_bypass.extend( + (&x3 - &x0) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + + // 0 = (y[1,b] - y[b,0]) + exprs_bypass.extend( + (&y3 - &y0) + .to_exprs() + .into_iter() + .zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), + ); + assert!(alpha_pows_iter.next().is_none()); + + let exprs_bypass = exprs_bypass.into_iter().sum::>() * sel_bypass_expr; + let (zerocheck_proof, state) = IOPProverState::prove( - expr_builder - .to_virtual_polys(&[exprs.into_iter().sum::>() * sel_expr], &[]), + expr_builder.to_virtual_polys(&[exprs_add + exprs_bypass], &[]), transcript, ); @@ -201,16 +241,17 @@ impl CpuEccProver { assert_eq!(zerocheck_proof.extract_sum(), E::ZERO); // 7 for x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt], s[0,rt] - assert_eq!(evals.len(), 1 + SEPTIC_EXTENSION_DEGREE * 7); + assert_eq!(evals.len(), 2 + SEPTIC_EXTENSION_DEGREE * 7); + let last_evaluation_index = (1 << n) - 1; let x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); let y3 = ys.iter().map(|y| y.as_view_slice(2, 1)).collect_vec(); let final_sum_x: SepticExtension = (x3.iter()) - .map(|x| x.get_base_field_vec()[num_instances - 1]) // x[1,...,1,0] + .map(|x| x.get_base_field_vec()[last_evaluation_index - 1]) // x[1,...,1,0] .collect_vec() .into(); let final_sum_y: SepticExtension = (y3.iter()) - .map(|y| y.get_base_field_vec()[num_instances - 1]) // x[1,...,1,0] + .map(|y| y.get_base_field_vec()[last_evaluation_index - 1]) // x[1,...,1,0] .collect_vec() .into(); let final_sum = SepticPoint::from_affine(final_sum_x, final_sum_y); @@ -225,7 +266,7 @@ impl CpuEccProver { // check evaluations assert_eq!( - eq_eval_less_or_equal_than(num_instances - 1, &out_rt, &rt), + eq_eval_less_or_equal_than(last_evaluation_index - 1, &out_rt, &rt), evals[0] ); for i in 0..SEPTIC_EXTENSION_DEGREE { @@ -257,7 +298,7 @@ impl CpuEccProver { // TODO: prove the validity of s[0,rt], x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt] EccQuarkProof { zerocheck_proof, - num_vars: n, + num_instances, evals, sum: final_sum, } @@ -269,12 +310,13 @@ impl> EccQuarkProver( &self, + num_instances: usize, xs: Vec>>, ys: Vec>>, invs: Vec>>, transcript: &mut impl Transcript, ) -> Result, ZKVMError> { - Ok(CpuEccProver::new().create_ecc_proof(xs, ys, invs, transcript)) + Ok(CpuEccProver::new().create_ecc_proof(num_instances, xs, ys, invs, transcript)) } } @@ -1098,8 +1140,12 @@ where #[cfg(test)] mod tests { - use std::{iter::repeat, sync::Arc}; - + use crate::scheme::{ + constants::SEPTIC_EXTENSION_DEGREE, + cpu::CpuEccProver, + septic_curve::{SepticExtension, SepticPoint}, + verifier::EccVerifier, + }; use ff_ext::BabyBearExt4; use itertools::Itertools; use multilinear_extensions::{ @@ -1107,22 +1153,22 @@ mod tests { util::transpose, }; use p3::babybear::BabyBear; + use std::{iter::repeat_n, sync::Arc}; use transcript::BasicTranscript; - - use crate::scheme::{ - constants::SEPTIC_EXTENSION_DEGREE, - cpu::CpuEccProver, - septic_curve::{SepticExtension, SepticPoint}, - verifier::EccVerifier, - }; + use witness::next_pow2_instance_padding; #[test] fn test_ecc_quark_prover() { + for n_points in 1..2 ^ 10 { + test_ecc_quark_prover_inner(n_points) + } + } + + fn test_ecc_quark_prover_inner(n_points: usize) { type E = BabyBearExt4; type F = BabyBear; - let log2_n = 6; - let n_points = 1 << log2_n; + let log2_n = next_pow2_instance_padding(n_points).ilog2(); let mut rng = rand::thread_rng(); let final_sum; @@ -1132,7 +1178,11 @@ mod tests { let mut points = (0..n_points) .map(|_| SepticPoint::::random(&mut rng)) .collect_vec(); - let mut s = Vec::with_capacity(n_points); + points.extend(repeat_n( + SepticPoint::point_at_infinity(), + (1 << log2_n) - points.len(), + )); + let mut s = Vec::with_capacity(1 << (log2_n + 1)); for layer in (1..=log2_n).rev() { let num_inputs = 1 << layer; @@ -1141,17 +1191,19 @@ mod tests { s.extend(inputs.chunks_exact(2).map(|chunk| { let p = &chunk[0]; let q = &chunk[1]; - - (&p.y - &q.y) * (&p.x - &q.x).inverse().unwrap() + if q.is_infinity { + SepticExtension::zero() + } else { + (&p.y - &q.y) * (&p.x - &q.x).inverse().unwrap() + } })); points.extend( - points[points.len() - num_inputs..] + inputs .chunks_exact(2) .map(|chunk| { let p = chunk[0].clone(); let q = chunk[1].clone(); - p + q }) .collect_vec(), @@ -1160,11 +1212,14 @@ mod tests { final_sum = points.last().cloned().unwrap(); // padding to 2*N - s.extend(repeat(SepticExtension::zero()).take(n_points + 1)); + s.extend(repeat_n( + SepticExtension::zero(), + (1 << (log2_n + 1)) - s.len(), + )); points.push(SepticPoint::point_at_infinity()); - assert_eq!(s.len(), 2 * n_points); - assert_eq!(points.len(), 2 * n_points); + assert_eq!(s.len(), 1 << (log2_n + 1)); + assert_eq!(points.len(), 1 << (log2_n + 1)); // transform points to row major matrix let trace = points @@ -1191,6 +1246,7 @@ mod tests { let mut transcript = BasicTranscript::new(b"test"); let prover = CpuEccProver::new(); let quark_proof = prover.create_ecc_proof( + n_points, xs.to_vec().into_iter().map(Arc::new).collect_vec(), ys.to_vec().into_iter().map(Arc::new).collect_vec(), s.to_vec().into_iter().map(Arc::new).collect_vec(), @@ -1203,6 +1259,7 @@ mod tests { assert!( verifier .verify_ecc_proof(&quark_proof, &mut transcript) + .inspect_err(|err| println!("err {:?}", err)) .is_ok() ); } diff --git a/ceno_zkvm/src/scheme/hal.rs b/ceno_zkvm/src/scheme/hal.rs index 4d96043dc..7f3ca1cd0 100644 --- a/ceno_zkvm/src/scheme/hal.rs +++ b/ceno_zkvm/src/scheme/hal.rs @@ -78,6 +78,7 @@ pub trait TraceCommitter { pub trait EccQuarkProver { fn prove_ec_sum_quark<'a>( &self, + num_instances: usize, xs: Vec>>, ys: Vec>>, invs: Vec>>, diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 708f1c39b..8fa0d72e9 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -351,7 +351,10 @@ impl< _ => unreachable!("ec inv's expression must be WitIn"), }) .collect_vec(); - Some(self.device.prove_ec_sum_quark(xs, ys, invs, transcript)?) + Some( + self.device + .prove_ec_sum_quark(input.num_instances, xs, ys, invs, transcript)?, + ) } else { None }; diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 788d8b637..fa288ac0f 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -594,6 +594,28 @@ impl MulAssign for SepticExtension { #[derive(Clone, Debug)] pub struct SymbolicSepticExtension(pub Vec>); +impl SymbolicSepticExtension { + pub fn mul_scalar(&self, scalar: Either) -> Self { + let res = self + .0 + .iter() + .map(|a| a.clone() * Expression::Constant(scalar)) + .collect(); + + SymbolicSepticExtension(res) + } + + pub fn add_scalar(&self, scalar: Either) -> Self { + let res = self + .0 + .iter() + .map(|a| a.clone() + Expression::Constant(scalar)) + .collect(); + + SymbolicSepticExtension(res) + } +} + impl Add for &SymbolicSepticExtension { type Output = SymbolicSepticExtension; diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index e3229971d..0d703beb3 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -5,24 +5,6 @@ use ff_ext::ExtensionField; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; -use gkr_iop::{gkr::GKRClaims, selector::SelectorContext, utils::eq_eval_less_or_equal_than}; -use itertools::{Itertools, chain, interleave, izip}; -use mpcs::{Point, PolynomialCommitmentScheme}; -use multilinear_extensions::{ - Instance, StructuralWitIn, StructuralWitInType, - mle::IntoMLE, - util::ceil_log2, - utils::eval_by_expr_with_instance, - virtual_poly::{VPAuxInfo, build_eq_x_r_vec_sequential, eq_eval}, -}; -use p3::field::FieldAlgebra; -use sumcheck::{ - structs::{IOPProof, IOPVerifierState}, - util::get_challenge_pows, -}; -use transcript::{ForkableTranscript, Transcript}; -use witness::next_pow2_instance_padding; - use crate::{ error::ZKVMError, scheme::{ @@ -38,6 +20,28 @@ use crate::{ eval_stacked_constant_vec, eval_stacked_wellform_address_vec, eval_wellform_address_vec, }, }; +use gkr_iop::{ + gkr::GKRClaims, + selector::{SelectorContext, SelectorType}, + utils::eq_eval_less_or_equal_than, +}; +use itertools::{Itertools, chain, interleave, izip}; +use mpcs::{Point, PolynomialCommitmentScheme}; +use multilinear_extensions::{ + Expression, Instance, StructuralWitIn, StructuralWitInType, + StructuralWitInType::StackedConstantSequence, + mle::IntoMLE, + util::ceil_log2, + utils::eval_by_expr_with_instance, + virtual_poly::{VPAuxInfo, build_eq_x_r_vec_sequential, eq_eval}, +}; +use p3::field::FieldAlgebra; +use sumcheck::{ + structs::{IOPProof, IOPVerifierState}, + util::get_challenge_pows, +}; +use transcript::{ForkableTranscript, Transcript}; +use witness::next_pow2_instance_padding; use super::{ZKVMChipProof, ZKVMProof}; @@ -842,7 +846,7 @@ impl TowerVerify { let max_num_variables = num_variables.iter().max().unwrap(); - let (next_rt, _) = (0..(max_num_variables-1)).try_fold( + let (next_rt, _) = (0..(max_num_variables - 1)).try_fold( ( PointAndEval { point: initial_rt, @@ -875,31 +879,31 @@ impl TowerVerify { // prod'[b] = prod[0,b] * prod[1,b] // prod'[out_rt] = \sum_b eq(out_rt,b) * prod'[b] = \sum_b eq(out_rt,b) * prod[0,b] * prod[1,b] eq * *alpha - * if round < *max_round-1 {tower_proofs.prod_specs_eval[spec_index][round].iter().copied().product()} else { - E::ZERO - } + * if round < *max_round - 1 { tower_proofs.prod_specs_eval[spec_index][round].iter().copied().product() } else { + E::ZERO + } }) .sum::() + (0..num_logup_spec) - .zip_eq(alpha_pows[num_prod_spec..].chunks(2)) - .zip_eq(num_variables[num_prod_spec..].iter()) - .map(|((spec_index, alpha), max_round)| { - // logup_q'[b] = logup_q[0,b] * logup_q[1,b] - // logup_p'[b] = logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b] - // logup_p'[out_rt] = \sum_b eq(out_rt,b) * (logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b]) - // logup_q'[out_rt] = \sum_b eq(out_rt,b) * logup_q[0,b] * logup_q[1,b] - let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); - eq * if round < *max_round-1 { - let evals = &tower_proofs.logup_specs_eval[spec_index][round]; - let (p1, p2, q1, q2) = - (evals[0], evals[1], evals[2], evals[3]); - *alpha_numerator * (p1 * q2 + p2 * q1) - + *alpha_denominator * (q1 * q2) - } else { - E::ZERO - } - }) - .sum::(); + .zip_eq(alpha_pows[num_prod_spec..].chunks(2)) + .zip_eq(num_variables[num_prod_spec..].iter()) + .map(|((spec_index, alpha), max_round)| { + // logup_q'[b] = logup_q[0,b] * logup_q[1,b] + // logup_p'[b] = logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b] + // logup_p'[out_rt] = \sum_b eq(out_rt,b) * (logup_p[0,b] * logup_q[1,b] + logup_p[1,b] * logup_q[0,b]) + // logup_q'[out_rt] = \sum_b eq(out_rt,b) * logup_q[0,b] * logup_q[1,b] + let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); + eq * if round < *max_round - 1 { + let evals = &tower_proofs.logup_specs_eval[spec_index][round]; + let (p1, p2, q1, q2) = + (evals[0], evals[1], evals[2], evals[3]); + *alpha_numerator * (p1 * q2 + p2 * q1) + + *alpha_denominator * (q1 * q2) + } else { + E::ZERO + } + }) + .sum::(); if expected_evaluation != sumcheck_claim.expected_evaluation { return Err(ZKVMError::VerifyError("mismatch tower evaluation".into())); @@ -908,7 +912,7 @@ impl TowerVerify { // derive single eval // rt' = r_merge || rt // r_merge.len() == ceil_log2(num_product_fanin) - let r_merge =transcript.sample_and_append_vec(b"merge", log2_num_fanin); + let r_merge = transcript.sample_and_append_vec(b"merge", log2_num_fanin); let coeffs = build_eq_x_r_vec_sequential(&r_merge); assert_eq!(coeffs.len(), num_fanin); let rt_prime = [rt, r_merge].concat(); @@ -924,17 +928,17 @@ impl TowerVerify { .zip(num_variables.iter()) .map(|((spec_index, alpha), max_round)| { // prod'[rt,r_merge] = \sum_b eq(r_merge, b) * prod'[b,rt] - if round < max_round -1 { + if round < max_round - 1 { // merged evaluation let evals = izip!( tower_proofs.prod_specs_eval[spec_index][round].iter(), coeffs.iter() ) - .map(|(a, b)| *a * *b) - .sum::(); + .map(|(a, b)| *a * *b) + .sum::(); // this will keep update until round > evaluation prod_spec_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), evals); - if next_round < max_round -1 { + if next_round < max_round - 1 { *alpha * evals } else { E::ZERO @@ -948,28 +952,28 @@ impl TowerVerify { .zip_eq(next_alpha_pows[num_prod_spec..].chunks(2)) .zip_eq(num_variables[num_prod_spec..].iter()) .map(|((spec_index, alpha), max_round)| { - if round < max_round -1 { + if round < max_round - 1 { let (alpha_numerator, alpha_denominator) = (&alpha[0], &alpha[1]); // merged evaluation let p_evals = izip!( tower_proofs.logup_specs_eval[spec_index][round][0..2].iter(), coeffs.iter() ) - .map(|(a, b)| *a * *b) - .sum::(); + .map(|(a, b)| *a * *b) + .sum::(); let q_evals = izip!( tower_proofs.logup_specs_eval[spec_index][round][2..4].iter(), coeffs.iter() ) - .map(|(a, b)| *a * *b) - .sum::(); + .map(|(a, b)| *a * *b) + .sum::(); // this will keep update until round > evaluation logup_spec_p_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), p_evals); logup_spec_q_point_n_eval[spec_index] = PointAndEval::new(rt_prime.clone(), q_evals); - if next_round < max_round -1 { + if next_round < max_round - 1 { *alpha_numerator * p_evals + *alpha_denominator * q_evals } else { E::ZERO @@ -1011,50 +1015,53 @@ impl EccVerifier { proof: &EccQuarkProof, transcript: &mut impl Transcript, ) -> Result<(), ZKVMError> { - let out_rt = transcript.sample_and_append_vec(b"ecc", proof.num_vars); - let alpha_pows = - transcript.sample_and_append_challenge_pows(SEPTIC_EXTENSION_DEGREE * 3, b"ecc_alpha"); + let num_vars = next_pow2_instance_padding(proof.num_instances).ilog2() as usize; + let out_rt = transcript.sample_and_append_vec(b"ecc", num_vars); + let alpha_pows = transcript.sample_and_append_challenge_pows( + SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 2, + b"ecc_alpha", + ); + let mut alpha_pows_iter = alpha_pows.iter(); let sumcheck_claim = IOPVerifierState::verify( E::ZERO, &proof.zerocheck_proof, &VPAuxInfo { max_degree: 3, - max_num_variables: proof.num_vars, + max_num_variables: num_vars, phantom: PhantomData, }, transcript, ); - let s0: SepticExtension = proof.evals[1..][0..SEPTIC_EXTENSION_DEGREE] + let s0: SepticExtension = proof.evals[2..][0..][..SEPTIC_EXTENSION_DEGREE] .try_into() .unwrap(); - let x0: SepticExtension = proof.evals[1..] - [SEPTIC_EXTENSION_DEGREE..2 * SEPTIC_EXTENSION_DEGREE] + let x0: SepticExtension = proof.evals[2..][SEPTIC_EXTENSION_DEGREE..] + [..SEPTIC_EXTENSION_DEGREE] .try_into() .unwrap(); - let y0: SepticExtension = proof.evals[1..] - [2 * SEPTIC_EXTENSION_DEGREE..3 * SEPTIC_EXTENSION_DEGREE] + let y0: SepticExtension = proof.evals[2..][2 * SEPTIC_EXTENSION_DEGREE..] + [..SEPTIC_EXTENSION_DEGREE] .try_into() .unwrap(); - let x1: SepticExtension = proof.evals[1..] - [3 * SEPTIC_EXTENSION_DEGREE..4 * SEPTIC_EXTENSION_DEGREE] + let x1: SepticExtension = proof.evals[2..][3 * SEPTIC_EXTENSION_DEGREE..] + [..SEPTIC_EXTENSION_DEGREE] .try_into() .unwrap(); - let y1: SepticExtension = proof.evals[1..] - [4 * SEPTIC_EXTENSION_DEGREE..5 * SEPTIC_EXTENSION_DEGREE] + let y1: SepticExtension = proof.evals[2..][4 * SEPTIC_EXTENSION_DEGREE..] + [..SEPTIC_EXTENSION_DEGREE] .try_into() .unwrap(); - let x3: SepticExtension = proof.evals[1..] - [5 * SEPTIC_EXTENSION_DEGREE..6 * SEPTIC_EXTENSION_DEGREE] + let x3: SepticExtension = proof.evals[2..][5 * SEPTIC_EXTENSION_DEGREE..] + [..SEPTIC_EXTENSION_DEGREE] .try_into() .unwrap(); - let y3: SepticExtension = proof.evals[1..] - [6 * SEPTIC_EXTENSION_DEGREE..7 * SEPTIC_EXTENSION_DEGREE] + let y3: SepticExtension = proof.evals[2..][6 * SEPTIC_EXTENSION_DEGREE..] + [..SEPTIC_EXTENSION_DEGREE] .try_into() .unwrap(); - let num_instances = (1 << proof.num_vars) - 1; let rt = sumcheck_claim .point .iter() @@ -1064,6 +1071,8 @@ impl EccVerifier { // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) + // zerocheck: 0 = (x[1,b] - x[b,0]) + // zerocheck: 0 = (y[1,b] - y[b,0]) // // note that they are not septic extension field elements, // we just want to reuse the multiply/add/sub formulas @@ -1071,25 +1080,70 @@ impl EccVerifier { let v2: SepticExtension = s0.square() - &x0 - &x1 - &x3; let v3: SepticExtension = s0 * (&x0 - &x3) - (&y0 + &y3); - let v: E = vec![v1, v2, v3] - .into_iter() - .enumerate() - .flat_map(|(i, v)| { - let start = i * SEPTIC_EXTENSION_DEGREE; - let end = (i + 1) * SEPTIC_EXTENSION_DEGREE; - v.0.into_iter() - .zip(alpha_pows[start..end].iter()) - .map(|(c, alpha)| c * *alpha) - }) - .sum(); + let v4: SepticExtension = &x3 - &x0; + let v5: SepticExtension = &y3 - &y0; + + let [v1, v2, v3, v4, v5] = [v1, v2, v3, v4, v5].map(|v| { + v.0.into_iter() + .zip(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE)) + .map(|(c, alpha)| c * *alpha) + .collect_vec() + }); + + let sel_add_expr = SelectorType::::QuarkBinaryTreeLessThan(Expression::StructuralWitIn( + 0, + // this value doesn't matter, as we only need structural id + StackedConstantSequence { max_value: 0 }, + )); + let mut sel_evals = vec![E::ZERO]; + sel_add_expr.evaluate( + &mut sel_evals, + &out_rt, + &rt, + &SelectorContext { + offset: 0, + num_instances: proof.num_instances, + num_vars, + }, + 0, + ); + let expected_sel_add = sel_evals[0]; + + if proof.evals[0] != expected_sel_add { + return Err(ZKVMError::VerifyError( + (format!( + "sel_add evaluation mismatch, expected {}, got {}", + expected_sel_add, proof.evals[0] + )) + .into(), + )); + } + + // derive `sel_bypass = eq - sel_add - sel_last_onehot` + let expected_sel_bypass = eq_eval(&out_rt, &rt) + - expected_sel_add + - (out_rt.iter().copied().product::() * rt.iter().copied().product::()); - let sel = eq_eval_less_or_equal_than(num_instances - 1, &out_rt, &rt); - if sumcheck_claim.expected_evaluation != v * sel { + if proof.evals[1] != expected_sel_bypass { + return Err(ZKVMError::VerifyError( + (format!( + "sel_bypass evaluation mismatch, expected {}, got {}", + expected_sel_bypass, proof.evals[1] + )) + .into(), + )); + } + + let add_evaluations = vec![v1, v2, v3].into_iter().flatten().sum::(); + let bypass_evaluations = vec![v4, v5].into_iter().flatten().sum::(); + if sumcheck_claim.expected_evaluation + != add_evaluations * expected_sel_add + bypass_evaluations * expected_sel_bypass + { return Err(ZKVMError::VerifyError( (format!( "ecc zerocheck failed: mismatched evaluation, expected {}, got {}", sumcheck_claim.expected_evaluation, - v * sel + add_evaluations * expected_sel_add + bypass_evaluations * expected_sel_bypass )) .into(), )); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 47820680d..e400d3fdb 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -31,7 +31,7 @@ use witness::RowMajorMatrix; ))] pub struct EccQuarkProof { pub zerocheck_proof: IOPProof, - pub num_vars: usize, + pub num_instances: usize, pub evals: Vec, // x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[0,rt], y[0,rt], s[0,rt] pub sum: SepticPoint, } diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index cd4a8036a..0f38885be 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -456,7 +456,8 @@ pub fn extend_exprs_with_rotation( | SelectorType::Prefix(sel) | SelectorType::OrderedSparse32 { expression: sel, .. - } => match_expr(sel) * zero_check_expr, + } + | SelectorType::QuarkBinaryTreeLessThan(sel) => match_expr(sel) * zero_check_expr, }; zero_check_exprs.push(expr); } diff --git a/gkr_iop/src/selector.rs b/gkr_iop/src/selector.rs index 80e4df981..2330cfed7 100644 --- a/gkr_iop/src/selector.rs +++ b/gkr_iop/src/selector.rs @@ -53,6 +53,8 @@ pub enum SelectorType { indices: Vec, expression: Expression, }, + /// binary tree [`quark`] from paper + QuarkBinaryTreeLessThan(Expression), } impl SelectorType { @@ -119,6 +121,7 @@ impl SelectorType { .into_mle(), ) } + SelectorType::QuarkBinaryTreeLessThan(..) => unimplemented!(), } } @@ -149,6 +152,7 @@ impl SelectorType { sel.splice(end..sel.len(), repeat_n(E::ZERO, sel.len() - end)); Some(sel.into_mle()) } + // compute true and false mle eq(1; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (eq() - sel(y; b[5..])) SelectorType::OrderedSparse32 { indices, .. } => { assert_eq!(out_point.len(), ceil_log2(ctx.num_instances) + 5); @@ -176,10 +180,63 @@ impl SelectorType { }); Some(sel.into_mle()) } + // also see evaluate() function for more explanation + SelectorType::QuarkBinaryTreeLessThan(_) => { + assert_eq!(ctx.offset, 0); + // num_instances: number of prefix one in leaf layer + let mut sel: Vec = build_eq_x_r_vec(out_point); + let n = sel.len(); + + let num_instances_sequence = (0..out_point.len()) + // clean up sig bits + .scan(ctx.num_instances, |n_instance, _| { + // n points to sum means we have n/2 addition pairs + let cur = *n_instance / 2; + // the next layer has ceil(n/2) points to sum + *n_instance = (*n_instance).div_ceil(2); + Some(cur) + }) + .collect::>(); + + // split sel into different size of region, set tailing 0 of respective chunk size + // 1st round: take v = sel[0..sel.len()/2], zero out v[num_instances_sequence[0]..] + // 2nd round: take v = sel[sel.len()/2 .. sel.len()/4], zero out v[num_instances_sequence[1]..] + // ... + // each round: progressively smaller chunk + // example: round 0 uses first half, round 1 uses next quarter, etc. + // compute cumulative start indices: + // e.g. chunk = n/2, then start = 0, chunk, chunk + chunk/2, chunk + chunk/2 + chunk/4, ... + // compute disjoint start indices and lengths + let chunks: Vec<(usize, usize)> = { + let mut result = Vec::new(); + let mut start = 0; + let mut chunk_len = n / 2; + while chunk_len > 0 { + result.push((start, chunk_len)); + start += chunk_len; + chunk_len /= 2; + } + result + }; + + for (i, (start, len)) in chunks.into_iter().enumerate() { + let slice = &mut sel[start..start + len]; + + // determine from which index to zero + let zero_start = num_instances_sequence.get(i).copied().unwrap_or(0).min(len); + + for x in &mut slice[zero_start..] { + *x = E::ZERO; + } + } + + // zero out last bh evaluations + *sel.last_mut().unwrap() = E::ZERO; + Some(sel.into_mle()) + } } } - /// Evaluate true and false mle eq(CYCLIC_POW2_5[round]; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (1 - sel(y; b[5..])) pub fn evaluate( &self, evals: &mut Vec, @@ -219,6 +276,7 @@ impl SelectorType { }; (expression, sel) } + // evaluate true and false mle eq(CYCLIC_POW2_5[round]; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (1 - sel(y; b[5..])) SelectorType::OrderedSparse32 { indices, expression, @@ -236,6 +294,57 @@ impl SelectorType { ); (expression, eval * sel) } + SelectorType::QuarkBinaryTreeLessThan(expr) => { + // num_instances count on leaf layer + // where nodes size is 2^(N) / 2 + // out_point.len() is also log(2^(N)) - 1 + // so num_instances and 1 << out_point.len() are on same scaling + assert!(ctx.num_instances > 0); + assert!(ctx.num_instances <= (1 << out_point.len())); + assert!(!out_point.is_empty()); + assert_eq!(out_point.len(), in_point.len()); + + // we break down this special selector evaluation into recursive structure + // iterating through out_point and in_point, for each i + // next_eval = lhs * (1-out_point[i]) * (1 - in_point[i]) + prev_eval * out_point[i] * in_point[i] + // where the lhs is in consecutive prefix 1 follow by 0 + + // calculate prefix 1 length of each layer + let mut prefix_one_seq = (0..out_point.len()) + .scan(ctx.num_instances, |n_instance, _| { + // n points to sum means we have n/2 addition pairs + let cur = *n_instance / 2; + // next layer has ceil(n/2) points to sum + *n_instance = (*n_instance).div_ceil(2); + Some(cur) + }) + .collect::>(); + prefix_one_seq.reverse(); + + let mut res = if prefix_one_seq[0] == 0 { + E::ZERO + } else { + assert_eq!(prefix_one_seq[0], 1); + (E::ONE - out_point[0]) * (E::ONE - in_point[0]) + }; + for i in 1..out_point.len() { + let num_prefix_one_lhs = prefix_one_seq[i]; + let lhs_res = if num_prefix_one_lhs == 0 { + E::ZERO + } else { + (E::ONE - out_point[i]) + * (E::ONE - in_point[i]) + * eq_eval_less_or_equal_than( + num_prefix_one_lhs - 1, + &out_point[..i], + &in_point[..i], + ) + }; + let rhs_res = (out_point[i] * in_point[i]) * res; + res = lhs_res + rhs_res; + } + (expr, res) + } }; let Expression::StructuralWitIn(wit_id, _) = expr else { panic!("Wrong selector expression format"); @@ -264,3 +373,58 @@ impl SelectorType { } } } + +#[cfg(test)] +mod tests { + use ff_ext::{BabyBearExt4, FromUniformBytes}; + use multilinear_extensions::{ + StructuralWitIn, ToExpr, util::ceil_log2, virtual_poly::build_eq_x_r_vec, + }; + use p3::field::FieldAlgebra; + use rand::thread_rng; + + use crate::selector::{SelectorContext, SelectorType}; + + type E = BabyBearExt4; + + #[test] + fn test_quark_lt_selector() { + let mut rng = thread_rng(); + let n_points = 5; + let n_vars = ceil_log2(n_points); + let witin = StructuralWitIn { + id: 0, + witin_type: multilinear_extensions::StructuralWitInType::EqualDistanceSequence { + max_len: 0, + offset: 0, + multi_factor: 0, + descending: false, + }, + }; + let selector = SelectorType::QuarkBinaryTreeLessThan(witin.expr()); + let ctx = SelectorContext::new(0, n_points, n_vars); + let out_rt = E::random_vec(n_vars, &mut rng); + let sel_mle = selector.compute(&out_rt, &ctx).unwrap(); + + // if we have 5 points to sum, then + // in 1st layer: two additions p12 = p1 + p2, p34 = p3 + p4, p5 kept + // in 2nd layer: one addition p14 = p12 + p34, p5 kept + // in 3rd layer: one addition p15 = p14 + p5 + let eq = build_eq_x_r_vec(&out_rt); + let vec = sel_mle.get_ext_field_vec(); + assert_eq!(vec[0], eq[0]); // p1+p2 + assert_eq!(vec[1], eq[1]); // p3+p4 + assert_eq!(vec[2], E::ZERO); // p5 + assert_eq!(vec[3], E::ZERO); + assert_eq!(vec[4], eq[4]); // p1+p2+p3+p4 + assert_eq!(vec[5], E::ZERO); // p5 + assert_eq!(vec[6], eq[6]); // p1+p2+p3+p4+p5 + assert_eq!(vec[7], E::ZERO); + + let in_rt = E::random_vec(n_vars, &mut rng); + let mut evals = vec![]; + // TODO: avoid the param evals when we evaluate a selector + selector.evaluate(&mut evals, &out_rt, &in_rt, &ctx, 0); + assert_eq!(sel_mle.evaluate(&in_rt), evals[0]); + } +} From bcd3eb9e8426c5931cbf20c2871f4dd856095f48 Mon Sep 17 00:00:00 2001 From: xkx Date: Tue, 28 Oct 2025 08:38:10 +0800 Subject: [PATCH 82/91] Feat: integrate ecc quark prover into prover's and verifier's workflow (#1093) --- ceno_zkvm/src/chip_handler/general.rs | 2 - ceno_zkvm/src/gadgets/poseidon2.rs | 7 +- ceno_zkvm/src/gadgets/poseidon2_constants.rs | 2 + ceno_zkvm/src/instructions.rs | 1 - ceno_zkvm/src/instructions/global.rs | 372 +++++++++++------- .../src/instructions/riscv/memory/test.rs | 1 - ceno_zkvm/src/scheme/cpu/mod.rs | 32 +- ceno_zkvm/src/scheme/gpu/mod.rs | 4 +- ceno_zkvm/src/scheme/hal.rs | 14 +- ceno_zkvm/src/scheme/prover.rs | 9 +- ceno_zkvm/src/scheme/septic_curve.rs | 37 +- ceno_zkvm/src/scheme/tests.rs | 1 + ceno_zkvm/src/scheme/utils.rs | 4 +- ceno_zkvm/src/scheme/verifier.rs | 59 ++- ceno_zkvm/src/structs.rs | 5 + gkr_iop/src/circuit_builder.rs | 4 +- gkr_iop/src/cpu/mod.rs | 2 +- gkr_iop/src/gkr/layer/zerocheck_layer.rs | 2 +- 18 files changed, 322 insertions(+), 236 deletions(-) diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index a86d82dd0..8075bb9ba 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -72,14 +72,12 @@ impl<'a, E: ExtensionField> PublicIOQuery for CircuitBuilder<'a, E> { fn query_global_rw_sum(&mut self) -> Result, CircuitBuilderError> { let x = (0..SEPTIC_EXTENSION_DEGREE) - .into_iter() .map(|i| { self.cs .query_instance(|| format!("global_rw_sum_x_{}", i), GLOBAL_RW_SUM_IDX + i) }) .collect::, CircuitBuilderError>>()?; let y = (0..SEPTIC_EXTENSION_DEGREE) - .into_iter() .map(|i| { self.cs.query_instance( || format!("global_rw_sum_y_{}", i), diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs index 80e6e6728..e59817bf3 100644 --- a/ceno_zkvm/src/gadgets/poseidon2.rs +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -12,7 +12,7 @@ use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn}; use num_bigint::BigUint; use p3::{ - babybear::{BabyBear, BabyBearInternalLayerParameters}, + babybear::BabyBearInternalLayerParameters, field::{Field, FieldAlgebra, PrimeField}, monty_31::InternalLayerBaseParameters, poseidon2::{GenericPoseidon2LinearLayers, MDSMat4, mds_light_permutation}, @@ -65,7 +65,7 @@ impl GenericPoseidon2LinearLayers let diag_m1_matrix: &[F; WIDTH] = unsafe { transmute(diag_m1_matrix) }; let sum = state.iter().cloned().sum::(); for (input, diag_m1) in state.iter_mut().zip(diag_m1_matrix) { - *input = sum.clone() + F::from_f(*diag_m1) * input.clone(); + *input = sum + F::from_f(*diag_m1) * *input; } } else { panic!("Unsupported field"); @@ -325,7 +325,6 @@ impl< ////////////////////////////////////////////////////////////////////////// /// The following routines are taken from poseidon2-air/src/generation.rs ////////////////////////////////////////////////////////////////////////// - fn generate_trace_rows_for_perm< F: PrimeField, LinearLayers: GenericPoseidon2LinearLayers, @@ -485,6 +484,6 @@ mod tests { let mut cb = CircuitBuilder::::new(&mut cs); let poseidon2_constants = horizen_round_consts(); - let poseidon2_config = Poseidon2BabyBearConfig::construct(&mut cb, poseidon2_constants); + let _ = Poseidon2BabyBearConfig::construct(&mut cb, poseidon2_constants); } } diff --git a/ceno_zkvm/src/gadgets/poseidon2_constants.rs b/ceno_zkvm/src/gadgets/poseidon2_constants.rs index cf807a56d..85ae2a1ae 100644 --- a/ceno_zkvm/src/gadgets/poseidon2_constants.rs +++ b/ceno_zkvm/src/gadgets/poseidon2_constants.rs @@ -12,10 +12,12 @@ const BABY_BEAR_POSEIDON2_WIDTH: usize = 16; const BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS: usize = 4; const BABY_BEAR_POSEIDON2_PARTIAL_ROUNDS: usize = 13; +#[allow(dead_code)] pub(crate) fn horizen_to_p3_babybear(horizen_babybear: HorizenBabyBear) -> BabyBear { BabyBear::from_canonical_u64(horizen_babybear.into_bigint().0[0]) } +#[allow(dead_code)] pub(crate) fn horizen_round_consts() -> RoundConstants< BabyBear, BABY_BEAR_POSEIDON2_WIDTH, diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 2ad63d163..70df0b67f 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -2,7 +2,6 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, structs::ProgramParams, tables::RMMCollections, witness::LkMultiplicity, }; -use ceno_emul::StepRecord; use ff_ext::{ExtensionField, FieldInto}; use gkr_iop::{ chip::Chip, diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 806c81e32..4d2e1ef56 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -1,4 +1,4 @@ -use std::iter::repeat; +use std::iter::repeat_n; use crate::{ Value, @@ -16,7 +16,7 @@ use gkr_iop::{ chip::Chip, circuit_builder::CircuitBuilder, error::CircuitBuilderError, gkr::layer::Layer, selector::SelectorType, utils::lk_multiplicity::Multiplicity, }; -use itertools::Itertools; +use itertools::{Itertools, chain}; use multilinear_extensions::{ Expression, StructuralWitInType::EqualDistanceSequence, ToExpr, WitIn, util::max_usable_threads, }; @@ -24,7 +24,6 @@ use p3::{ field::{Field, FieldAlgebra}, matrix::dense::RowMajorMatrix, symmetric::Permutation, - util::log2_ceil_usize, }; use rayon::{ iter::{IndexedParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator}, @@ -39,11 +38,92 @@ use crate::{ scheme::constants::SEPTIC_EXTENSION_DEGREE, }; -// opcode circuit + mem init/final table + global chip: -// have read/write consistency for RAMType::Register and RAMType::Memory -// -// global chip: read from and write into a global set shared -// among multiple shards +/// A record for a read/write into the global set +#[derive(Default, Debug, Clone)] +pub struct GlobalRecord { + pub addr: u32, + pub ram_type: RAMType, + pub value: u32, + pub shard: u64, + pub local_clk: u64, + pub global_clk: u64, + pub is_write: bool, +} + +/// An EC point corresponding to a global read/write record +/// whose x-coordinate is derived from Poseidon2 hash of the record +#[derive(Clone, Debug)] +pub struct GlobalPoint { + pub nonce: u32, + pub point: SepticPoint, +} + +impl GlobalRecord { + pub fn to_ec_point< + E: ExtensionField, + P: Permutation<[E::BaseField; POSEIDON2_BABYBEAR_WIDTH]>, + >( + &self, + hasher: &P, + ) -> GlobalPoint { + let mut nonce = 0; + let mut input = [ + E::BaseField::from_canonical_u32(self.addr), + E::BaseField::from_canonical_u32(self.ram_type as u32), + E::BaseField::from_canonical_u32(self.value & 0xFFFF), // lower 16 bits + E::BaseField::from_canonical_u32((self.value >> 16) & 0xFFFF), // higher 16 bits + E::BaseField::from_canonical_u64(self.shard), + E::BaseField::from_canonical_u64(self.global_clk), + E::BaseField::from_canonical_u32(nonce), + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + E::BaseField::ZERO, + ]; + + let prime = E::BaseField::order().to_u64_digits()[0]; + loop { + let x: SepticExtension = + hasher.permute(input)[0..SEPTIC_EXTENSION_DEGREE].into(); + if let Some(p) = SepticPoint::from_x(x) { + let y6 = (p.y.0)[SEPTIC_EXTENSION_DEGREE - 1].to_canonical_u64(); + let is_y_in_2nd_half = y6 >= (prime / 2); + + // we negate y if needed + // to ensure read => y in [0, p/2) and write => y in [p/2, p) + let negate = match (self.is_write, is_y_in_2nd_half) { + (true, false) => true, // write, y in [0, p/2) + (false, true) => true, // read, y in [p/2, p) + _ => false, + }; + + let point = if negate { -p } else { p }; + + return GlobalPoint { nonce, point }; + } else { + // try again with different nonce + nonce += 1; + input[6] = E::BaseField::from_canonical_u32(nonce); + } + } + } +} +/// opcode circuit + mem init/final table + local finalize circuit + global chip +/// global chip is used to ensure the **local** reads and writes produced by +/// opcode circuits / memory init / memory finalize table / local finalize circuit +/// can balance out. +/// +/// 1. For a local memory read record whose previous write is not in the same shard, +/// the global chip will read it from the **global set** and insert a local write record. +/// 2. For a local memory write record which will **not** be read in the future, +/// the local finalize circuit will consume it by inserting a local read record. +/// 3. For a local memory write record which will be read in the future, +/// the global chip will insert a local read record and write it to the **global set**. pub struct GlobalConfig { addr: WitIn, is_ram_register: WitIn, @@ -103,21 +183,21 @@ impl GlobalConfig { input.push(global_clk.expr()); // add nonce to ensure poseidon2(input) always map to a valid ec point input.push(nonce.expr()); - input.extend(repeat(E::BaseField::ZERO.expr()).take(16 - input.len())); + input.extend(repeat_n(E::BaseField::ZERO.expr(), 16 - input.len())); let mut record = vec![]; record.push(addr.expr()); record.push(ram_type); record.extend(value.memory_expr()); - record.push(shard.expr()); record.push(local_clk.expr()); // if is_global_write = 1, then it means we are propagating a local write to global // so we need to insert a local read record to cancel out this local write - cb.assert_bit(|| "is_global_write must be boolean", is_global_write.expr())?; + // TODO: for all local reads, enforce they come to global writes + // TODO: for all local writes, enforce they come from global reads - // local read/write consistency + // global read => insert a local write with local_clk = 0 cb.condition_require_zero( || "is_global_read => local_clk = 0", 1 - is_global_write.expr(), @@ -127,12 +207,12 @@ impl GlobalConfig { cb.read_record( || "r_record", - gkr_iop::RAMType::Memory, // TODO fixme + gkr_iop::RAMType::Memory, // FIXME: should be dynamic, either Register or Memory record.clone(), )?; cb.write_record( || "w_record", - gkr_iop::RAMType::Memory, // TODO fixme + gkr_iop::RAMType::Memory, // FIXME: should be dynamic, either Register or Memory record.clone(), )?; @@ -179,86 +259,24 @@ impl GlobalConfig { } } -#[derive(Default)] -pub struct GlobalRecord { - pub addr: u32, - pub ram_type: RAMType, - pub value: u32, - pub shard: u64, - pub local_clk: u64, - pub global_clk: u64, - pub is_write: bool, -} - -impl GlobalRecord { - pub fn to_ec_point< - E: ExtensionField, - P: Permutation<[E::BaseField; POSEIDON2_BABYBEAR_WIDTH]>, - >( - &self, - hasher: &P, - ) -> (u32, SepticPoint) { - let mut nonce = 0; - let mut input = [ - E::BaseField::from_canonical_u32(self.addr), - E::BaseField::from_canonical_u32(self.ram_type as u32), - E::BaseField::from_canonical_u32(self.value & 0xFFFF), // lower 16 bits - E::BaseField::from_canonical_u32((self.value >> 16) & 0xFFFF), // higher 16 bits - E::BaseField::from_canonical_u64(self.shard), - E::BaseField::from_canonical_u64(self.global_clk), - E::BaseField::from_canonical_u32(nonce), - E::BaseField::ZERO, - E::BaseField::ZERO, - E::BaseField::ZERO, - E::BaseField::ZERO, - E::BaseField::ZERO, - E::BaseField::ZERO, - E::BaseField::ZERO, - E::BaseField::ZERO, - E::BaseField::ZERO, - ]; - - let prime = E::BaseField::order().to_u64_digits()[0]; - loop { - let x: SepticExtension = - hasher.permute(input)[0..SEPTIC_EXTENSION_DEGREE].into(); - if let Some(p) = SepticPoint::from_x(x) { - let y6 = (p.y.0)[SEPTIC_EXTENSION_DEGREE - 1].to_canonical_u64(); - let is_y_in_2nd_half = y6 >= (prime / 2); - - // we negate y if needed - let negate = match (self.is_write, is_y_in_2nd_half) { - (true, false) => true, // write, y in [0, p/2) - (false, true) => true, // read, y in [p/2, p) - _ => false, - }; - - if negate { - return (nonce, -p); - } else { - return (nonce, p); - } - } else { - // try again with different nonce - nonce += 1; - input[6] = E::BaseField::from_canonical_u32(nonce); - } - } - } -} - -// This chip is used to manage read/write into a global set -// shared among multiple shards +/// This chip is used to manage read/write into a global set +/// shared among multiple shards pub struct GlobalChip { rc: RoundConstants, perm: P, } +#[derive(Clone, Debug)] +pub struct GlobalChipInput { + pub record: GlobalRecord, + pub ec_point: Option>, // to be filled during instance assignment +} + impl + Send> Instruction for GlobalChip { type InstructionConfig = GlobalConfig; - type Record = GlobalRecord; + type Record = GlobalChipInput; fn name() -> String { "Global".to_string() @@ -280,14 +298,6 @@ impl param: &ProgramParams, ) -> Result<(Self::InstructionConfig, gkr_iop::gkr::GKRCircuit), crate::error::ZKVMError> { - let config = self.construct_circuit(cb, param)?; - - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); - let zero_len = - cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - // create three selectors: selector_r, selector_w, selector_zero let selector_r = cb.create_structural_witin( || "selector_r", @@ -301,6 +311,7 @@ impl ); let selector_w = cb.create_structural_witin( || "selector_w", + // this is just a placeholder, the actural type is SelectorType::Prefix() EqualDistanceSequence { max_len: 0, offset: 0, @@ -310,6 +321,7 @@ impl ); let selector_zero = cb.create_structural_witin( || "selector_zero", + // this is just a placeholder, the actural type is SelectorType::Prefix() EqualDistanceSequence { max_len: 0, offset: 0, @@ -317,6 +329,15 @@ impl descending: false, }, ); + + let config = self.construct_circuit(cb, param)?; + + let w_len = cb.cs.w_expressions.len(); + let r_len = cb.cs.r_expressions.len(); + let lk_len = cb.cs.lk_expressions.len(); + let zero_len = + cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); + let selector_r = SelectorType::Prefix(selector_r.expr()); // note that the actual offset should be set by prover // depending on the number of local read instances @@ -354,10 +375,11 @@ impl fn assign_instance( config: &Self::InstructionConfig, instance: &mut [E::BaseField], - _lk_multiplicity: &mut LkMultiplicity, - record: &GlobalRecord, + lk_multiplicity: &mut LkMultiplicity, + input: &Self::Record, ) -> Result<(), crate::error::ZKVMError> { // assign basic fields + let record = &input.record; let is_ram_register = match record.ram_type { RAMType::Register => 1, RAMType::Memory => 0, @@ -365,7 +387,7 @@ impl }; set_val!(instance, config.addr, record.addr as u64); set_val!(instance, config.is_ram_register, is_ram_register as u64); - let value = Value::new_unchecked(record.value); + let value = Value::new(record.value, lk_multiplicity); config.value.assign_limbs(instance, value.as_u16_limbs()); set_val!(instance, config.shard, record.shard); set_val!(instance, config.global_clk, record.global_clk); @@ -373,15 +395,15 @@ impl set_val!(instance, config.is_global_write, record.is_write as u64); // assign (x, y) and nonce - let (nonce, point) = record.to_ec_point::(&config.perm); - set_val!(instance, config.nonce, nonce as u64); + let GlobalPoint { nonce, point } = input.ec_point.as_ref().unwrap(); + set_val!(instance, config.nonce, *nonce as u64); config .x .iter() .chain(config.y.iter()) .zip_eq((point.x.deref()).iter().chain((point.y.deref()).iter())) .for_each(|(witin, fe)| { - set_val!(instance, *witin, fe.to_canonical_u64()); + instance[witin.id as usize] = *fe; }); let ram_type = E::BaseField::from_canonical_u32(record.ram_type as u32); @@ -396,11 +418,12 @@ impl .for_each(|(i, v)| *i = E::BaseField::from_canonical_u16(*v)); input[2 + k] = E::BaseField::from_canonical_u64(record.shard); input[2 + k + 1] = E::BaseField::from_canonical_u64(record.global_clk); - input[2 + k + 2] = E::BaseField::from_canonical_u32(nonce); + input[2 + k + 2] = E::BaseField::from_canonical_u32(*nonce); config .perm_config - .assign_instance(&mut instance[21 + UINT_LIMBS..], input); + // TODO: remove hardcoded constant 28 + .assign_instance(&mut instance[28 + UINT_LIMBS..], input); Ok(()) } @@ -409,11 +432,12 @@ impl config: &Self::InstructionConfig, num_witin: usize, num_structural_witin: usize, - steps: Vec, + mut steps: Vec, ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { // FIXME selector is the only structural witness // this is workaround, as call `construct_circuit` will not initialized selector // we can remove this one all opcode unittest migrate to call `build_gkr_iop_circuit` + assert!(num_structural_witin == 3); let selector_r_witin = WitIn { id: 0 }; let selector_w_witin = WitIn { id: 1 }; @@ -421,8 +445,8 @@ impl let nthreads = max_usable_threads(); - // local read => global write - let num_local_reads = steps.iter().filter(|s| s.is_write).count(); + // local read iff it's global write + let num_local_reads = steps.iter().filter(|s| s.record.is_write).count(); let num_instance_per_batch = if steps.len() > 256 { steps.len().div_ceil(nthreads) @@ -430,9 +454,22 @@ impl steps.len() } .max(1); + + let n = next_pow2_instance_padding(steps.len()); + // compute the input for the binary tree for ec point summation + steps + .par_chunks_mut(num_instance_per_batch) + .for_each(|chunk| { + chunk.iter_mut().for_each(|step| { + let point = step.record.to_ec_point::(&config.perm); + + step.ec_point.replace(point); + }); + }); + let lk_multiplicity = LkMultiplicity::default(); // *2 because we need to store the internal nodes of binary tree for ec point summation - let num_rows_padded = next_pow2_instance_padding(steps.len()) * 2; + let num_rows_padded = 2 * n; let mut raw_witin = { let matrix_size = num_rows_padded * num_witin; @@ -463,7 +500,8 @@ impl raw_witin_iter .zip_eq(raw_structual_witin_iter) .zip_eq(steps.par_chunks(num_instance_per_batch)) - .flat_map(|((instances, structural_instance), steps)| { + .enumerate() + .flat_map(|(chunk_idx, ((instances, structural_instance), steps))| { let mut lk_multiplicity = lk_multiplicity.clone(); instances .chunks_mut(num_witin) @@ -471,7 +509,8 @@ impl .zip_eq(steps) .enumerate() .map(|(i, ((instance, structural_instance), step))| { - let (sel_r, sel_w) = if i < num_local_reads { + let row = chunk_idx * num_instance_per_batch + i; + let (sel_r, sel_w) = if row < num_local_reads { (E::BaseField::ONE, E::BaseField::ZERO) } else { (E::BaseField::ZERO, E::BaseField::ONE) @@ -486,17 +525,67 @@ impl .collect::>()?; // assign internal nodes in the binary tree for ec point summation - let half_witin_matrix_size = num_rows_padded / 2 * num_witin; - let raw_witin_iter = raw_witin.values - [half_witin_matrix_size..(2 * half_witin_matrix_size - 1)] - .par_chunks_mut(num_witin) - .for_each(|instance| { - for i in 0..SEPTIC_EXTENSION_DEGREE { - set_val!(instance, config.x[i], E::BaseField::default()); - set_val!(instance, config.y[i], E::BaseField::default()); - set_val!(instance, config.slope[i], E::BaseField::default()); - } - }); + let mut cur_layer_points = steps + .iter() + .map(|step| step.ec_point.as_ref().map(|p| p.point.clone()).unwrap()) + .enumerate() + .collect_vec(); + + // slope[1,b] = (input[b,0].y - input[b,1].y) / (input[b,0].x - input[b,1].x) + loop { + if cur_layer_points.len() <= 1 { + break; + } + // 2b -> b + 2^log_n + let next_layer_offset = cur_layer_points.first().map(|(i, _)| *i / 2 + n).unwrap(); + cur_layer_points = cur_layer_points + .par_chunks(2) + .zip(raw_witin.values[next_layer_offset * num_witin..].par_chunks_mut(num_witin)) + .with_min_len(64) + .map(|(pair, instance)| { + // input[1,b] = affine_add(input[b,0], input[b,1]) + // the left node is at index 2b, right node is at index 2b+1 + // the parent node is at index b + 2^n + let (o, slope, q) = match pair.len() { + 2 => { + // l = 2b, r = 2b+1 + let (l, p1) = &pair[0]; + let (r, p2) = &pair[1]; + assert_eq!(*r - *l, 1); + + // parent node idx = b + 2^log2_n + let o = n + l / 2; + let slope = (&p1.y - &p2.y) * (&p1.x - &p2.x).inverse().unwrap(); + let q = p1.clone() + p2.clone(); + + (o, slope, q) + } + 1 => { + let (l, p) = &pair[0]; + let o = n + l / 2; + (o, SepticExtension::zero(), p.clone()) + } + _ => unreachable!(), + }; + + config + .x + .iter() + .chain(config.y.iter()) + .chain(config.slope.iter()) + .zip_eq(chain!( + q.x.deref().iter(), + q.y.deref().iter(), + slope.deref().iter(), + )) + .for_each(|(witin, fe)| { + set_val!(instance, *witin, *fe); + }); + + (o, q) + }) + .collect::>(); + } let raw_witin = witness::RowMajorMatrix::new_by_inner_matrix( raw_witin, @@ -520,12 +609,10 @@ mod tests { use ff_ext::{BabyBearExt4, FromUniformBytes, PoseidonField}; use itertools::Itertools; use mpcs::{BasefoldDefault, PolynomialCommitmentScheme, SecurityLevel}; - use p3::{babybear::BabyBear, field::FieldAlgebra}; + use p3::babybear::BabyBear; use rand::thread_rng; use tracing_forest::{ForestLayer, util::LevelFilter}; - use tracing_subscriber::{ - EnvFilter, Registry, fmt, layer::SubscriberExt, util::SubscriberInitExt, - }; + use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; use transcript::BasicTranscript; use crate::{ @@ -533,7 +620,7 @@ mod tests { gadgets::horizen_round_consts, instructions::{ Instruction, - global::{GlobalChip, GlobalRecord}, + global::{GlobalChip, GlobalChipInput, GlobalRecord}, }, scheme::{ PublicValues, create_backend, create_prover, hal::ProofInput, prover::ZKVMProver, @@ -546,8 +633,8 @@ mod tests { type E = BabyBearExt4; type F = BabyBear; - type PERM = ::P; - type PCS = BasefoldDefault; + type Perm = ::P; + type Pcs = BasefoldDefault; #[test] fn test_global_chip() { @@ -555,22 +642,16 @@ mod tests { let default_filter = EnvFilter::builder() .with_default_directive(LevelFilter::DEBUG.into()) .from_env_lossy(); - // let fmt_layer = fmt::layer() - // .compact() - // .with_thread_ids(false) - // .with_thread_names(false) - // .without_time(); Registry::default() .with(ForestLayer::default()) - // .with(fmt_layer) .with(default_filter) .init(); // init global chip with horizen_rc_consts let rc = horizen_round_consts(); let perm = ::get_default_perm(); - let global_chip = GlobalChip:: { rc, perm }; + let global_chip = GlobalChip:: { rc, perm }; let mut cs = ConstraintSystem::new(|| "global chip test"); let mut cb = CircuitBuilder::new(&mut cs); @@ -580,8 +661,8 @@ mod tests { .unwrap(); // create a bunch of random memory read/write records - let n_global_reads = 16; - let n_global_writes = 16; + let n_global_reads = 1700; + let n_global_writes = 1420; let global_reads = (0..n_global_reads) .map(|i| { let addr = i * 8; @@ -619,7 +700,7 @@ mod tests { let global_ec_sum: SepticPoint = global_reads .iter() .chain(global_writes.iter()) - .map(|record| record.to_ec_point::(&global_chip.perm).1) + .map(|record| record.to_ec_point::(&global_chip.perm).point) .sum(); let public_value = PublicValues::new( @@ -637,13 +718,17 @@ mod tests { .collect_vec(), ); // assign witness - let (witness, lk) = GlobalChip::assign_instances( + let (witness, _) = GlobalChip::assign_instances( &config, cs.num_witin as usize, cs.num_structural_witin as usize, global_writes // local reads .into_iter() - .chain(global_reads.into_iter()) // local writes + .chain(global_reads) // local writes + .map(|record| GlobalChipInput { + record, + ec_point: None, + }) .collect::>(), ) .unwrap(); @@ -655,13 +740,12 @@ mod tests { let pk = composed_cs.key_gen(); // create chip proof for global chip - let pcs_param = PCS::setup(1 << 20, SecurityLevel::Conjecture100bits).unwrap(); - let (pp, vp) = PCS::trim(pcs_param, 1 << 20).unwrap(); - let backend = create_backend::(20, SecurityLevel::Conjecture100bits); + let pcs_param = Pcs::setup(1 << 20, SecurityLevel::Conjecture100bits).unwrap(); + let (pp, vp) = Pcs::trim(pcs_param, 1 << 20).unwrap(); + let backend = create_backend::(20, SecurityLevel::Conjecture100bits); let pd = create_prover(backend); - // let pk = prover.create_chip_proof(); - let mut zkvm_pk = ZKVMProvingKey::new(pp, vp); + let zkvm_pk = ZKVMProvingKey::new(pp, vp); let zkvm_vk = zkvm_pk.get_vk_slow(); let zkvm_prover = ZKVMProver::new(zkvm_pk, pd); let mut transcript = BasicTranscript::new(b"global chip test"); @@ -679,6 +763,7 @@ mod tests { num_read_instances: n_global_writes as usize, num_write_instances: n_global_reads as usize, num_instances: (n_global_reads + n_global_writes) as usize, + has_ecc_ops: true, }; let mut rng = thread_rng(); let challenges = [E::random(&mut rng), E::random(&mut rng)]; @@ -698,7 +783,7 @@ mod tests { .iter() .map(|mle| mle.evaluate(&point[..mle.num_vars()])) .collect_vec(); - let opening_point = verifier + let vrf_point = verifier .verify_opcode_proof( "global", &pk.vk, @@ -710,5 +795,6 @@ mod tests { &challenges, ) .expect("verify global chip proof"); + assert_eq!(vrf_point, point); } } diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index da519027f..700bb2ead 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -21,7 +21,6 @@ use ff_ext::BabyBearExt4; use ff_ext::{ExtensionField, GoldilocksExt2}; use gkr_iop::circuit_builder::DebugIndex; use std::hash::Hash; -use tracing::span::Record; fn sb(prev: Word, rs2: Word, shift: u32) -> Word { let shift = (shift * 8) as usize; diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 8e3554236..69173609d 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -46,7 +46,7 @@ use transcript::Transcript; use witness::next_pow2_instance_padding; #[cfg(feature = "sanity-check")] -use {crate::scheme::septic_curve::SepticExtension, gkr_iop::utils::eq_eval_less_or_equal_than}; +use gkr_iop::utils::eq_eval_less_or_equal_than; pub type TowerRelationOutput = ( Point, @@ -139,10 +139,10 @@ impl CpuEccProver { let mut y0 = filter_bj(&ys, 0); let mut x1 = filter_bj(&xs, 1); let mut y1 = filter_bj(&ys, 1); - // build x[1,b], y[1,b], s[0,b] + // build x[1,b], y[1,b], s[1,b] let mut x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); let mut y3 = ys.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); - let mut s = invs.iter().map(|x| x.as_view_slice(2, 0)).collect_vec(); + let mut s = invs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); let s = SymbolicSepticExtension::new( s.iter_mut() @@ -180,7 +180,7 @@ impl CpuEccProver { .collect(), ); // affine addition - // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) with b != (1,...,1) + // zerocheck: 0 = s[1,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) with b != (1,...,1) exprs_add.extend( (s.clone() * (&x0 - &x1) - (&y0 - &y1)) .to_exprs() @@ -189,7 +189,7 @@ impl CpuEccProver { .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), ); - // zerocheck: 0 = s[0,b]^2 - x[b,0] - x[b,1] - x[1,b] with b != (1,...,1) + // zerocheck: 0 = s[1,b]^2 - x[b,0] - x[b,1] - x[1,b] with b != (1,...,1) exprs_add.extend( ((&s * &s) - &x0 - &x1 - &x3) .to_exprs() @@ -198,7 +198,7 @@ impl CpuEccProver { .map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))), ); - // zerocheck: 0 = s[0,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) with b != (1,...,1) + // zerocheck: 0 = s[1,b] * (x[b,0] - x[1,b]) - (y[b,0] + y[1,b]) with b != (1,...,1) exprs_add.extend( (s.clone() * (&x0 - &x3) - (&y0 + &y3)) .to_exprs() @@ -240,7 +240,7 @@ impl CpuEccProver { let evals = state.get_mle_flatten_final_evaluations(); assert_eq!(zerocheck_proof.extract_sum(), E::ZERO); - // 7 for x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt], s[0,rt] + // 7 for x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt], s[1,rt] assert_eq!(evals.len(), 2 + SEPTIC_EXTENSION_DEGREE * 7); let last_evaluation_index = (1 << n) - 1; @@ -295,11 +295,12 @@ impl CpuEccProver { } } - // TODO: prove the validity of s[0,rt], x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt] + // TODO: prove the validity of s[1,rt], x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt] EccQuarkProof { zerocheck_proof, num_instances, evals, + rt, sum: final_sum, } } @@ -887,8 +888,7 @@ impl> MainSumcheckProver> MainSumcheckProver { pub num_read_instances: usize, pub num_write_instances: usize, pub num_instances: usize, + pub has_ecc_ops: bool, } impl<'a, PB: ProverBackend> ProofInput<'a, PB> { #[inline] pub fn log2_num_instances(&self) -> usize { - ceil_log2(next_pow2_instance_padding(self.num_instances)) + let log2 = ceil_log2(next_pow2_instance_padding(self.num_instances)); + if self.has_ecc_ops { + // the mles have one extra variable to store + // the internal partial sums for ecc additions + log2 + 1 + } else { + log2 + } } } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 8fa0d72e9..65f76d365 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -237,6 +237,7 @@ impl< num_read_instances: num_instances, // TODO: fixme num_write_instances: num_instances, // TODO: fixme num_instances, + has_ecc_ops: cs.has_ecc_ops(), }; if cs.is_opcode_circuit() { @@ -334,7 +335,7 @@ impl< let ec_point_exprs = &cs.zkvm_v1_css.ec_point_exprs; assert_eq!(ec_point_exprs.len(), SEPTIC_EXTENSION_DEGREE * 2); let mut xs_ys = ec_point_exprs - .into_iter() + .iter() .map(|expr| match expr { Expression::WitIn(id) => input.witness[*id as usize].clone(), _ => unreachable!("ec point's expression must be WitIn"), @@ -342,18 +343,18 @@ impl< .collect_vec(); let ys = xs_ys.split_off(SEPTIC_EXTENSION_DEGREE); let xs = xs_ys; - let invs = cs + let slopes = cs .zkvm_v1_css .ec_slope_exprs .iter() .map(|expr| match expr { Expression::WitIn(id) => input.witness[*id as usize].clone(), - _ => unreachable!("ec inv's expression must be WitIn"), + _ => unreachable!("slope's expression must be WitIn"), }) .collect_vec(); Some( self.device - .prove_ec_sum_quark(input.num_instances, xs, ys, invs, transcript)?, + .prove_ec_sum_quark(input.num_instances, xs, ys, slopes, transcript)?, ) } else { None diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index fa288ac0f..6c6120050 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -412,8 +412,8 @@ impl QuadraticExtension { impl SepticExtension { pub fn random(mut rng: impl RngCore) -> Self { let mut arr = [F::ZERO; 7]; - for i in 0..7 { - arr[i] = F::random(&mut rng); + for item in arr.iter_mut() { + *item = F::random(&mut rng); } Self(arr) } @@ -434,8 +434,8 @@ impl Add<&Self> for SepticExtension { fn add(self, other: &Self) -> Self { let mut result = [F::ZERO; 7]; - for i in 0..7 { - result[i] = self.0[i] + other.0[i]; + for (i, res) in result.iter_mut().enumerate() { + *res = self.0[i] + other.0[i]; } Self(result) } @@ -446,8 +446,8 @@ impl Add for &SepticExtension { fn add(self, other: Self) -> SepticExtension { let mut result = [F::ZERO; 7]; - for i in 0..7 { - result[i] = self.0[i] + other.0[i]; + for (i, res) in result.iter_mut().enumerate() { + *res = self.0[i] + other.0[i]; } SepticExtension(result) } @@ -466,8 +466,8 @@ impl Neg for SepticExtension { fn neg(self) -> Self { let mut result = [F::ZERO; 7]; - for i in 0..7 { - result[i] = -self.0[i]; + for (res, src) in result.iter_mut().zip(self.0.iter()) { + *res = -(*src); } Self(result) } @@ -478,8 +478,8 @@ impl Sub<&Self> for SepticExtension { fn sub(self, other: &Self) -> Self { let mut result = [F::ZERO; 7]; - for i in 0..7 { - result[i] = self.0[i] - other.0[i]; + for (i, res) in result.iter_mut().enumerate() { + *res = self.0[i] - other.0[i]; } Self(result) } @@ -490,8 +490,8 @@ impl Sub for &SepticExtension { fn sub(self, other: Self) -> SepticExtension { let mut result = [F::ZERO; 7]; - for i in 0..7 { - result[i] = self.0[i] - other.0[i]; + for (i, res) in result.iter_mut().enumerate() { + *res = self.0[i] - other.0[i]; } SepticExtension(result) } @@ -529,8 +529,8 @@ impl Mul for &SepticExtension { fn mul(self, other: F) -> Self::Output { let mut result = [F::ZERO; 7]; - for i in 0..7 { - result[i] = self.0[i] * other; + for (i, res) in result.iter_mut().enumerate() { + *res = self.0[i] * other; } SepticExtension(result) } @@ -787,11 +787,7 @@ impl SepticPoint { } pub fn from_affine(x: SepticExtension, y: SepticExtension) -> Self { - let is_infinity = if x.is_zero() && y.is_zero() { - true - } else { - false - }; + let is_infinity = x.is_zero() && y.is_zero(); Self { x, y, is_infinity } } @@ -907,9 +903,6 @@ impl SepticPoint { impl SepticPoint { pub fn random(mut rng: impl RngCore) -> Self { - let b: SepticExtension = [0, 0, 0, 0, 0, 26, 0].into(); - let a: F = F::from_canonical_u32(2); - loop { let x = SepticExtension::random(&mut rng); if let Some(point) = Self::from_x(x) { diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 99aeb9648..bcf4a0263 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -201,6 +201,7 @@ fn test_rw_lk_expression_combination() { num_read_instances: num_instances, num_write_instances: num_instances, num_instances, + has_ecc_ops: false, }; let (proof, _, _) = prover .create_chip_proof( diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 52887f9a2..419d499e9 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -275,7 +275,7 @@ pub fn infer_tower_product_witness( assert!(num_product_fanin.is_power_of_two()); let log2_num_product_fanin = log2_strict_usize(num_product_fanin); - assert!(num_vars % log2_num_product_fanin == 0); + assert!(num_vars.is_multiple_of(log2_num_product_fanin)); assert!( last_layer .iter() @@ -378,7 +378,7 @@ pub fn infer_septic_sum_witness( outputs.set_len(SEPTIC_JACOBIAN_NUM_MLES * 2 * output_len); } - (0..2).into_iter().for_each(|chunk| { + (0..2).for_each(|chunk| { (0..output_len) .into_par_iter() .with_min_len(MIN_PAR_SIZE) diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 0d703beb3..753b1ced2 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -23,7 +23,6 @@ use crate::{ use gkr_iop::{ gkr::GKRClaims, selector::{SelectorContext, SelectorType}, - utils::eq_eval_less_or_equal_than, }; use itertools::{Itertools, chain, interleave, izip}; use mpcs::{Point, PolynomialCommitmentScheme}; @@ -372,9 +371,22 @@ impl> ZKVMVerifier let num_batched = r_counts_per_instance + w_counts_per_instance + lk_counts_per_instance; let next_pow2_instance = next_pow2_instance_padding(num_instances); - let log2_num_instances = ceil_log2(next_pow2_instance); + let mut log2_num_instances = ceil_log2(next_pow2_instance); + if composed_cs.has_ecc_ops() { + // for opcode circuit with ecc ops, the mles have one extra variable + // to store the internal partial sums for ecc additions + log2_num_instances += 1; + } + println!("{log2_num_instances}"); let num_var_with_rotation = log2_num_instances + composed_cs.rotation_vars().unwrap_or(0); + // verify ecc proof if exists + if composed_cs.has_ecc_ops() { + assert!(proof.ecc_proof.is_some()); + let ecc_proof = proof.ecc_proof.as_ref().unwrap(); + EccVerifier::new().verify_ecc_proof(ecc_proof, transcript)?; + } + // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; @@ -1003,6 +1015,7 @@ impl TowerVerify { } } +#[derive(Default)] pub struct EccVerifier; impl EccVerifier { @@ -1034,38 +1047,24 @@ impl EccVerifier { transcript, ); - let s0: SepticExtension = proof.evals[2..][0..][..SEPTIC_EXTENSION_DEGREE] - .try_into() - .unwrap(); - let x0: SepticExtension = proof.evals[2..][SEPTIC_EXTENSION_DEGREE..] - [..SEPTIC_EXTENSION_DEGREE] - .try_into() - .unwrap(); - let y0: SepticExtension = proof.evals[2..][2 * SEPTIC_EXTENSION_DEGREE..] - [..SEPTIC_EXTENSION_DEGREE] - .try_into() - .unwrap(); - let x1: SepticExtension = proof.evals[2..][3 * SEPTIC_EXTENSION_DEGREE..] - [..SEPTIC_EXTENSION_DEGREE] - .try_into() - .unwrap(); - let y1: SepticExtension = proof.evals[2..][4 * SEPTIC_EXTENSION_DEGREE..] - [..SEPTIC_EXTENSION_DEGREE] - .try_into() - .unwrap(); - let x3: SepticExtension = proof.evals[2..][5 * SEPTIC_EXTENSION_DEGREE..] - [..SEPTIC_EXTENSION_DEGREE] - .try_into() - .unwrap(); - let y3: SepticExtension = proof.evals[2..][6 * SEPTIC_EXTENSION_DEGREE..] - [..SEPTIC_EXTENSION_DEGREE] - .try_into() - .unwrap(); + let s0: SepticExtension = proof.evals[2..][0..][..SEPTIC_EXTENSION_DEGREE].into(); + let x0: SepticExtension = + proof.evals[2..][SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + let y0: SepticExtension = + proof.evals[2..][2 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + let x1: SepticExtension = + proof.evals[2..][3 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + let y1: SepticExtension = + proof.evals[2..][4 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + let x3: SepticExtension = + proof.evals[2..][5 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); + let y3: SepticExtension = + proof.evals[2..][6 * SEPTIC_EXTENSION_DEGREE..][..SEPTIC_EXTENSION_DEGREE].into(); let rt = sumcheck_claim .point .iter() - .map(|c| c.elements.clone()) + .map(|c| c.elements) .collect_vec(); // zerocheck: 0 = s[0,b] * (x[b,0] - x[b,1]) - (y[b,0] - y[b,1]) diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index e400d3fdb..8296f9f76 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -33,6 +33,7 @@ pub struct EccQuarkProof { pub zerocheck_proof: IOPProof, pub num_instances: usize, pub evals: Vec, // x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[0,rt], y[0,rt], s[0,rt] + pub rt: Point, pub sum: SepticPoint, } @@ -137,6 +138,10 @@ impl ComposedConstrainSystem { self.zkvm_v1_css.w_expressions.len() + self.zkvm_v1_css.w_table_expressions.len() } + pub fn has_ecc_ops(&self) -> bool { + !self.zkvm_v1_css.ec_final_sum.is_empty() + } + pub fn instance_name_map(&self) -> &HashMap { &self.zkvm_v1_css.instance_name_map } diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 9e08b8b70..1364d1e6e 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -425,8 +425,8 @@ impl ConstraintSystem { assert_eq!(final_sum.len(), 7 * 2); assert_eq!(self.ec_point_exprs.len(), 0); - self.ec_point_exprs.extend(xs.into_iter()); - self.ec_point_exprs.extend(ys.into_iter()); + self.ec_point_exprs.extend(xs); + self.ec_point_exprs.extend(ys); self.ec_slope_exprs = slopes; self.ec_final_sum = final_sum; diff --git a/gkr_iop/src/cpu/mod.rs b/gkr_iop/src/cpu/mod.rs index 2e22fc5fc..32ccb77a0 100644 --- a/gkr_iop/src/cpu/mod.rs +++ b/gkr_iop/src/cpu/mod.rs @@ -4,7 +4,7 @@ use crate::{ hal::{MultilinearPolynomial, ProtocolWitnessGeneratorProver, ProverBackend, ProverDevice}, }; use ff_ext::ExtensionField; -use itertools::{Itertools, izip}; +use itertools::izip; use mpcs::{PolynomialCommitmentScheme, SecurityLevel, SecurityLevel::Conjecture100bits}; use multilinear_extensions::{ mle::{ArcMultilinearExtension, MultilinearExtension, Point}, diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index 0f38885be..d9f13a2a9 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -27,7 +27,7 @@ use crate::{ }, }, hal::{ProverBackend, ProverDevice}, - selector::{self, SelectorContext, SelectorType}, + selector::{SelectorContext, SelectorType}, utils::rotation_selector_eval, }; From 2d8d643565cb741f8f8565f2d590cbb6b5021abd Mon Sep 17 00:00:00 2001 From: Ming Date: Tue, 28 Oct 2025 15:26:55 +0800 Subject: [PATCH 83/91] chores: fix quick error in e2e (#1096) --- ceno_zkvm/src/e2e.rs | 2 +- ceno_zkvm/src/scheme/verifier.rs | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 52b4eb035..167e35364 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -280,7 +280,7 @@ impl<'a> ShardContext<'a> { #[inline(always)] pub fn aligned_prev_ts(&self, prev_cycle: Cycle) -> Cycle { - let mut ts = prev_cycle - self.current_shard_offset_cycle(); + let mut ts = prev_cycle.saturating_sub(self.current_shard_offset_cycle()); if ts < Tracer::SUBCYCLES_PER_INSN { ts = 0 } diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 88aa6a1b5..2dfc21c8a 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -75,12 +75,12 @@ impl> ZKVMVerifier expect_halt: bool, ) -> Result { // require ecall/halt proof to exist, depending whether we expect a halt. - let has_halt = vm_proof.has_halt(&self.vk); - if has_halt != expect_halt { - return Err(ZKVMError::VerifyError( - format!("ecall/halt mismatch: expected {expect_halt} != {has_halt}",).into(), - )); - } + // let has_halt = vm_proof.has_halt(&self.vk); + // if has_halt != expect_halt { + // return Err(ZKVMError::VerifyError( + // format!("ecall/halt mismatch: expected {expect_halt} != {has_halt}",).into(), + // )); + // } self.verify_proof_validity(vm_proof, transcript) } From 9676da8124814cda40ec7f0a8fede67de478710f Mon Sep 17 00:00:00 2001 From: Ming Date: Wed, 29 Oct 2025 17:24:26 +0800 Subject: [PATCH 84/91] rw record ramtype expression (#1098) --- ceno_zkvm/src/instructions/global.rs | 15 ++++---- gkr_iop/src/circuit_builder.rs | 56 +++++++++++++++++++++++++++- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 4a76bad78..f4a43b5f3 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -188,7 +188,7 @@ impl GlobalConfig { let mut record = vec![]; record.push(addr.expr()); - record.push(ram_type); + record.push(ram_type.clone()); record.extend(value.memory_expr()); record.push(local_clk.expr()); @@ -205,16 +205,17 @@ impl GlobalConfig { local_clk.expr(), )?; // TODO: enforce shard = shard_id in the public values - - cb.read_record( + cb.read_rlc_record( || "r_record", - gkr_iop::RAMType::Memory, // FIXME: should be dynamic, either Register or Memory + ram_type.clone(), record.clone(), + cb.rlc_chip_record(record.clone()), )?; - cb.write_record( + cb.write_rlc_record( || "w_record", - gkr_iop::RAMType::Memory, // FIXME: should be dynamic, either Register or Memory + ram_type, record.clone(), + cb.rlc_chip_record(record), )?; // enforces final_sum = \sum_i (x_i, y_i) using ecc quark protocol @@ -441,7 +442,7 @@ impl // this is workaround, as call `construct_circuit` will not initialized selector // we can remove this one all opcode unittest migrate to call `build_gkr_iop_circuit` - assert!(num_structural_witin == 3); + assert_eq!(num_structural_witin, 3); let selector_r_witin = WitIn { id: 0 }; let selector_w_witin = WitIn { id: 1 }; let selector_zero_witin = WitIn { id: 2 }; diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 9d1d153b5..70de7f171 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -419,12 +419,22 @@ impl ConstraintSystem { record: Vec>, ) -> Result<(), CircuitBuilderError> { let rlc_record = self.rlc_chip_record(record.clone()); + self.read_rlc_record(name_fn, (ram_type as u64).into(), record, rlc_record) + } + + pub fn read_rlc_record, N: FnOnce() -> NR>( + &mut self, + name_fn: N, + ram_type: Expression, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> { self.r_expressions.push(rlc_record); let path = self.ns.compute_path(name_fn().into()); self.r_expressions_namespace_map.push(path); // Since r_expression is RLC(record) and when we're debugging // it's helpful to recover the value of record itself. - self.r_ram_types.push(((ram_type as u64).into(), record)); + self.r_ram_types.push((ram_type, record)); Ok(()) } @@ -435,10 +445,22 @@ impl ConstraintSystem { record: Vec>, ) -> Result<(), CircuitBuilderError> { let rlc_record = self.rlc_chip_record(record.clone()); + self.write_rlc_record(name_fn, (ram_type as u64).into(), record, rlc_record) + } + + pub fn write_rlc_record, N: FnOnce() -> NR>( + &mut self, + name_fn: N, + ram_type: Expression, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> { self.w_expressions.push(rlc_record); let path = self.ns.compute_path(name_fn().into()); self.w_expressions_namespace_map.push(path); - self.w_ram_types.push(((ram_type as u64).into(), record)); + // Since w_expression is RLC(record) and when we're debugging + // it's helpful to recover the value of record itself. + self.w_ram_types.push((ram_type, record)); Ok(()) } @@ -696,6 +718,21 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.cs.read_record(name_fn, ram_type, record) } + pub fn read_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs + .read_rlc_record(name_fn, ram_type, record, rlc_record) + } + pub fn write_record( &mut self, name_fn: N, @@ -709,6 +746,21 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.cs.write_record(name_fn, ram_type, record) } + pub fn write_rlc_record( + &mut self, + name_fn: N, + ram_type: Expression, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> + where + NR: Into, + N: FnOnce() -> NR, + { + self.cs + .write_rlc_record(name_fn, ram_type, record, rlc_record) + } + pub fn rlc_chip_record(&self, records: Vec>) -> Expression { self.cs.rlc_chip_record(records) } From 2d5ed663b4c4f69fddaf25582aef80908527038f Mon Sep 17 00:00:00 2001 From: xkx Date: Thu, 30 Oct 2025 10:09:43 +0800 Subject: [PATCH 85/91] Feat: integrate `Global` chip into e2e workflow (#1099) --- Cargo.lock | 200 +---------- Cargo.toml | 22 +- ceno_zkvm/Cargo.toml | 1 - ceno_zkvm/benches/riscv_add.rs | 3 + ceno_zkvm/src/e2e.rs | 3 +- ceno_zkvm/src/gadgets/mod.rs | 3 - ceno_zkvm/src/gadgets/poseidon2.rs | 49 ++- ceno_zkvm/src/gadgets/poseidon2_constants.rs | 60 ---- ceno_zkvm/src/instructions/global.rs | 316 ++++++++---------- .../src/instructions/riscv/rv32im/mmu.rs | 54 ++- ceno_zkvm/src/scheme/verifier.rs | 27 +- ceno_zkvm/src/structs.rs | 6 +- 12 files changed, 289 insertions(+), 455 deletions(-) delete mode 100644 ceno_zkvm/src/gadgets/poseidon2_constants.rs diff --git a/Cargo.lock b/Cargo.lock index 1145ce390..2c8d916a7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -717,26 +717,6 @@ dependencies = [ "wyz", ] -[[package]] -name = "blake2" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" -dependencies = [ - "digest 0.10.7", -] - -[[package]] -name = "blake2b_simd" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06e903a20b159e944f91ec8499fe1e55651480c541ea0a584f5d967c49ad9d99" -dependencies = [ - "arrayref", - "arrayvec", - "constant_time_eq", -] - [[package]] name = "block-buffer" version = "0.9.0" @@ -755,19 +735,6 @@ dependencies = [ "generic-array 0.14.7", ] -[[package]] -name = "bls12_381" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3c196a77437e7cc2fb515ce413a6401291578b5afc8ecb29a3c7ab957f05941" -dependencies = [ - "ff 0.12.1", - "group 0.12.1", - "pairing", - "rand_core 0.6.4", - "subtle", -] - [[package]] name = "blst" version = "0.3.16" @@ -1104,7 +1071,6 @@ dependencies = [ "transcript", "whir", "witness", - "zkhash", ] [[package]] @@ -1249,12 +1215,6 @@ dependencies = [ "unicode-xid", ] -[[package]] -name = "constant_time_eq" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" - [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1830,9 +1790,9 @@ dependencies = [ "base16ct", "crypto-bigint", "digest 0.10.7", - "ff 0.13.1", + "ff", "generic-array 0.14.7", - "group 0.13.0", + "group", "hkdf", "pem-rfc7468", "pkcs8", @@ -1932,24 +1892,12 @@ dependencies = [ "bytes", ] -[[package]] -name = "ff" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d013fc25338cc558c5c2cfbad646908fb23591e2404481826742b651c9af7160" -dependencies = [ - "bitvec", - "rand_core 0.6.4", - "subtle", -] - [[package]] name = "ff" version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0b50bfb653653f9ca9095b427bed08ab8d75a137839d9ad64eb11810d5b6393" dependencies = [ - "bitvec", "rand_core 0.6.4", "subtle", ] @@ -1957,7 +1905,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "once_cell", "p3", @@ -2148,25 +2096,13 @@ dependencies = [ "windows-sys 0.60.2", ] -[[package]] -name = "group" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfbfb3a6cfbd390d5c9564ab283a0349b9b9fcd46a706c1eb10e0db70bfbac7" -dependencies = [ - "ff 0.12.1", - "memuse", - "rand_core 0.6.4", - "subtle", -] - [[package]] name = "group" version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" dependencies = [ - "ff 0.13.1", + "ff", "rand_core 0.6.4", "subtle", ] @@ -2181,29 +2117,6 @@ dependencies = [ "crunchy", ] -[[package]] -name = "halo2" -version = "0.1.0-beta.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a23c779b38253fe1538102da44ad5bd5378495a61d2c4ee18d64eaa61ae5995" -dependencies = [ - "halo2_proofs", -] - -[[package]] -name = "halo2_proofs" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e925780549adee8364c7f2b685c753f6f3df23bde520c67416e93bf615933760" -dependencies = [ - "blake2b_simd", - "ff 0.12.1", - "group 0.12.1", - "pasta_curves 0.4.1", - "rand_core 0.6.4", - "rayon", -] - [[package]] name = "hashbrown" version = "0.12.3" @@ -2572,20 +2485,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "jubjub" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a575df5f985fe1cd5b2b05664ff6accfc46559032b954529fd225a2168d27b0f" -dependencies = [ - "bitvec", - "bls12_381", - "ff 0.12.1", - "group 0.12.1", - "rand_core 0.6.4", - "subtle", -] - [[package]] name = "k256" version = "0.13.4" @@ -2806,12 +2705,6 @@ dependencies = [ "libc", ] -[[package]] -name = "memuse" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d97bbf43eb4f088f8ca469930cde17fa036207c9a5e02ccc5107c4e8b17c964" - [[package]] name = "miniz_oxide" version = "0.8.8" @@ -2824,7 +2717,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "bincode", "clap", @@ -2848,7 +2741,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "either", "ff_ext", @@ -3169,7 +3062,7 @@ dependencies = [ [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "p3-air", "p3-baby-bear", @@ -3449,15 +3342,6 @@ dependencies = [ "serde", ] -[[package]] -name = "pairing" -version = "0.22.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "135590d8bdba2b31346f9cd1fb2a912329f5135e832a4f422942eb6ead8b6b3b" -dependencies = [ - "group 0.12.1", -] - [[package]] name = "parity-scale-codec" version = "3.7.5" @@ -3515,36 +3399,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "487f2ccd1e17ce8c1bfab3a65c89525af41cfad4c8659021a1e9a2aacd73b89b" -[[package]] -name = "pasta_curves" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cc65faf8e7313b4b1fbaa9f7ca917a0eed499a9663be71477f87993604341d8" -dependencies = [ - "blake2b_simd", - "ff 0.12.1", - "group 0.12.1", - "lazy_static", - "rand 0.8.5", - "static_assertions", - "subtle", -] - -[[package]] -name = "pasta_curves" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3e57598f73cc7e1b2ac63c79c517b31a0877cd7c402cdcaa311b5208de7a095" -dependencies = [ - "blake2b_simd", - "ff 0.13.1", - "group 0.13.0", - "lazy_static", - "rand 0.8.5", - "static_assertions", - "subtle", -] - [[package]] name = "paste" version = "1.0.15" @@ -3645,7 +3499,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "ff_ext", "p3", @@ -4629,7 +4483,7 @@ dependencies = [ [[package]] name = "sp1-curves" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "cfg-if", "dashu", @@ -4751,7 +4605,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "either", "ff_ext", @@ -4769,7 +4623,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "itertools 0.13.0", "p3", @@ -5164,7 +5018,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -5436,7 +5290,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "bincode", "clap", @@ -5723,7 +5577,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#c2580ace319e01bc8657dc92a6b5775348ce3133" +source = "git+https://github.com/scroll-tech/gkr-backend.git?branch=chore%2Fsw_curve_default#b155b40356762733fac4264444a5e8cef323607e" dependencies = [ "ff_ext", "multilinear_extensions", @@ -5880,29 +5734,3 @@ dependencies = [ "quote", "syn 2.0.101", ] - -[[package]] -name = "zkhash" -version = "0.2.0" -source = "git+https://github.com/HorizenLabs/poseidon2.git?rev=bb476b9#bb476b9ca38198cf5092487283c8b8c5d4317c4e" -dependencies = [ - "ark-ff 0.4.2", - "ark-std 0.4.0", - "bitvec", - "blake2", - "bls12_381", - "byteorder", - "cfg-if", - "group 0.12.1", - "group 0.13.0", - "halo2", - "hex", - "jubjub", - "lazy_static", - "pasta_curves 0.5.1", - "rand 0.8.5", - "serde", - "sha2 0.10.9", - "sha3", - "subtle", -] diff --git a/Cargo.toml b/Cargo.toml index 9cd889535..c5d5b3784 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,7 +80,6 @@ tracing = { version = "0.1", features = [ tracing-forest = { version = "0.1.6" } tracing-subscriber = { version = "0.3", features = ["env-filter"] } uint = "0.8" -zkhash = { git = "https://github.com/HorizenLabs/poseidon2.git", rev = "bb476b9" } lazy_static = "1.5.0" ceno_gpu = { path = "utils/cuda_hal", package = "cuda_hal" } @@ -101,13 +100,14 @@ lto = "thin" # [patch."ssh://git@github.com/scroll-tech/ceno-gpu.git"] # ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal" } -#[patch."https://github.com/scroll-tech/gkr-backend"] -#ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } -#mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } -#multilinear_extensions = { path = "../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" } -#p3 = { path = "../gkr-backend/crates/p3", package = "p3" } -#poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" } -#sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" } -#transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" } -#whir = { path = "../gkr-backend/crates/whir", package = "whir" } -#witness = { path = "../gkr-backend/crates/witness", package = "witness" } +# [patch."https://github.com/scroll-tech/gkr-backend"] +# ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } +# mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } +# multilinear_extensions = { path = "../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" } +# p3 = { path = "../gkr-backend/crates/p3", package = "p3" } +# poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" } +# sp1-curves = { path = "../gkr-backend/crates/curves", package = "sp1-curves" } +# sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" } +# transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" } +# whir = { path = "../gkr-backend/crates/whir", package = "whir" } +# witness = { path = "../gkr-backend/crates/witness", package = "witness" } diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 481d92024..80d69111b 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -30,7 +30,6 @@ sumcheck.workspace = true transcript.workspace = true whir.workspace = true witness.workspace = true -zkhash.workspace = true lazy_static.workspace = true itertools.workspace = true diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 028748058..b438244a3 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -111,7 +111,10 @@ fn bench_add(c: &mut Criterion) { witness: polys, structural_witness: vec![], public_input: vec![], + num_read_instances: num_instances, + num_write_instances: num_instances, num_instances, + has_ecc_ops: false, }; let _ = prover .create_chip_proof( diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 167e35364..c490bb597 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -3,6 +3,7 @@ use crate::{ instructions::riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig}, scheme::{ PublicValues, ZKVMProof, + constants::SEPTIC_EXTENSION_DEGREE, hal::ProverDevice, mock_prover::{LkMultiplicityKey, MockProver}, prover::ZKVMProver, @@ -433,7 +434,7 @@ pub fn emulate_program<'a>( end_cycle, shards.shard_id as u32, io_init.iter().map(|rec| rec.value).collect_vec(), - vec![0; 14], // point_at_infinity + vec![0; SEPTIC_EXTENSION_DEGREE * 2], // point_at_infinity ); // Find the final register values and cycles. diff --git a/ceno_zkvm/src/gadgets/mod.rs b/ceno_zkvm/src/gadgets/mod.rs index d0d8ed67d..a4d624568 100644 --- a/ceno_zkvm/src/gadgets/mod.rs +++ b/ceno_zkvm/src/gadgets/mod.rs @@ -2,7 +2,6 @@ mod div; mod field; mod is_lt; mod poseidon2; -mod poseidon2_constants; mod signed; mod signed_ext; mod signed_limbs; @@ -15,9 +14,7 @@ pub use gkr_iop::gadgets::{ AssertLtConfig, InnerLtConfig, IsEqualConfig, IsLtConfig, IsZeroConfig, cal_lt_diff, }; pub use is_lt::{AssertSignedLtConfig, SignedLtConfig}; -pub(crate) use poseidon2::RoundConstants; pub use poseidon2::{Poseidon2BabyBearConfig, Poseidon2Config}; -pub(crate) use poseidon2_constants::horizen_round_consts; pub use signed::Signed; pub use signed_ext::SignedExtendConfig; pub use signed_limbs::{UIntLimbsLT, UIntLimbsLTConfig}; diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs index e59817bf3..c4e7b621c 100644 --- a/ceno_zkvm/src/gadgets/poseidon2.rs +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -35,6 +35,42 @@ pub struct RoundConstants< pub ending_full_round_constants: [[F; WIDTH]; HALF_FULL_ROUNDS], } +impl + From> for RoundConstants +{ + fn from(value: Vec) -> Self { + let mut iter = value.into_iter(); + let mut beginning_full_round_constants = [[F::ZERO; WIDTH]; HALF_FULL_ROUNDS]; + for round in 0..HALF_FULL_ROUNDS { + for i in 0..WIDTH { + beginning_full_round_constants[round][i] = + iter.next().expect("insufficient round constants"); + } + } + + let mut partial_round_constants = [F::ZERO; PARTIAL_ROUNDS]; + for round in 0..PARTIAL_ROUNDS { + partial_round_constants[round] = iter.next().expect("insufficient round constants"); + } + + let mut ending_full_round_constants = [[F::ZERO; WIDTH]; HALF_FULL_ROUNDS]; + for round in 0..HALF_FULL_ROUNDS { + for i in 0..WIDTH { + ending_full_round_constants[round][i] = + iter.next().expect("insufficient round constants"); + } + } + + assert!(iter.next().is_none(), "round constants are too many"); + + RoundConstants { + beginning_full_round_constants, + partial_round_constants, + ending_full_round_constants, + } + } +} + pub type Poseidon2BabyBearConfig = Poseidon2Config; pub struct Poseidon2Config< E: ExtensionField, @@ -471,19 +507,20 @@ fn generate_sbox( #[cfg(test)] mod tests { - use crate::gadgets::{ - poseidon2::Poseidon2BabyBearConfig, poseidon2_constants::horizen_round_consts, - }; - use ff_ext::BabyBearExt4; + use crate::gadgets::poseidon2::Poseidon2BabyBearConfig; + use ff_ext::{BabyBearExt4, PoseidonField}; use gkr_iop::circuit_builder::{CircuitBuilder, ConstraintSystem}; + use p3::babybear::BabyBear; type E = BabyBearExt4; + type F = BabyBear; #[test] fn test_poseidon2_gadget() { let mut cs = ConstraintSystem::new(|| "poseidon2 gadget test"); let mut cb = CircuitBuilder::::new(&mut cs); - let poseidon2_constants = horizen_round_consts(); - let _ = Poseidon2BabyBearConfig::construct(&mut cb, poseidon2_constants); + // let poseidon2_constants = horizen_round_consts(); + let rc = ::get_default_perm_rc().into(); + let _ = Poseidon2BabyBearConfig::construct(&mut cb, rc); } } diff --git a/ceno_zkvm/src/gadgets/poseidon2_constants.rs b/ceno_zkvm/src/gadgets/poseidon2_constants.rs deleted file mode 100644 index 85ae2a1ae..000000000 --- a/ceno_zkvm/src/gadgets/poseidon2_constants.rs +++ /dev/null @@ -1,60 +0,0 @@ -// taken from openvm/crates/circuits/poseidon2-air/src/babybear.rs -use super::poseidon2::RoundConstants; -use lazy_static::lazy_static; -use p3::{babybear::BabyBear, field::FieldAlgebra}; -use std::array::from_fn; -use zkhash::{ - ark_ff::PrimeField as _, fields::babybear::FpBabyBear as HorizenBabyBear, - poseidon2::poseidon2_instance_babybear::RC16, -}; - -const BABY_BEAR_POSEIDON2_WIDTH: usize = 16; -const BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS: usize = 4; -const BABY_BEAR_POSEIDON2_PARTIAL_ROUNDS: usize = 13; - -#[allow(dead_code)] -pub(crate) fn horizen_to_p3_babybear(horizen_babybear: HorizenBabyBear) -> BabyBear { - BabyBear::from_canonical_u64(horizen_babybear.into_bigint().0[0]) -} - -#[allow(dead_code)] -pub(crate) fn horizen_round_consts() -> RoundConstants< - BabyBear, - BABY_BEAR_POSEIDON2_WIDTH, - BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS, - BABY_BEAR_POSEIDON2_PARTIAL_ROUNDS, -> { - let p3_rc16: Vec> = RC16 - .iter() - .map(|round| { - round - .iter() - .map(|babybear| horizen_to_p3_babybear(*babybear)) - .collect() - }) - .collect(); - let p_end = BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS + BABY_BEAR_POSEIDON2_PARTIAL_ROUNDS; - - let beginning_full_round_constants: [[BabyBear; BABY_BEAR_POSEIDON2_WIDTH]; - BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS] = from_fn(|i| p3_rc16[i].clone().try_into().unwrap()); - let partial_round_constants: [BabyBear; BABY_BEAR_POSEIDON2_PARTIAL_ROUNDS] = - from_fn(|i| p3_rc16[i + BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS][0]); - let ending_full_round_constants: [[BabyBear; BABY_BEAR_POSEIDON2_WIDTH]; - BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS] = - from_fn(|i| p3_rc16[i + p_end].clone().try_into().unwrap()); - - RoundConstants { - beginning_full_round_constants, - partial_round_constants, - ending_full_round_constants, - } -} - -lazy_static! { - pub static ref BABYBEAR_BEGIN_EXT_CONSTS: [[BabyBear; BABY_BEAR_POSEIDON2_WIDTH]; BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS] = - horizen_round_consts().beginning_full_round_constants; - pub static ref BABYBEAR_PARTIAL_CONSTS: [BabyBear; BABY_BEAR_POSEIDON2_PARTIAL_ROUNDS] = - horizen_round_consts().partial_round_constants; - pub static ref BABYBEAR_END_EXT_CONSTS: [[BabyBear; BABY_BEAR_POSEIDON2_WIDTH]; BABY_BEAR_POSEIDON2_HALF_FULL_ROUNDS] = - horizen_round_consts().ending_full_round_constants; -} diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index f4a43b5f3..2bad3fbbf 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -1,21 +1,23 @@ -use std::iter::repeat_n; +use std::{collections::HashMap, iter::repeat_n, marker::PhantomData}; use crate::{ Value, chip_handler::general::PublicIOQuery, - e2e::ShardContext, error::ZKVMError, - gadgets::{Poseidon2Config, RoundConstants}, + gadgets::Poseidon2Config, instructions::riscv::constants::UINT_LIMBS, scheme::septic_curve::{SepticExtension, SepticPoint}, structs::{ProgramParams, RAMType}, - tables::RMMCollections, + tables::{RMMCollections, TableCircuit}, witness::LkMultiplicity, }; -use ff_ext::{ExtensionField, FieldInto, POSEIDON2_BABYBEAR_WIDTH, SmallField}; +use ff_ext::{ExtensionField, FieldInto, PoseidonField, SmallField}; use gkr_iop::{ - chip::Chip, circuit_builder::CircuitBuilder, error::CircuitBuilderError, gkr::layer::Layer, - selector::SelectorType, utils::lk_multiplicity::Multiplicity, + chip::Chip, + circuit_builder::CircuitBuilder, + error::CircuitBuilderError, + gkr::{GKRCircuit, layer::Layer}, + selector::SelectorType, }; use itertools::{Itertools, chain}; use multilinear_extensions::{ @@ -34,10 +36,7 @@ use rayon::{ use std::ops::Deref; use witness::{InstancePaddingStrategy, next_pow2_instance_padding, set_val}; -use crate::{ - instructions::{Instruction, riscv::constants::UInt}, - scheme::constants::SEPTIC_EXTENSION_DEGREE, -}; +use crate::{instructions::riscv::constants::UInt, scheme::constants::SEPTIC_EXTENSION_DEGREE}; /// A record for a read/write into the global set #[derive(Default, Debug, Clone)] @@ -60,15 +59,12 @@ pub struct GlobalPoint { } impl GlobalRecord { - pub fn to_ec_point< - E: ExtensionField, - P: Permutation<[E::BaseField; POSEIDON2_BABYBEAR_WIDTH]>, - >( + pub fn to_ec_point>>( &self, hasher: &P, ) -> GlobalPoint { let mut nonce = 0; - let mut input = [ + let mut input = vec![ E::BaseField::from_canonical_u32(self.addr), E::BaseField::from_canonical_u32(self.ram_type as u32), E::BaseField::from_canonical_u32(self.value & 0xFFFF), // lower 16 bits @@ -90,7 +86,7 @@ impl GlobalRecord { let prime = E::BaseField::order().to_u64_digits()[0]; loop { let x: SepticExtension = - hasher.permute(input)[0..SEPTIC_EXTENSION_DEGREE].into(); + hasher.permute(input.clone())[0..SEPTIC_EXTENSION_DEGREE].into(); if let Some(p) = SepticPoint::from_x(x) { let y6 = (p.y.0)[SEPTIC_EXTENSION_DEGREE - 1].to_canonical_u64(); let is_y_in_2nd_half = y6 >= (prime / 2); @@ -125,7 +121,7 @@ impl GlobalRecord { /// the local finalize circuit will consume it by inserting a local read record. /// 3. For a local memory write record which will be read in the future, /// the global chip will insert a local read record and write it to the **global set**. -pub struct GlobalConfig { +pub struct GlobalConfig { addr: WitIn, is_ram_register: WitIn, value: UInt, @@ -141,16 +137,11 @@ pub struct GlobalConfig { y: Vec, slope: Vec, perm_config: Poseidon2Config, - perm: P, } -impl GlobalConfig { +impl GlobalConfig { // TODO: make `WIDTH`, `HALF_FULL_ROUNDS`, `PARTIAL_ROUNDS` generic parameters - pub fn configure( - cb: &mut CircuitBuilder, - rc: RoundConstants, - perm: P, - ) -> Result { + pub fn configure(cb: &mut CircuitBuilder) -> Result { let x: Vec = (0..SEPTIC_EXTENSION_DEGREE) .map(|i| cb.create_witin(|| format!("x{}", i))) .collect(); @@ -173,6 +164,8 @@ impl GlobalConfig { let reg: Expression = RAMType::Register.into(); let mem: Expression = RAMType::Memory.into(); let ram_type: Expression = is_ram_reg.clone() * reg + (1 - is_ram_reg) * mem; + + let rc = ::get_default_perm_rc().into(); let perm_config = Poseidon2Config::construct(cb, rc); let mut input = vec![]; @@ -256,50 +249,103 @@ impl GlobalConfig { nonce, is_global_write, perm_config, - perm, }) } } /// This chip is used to manage read/write into a global set /// shared among multiple shards -pub struct GlobalChip { - rc: RoundConstants, - perm: P, +#[derive(Default)] +pub struct GlobalChip { + _marker: PhantomData, } #[derive(Clone, Debug)] pub struct GlobalChipInput { pub record: GlobalRecord, - pub ec_point: Option>, // to be filled during instance assignment + pub ec_point: GlobalPoint, +} + +impl GlobalChip { + fn assign_instance<'a>( + config: &GlobalConfig, + instance: &mut [E::BaseField], + _lk_multiplicity: &mut LkMultiplicity, + input: &GlobalChipInput, + ) -> Result<(), crate::error::ZKVMError> { + // assign basic fields + let record = &input.record; + let is_ram_register = match record.ram_type { + RAMType::Register => 1, + RAMType::Memory => 0, + _ => unreachable!(), + }; + set_val!(instance, config.addr, record.addr as u64); + set_val!(instance, config.is_ram_register, is_ram_register as u64); + let value = Value::new_unchecked(record.value); + config.value.assign_limbs(instance, value.as_u16_limbs()); + set_val!(instance, config.shard, record.shard); + set_val!(instance, config.global_clk, record.global_clk); + set_val!(instance, config.local_clk, record.local_clk); + set_val!(instance, config.is_global_write, record.is_write as u64); + + // assign (x, y) and nonce + let GlobalPoint { nonce, point } = &input.ec_point; + set_val!(instance, config.nonce, *nonce as u64); + config + .x + .iter() + .chain(config.y.iter()) + .zip_eq((point.x.deref()).iter().chain((point.y.deref()).iter())) + .for_each(|(witin, fe)| { + instance[witin.id as usize] = *fe; + }); + + let ram_type = E::BaseField::from_canonical_u32(record.ram_type as u32); + let mut input = [E::BaseField::ZERO; 16]; + + let k = UINT_LIMBS; + input[0] = E::BaseField::from_canonical_u32(record.addr); + input[1] = ram_type; + input[2..(k + 2)] + .iter_mut() + .zip(value.as_u16_limbs().iter()) + .for_each(|(i, v)| *i = E::BaseField::from_canonical_u16(*v)); + input[2 + k] = E::BaseField::from_canonical_u64(record.shard); + input[2 + k + 1] = E::BaseField::from_canonical_u64(record.global_clk); + input[2 + k + 2] = E::BaseField::from_canonical_u32(*nonce); + + config + .perm_config + // TODO: remove hardcoded constant 28 + .assign_instance(&mut instance[28 + UINT_LIMBS..], input); + + Ok(()) + } } -impl + Send> - Instruction for GlobalChip -{ - type InstructionConfig = GlobalConfig; - type Record = GlobalChipInput; +impl TableCircuit for GlobalChip { + type TableConfig = GlobalConfig; + type FixedInput = (); + type WitnessInput = Vec>; fn name() -> String { "Global".to_string() } fn construct_circuit( - &self, cb: &mut CircuitBuilder, _param: &ProgramParams, - ) -> Result { - let config = GlobalConfig::configure(cb, self.rc.clone(), self.perm.clone())?; + ) -> Result { + let config = GlobalConfig::configure(cb)?; Ok(config) } fn build_gkr_iop_circuit( - &self, cb: &mut CircuitBuilder, param: &ProgramParams, - ) -> Result<(Self::InstructionConfig, gkr_iop::gkr::GKRCircuit), crate::error::ZKVMError> - { + ) -> Result<(Self::TableConfig, Option>), crate::error::ZKVMError> { // create three selectors: selector_r, selector_w, selector_zero let selector_r = cb.create_structural_witin( || "selector_r", @@ -332,7 +378,7 @@ impl }, ); - let config = self.construct_circuit(cb, param)?; + let config = Self::construct_circuit(cb, param)?; let w_len = cb.cs.w_expressions.len(); let r_len = cb.cs.r_expressions.len(); @@ -371,73 +417,23 @@ impl let layer = Layer::from_circuit_builder(cb, format!("{}_main", Self::name()), 0, out_evals); chip.add_layer(layer); - Ok((config, chip.gkr_circuit())) + Ok((config, Some(chip.gkr_circuit()))) } - fn assign_instance<'a>( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext<'a>, - instance: &mut [E::BaseField], - lk_multiplicity: &mut LkMultiplicity, - input: &Self::Record, - ) -> Result<(), crate::error::ZKVMError> { - // assign basic fields - let record = &input.record; - let is_ram_register = match record.ram_type { - RAMType::Register => 1, - RAMType::Memory => 0, - _ => unreachable!(), - }; - set_val!(instance, config.addr, record.addr as u64); - set_val!(instance, config.is_ram_register, is_ram_register as u64); - let value = Value::new(record.value, lk_multiplicity); - config.value.assign_limbs(instance, value.as_u16_limbs()); - set_val!(instance, config.shard, record.shard); - set_val!(instance, config.global_clk, record.global_clk); - set_val!(instance, config.local_clk, record.local_clk); - set_val!(instance, config.is_global_write, record.is_write as u64); - - // assign (x, y) and nonce - let GlobalPoint { nonce, point } = input.ec_point.as_ref().unwrap(); - set_val!(instance, config.nonce, *nonce as u64); - config - .x - .iter() - .chain(config.y.iter()) - .zip_eq((point.x.deref()).iter().chain((point.y.deref()).iter())) - .for_each(|(witin, fe)| { - instance[witin.id as usize] = *fe; - }); - - let ram_type = E::BaseField::from_canonical_u32(record.ram_type as u32); - let mut input = [E::BaseField::ZERO; 16]; - - let k = UINT_LIMBS; - input[0] = E::BaseField::from_canonical_u32(record.addr); - input[1] = ram_type; - input[2..(k + 2)] - .iter_mut() - .zip(value.as_u16_limbs().iter()) - .for_each(|(i, v)| *i = E::BaseField::from_canonical_u16(*v)); - input[2 + k] = E::BaseField::from_canonical_u64(record.shard); - input[2 + k + 1] = E::BaseField::from_canonical_u64(record.global_clk); - input[2 + k + 2] = E::BaseField::from_canonical_u32(*nonce); - - config - .perm_config - // TODO: remove hardcoded constant 28 - .assign_instance(&mut instance[28 + UINT_LIMBS..], input); - - Ok(()) + fn generate_fixed_traces( + _config: &Self::TableConfig, + _num_fixed: usize, + _input: &Self::FixedInput, + ) -> witness::RowMajorMatrix<::BaseField> { + unimplemented!() } - fn assign_instances<'a>( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext<'a>, + config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - mut steps: Vec, - ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + _multiplicity: &[HashMap], + steps: &Self::WitnessInput, + ) -> Result, ZKVMError> { // FIXME selector is the only structural witness // this is workaround, as call `construct_circuit` will not initialized selector // we can remove this one all opcode unittest migrate to call `build_gkr_iop_circuit` @@ -461,15 +457,6 @@ impl let n = next_pow2_instance_padding(steps.len()); // compute the input for the binary tree for ec point summation - steps - .par_chunks_mut(num_instance_per_batch) - .for_each(|chunk| { - chunk.iter_mut().for_each(|step| { - let point = step.record.to_ec_point::(&config.perm); - - step.ec_point.replace(point); - }); - }); let lk_multiplicity = LkMultiplicity::default(); // *2 because we need to store the internal nodes of binary tree for ec point summation @@ -501,47 +488,37 @@ impl [0..steps.len() * num_structural_witin] .par_chunks_mut(num_instance_per_batch * num_structural_witin); - let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter .zip_eq(raw_structual_witin_iter) .zip_eq(steps.par_chunks(num_instance_per_batch)) - .zip(shard_ctx_vec) .enumerate() - .flat_map( - |(chunk_idx, (((instances, structural_instance), steps), mut shard_ctx))| { - let mut lk_multiplicity = lk_multiplicity.clone(); - instances - .chunks_mut(num_witin) - .zip_eq(structural_instance.chunks_mut(num_structural_witin)) - .zip_eq(steps) - .enumerate() - .map(|(i, ((instance, structural_instance), step))| { - let row = chunk_idx * num_instance_per_batch + i; - let (sel_r, sel_w) = if row < num_local_reads { - (E::BaseField::ONE, E::BaseField::ZERO) - } else { - (E::BaseField::ZERO, E::BaseField::ONE) - }; - set_val!(structural_instance, selector_r_witin, sel_r); - set_val!(structural_instance, selector_w_witin, sel_w); - set_val!(structural_instance, selector_zero_witin, E::BaseField::ONE); - Self::assign_instance( - config, - &mut shard_ctx, - instance, - &mut lk_multiplicity, - step, - ) - }) - .collect::>() - }, - ) + .flat_map(|(chunk_idx, ((instances, structural_instance), steps))| { + let mut lk_multiplicity = lk_multiplicity.clone(); + instances + .chunks_mut(num_witin) + .zip_eq(structural_instance.chunks_mut(num_structural_witin)) + .zip_eq(steps) + .enumerate() + .map(|(i, ((instance, structural_instance), step))| { + let row = chunk_idx * num_instance_per_batch + i; + let (sel_r, sel_w) = if row < num_local_reads { + (E::BaseField::ONE, E::BaseField::ZERO) + } else { + (E::BaseField::ZERO, E::BaseField::ONE) + }; + set_val!(structural_instance, selector_r_witin, sel_r); + set_val!(structural_instance, selector_w_witin, sel_w); + set_val!(structural_instance, selector_zero_witin, E::BaseField::ONE); + Self::assign_instance(config, instance, &mut lk_multiplicity, step) + }) + .collect::>() + }) .collect::>()?; // assign internal nodes in the binary tree for ec point summation let mut cur_layer_points = steps .iter() - .map(|step| step.ec_point.as_ref().map(|p| p.point.clone()).unwrap()) + .map(|step| step.ec_point.point.clone()) .enumerate() .collect_vec(); @@ -609,10 +586,7 @@ impl raw_structual_witin, InstancePaddingStrategy::Default, ); - Ok(( - [raw_witin, raw_structual_witin], - lk_multiplicity.into_finalize_result(), - )) + Ok([raw_witin, raw_structual_witin]) } } @@ -631,17 +605,13 @@ mod tests { use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - e2e::ShardContext, - gadgets::horizen_round_consts, - instructions::{ - Instruction, - global::{GlobalChip, GlobalChipInput, GlobalRecord}, - }, + instructions::global::{GlobalChip, GlobalChipInput, GlobalRecord}, scheme::{ PublicValues, create_backend, create_prover, hal::ProofInput, prover::ZKVMProver, septic_curve::SepticPoint, verifier::ZKVMVerifier, }, structs::{ComposedConstrainSystem, PointAndEval, ProgramParams, RAMType, ZKVMProvingKey}, + tables::TableCircuit, }; use multilinear_extensions::mle::IntoMLE; use p3::field::PrimeField32; @@ -664,16 +634,13 @@ mod tests { .init(); // init global chip with horizen_rc_consts - let rc = horizen_round_consts(); let perm = ::get_default_perm(); - let global_chip = GlobalChip:: { rc, perm }; let mut cs = ConstraintSystem::new(|| "global chip test"); let mut cb = CircuitBuilder::new(&mut cs); - let (config, gkr_circuit) = global_chip - .build_gkr_iop_circuit(&mut cb, &ProgramParams::default()) - .unwrap(); + let (config, gkr_circuit) = + GlobalChip::build_gkr_iop_circuit(&mut cb, &ProgramParams::default()).unwrap(); // create a bunch of random memory read/write records let n_global_reads = 1700; @@ -712,10 +679,18 @@ mod tests { }) .collect::>(); - let global_ec_sum: SepticPoint = global_reads + let input = global_writes // local reads + .into_iter() + .chain(global_reads) // local writes + .map(|record| { + let ec_point = record.to_ec_point::(&perm); + GlobalChipInput { record, ec_point } + }) + .collect::>(); + + let global_ec_sum: SepticPoint = input .iter() - .chain(global_writes.iter()) - .map(|record| record.to_ec_point::(&global_chip.perm).point) + .map(|record| record.ec_point.point.clone()) .sum(); let public_value = PublicValues::new( @@ -733,27 +708,20 @@ mod tests { .map(|fe| fe.as_canonical_u32()) .collect_vec(), ); - let mut shard_context = ShardContext::default(); + // assign witness - let (witness, _) = GlobalChip::assign_instances( + let witness = GlobalChip::assign_instances( &config, - &mut shard_context, cs.num_witin as usize, cs.num_structural_witin as usize, - global_writes // local reads - .into_iter() - .chain(global_reads) // local writes - .map(|record| GlobalChipInput { - record, - ec_point: None, - }) - .collect::>(), + &[], + &input, ) .unwrap(); let composed_cs = ComposedConstrainSystem { zkvm_v1_css: cs, - gkr_circuit: Some(gkr_circuit), + gkr_circuit, }; let pk = composed_cs.key_gen(); diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index 82a8d0c91..6e24f8f10 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -1,17 +1,19 @@ use crate::{ e2e::ShardContext, error::ZKVMError, + instructions::global::{GlobalChip, GlobalChipInput, GlobalRecord}, structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ DynVolatileRamTable, HeapInitCircuit, HeapTable, HintsCircuit, LocalFinalCircuit, - MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, RBCircuit, - RegTable, RegTableInitCircuit, StackInitCircuit, StackTable, StaticMemInitCircuit, - StaticMemTable, TableCircuit, + MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, RegTable, + RegTableInitCircuit, StackInitCircuit, StackTable, StaticMemInitCircuit, StaticMemTable, + TableCircuit, }, }; use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word}; -use ff_ext::ExtensionField; +use ff_ext::{ExtensionField, PoseidonField}; use itertools::{Itertools, chain}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator as _}; use std::{collections::HashSet, iter::zip, ops::Range, sync::Arc}; use witness::InstancePaddingStrategy; @@ -31,7 +33,7 @@ pub struct MmuConfig<'a, E: ExtensionField> { /// finalized circuit for all MMIO pub local_final_circuit: as TableCircuit>::TableConfig, /// ram bus to deal with cross shard read/write - pub ram_bus_circuit: as TableCircuit>::TableConfig, + pub ram_bus_circuit: as TableCircuit>::TableConfig, pub params: ProgramParams, } @@ -47,7 +49,7 @@ impl MmuConfig<'_, E> { let stack_init_config = cs.register_table_circuit::>(); let heap_init_config = cs.register_table_circuit::>(); let local_final_circuit = cs.register_table_circuit::>(); - let ram_bus_circuit = cs.register_table_circuit::>(); + let ram_bus_circuit = cs.register_table_circuit::>(); Self { reg_init_config, @@ -94,7 +96,7 @@ impl MmuConfig<'_, E> { fixed.register_table_circuit::>(cs, &self.stack_init_config, &()); fixed.register_table_circuit::>(cs, &self.heap_init_config, &()); fixed.register_table_circuit::>(cs, &self.local_final_circuit, &()); - fixed.register_table_circuit::>(cs, &self.ram_bus_circuit, &()); + // fixed.register_table_circuit::>(cs, &self.ram_bus_circuit, &()); } #[allow(clippy::too_many_arguments)] @@ -163,7 +165,43 @@ impl MmuConfig<'_, E> { &(shard_ctx, all_records.as_slice()), )?; - witness.assign_table_circuit::>(cs, &self.ram_bus_circuit, shard_ctx)?; + let perm = ::get_default_perm(); + let global_input = shard_ctx + .read_records() + .par_iter() + .chain(shard_ctx.write_records().par_iter()) + .flat_map_iter(|records| { + records.iter().map(|(vma, record)| { + let addr = match record.ram_type { + gkr_iop::RAMType::Register => record.id as u32, + gkr_iop::RAMType::Memory => (*vma).into(), + _ => unreachable!(), + }; + let (is_write, local_clk, global_clk) = if record.prev_value.is_some() { + // global write + (true, record.cycle, record.cycle) + } else { + (false, 0, record.prev_cycle) + }; + let global_rw = GlobalRecord { + addr, + ram_type: record.ram_type, + value: record.value, + shard: record.prev_cycle, // TODO: extract shard id properly + local_clk: local_clk, + global_clk: global_clk, + is_write: is_write, + }; + + let ec_point = global_rw.to_ec_point(&perm); + GlobalChipInput { + record: global_rw, + ec_point, + } + }) + }) + .collect::>(); + witness.assign_table_circuit::>(cs, &self.ram_bus_circuit, &global_input)?; Ok(()) } diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 2dfc21c8a..1e1d738e6 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -9,7 +9,7 @@ use crate::{ error::ZKVMError, scheme::{ constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEPTIC_EXTENSION_DEGREE}, - septic_curve::SepticExtension, + septic_curve::{SepticExtension, SepticPoint}, }, structs::{ ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs, VerifyingKey, @@ -372,13 +372,36 @@ impl> ZKVMVerifier // to store the internal partial sums for ecc additions log2_num_instances += 1; } - println!("{log2_num_instances}"); let num_var_with_rotation = log2_num_instances + composed_cs.rotation_vars().unwrap_or(0); // verify ecc proof if exists if composed_cs.has_ecc_ops() { + tracing::debug!("verifying ecc proof..."); assert!(proof.ecc_proof.is_some()); let ecc_proof = proof.ecc_proof.as_ref().unwrap(); + + let xy = cs + .ec_final_sum + .iter() + .map(|expr| { + eval_by_expr_with_instance(&[], &[], &[], pi, challenges, &expr) + .right() + .and_then(|v| v.as_base()) + .unwrap() + }) + .collect_vec(); + let x: SepticExtension = xy[0..SEPTIC_EXTENSION_DEGREE].into(); + let y: SepticExtension = xy[SEPTIC_EXTENSION_DEGREE..].into(); + + assert_eq!( + SepticPoint { + x, + y, + is_infinity: false, + }, + ecc_proof.sum + ); + // assert ec sum in public input matches that in ecc proof EccVerifier::new().verify_ecc_proof(ecc_proof, transcript)?; } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 57202bd87..54ffaeb12 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -7,7 +7,7 @@ use crate::{ state::StateCircuit, tables::{RMMCollections, TableCircuit}, }; -use ceno_emul::{CENO_PLATFORM, Platform, StepRecord}; +use ceno_emul::{CENO_PLATFORM, Platform}; use ff_ext::ExtensionField; use gkr_iop::{gkr::GKRCircuit, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use itertools::Itertools; @@ -336,12 +336,12 @@ impl ZKVMWitnesses { self.lk_mlts.get(name) } - pub fn assign_opcode_circuit>( + pub fn assign_opcode_circuit>( &mut self, cs: &ZKVMConstraintSystem, shard_ctx: &mut ShardContext, config: &OC::InstructionConfig, - records: Vec, + records: Vec, ) -> Result<(), ZKVMError> { assert!(self.combined_lk_mlt.is_none()); From a8a06a58181e583a2d43dd1bcd3a4c5c03bd9825 Mon Sep 17 00:00:00 2001 From: xkx Date: Fri, 31 Oct 2025 09:02:19 +0800 Subject: [PATCH 86/91] Fix integration bugs (#1102) Co-authored-by: Ming --- ceno_zkvm/src/instructions/global.rs | 45 +++++++- .../src/instructions/riscv/rv32im/mmu.rs | 46 +------- ceno_zkvm/src/scheme.rs | 6 +- ceno_zkvm/src/scheme/cpu/mod.rs | 59 ++++------ ceno_zkvm/src/scheme/gpu/mod.rs | 2 +- ceno_zkvm/src/scheme/hal.rs | 11 +- ceno_zkvm/src/scheme/prover.rs | 61 +++++----- ceno_zkvm/src/scheme/septic_curve.rs | 4 +- ceno_zkvm/src/scheme/tests.rs | 4 +- ceno_zkvm/src/scheme/utils.rs | 2 +- ceno_zkvm/src/scheme/verifier.rs | 79 +++++++------ ceno_zkvm/src/structs.rs | 105 +++++++++++++++++- gkr_iop/src/selector.rs | 18 +-- 13 files changed, 267 insertions(+), 175 deletions(-) diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 2bad3fbbf..8cef4d869 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -3,6 +3,7 @@ use std::{collections::HashMap, iter::repeat_n, marker::PhantomData}; use crate::{ Value, chip_handler::general::PublicIOQuery, + e2e::RAMRecord, error::ZKVMError, gadgets::Poseidon2Config, instructions::riscv::constants::UINT_LIMBS, @@ -11,6 +12,7 @@ use crate::{ tables::{RMMCollections, TableCircuit}, witness::LkMultiplicity, }; +use ceno_emul::WordAddr; use ff_ext::{ExtensionField, FieldInto, PoseidonField, SmallField}; use gkr_iop::{ chip::Chip, @@ -50,6 +52,32 @@ pub struct GlobalRecord { pub is_write: bool, } +impl From<(&WordAddr, &RAMRecord, bool)> for GlobalRecord { + fn from((vma, record, is_write): (&WordAddr, &RAMRecord, bool)) -> Self { + let addr = match record.ram_type { + RAMType::Register => record.id as u32, + RAMType::Memory => (*vma).into(), + _ => unreachable!(), + }; + let value = record.prev_value.map_or(record.value, |v| v); + let (shard, local_clk, global_clk) = if is_write { + // FIXME: extract shard id and local_clk from record.cycle + (record.cycle, record.cycle, record.cycle) + } else { + (record.prev_cycle, 0, record.prev_cycle) + }; + + GlobalRecord { + addr, + ram_type: record.ram_type, + value, + shard, + local_clk, + global_clk, + is_write, + } + } +} /// An EC point corresponding to a global read/write record /// whose x-coordinate is derived from Poseidon2 hash of the record #[derive(Clone, Debug)] @@ -153,7 +181,7 @@ impl GlobalConfig { .collect(); let addr = cb.create_witin(|| "addr"); let is_ram_register = cb.create_witin(|| "is_ram_register"); - let value = UInt::new(|| "value", cb)?; + let value = UInt::new_unchecked(|| "value", cb)?; let shard = cb.create_witin(|| "shard"); let global_clk = cb.create_witin(|| "global_clk"); let local_clk = cb.create_witin(|| "local_clk"); @@ -434,6 +462,12 @@ impl TableCircuit for GlobalChip { _multiplicity: &[HashMap], steps: &Self::WitnessInput, ) -> Result, ZKVMError> { + if steps.is_empty() { + return Ok([ + witness::RowMajorMatrix::empty(), + witness::RowMajorMatrix::empty(), + ]); + } // FIXME selector is the only structural witness // this is workaround, as call `construct_circuit` will not initialized selector // we can remove this one all opcode unittest migrate to call `build_gkr_iop_circuit` @@ -447,6 +481,11 @@ impl TableCircuit for GlobalChip { // local read iff it's global write let num_local_reads = steps.iter().filter(|s| s.record.is_write).count(); + tracing::debug!( + "{} local reads / {} local writes in global chip", + num_local_reads, + steps.len() - num_local_reads + ); let num_instance_per_batch = if steps.len() > 256 { steps.len().div_ceil(nthreads) @@ -746,9 +785,7 @@ mod tests { structural_witness: witness[1].to_mles().into_iter().map(Arc::new).collect(), fixed: vec![], public_input: public_input_mles.clone(), - num_read_instances: n_global_writes as usize, - num_write_instances: n_global_reads as usize, - num_instances: (n_global_reads + n_global_writes) as usize, + num_instances: vec![n_global_writes as usize, n_global_reads as usize], has_ecc_ops: true, }; let mut rng = thread_rng(); diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index 6e24f8f10..1c4776cd0 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -1,7 +1,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, - instructions::global::{GlobalChip, GlobalChipInput, GlobalRecord}, + instructions::global::GlobalChip, structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::{ DynVolatileRamTable, HeapInitCircuit, HeapTable, HintsCircuit, LocalFinalCircuit, @@ -11,9 +11,8 @@ use crate::{ }, }; use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word}; -use ff_ext::{ExtensionField, PoseidonField}; +use ff_ext::ExtensionField; use itertools::{Itertools, chain}; -use rayon::iter::{IntoParallelRefIterator, ParallelIterator as _}; use std::{collections::HashSet, iter::zip, ops::Range, sync::Arc}; use witness::InstancePaddingStrategy; @@ -158,50 +157,13 @@ impl MmuConfig<'_, E> { .into_iter() .filter(|(_, record)| !record.is_empty()) .collect_vec(); - // take all mem result and + witness.assign_table_circuit::>( cs, &self.local_final_circuit, &(shard_ctx, all_records.as_slice()), )?; - - let perm = ::get_default_perm(); - let global_input = shard_ctx - .read_records() - .par_iter() - .chain(shard_ctx.write_records().par_iter()) - .flat_map_iter(|records| { - records.iter().map(|(vma, record)| { - let addr = match record.ram_type { - gkr_iop::RAMType::Register => record.id as u32, - gkr_iop::RAMType::Memory => (*vma).into(), - _ => unreachable!(), - }; - let (is_write, local_clk, global_clk) = if record.prev_value.is_some() { - // global write - (true, record.cycle, record.cycle) - } else { - (false, 0, record.prev_cycle) - }; - let global_rw = GlobalRecord { - addr, - ram_type: record.ram_type, - value: record.value, - shard: record.prev_cycle, // TODO: extract shard id properly - local_clk: local_clk, - global_clk: global_clk, - is_write: is_write, - }; - - let ec_point = global_rw.to_ec_point(&perm); - GlobalChipInput { - record: global_rw, - ec_point, - } - }) - }) - .collect::>(); - witness.assign_table_circuit::>(cs, &self.ram_bus_circuit, &global_input)?; + witness.assign_global_chip_circuit(cs, &shard_ctx, &self.ram_bus_circuit)?; Ok(()) } diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 382c94de5..3fd2517c8 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -62,9 +62,7 @@ pub struct ZKVMChipProof { pub tower_proof: TowerProofs, pub ecc_proof: Option>, - pub num_read_instances: usize, - pub num_write_instances: usize, - pub num_instances: usize, + pub num_instances: Vec, pub fixed_in_evals: Vec, pub wits_in_evals: Vec, @@ -212,7 +210,7 @@ impl> ZKVMProof { let halt_instance_count = self .chip_proofs .get(&halt_circuit_index) - .map_or(0, |proof| proof.num_instances); + .map_or(0, |proof| proof.num_instances.iter().sum()); if halt_instance_count > 0 { assert_eq!( halt_instance_count, 1, diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index bbd7f97c1..e772fc563 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -45,9 +45,6 @@ use sumcheck::{ use transcript::Transcript; use witness::next_pow2_instance_padding; -#[cfg(feature = "sanity-check")] -use gkr_iop::utils::eq_eval_less_or_equal_than; - pub type TowerRelationOutput = ( Point, TowerProofs, @@ -78,6 +75,11 @@ impl CpuEccProver { assert_eq!(ys.len(), SEPTIC_EXTENSION_DEGREE); let n = xs[0].num_vars() - 1; + tracing::debug!( + "Creating EC Summation Quark proof with {} points in {n} variables", + num_instances + ); + let out_rt = transcript.sample_and_append_vec(b"ecc", n); let num_threads = optimal_sumcheck_threads(out_rt.len()); @@ -258,44 +260,25 @@ impl CpuEccProver { #[cfg(feature = "sanity-check")] { - let s = invs.iter().map(|x| x.as_view_slice(2, 0)).collect_vec(); + let s = invs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec(); let x0 = filter_bj(&xs, 0); let y0 = filter_bj(&ys, 0); let x1 = filter_bj(&xs, 1); let y1 = filter_bj(&ys, 1); + let evals = &evals[2..]; // check evaluations - assert_eq!( - eq_eval_less_or_equal_than(last_evaluation_index - 1, &out_rt, &rt), - evals[0] - ); for i in 0..SEPTIC_EXTENSION_DEGREE { - assert_eq!(s[i].evaluate(&rt), evals[1 + i]); - assert_eq!(x0[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE + 1 + i]); - assert_eq!( - y0[i].evaluate(&rt), - evals[SEPTIC_EXTENSION_DEGREE * 2 + 1 + i] - ); - assert_eq!( - x1[i].evaluate(&rt), - evals[SEPTIC_EXTENSION_DEGREE * 3 + 1 + i] - ); - assert_eq!( - y1[i].evaluate(&rt), - evals[SEPTIC_EXTENSION_DEGREE * 4 + 1 + i] - ); - assert_eq!( - x3[i].evaluate(&rt), - evals[SEPTIC_EXTENSION_DEGREE * 5 + 1 + i] - ); - assert_eq!( - y3[i].evaluate(&rt), - evals[SEPTIC_EXTENSION_DEGREE * 6 + 1 + i] - ); + assert_eq!(s[i].evaluate(&rt), evals[i]); + assert_eq!(x0[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE + i]); + assert_eq!(y0[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 2 + i]); + assert_eq!(x1[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 3 + i]); + assert_eq!(y1[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 4 + i]); + assert_eq!(x3[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 5 + i]); + assert_eq!(y3[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 6 + i]); } } - // TODO: prove the validity of s[1,rt], x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt] EccQuarkProof { zerocheck_proof, num_instances, @@ -585,8 +568,8 @@ impl> TowerProver> MainSumcheckProver> MainSumcheckProver( zkvm_v1_css: cs, .. } = composed_cs; let num_instances_with_rotation = - input.num_instances << composed_cs.rotation_vars().unwrap_or(0); + input.num_instances() << composed_cs.rotation_vars().unwrap_or(0); let chip_record_alpha = challenges[0]; // TODO: safety ? diff --git a/ceno_zkvm/src/scheme/hal.rs b/ceno_zkvm/src/scheme/hal.rs index 408f3a921..44aa75c21 100644 --- a/ceno_zkvm/src/scheme/hal.rs +++ b/ceno_zkvm/src/scheme/hal.rs @@ -38,16 +38,19 @@ pub struct ProofInput<'a, PB: ProverBackend> { pub structural_witness: Vec>>, pub fixed: Vec>>, pub public_input: Vec>>, - pub num_read_instances: usize, - pub num_write_instances: usize, - pub num_instances: usize, + pub num_instances: Vec, pub has_ecc_ops: bool, } impl<'a, PB: ProverBackend> ProofInput<'a, PB> { + pub fn num_instances(&self) -> usize { + self.num_instances.iter().sum() + } + #[inline] pub fn log2_num_instances(&self) -> usize { - let log2 = ceil_log2(next_pow2_instance_padding(self.num_instances)); + let num_instance = self.num_instances(); + let log2 = ceil_log2(next_pow2_instance_padding(num_instance)); if self.has_ecc_ops { // the mles have one extra variable to store // the internal partial sums for ecc additions diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index bba5b8c9f..187d2a708 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -118,17 +118,11 @@ impl< { // num_instance from witness might include rotation if let Some(num_instance) = witnesses - .get_opcode_witness(circuit_name) - .or_else(|| witnesses.get_table_witness(circuit_name)) - .map(|rmms| { - if rmms[0].num_instances() == 0 { - rmms[1].num_instances() - } else { - rmms[0].num_instances() - } - }) + .num_instances + .get(circuit_name) + .cloned() .and_then(|num_instance| { - if num_instance > 0 { + if num_instance.iter().sum::() > 0 { Some(num_instance) } else { None @@ -140,26 +134,28 @@ impl< .circuit_index_fixed_num_instances .get(&index) .copied() - .unwrap_or(0) + .map(|num_instance| vec![num_instance]) + .unwrap_or(vec![]) }) }) { - num_instances.push(( - index, - num_instance >> vk.get_cs().rotation_vars().unwrap_or(0), - )); + let num_instance_exclude_rotation = num_instance + .iter() + .map(|num_instance| num_instance >> vk.get_cs().rotation_vars().unwrap_or(0)) + .collect_vec(); + num_instances.push((index, num_instance_exclude_rotation.clone())); + circuit_name_num_instances_mapping + .insert(circuit_name, num_instance_exclude_rotation); num_instances_with_rotation.push((index, num_instance)); - circuit_name_num_instances_mapping.insert( - circuit_name, - num_instance >> vk.get_cs().rotation_vars().unwrap_or(0), - ); } } // write (circuit_idx, num_var) to transcript for (circuit_idx, num_instance) in &num_instances { transcript.append_message(&circuit_idx.to_le_bytes()); - transcript.append_message(&num_instance.to_le_bytes()); + for num_instance in num_instance { + transcript.append_message(&num_instance.to_le_bytes()); + } } let commit_to_traces_span = entered_span!("batch commit to traces", profiling_1 = true); @@ -216,10 +212,10 @@ impl< |(mut points, mut evaluations), (index, (circuit_name, pk))| { let num_instances = circuit_name_num_instances_mapping .get(&circuit_name) - .copied() - .unwrap_or(0); + .cloned() + .unwrap_or_default(); let cs = pk.get_cs(); - if num_instances == 0 { + if num_instances.is_empty() { // we need to drain respective fixed when num_instances is 0 if cs.num_fixed() > 0 { let _ = fixed_mles.drain(..cs.num_fixed()).collect_vec(); @@ -249,9 +245,7 @@ impl< fixed, structural_witness, public_input: public_input.clone(), - num_read_instances: num_instances, // TODO: fixme - num_write_instances: num_instances, // TODO: fixme - num_instances, + num_instances: num_instances.clone(), has_ecc_ops: cs.has_ecc_ops(), }; @@ -264,7 +258,7 @@ impl< &challenges, )?; tracing::trace!( - "generated proof for opcode {} with num_instances={}", + "generated proof for opcode {} with num_instances={:?}", circuit_name, num_instances ); @@ -366,10 +360,13 @@ impl< _ => unreachable!("slope's expression must be WitIn"), }) .collect_vec(); - Some( - self.device - .prove_ec_sum_quark(input.num_instances, xs, ys, slopes, transcript)?, - ) + Some(self.device.prove_ec_sum_quark( + input.num_instances(), + xs, + ys, + slopes, + transcript, + )?) } else { None }; @@ -437,8 +434,6 @@ impl< ecc_proof, fixed_in_evals, wits_in_evals, - num_read_instances: input.num_read_instances, - num_write_instances: input.num_write_instances, num_instances: input.num_instances, }, pi_in_evals, diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 6c6120050..f9b6b4f76 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -41,7 +41,7 @@ pub struct SexticExtension([F; 6]); /// # check if f(x) is irreducible /// print(f.is_irreducible()) /// ``` -#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct SepticExtension(pub [F; 7]); impl From<&[F]> for SepticExtension { @@ -759,7 +759,7 @@ impl SymbolicSepticExtension { /// Note that /// 1. The curve's cofactor is 1 /// 2. The curve's order is a large prime number of 31x7 bits -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct SepticPoint { pub x: SepticExtension, pub y: SepticExtension, diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 65211b65b..a8dd0d015 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -201,9 +201,7 @@ fn test_rw_lk_expression_combination() { witness: wits_in, structural_witness: structural_in, public_input: vec![], - num_read_instances: num_instances, - num_write_instances: num_instances, - num_instances, + num_instances: vec![num_instances], has_ecc_ops: false, }; let (proof, _, _) = prover diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 41e178fa6..74314aca0 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -502,7 +502,7 @@ pub fn build_main_witness< } else { ( >::table_witness(device, input, cs, challenges), - input.num_instances > 1 && input.num_instances.is_power_of_two(), + input.num_instances() > 1 && input.num_instances().is_power_of_two(), ) } }; diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 1e1d738e6..e162da405 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -9,7 +9,7 @@ use crate::{ error::ZKVMError, scheme::{ constants::{NUM_FANIN, NUM_FANIN_LOGUP, SEPTIC_EXTENSION_DEGREE}, - septic_curve::{SepticExtension, SepticPoint}, + septic_curve::SepticExtension, }, structs::{ ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs, VerifyingKey, @@ -142,7 +142,9 @@ impl> ZKVMVerifier // write (circuit_idx, num_instance) to transcript for (circuit_idx, proof) in &vm_proof.chip_proofs { transcript.append_message(&circuit_idx.to_le_bytes()); - transcript.append_message(&proof.num_instances.to_le_bytes()); + for num_instance in &proof.num_instances { + transcript.append_message(&num_instance.to_le_bytes()); + } } // write witin commitment to transcript @@ -169,7 +171,8 @@ impl> ZKVMVerifier let mut witin_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); let mut fixed_openings = Vec::with_capacity(vm_proof.chip_proofs.len()); for (index, proof) in &vm_proof.chip_proofs { - assert!(proof.num_instances > 0); + let num_instance: usize = proof.num_instances.iter().sum(); + assert!(num_instance > 0); let circuit_name = &self.vk.circuit_index_to_name[index]; let circuit_vk = &self.vk.circuit_vks[circuit_name]; @@ -227,11 +230,10 @@ impl> ZKVMVerifier // getting the number of dummy padding item that we used in this opcode circuit let num_lks = circuit_vk.get_cs().num_lks(); // each padding instance contribute to (2^rotation_vars) dummy lookup padding - let num_padded_instance = (next_pow2_instance_padding(proof.num_instances) - - proof.num_instances) + let num_padded_instance = (next_pow2_instance_padding(num_instance) - num_instance) * (1 << circuit_vk.get_cs().rotation_vars().unwrap_or(0)); // each instance contribute to (2^rotation_vars - rotated) dummy lookup padding - let num_instance_non_selected = proof.num_instances + let num_instance_non_selected = num_instance * ((1 << circuit_vk.get_cs().rotation_vars().unwrap_or(0)) - (circuit_vk.get_cs().rotation_subgroup_size().unwrap_or(0) + 1)); dummy_table_item_multiplicity += @@ -357,7 +359,7 @@ impl> ZKVMVerifier zkvm_v1_css: cs, gkr_circuit, } = &composed_cs; - let num_instances = proof.num_instances; + let num_instances = proof.num_instances.iter().sum(); let (r_counts_per_instance, w_counts_per_instance, lk_counts_per_instance) = ( cs.r_expressions.len() + cs.r_table_expressions.len(), cs.w_expressions.len() + cs.w_table_expressions.len(), @@ -380,29 +382,31 @@ impl> ZKVMVerifier assert!(proof.ecc_proof.is_some()); let ecc_proof = proof.ecc_proof.as_ref().unwrap(); - let xy = cs - .ec_final_sum - .iter() - .map(|expr| { - eval_by_expr_with_instance(&[], &[], &[], pi, challenges, &expr) - .right() - .and_then(|v| v.as_base()) - .unwrap() - }) - .collect_vec(); - let x: SepticExtension = xy[0..SEPTIC_EXTENSION_DEGREE].into(); - let y: SepticExtension = xy[SEPTIC_EXTENSION_DEGREE..].into(); - - assert_eq!( - SepticPoint { - x, - y, - is_infinity: false, - }, - ecc_proof.sum - ); + // TODO: enable this + // let xy = cs + // .ec_final_sum + // .iter() + // .map(|expr| { + // eval_by_expr_with_instance(&[], &[], &[], pi, challenges, &expr) + // .right() + // .and_then(|v| v.as_base()) + // .unwrap() + // }) + // .collect_vec(); + // let x: SepticExtension = xy[0..SEPTIC_EXTENSION_DEGREE].into(); + // let y: SepticExtension = xy[SEPTIC_EXTENSION_DEGREE..].into(); + + // assert_eq!( + // SepticPoint { + // x, + // y, + // is_infinity: false, + // }, + // ecc_proof.sum + // ); // assert ec sum in public input matches that in ecc proof EccVerifier::new().verify_ecc_proof(ecc_proof, transcript)?; + tracing::debug!("ecc proof verified."); } // verify and reduce product tower sumcheck @@ -450,6 +454,7 @@ impl> ZKVMVerifier let gkr_circuit = gkr_circuit.as_ref().unwrap(); let selector_ctxs = if cs.ec_final_sum.is_empty() { + assert_eq!(proof.num_instances.len(), 1); // it's not global chip vec![ SelectorContext::new(0, num_instances, num_var_with_rotation); @@ -460,21 +465,28 @@ impl> ZKVMVerifier .unwrap_or(0) ] } else { + assert_eq!(proof.num_instances.len(), 2); // it's global chip + tracing::debug!( + "num_reads: {}, num_writes: {}, total: {}", + proof.num_instances[0], + proof.num_instances[1], + proof.num_instances[0] + proof.num_instances[1], + ); vec![ SelectorContext { offset: 0, - num_instances: proof.num_read_instances, + num_instances: proof.num_instances[0], num_vars: num_var_with_rotation, }, SelectorContext { - offset: proof.num_read_instances, - num_instances: proof.num_write_instances, + offset: proof.num_instances[0], + num_instances: proof.num_instances[1], num_vars: num_var_with_rotation, }, SelectorContext { offset: 0, - num_instances: proof.num_instances, + num_instances: proof.num_instances[0] + proof.num_instances[1], num_vars: num_var_with_rotation, }, ] @@ -516,7 +528,8 @@ impl> ZKVMVerifier .all(|(r, w)| r.table_spec.len == w.table_spec.len) ); } - let log2_num_instances = next_pow2_instance_padding(proof.num_instances).ilog2() as usize; + let num_instances = proof.num_instances.iter().sum(); + let log2_num_instances = next_pow2_instance_padding(num_instances).ilog2() as usize; // verify and reduce product tower sumcheck let tower_proofs = &proof.tower_proof; diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 54ffaeb12..f82918cfb 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -2,17 +2,21 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, e2e::ShardContext, error::ZKVMError, - instructions::Instruction, + instructions::{ + Instruction, + global::{GlobalChip, GlobalChipInput, GlobalPoint, GlobalRecord}, + }, scheme::septic_curve::SepticPoint, state::StateCircuit, tables::{RMMCollections, TableCircuit}, }; use ceno_emul::{CENO_PLATFORM, Platform}; -use ff_ext::ExtensionField; +use ff_ext::{ExtensionField, PoseidonField}; use gkr_iop::{gkr::GKRCircuit, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{Expression, Instance}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ collections::{BTreeMap, HashMap}, @@ -157,7 +161,8 @@ impl ComposedConstrainSystem { } pub fn is_opcode_circuit(&self) -> bool { - self.gkr_circuit.is_some() + // TODO: is global chip opcode circuit?? + self.gkr_circuit.is_some() || self.has_ecc_ops() } /// return number of lookup operation @@ -321,6 +326,8 @@ pub struct ZKVMWitnesses { witnesses_tables: BTreeMap>, lk_mlts: BTreeMap>, combined_lk_mlt: Option>>, + // in ram bus chip, num_instances length would be > 1 + pub num_instances: BTreeMap>, } impl ZKVMWitnesses { @@ -353,6 +360,11 @@ impl ZKVMWitnesses { cs.zkvm_v1_css.num_structural_witin as usize, records, )?; + assert!( + self.num_instances + .insert(OC::name(), vec![witness[0].num_instances()]) + .is_none() + ); assert!(self.witnesses_opcodes.insert(OC::name(), witness).is_none()); assert!(!self.witnesses_tables.contains_key(&OC::name())); assert!( @@ -406,12 +418,99 @@ impl ZKVMWitnesses { self.combined_lk_mlt.as_ref().unwrap(), input, )?; + let num_instances = std::cmp::max(witness[0].num_instances(), witness[1].num_instances()); + assert!( + self.num_instances + .insert(TC::name(), vec![num_instances]) + .is_none() + ); assert!(self.witnesses_tables.insert(TC::name(), witness).is_none()); assert!(!self.witnesses_opcodes.contains_key(&TC::name())); Ok(()) } + pub fn assign_global_chip_circuit( + &mut self, + cs: &ZKVMConstraintSystem, + shard_ctx: &ShardContext, + config: & as TableCircuit>::TableConfig, + ) -> Result<(), ZKVMError> { + let perm = ::get_default_perm(); + let global_input = shard_ctx + .read_records() + .par_iter() + .flat_map_iter(|records| { + records.iter().map(|(vma, record)| { + let global_read: GlobalRecord = (vma, record, false).into(); + let ec_point: GlobalPoint = global_read.to_ec_point(&perm); + GlobalChipInput { + record: global_read, + ec_point, + } + }) + }) + .chain( + shard_ctx + .write_records() + .par_iter() + .flat_map_iter(|records| { + records.iter().map(|(vma, record)| { + let global_write: GlobalRecord = (vma, record, true).into(); + let ec_point: GlobalPoint = global_write.to_ec_point(&perm); + GlobalChipInput { + record: global_write, + ec_point, + } + }) + }), + ) + .collect::>(); + assert!(self.combined_lk_mlt.is_some()); + let cs = cs.get_cs(&GlobalChip::::name()).unwrap(); + let witness = GlobalChip::assign_instances( + config, + cs.zkvm_v1_css.num_witin as usize, + cs.zkvm_v1_css.num_structural_witin as usize, + self.combined_lk_mlt.as_ref().unwrap(), + &global_input, + )?; + // set num_read, num_write as separate instance + assert!( + self.num_instances + .insert( + GlobalChip::::name(), + vec![ + // global write -> local read + shard_ctx + .write_records() + .iter() + .map(|records| records.len()) + .sum(), + // global read -> local write + shard_ctx + .read_records() + .iter() + .map(|records| records.len()) + .sum(), + ] + ) + .is_none() + ); + assert!( + self.witnesses_tables + .insert(GlobalChip::::name(), witness) + .is_none() + ); + assert!( + !self + .witnesses_opcodes + .contains_key(&GlobalChip::::name()) + ); + + Ok(()) + } + /// Iterate opcode/table circuits, sorted by alphabetical order. pub fn into_iter_sorted( self, diff --git a/gkr_iop/src/selector.rs b/gkr_iop/src/selector.rs index 2330cfed7..9f10d2249 100644 --- a/gkr_iop/src/selector.rs +++ b/gkr_iop/src/selector.rs @@ -267,14 +267,18 @@ impl SelectorType { ctx.num_vars ); - let eq_end = eq_eval_less_or_equal_than(end - 1, out_point, in_point); - let sel = if start > 0 { - let eq_start = eq_eval_less_or_equal_than(start - 1, out_point, in_point); - eq_end - eq_start + if end == 0 { + (expression, E::ZERO) } else { - eq_end - }; - (expression, sel) + let eq_end = eq_eval_less_or_equal_than(end - 1, out_point, in_point); + let sel = if start > 0 { + let eq_start = eq_eval_less_or_equal_than(start - 1, out_point, in_point); + eq_end - eq_start + } else { + eq_end + }; + (expression, sel) + } } // evaluate true and false mle eq(CYCLIC_POW2_5[round]; b[..5]) * sel(y; b[5..]), and eq(1; b[..5]) * (1 - sel(y; b[5..])) SelectorType::OrderedSparse32 { From 2214fdc1ab8341e4527cff51ef2827b896891871 Mon Sep 17 00:00:00 2001 From: Ming Date: Fri, 31 Oct 2025 11:28:32 +0800 Subject: [PATCH 87/91] extract shard id from cycle (#1101) --- ceno_zkvm/src/e2e.rs | 29 ++++++++++++++++++++++++++++ ceno_zkvm/src/instructions/global.rs | 8 ++++---- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index c490bb597..62f3e425f 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -106,10 +106,16 @@ pub struct RAMRecord { pub ram_type: RAMType, pub id: u64, pub addr: WordAddr, + // prev_cycle and cycle are global cycle pub prev_cycle: Cycle, pub cycle: Cycle, + // shard_cycle is cycle in current local shard, which already offset by start cycle + pub shard_cycle: Cycle, pub prev_value: Option, pub value: Word, + // for global reads, `shard_id` refers to the shard that previously produced this value. + // for global write, `shard_id` refers to current shard. + pub shard_id: usize, } #[derive(Clone, Debug)] @@ -155,6 +161,7 @@ pub struct ShardContext<'a> { write_thread_based_record_storage: Either>, &'a mut BTreeMap>, pub cur_shard_cycle_range: std::ops::Range, + pub expected_inst_per_shard: usize, } impl<'a> Default for ShardContext<'a> { @@ -177,6 +184,7 @@ impl<'a> Default for ShardContext<'a> { .collect::>(), ), cur_shard_cycle_range: Tracer::SUBCYCLES_PER_INSN as usize..usize::MAX, + expected_inst_per_shard: usize::MAX, } } } @@ -223,6 +231,7 @@ impl<'a> ShardContext<'a> { .collect::>(), ), cur_shard_cycle_range, + expected_inst_per_shard, } } @@ -244,6 +253,7 @@ impl<'a> ShardContext<'a> { read_thread_based_record_storage: Either::Right(read), write_thread_based_record_storage: Either::Right(write), cur_shard_cycle_range: self.cur_shard_cycle_range.clone(), + expected_inst_per_shard: self.expected_inst_per_shard, }) .collect_vec(), _ => panic!("invalid type"), @@ -279,6 +289,14 @@ impl<'a> ShardContext<'a> { self.cur_shard_cycle_range.contains(&(cycle as usize)) } + #[inline(always)] + pub fn extract_prev_shard_id(&self, cycle: Cycle) -> usize { + let subcycle_per_insn = Tracer::SUBCYCLES_PER_INSN; + let per_shard_cycles = + (self.expected_inst_per_shard as u64).saturating_mul(subcycle_per_insn); + ((cycle.saturating_sub(subcycle_per_insn)) / per_shard_cycles) as usize + } + #[inline(always)] pub fn aligned_prev_ts(&self, prev_cycle: Cycle) -> Cycle { let mut ts = prev_cycle.saturating_sub(self.current_shard_offset_cycle()); @@ -288,6 +306,11 @@ impl<'a> ShardContext<'a> { ts } + #[inline(always)] + pub fn aligned_current_ts(&self, cycle: Cycle) -> Cycle { + cycle.saturating_sub(self.current_shard_offset_cycle()) + } + pub fn current_shard_offset_cycle(&self) -> Cycle { // cycle of each local shard start from Tracer::SUBCYCLES_PER_INSN (self.cur_shard_cycle_range.start as Cycle) - Tracer::SUBCYCLES_PER_INSN @@ -311,6 +334,7 @@ impl<'a> ShardContext<'a> { && self.is_current_shard_cycle(cycle) && !self.is_first_shard() { + let prev_shard_id = self.extract_prev_shard_id(prev_cycle); let ram_record = self .read_thread_based_record_storage .as_mut() @@ -324,8 +348,10 @@ impl<'a> ShardContext<'a> { addr, prev_cycle, cycle, + shard_cycle: 0, prev_value, value, + shard_id: prev_shard_id, }, ); } @@ -348,6 +374,7 @@ impl<'a> ShardContext<'a> { && future_touch_cycle >= self.cur_shard_cycle_range.end as Cycle && self.is_current_shard_cycle(cycle) { + let shard_cycle = self.aligned_current_ts(cycle); let ram_record = self .write_thread_based_record_storage .as_mut() @@ -361,8 +388,10 @@ impl<'a> ShardContext<'a> { addr, prev_cycle, cycle, + shard_cycle, prev_value, value, + shard_id: self.shards.shard_id, }, ); } diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 8cef4d869..6d23a1594 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -61,17 +61,17 @@ impl From<(&WordAddr, &RAMRecord, bool)> for GlobalRecord { }; let value = record.prev_value.map_or(record.value, |v| v); let (shard, local_clk, global_clk) = if is_write { - // FIXME: extract shard id and local_clk from record.cycle - (record.cycle, record.cycle, record.cycle) + (record.shard_id, record.shard_cycle, record.cycle) } else { - (record.prev_cycle, 0, record.prev_cycle) + debug_assert_eq!(record.shard_cycle, 0); + (record.shard_id, 0, record.prev_cycle) }; GlobalRecord { addr, ram_type: record.ram_type, value, - shard, + shard: shard as u64, local_clk, global_clk, is_write, From 3860d14f3ebf40a7882bb411b1f9395df261a0fe Mon Sep 17 00:00:00 2001 From: xkx Date: Fri, 31 Oct 2025 12:56:42 +0800 Subject: [PATCH 88/91] revert stateful trait Instruction (#1105) --- ceno_emul/src/syscalls/bn254/bn254_fptower.rs | 4 --- ceno_emul/src/syscalls/keccak_permute.rs | 1 - ceno_emul/src/syscalls/secp256k1.rs | 3 -- ceno_emul/src/syscalls/sha256.rs | 1 - ceno_zkvm/benches/riscv_add.rs | 4 +-- ceno_zkvm/src/instructions.rs | 10 +++--- ceno_zkvm/src/instructions/global.rs | 2 +- ceno_zkvm/src/instructions/riscv.rs | 2 +- ceno_zkvm/src/instructions/riscv/arith.rs | 17 ++++------ ceno_zkvm/src/instructions/riscv/arith_imm.rs | 8 ++--- .../riscv/arith_imm/arith_imm_circuit_v2.rs | 3 -- ceno_zkvm/src/instructions/riscv/auipc.rs | 9 ++---- ceno_zkvm/src/instructions/riscv/branch.rs | 11 ------- .../riscv/branch/branch_circuit.rs | 9 ++---- .../riscv/branch/branch_circuit_v2.rs | 3 -- .../src/instructions/riscv/branch/test.rs | 22 ++++++------- ceno_zkvm/src/instructions/riscv/div.rs | 21 ++---------- .../instructions/riscv/div/div_circuit_v2.rs | 3 -- .../instructions/riscv/dummy/dummy_circuit.rs | 3 -- .../instructions/riscv/dummy/dummy_ecall.rs | 3 -- .../src/instructions/riscv/dummy/test.rs | 12 +++---- ceno_zkvm/src/instructions/riscv/ecall.rs | 2 -- .../src/instructions/riscv/ecall/halt.rs | 3 -- .../src/instructions/riscv/ecall/keccak.rs | 4 --- .../riscv/ecall/weierstrass_add.rs | 4 --- .../riscv/ecall/weierstrass_decompress.rs | 4 --- .../riscv/ecall/weierstrass_double.rs | 4 --- .../src/instructions/riscv/jump/jal_v2.rs | 5 +-- ceno_zkvm/src/instructions/riscv/jump/jalr.rs | 1 - .../src/instructions/riscv/jump/jalr_v2.rs | 5 +-- ceno_zkvm/src/instructions/riscv/jump/test.rs | 6 ++-- ceno_zkvm/src/instructions/riscv/logic.rs | 6 ---- .../instructions/riscv/logic/logic_circuit.rs | 3 -- .../src/instructions/riscv/logic/test.rs | 9 ++---- ceno_zkvm/src/instructions/riscv/logic_imm.rs | 6 ---- .../riscv/logic_imm/logic_imm_circuit_v2.rs | 3 -- .../src/instructions/riscv/logic_imm/test.rs | 13 +++----- ceno_zkvm/src/instructions/riscv/lui.rs | 9 ++---- ceno_zkvm/src/instructions/riscv/memory.rs | 15 --------- .../src/instructions/riscv/memory/load_v2.rs | 5 +-- .../src/instructions/riscv/memory/store_v2.rs | 3 -- .../src/instructions/riscv/memory/test.rs | 22 +++---------- ceno_zkvm/src/instructions/riscv/mulh.rs | 32 +++++++++++-------- .../riscv/mulh/mulh_circuit_v2.rs | 5 +-- ceno_zkvm/src/instructions/riscv/shift.rs | 14 ++++---- .../riscv/shift/shift_circuit_v2.rs | 10 ++---- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 12 +++---- ceno_zkvm/src/instructions/riscv/slt.rs | 10 +++--- .../instructions/riscv/slt/slt_circuit_v2.rs | 3 -- ceno_zkvm/src/instructions/riscv/slti.rs | 12 +++---- .../riscv/slti/slti_circuit_v2.rs | 5 +-- ceno_zkvm/src/instructions/riscv/test.rs | 6 ++-- ceno_zkvm/src/scheme/cpu/mod.rs | 1 - ceno_zkvm/src/scheme/tests.rs | 3 -- ceno_zkvm/src/scheme/verifier.rs | 1 - ceno_zkvm/src/structs.rs | 14 +++----- gkr_iop/src/lib.rs | 5 +-- 57 files changed, 110 insertions(+), 306 deletions(-) diff --git a/ceno_emul/src/syscalls/bn254/bn254_fptower.rs b/ceno_emul/src/syscalls/bn254/bn254_fptower.rs index 7294693e0..3fa98f368 100644 --- a/ceno_emul/src/syscalls/bn254/bn254_fptower.rs +++ b/ceno_emul/src/syscalls/bn254/bn254_fptower.rs @@ -11,7 +11,6 @@ use crate::{ use super::types::{BN254_FP_WORDS, BN254_FP2_WORDS}; -#[derive(Default)] pub struct Bn254FpAddSpec; impl SyscallSpec for Bn254FpAddSpec { @@ -22,7 +21,6 @@ impl SyscallSpec for Bn254FpAddSpec { const CODE: u32 = ceno_syscall::BN254_FP_ADD; } -#[derive(Default)] pub struct Bn254Fp2AddSpec; impl SyscallSpec for Bn254Fp2AddSpec { const NAME: &'static str = "BN254_FP2_ADD"; @@ -32,7 +30,6 @@ impl SyscallSpec for Bn254Fp2AddSpec { const CODE: u32 = ceno_syscall::BN254_FP2_ADD; } -#[derive(Default)] pub struct Bn254FpMulSpec; impl SyscallSpec for Bn254FpMulSpec { const NAME: &'static str = "BN254_FP_MUL"; @@ -42,7 +39,6 @@ impl SyscallSpec for Bn254FpMulSpec { const CODE: u32 = ceno_syscall::BN254_FP_MUL; } -#[derive(Default)] pub struct Bn254Fp2MulSpec; impl SyscallSpec for Bn254Fp2MulSpec { const NAME: &'static str = "BN254_FP2_MUL"; diff --git a/ceno_emul/src/syscalls/keccak_permute.rs b/ceno_emul/src/syscalls/keccak_permute.rs index 20b66b033..c9fc2cac7 100644 --- a/ceno_emul/src/syscalls/keccak_permute.rs +++ b/ceno_emul/src/syscalls/keccak_permute.rs @@ -8,7 +8,6 @@ use super::{SyscallEffects, SyscallSpec, SyscallWitness}; const KECCAK_CELLS: usize = 25; // u64 cells pub const KECCAK_WORDS: usize = KECCAK_CELLS * 2; // u32 words -#[derive(Default)] pub struct KeccakSpec; impl SyscallSpec for KeccakSpec { diff --git a/ceno_emul/src/syscalls/secp256k1.rs b/ceno_emul/src/syscalls/secp256k1.rs index 288da1075..fafabe78c 100644 --- a/ceno_emul/src/syscalls/secp256k1.rs +++ b/ceno_emul/src/syscalls/secp256k1.rs @@ -5,13 +5,10 @@ use std::iter; use super::{SyscallEffects, SyscallSpec, SyscallWitness}; -#[derive(Default)] pub struct Secp256k1AddSpec; -#[derive(Default)] pub struct Secp256k1DoubleSpec; -#[derive(Default)] pub struct Secp256k1DecompressSpec; impl SyscallSpec for Secp256k1AddSpec { diff --git a/ceno_emul/src/syscalls/sha256.rs b/ceno_emul/src/syscalls/sha256.rs index 60c65d871..a6b1e404c 100644 --- a/ceno_emul/src/syscalls/sha256.rs +++ b/ceno_emul/src/syscalls/sha256.rs @@ -4,7 +4,6 @@ use super::{SyscallEffects, SyscallSpec, SyscallWitness}; pub const SHA_EXTEND_WORDS: usize = 64; // u64 cells -#[derive(Default)] pub struct Sha256ExtendSpec; impl SyscallSpec for Sha256ExtendSpec { diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index b438244a3..9d8cc22e8 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -111,9 +111,7 @@ fn bench_add(c: &mut Criterion) { witness: polys, structural_witness: vec![], public_input: vec![], - num_read_instances: num_instances, - num_write_instances: num_instances, - num_instances, + num_instances: vec![num_instances], has_ecc_ops: false, }; let _ = prover diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index dd9ae0907..12c137aa8 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -2,6 +2,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, structs::ProgramParams, tables::RMMCollections, witness::LkMultiplicity, }; +use ceno_emul::StepRecord; use ff_ext::{ExtensionField, FieldInto}; use gkr_iop::{ chip::Chip, @@ -23,7 +24,6 @@ pub mod riscv; pub trait Instruction { type InstructionConfig: Send + Sync; - type Record: Sync; fn padding_strategy() -> InstancePaddingStrategy { InstancePaddingStrategy::Default @@ -33,17 +33,15 @@ pub trait Instruction { /// construct circuit and manipulate circuit builder, then return the respective config fn construct_circuit( - &self, circuit_builder: &mut CircuitBuilder, param: &ProgramParams, ) -> Result; fn build_gkr_iop_circuit( - &self, cb: &mut CircuitBuilder, param: &ProgramParams, ) -> Result<(Self::InstructionConfig, GKRCircuit), ZKVMError> { - let config = self.construct_circuit(cb, param)?; + let config = Self::construct_circuit(cb, param)?; let w_len = cb.cs.w_expressions.len(); let r_len = cb.cs.r_expressions.len(); let lk_len = cb.cs.lk_expressions.len(); @@ -101,7 +99,7 @@ pub trait Instruction { shard_ctx: &mut ShardContext<'a>, instance: &mut [E::BaseField], lk_multiplicity: &mut LkMultiplicity, - step: &Self::Record, + step: &StepRecord, ) -> Result<(), ZKVMError>; fn assign_instances( @@ -109,7 +107,7 @@ pub trait Instruction { shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec, + steps: Vec, ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { // FIXME selector is the only structural witness // this is workaround, as call `construct_circuit` will not initialized selector diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 6d23a1594..6de0c2d8b 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -41,7 +41,7 @@ use witness::{InstancePaddingStrategy, next_pow2_instance_padding, set_val}; use crate::{instructions::riscv::constants::UInt, scheme::constants::SEPTIC_EXTENSION_DEGREE}; /// A record for a read/write into the global set -#[derive(Default, Debug, Clone)] +#[derive(Debug, Clone)] pub struct GlobalRecord { pub addr: u32, pub ram_type: RAMType, diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index 7b86cfca5..69c656148 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -44,7 +44,7 @@ mod test; #[cfg(test)] mod test_utils; -pub trait RIVInstruction: Default { +pub trait RIVInstruction { const INST_KIND: InsnKind; } diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index bce17c0ee..a94024b4a 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -18,20 +18,15 @@ pub struct ArithConfig { rd_written: UInt, } -#[derive(Default)] -pub struct ArithInstruction(PhantomData<(E, I)>); +pub struct ArithInstruction(PhantomData<(E, I)>); -#[derive(Default)] pub struct AddOp; - impl RIVInstruction for AddOp { const INST_KIND: InsnKind = InsnKind::ADD; } pub type AddInstruction = ArithInstruction; -#[derive(Default)] pub struct SubOp; - impl RIVInstruction for SubOp { const INST_KIND: InsnKind = InsnKind::SUB; } @@ -39,14 +34,12 @@ pub type SubInstruction = ArithInstruction; impl Instruction for ArithInstruction { type InstructionConfig = ArithConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) } fn construct_circuit( - &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { @@ -171,11 +164,15 @@ mod test { fn verify(name: &'static str, rs1: u32, rs2: u32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = ArithInstruction::::default(); let config = cb .namespace( || format!("{:?}_({name})", I::INST_KIND), - |cb| Ok(inst.construct_circuit(cb, &ProgramParams::default())), + |cb| { + Ok(ArithInstruction::::construct_circuit( + cb, + &ProgramParams::default(), + )) + }, ) .unwrap() .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index dd96ca2df..4de4069d0 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -3,8 +3,6 @@ mod arith_imm_circuit; #[cfg(feature = "u16limb_circuit")] mod arith_imm_circuit_v2; -use ff_ext::ExtensionField; - #[cfg(feature = "u16limb_circuit")] pub use crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::AddiInstruction; @@ -13,7 +11,7 @@ pub use crate::instructions::riscv::arith_imm::arith_imm_circuit::AddiInstructio use super::RIVInstruction; -impl RIVInstruction for AddiInstruction { +impl RIVInstruction for AddiInstruction { const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::ADDI; } @@ -51,12 +49,12 @@ mod test { fn test_opcode_addi_internal(rs1: u32, rd: u32, imm: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = AddiInstruction::::default(); let config = cb .namespace( || "addi", |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = + AddiInstruction::::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index 2cf5b4b6d..8ed175d58 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -18,7 +18,6 @@ use p3::field::FieldAlgebra; use std::marker::PhantomData; use witness::set_val; -#[derive(Default)] pub struct AddiInstruction(PhantomData); pub struct InstructionConfig { @@ -33,14 +32,12 @@ pub struct InstructionConfig { impl Instruction for AddiInstruction { type InstructionConfig = InstructionConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", Self::INST_KIND) } fn construct_circuit( - &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 560b99298..3244c5d60 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -18,7 +18,7 @@ use crate::{ utils::split_to_u8, witness::LkMultiplicity, }; -use ceno_emul::{InsnKind, StepRecord}; +use ceno_emul::InsnKind; use gkr_iop::tables::{LookupTable, ops::XorTable}; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; @@ -33,19 +33,16 @@ pub struct AuipcConfig { pub rd_written: UInt8, } -#[derive(Default)] pub struct AuipcInstruction(PhantomData); impl Instruction for AuipcInstruction { type InstructionConfig = AuipcConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", InsnKind::AUIPC) } fn construct_circuit( - &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result, ZKVMError> { @@ -230,12 +227,12 @@ mod tests { fn test_opcode_auipc(rd: u32, imm: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = AuipcInstruction::default(); let config = cb .namespace( || "auipc", |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = + AuipcInstruction::::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/branch.rs b/ceno_zkvm/src/instructions/riscv/branch.rs index 082c3f897..dc2c8c9e6 100644 --- a/ceno_zkvm/src/instructions/riscv/branch.rs +++ b/ceno_zkvm/src/instructions/riscv/branch.rs @@ -6,25 +6,19 @@ mod branch_circuit_v2; #[cfg(test)] mod test; -#[derive(Default)] pub struct BeqOp; - impl RIVInstruction for BeqOp { const INST_KIND: InsnKind = InsnKind::BEQ; } pub type BeqInstruction = branch_circuit::BranchCircuit; -#[derive(Default)] pub struct BneOp; - impl RIVInstruction for BneOp { const INST_KIND: InsnKind = InsnKind::BNE; } pub type BneInstruction = branch_circuit::BranchCircuit; -#[derive(Default)] pub struct BltuOp; - impl RIVInstruction for BltuOp { const INST_KIND: InsnKind = InsnKind::BLTU; } @@ -33,9 +27,7 @@ pub type BltuInstruction = branch_circuit_v2::BranchCircuit; #[cfg(not(feature = "u16limb_circuit"))] pub type BltuInstruction = branch_circuit::BranchCircuit; -#[derive(Default)] pub struct BgeuOp; - impl RIVInstruction for BgeuOp { const INST_KIND: InsnKind = InsnKind::BGEU; } @@ -44,9 +36,7 @@ pub type BgeuInstruction = branch_circuit_v2::BranchCircuit; #[cfg(not(feature = "u16limb_circuit"))] pub type BgeuInstruction = branch_circuit::BranchCircuit; -#[derive(Default)] pub struct BltOp; - impl RIVInstruction for BltOp { const INST_KIND: InsnKind = InsnKind::BLT; } @@ -55,7 +45,6 @@ pub type BltInstruction = branch_circuit_v2::BranchCircuit; #[cfg(not(feature = "u16limb_circuit"))] pub type BltInstruction = branch_circuit::BranchCircuit; -#[derive(Default)] pub struct BgeOp; impl RIVInstruction for BgeOp { const INST_KIND: InsnKind = InsnKind::BGE; diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs index a2bf9a572..2c97a12ee 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs @@ -23,8 +23,7 @@ use crate::{ use multilinear_extensions::Expression; pub use p3::field::FieldAlgebra; -#[derive(Default)] -pub struct BranchCircuit(PhantomData<(E, I)>); +pub struct BranchCircuit(PhantomData<(E, I)>); pub struct BranchConfig { pub b_insn: BInstructionConfig, @@ -36,15 +35,13 @@ pub struct BranchConfig { } impl Instruction for BranchCircuit { - type InstructionConfig = BranchConfig; - type Record = StepRecord; - fn name() -> String { format!("{:?}", I::INST_KIND) } + type InstructionConfig = BranchConfig; + fn construct_circuit( - &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result, ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 01cd1b9aa..386d2c286 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -16,7 +16,6 @@ use ff_ext::ExtensionField; use multilinear_extensions::Expression; use std::marker::PhantomData; -#[derive(Default)] pub struct BranchCircuit(PhantomData<(E, I)>); pub struct BranchConfig { @@ -30,14 +29,12 @@ pub struct BranchConfig { impl Instruction for BranchCircuit { type InstructionConfig = BranchConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) } fn construct_circuit( - &self, circuit_builder: &mut CircuitBuilder, _param: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index 7d1083751..82dbcffac 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -25,12 +25,11 @@ fn test_opcode_beq() { fn impl_opcode_beq(equal: bool) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = BeqInstruction::default(); let config = cb .namespace( || "beq", |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = BeqInstruction::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -67,12 +66,11 @@ fn test_opcode_bne() { fn impl_opcode_bne(equal: bool) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = BneInstruction::default(); let config = cb .namespace( || "bne", |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = BneInstruction::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -115,8 +113,8 @@ fn test_bltu_circuit() -> Result<(), ZKVMError> { fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { let mut cs = ConstraintSystem::new(|| "riscv"); let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let inst = BltuInstruction::default(); - let config = inst.construct_circuit(&mut circuit_builder, &ProgramParams::default())?; + let config = + BltuInstruction::construct_circuit(&mut circuit_builder, &ProgramParams::default())?; let pc_after = if taken { ByteAddr(MOCK_PC_START.0 - 8) @@ -160,8 +158,8 @@ fn test_bgeu_circuit() -> Result<(), ZKVMError> { fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { let mut cs = ConstraintSystem::new(|| "riscv"); let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let inst = BgeuInstruction::default(); - let config = inst.construct_circuit(&mut circuit_builder, &ProgramParams::default())?; + let config = + BgeuInstruction::construct_circuit(&mut circuit_builder, &ProgramParams::default())?; let pc_after = if taken { ByteAddr(MOCK_PC_START.0 - 8) @@ -212,8 +210,8 @@ fn test_blt_circuit() -> Result<(), ZKVMError> { fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { let mut cs = ConstraintSystem::new(|| "riscv"); let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let inst = BltInstruction::default(); - let config = inst.construct_circuit(&mut circuit_builder, &ProgramParams::default())?; + let config = + BltInstruction::construct_circuit(&mut circuit_builder, &ProgramParams::default())?; let pc_after = if taken { ByteAddr(MOCK_PC_START.0 - 8) @@ -264,8 +262,8 @@ fn test_bge_circuit() -> Result<(), ZKVMError> { fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<(), ZKVMError> { let mut cs = ConstraintSystem::new(|| "riscv"); let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let inst = BgeInstruction::default(); - let config = inst.construct_circuit(&mut circuit_builder, &ProgramParams::default())?; + let config = + BgeInstruction::construct_circuit(&mut circuit_builder, &ProgramParams::default())?; let pc_after = if taken { ByteAddr(MOCK_PC_START.0 - 8) diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index 91085076d..966320407 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -7,9 +7,7 @@ mod div_circuit_v2; use super::RIVInstruction; -#[derive(Default)] pub struct DivuOp; - impl RIVInstruction for DivuOp { const INST_KIND: InsnKind = InsnKind::DIVU; } @@ -18,9 +16,7 @@ pub type DivuInstruction = div_circuit_v2::ArithInstruction; #[cfg(not(feature = "u16limb_circuit"))] pub type DivuInstruction = div_circuit::ArithInstruction; -#[derive(Default)] pub struct RemuOp; - impl RIVInstruction for RemuOp { const INST_KIND: InsnKind = InsnKind::REMU; } @@ -29,9 +25,7 @@ pub type RemuInstruction = div_circuit_v2::ArithInstruction; #[cfg(not(feature = "u16limb_circuit"))] pub type RemuInstruction = div_circuit::ArithInstruction; -#[derive(Default)] pub struct RemOp; - impl RIVInstruction for RemOp { const INST_KIND: InsnKind = InsnKind::REM; } @@ -40,9 +34,7 @@ pub type RemInstruction = div_circuit_v2::ArithInstruction; #[cfg(not(feature = "u16limb_circuit"))] pub type RemInstruction = div_circuit::ArithInstruction; -#[derive(Default)] pub struct DivOp; - impl RIVInstruction for DivOp { const INST_KIND: InsnKind = InsnKind::DIV; } @@ -167,10 +159,7 @@ mod test { const INSN_KIND: InsnKind = InsnKind::REMU; } - fn verify< - E: ExtensionField, - Insn: Instruction + TestInstance + Default, - >( + fn verify + TestInstance>( name: &str, dividend: >::NumType, divisor: >::NumType, @@ -179,11 +168,10 @@ mod test { ) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = Insn::default(); let config = cb .namespace( || format!("{}_({})", Insn::name(), name), - |cb| Ok(inst.construct_circuit(cb, &ProgramParams::default())), + |cb| Ok(Insn::construct_circuit(cb, &ProgramParams::default())), ) .unwrap() .unwrap(); @@ -234,10 +222,7 @@ mod test { } // shortcut to verify given pair produces correct output - fn verify_positive< - E: ExtensionField, - Insn: Instruction + TestInstance + Default, - >( + fn verify_positive + TestInstance>( name: &str, dividend: >::NumType, divisor: >::NumType, diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index 284a3ab22..f062ea949 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -44,19 +44,16 @@ pub struct DivRemConfig { lt_diff: WitIn, } -#[derive(Default)] pub struct ArithInstruction(PhantomData<(E, I)>); impl Instruction for ArithInstruction { type InstructionConfig = DivRemConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) } fn construct_circuit( - &self, cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs index b08a30c0a..1df279dd9 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs @@ -20,19 +20,16 @@ use p3::field::FieldAlgebra; use witness::set_val; /// DummyInstruction can handle any instruction and produce its side-effects. -#[derive(Default)] pub struct DummyInstruction(PhantomData<(E, I)>); impl Instruction for DummyInstruction { type InstructionConfig = DummyConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}_DUMMY", I::INST_KIND) } fn construct_circuit( - &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 8cc6abfd9..9cd5cb0f3 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -25,18 +25,15 @@ use witness::set_val; /// including multiple memory operations. /// /// Unsafe: The content is not constrained. -#[derive(Default)] pub struct LargeEcallDummy(PhantomData<(E, S)>); impl Instruction for LargeEcallDummy { type InstructionConfig = LargeEcallConfig; - type Record = StepRecord; fn name() -> String { S::NAME.to_owned() } fn construct_circuit( - &self, cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index f07c9b9a5..c6f51d142 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -20,12 +20,11 @@ type BeqDummy = DummyInstruction; fn test_dummy_ecall() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = EcallDummy::default(); let config = cb .namespace( || "ecall_dummy", |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = EcallDummy::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -52,12 +51,11 @@ fn test_dummy_keccak() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = KeccakDummy::default(); let config = cb .namespace( || "keccak_dummy", |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = KeccakDummy::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -81,12 +79,11 @@ fn test_dummy_keccak() { fn test_dummy_r() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = AddDummy::default(); let config = cb .namespace( || "add_dummy", |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = AddDummy::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -118,12 +115,11 @@ fn test_dummy_r() { fn test_dummy_b() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = BeqDummy::default(); let config = cb .namespace( || "beq_dummy", |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = BeqDummy::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/ecall.rs b/ceno_zkvm/src/instructions/riscv/ecall.rs index ba3a9d00e..a25bbeeb6 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall.rs @@ -14,9 +14,7 @@ pub use halt::HaltInstruction; use super::{RIVInstruction, dummy::DummyInstruction}; -#[derive(Default)] pub struct EcallOp; - impl RIVInstruction for EcallOp { const INST_KIND: InsnKind = InsnKind::ECALL; } diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index 3ebddd8cc..bf38a67c4 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -27,19 +27,16 @@ pub struct HaltConfig { lt_x10_cfg: AssertLtConfig, } -#[derive(Default)] pub struct HaltInstruction(PhantomData); impl Instruction for HaltInstruction { type InstructionConfig = HaltConfig; - type Record = StepRecord; fn name() -> String { "ECALL_HALT".into() } fn construct_circuit( - &self, cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index 530ce90a3..dccdf34a2 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -50,19 +50,16 @@ pub struct EcallKeccakConfig { } /// KeccakInstruction can handle any instruction and produce its side-effects. -#[derive(Default)] pub struct KeccakInstruction(PhantomData); impl Instruction for KeccakInstruction { type InstructionConfig = EcallKeccakConfig; - type Record = StepRecord; fn name() -> String { "Ecall_Keccak".to_string() } fn construct_circuit( - &self, _circuit_builder: &mut CircuitBuilder, _param: &ProgramParams, ) -> Result { @@ -70,7 +67,6 @@ impl Instruction for KeccakInstruction { } fn build_gkr_iop_circuit( - &self, cb: &mut CircuitBuilder, _param: &ProgramParams, ) -> Result<(Self::InstructionConfig, GKRCircuit), ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index 3f392e130..adf52683f 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -53,21 +53,18 @@ pub struct EcallWeierstrassAddAssignConfig } /// WeierstrassAddAssignInstruction can handle any instruction and produce its side-effects. -#[derive(Default)] pub struct WeierstrassAddAssignInstruction(PhantomData<(E, EC)>); impl Instruction for WeierstrassAddAssignInstruction { type InstructionConfig = EcallWeierstrassAddAssignConfig; - type Record = StepRecord; fn name() -> String { "Ecall_WeierstrassAddAssign_".to_string() + format!("{:?}", EC::CURVE_TYPE).as_str() } fn construct_circuit( - &self, _circuit_builder: &mut CircuitBuilder, _param: &ProgramParams, ) -> Result { @@ -75,7 +72,6 @@ impl Instruction } fn build_gkr_iop_circuit( - &self, cb: &mut CircuitBuilder, _param: &ProgramParams, ) -> Result<(Self::InstructionConfig, GKRCircuit), ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index 33c8133ec..250141669 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -60,21 +60,18 @@ pub struct EcallWeierstrassDecompressConfig(PhantomData<(E, EC)>); impl Instruction for WeierstrassDecompressInstruction { type InstructionConfig = EcallWeierstrassDecompressConfig; - type Record = StepRecord; fn name() -> String { "Ecall_WeierstrassDecompress_".to_string() + format!("{:?}", EC::CURVE_TYPE).as_str() } fn construct_circuit( - &self, _circuit_builder: &mut CircuitBuilder, _param: &ProgramParams, ) -> Result { @@ -82,7 +79,6 @@ impl Instruction, _param: &ProgramParams, ) -> Result<(Self::InstructionConfig, GKRCircuit), ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 96adb245a..c6922b8a5 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -55,21 +55,18 @@ pub struct EcallWeierstrassDoubleAssignConfig< } /// WeierstrassDoubleAssignInstruction can handle any instruction and produce its side-effects. -#[derive(Default)] pub struct WeierstrassDoubleAssignInstruction(PhantomData<(E, EC)>); impl Instruction for WeierstrassDoubleAssignInstruction { type InstructionConfig = EcallWeierstrassDoubleAssignConfig; - type Record = StepRecord; fn name() -> String { "Ecall_WeierstrassDoubleAssign_".to_string() + format!("{:?}", EC::CURVE_TYPE).as_str() } fn construct_circuit( - &self, _circuit_builder: &mut CircuitBuilder, _param: &ProgramParams, ) -> Result { @@ -77,7 +74,6 @@ impl Instruction, _param: &ProgramParams, ) -> Result<(Self::InstructionConfig, GKRCircuit), ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index 7c2074343..545adf275 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -17,7 +17,7 @@ use crate::{ utils::split_to_u8, witness::LkMultiplicity, }; -use ceno_emul::{InsnKind, PC_STEP_SIZE, StepRecord}; +use ceno_emul::{InsnKind, PC_STEP_SIZE}; use gkr_iop::tables::{LookupTable, ops::XorTable}; use multilinear_extensions::{Expression, ToExpr}; use p3::field::FieldAlgebra; @@ -27,7 +27,6 @@ pub struct JalConfig { pub rd_written: UInt8, } -#[derive(Default)] pub struct JalInstruction(PhantomData); /// JAL instruction circuit @@ -43,14 +42,12 @@ pub struct JalInstruction(PhantomData); /// of native WitIn values for address space arithmetic. impl Instruction for JalInstruction { type InstructionConfig = JalConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", InsnKind::JAL) } fn construct_circuit( - &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result, ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index 6bc9040eb..77f6ad1f8 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -29,7 +29,6 @@ pub struct JalrConfig { pub rd_written: UInt, } -#[derive(Default)] pub struct JalrInstruction(PhantomData); /// JALR instruction circuit diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index ba0c2c801..7f23ac9b6 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -20,7 +20,7 @@ use crate::{ utils::imm_sign_extend, witness::{LkMultiplicity, set_val}, }; -use ceno_emul::{InsnKind, PC_STEP_SIZE, StepRecord}; +use ceno_emul::{InsnKind, PC_STEP_SIZE}; use ff_ext::FieldInto; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; @@ -34,7 +34,6 @@ pub struct JalrConfig { pub rd_high: WitIn, } -#[derive(Default)] pub struct JalrInstruction(PhantomData); /// JALR instruction circuit @@ -43,14 +42,12 @@ pub struct JalrInstruction(PhantomData); /// the program table impl Instruction for JalrInstruction { type InstructionConfig = JalrConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", InsnKind::JALR) } fn construct_circuit( - &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result, ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index 68561d245..899e5a035 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -28,12 +28,11 @@ fn test_opcode_jal() { fn verify_test_opcode_jal(pc_offset: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = JalInstruction::default(); let config = cb .namespace( || "jal", |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = JalInstruction::::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -89,12 +88,11 @@ fn test_opcode_jalr() { fn verify_test_opcode_jalr(rs1_read: Word, imm: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = JalrInstruction::default(); let config = cb .namespace( || "jalr", |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = JalrInstruction::::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/logic.rs b/ceno_zkvm/src/instructions/riscv/logic.rs index 8749231ed..9ac2cd4c1 100644 --- a/ceno_zkvm/src/instructions/riscv/logic.rs +++ b/ceno_zkvm/src/instructions/riscv/logic.rs @@ -7,27 +7,21 @@ mod test; use ceno_emul::InsnKind; -#[derive(Default)] pub struct AndOp; - impl LogicOp for AndOp { const INST_KIND: InsnKind = InsnKind::AND; type OpsTable = AndTable; } pub type AndInstruction = LogicInstruction; -#[derive(Default)] pub struct OrOp; - impl LogicOp for OrOp { const INST_KIND: InsnKind = InsnKind::OR; type OpsTable = OrTable; } pub type OrInstruction = LogicInstruction; -#[derive(Default)] pub struct XorOp; - impl LogicOp for XorOp { const INST_KIND: InsnKind = InsnKind::XOR; type OpsTable = XorTable; diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index 77891758f..5a2d8e404 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -25,19 +25,16 @@ pub trait LogicOp { } /// The Instruction circuit for a given LogicOp. -#[derive(Default)] pub struct LogicInstruction(PhantomData<(E, I)>); impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) } fn construct_circuit( - &self, cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/logic/test.rs b/ceno_zkvm/src/instructions/riscv/logic/test.rs index 47ad05b7d..f68135c72 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/test.rs @@ -18,12 +18,11 @@ const B: Word = 0xef552020; fn test_opcode_and() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = AndInstruction::default(); let config = cb .namespace( || "and", |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = AndInstruction::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -62,12 +61,11 @@ fn test_opcode_and() { fn test_opcode_or() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = OrInstruction::default(); let config = cb .namespace( || "or", |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = OrInstruction::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -106,12 +104,11 @@ fn test_opcode_or() { fn test_opcode_xor() { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = XorInstruction::default(); let config = cb .namespace( || "xor", |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = XorInstruction::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm.rs b/ceno_zkvm/src/instructions/riscv/logic_imm.rs index 1e9dc25e1..a4b46edcc 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm.rs @@ -24,27 +24,21 @@ use gkr_iop::tables::ops::{AndTable, OrTable, XorTable}; use ceno_emul::InsnKind; use gkr_iop::tables::OpsTable; -#[derive(Default)] pub struct AndiOp; - impl LogicOp for AndiOp { const INST_KIND: InsnKind = InsnKind::ANDI; type OpsTable = AndTable; } pub type AndiInstruction = LogicInstruction; -#[derive(Default)] pub struct OriOp; - impl LogicOp for OriOp { const INST_KIND: InsnKind = InsnKind::ORI; type OpsTable = OrTable; } pub type OriInstruction = LogicInstruction; -#[derive(Default)] pub struct XoriOp; - impl LogicOp for XoriOp { const INST_KIND: InsnKind = InsnKind::XORI; type OpsTable = XorTable; diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index 5394f48e3..b48af7f5f 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -27,19 +27,16 @@ use ceno_emul::{InsnKind, StepRecord}; use multilinear_extensions::ToExpr; /// The Instruction circuit for a given LogicOp. -#[derive(Default)] pub struct LogicInstruction(PhantomData<(E, I)>); impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) } fn construct_circuit( - &self, cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs index 6a09f4de1..68032fd41 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs @@ -43,15 +43,9 @@ fn test_opcode_xori() { verify::("negative imm", TEST, NEG, TEST ^ NEG); } -fn verify( - name: &'static str, - rs1_read: u32, - imm: u32, - expected_rd_written: u32, -) { +fn verify(name: &'static str, rs1_read: u32, imm: u32, expected_rd_written: u32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = LogicInstruction::::default(); let (prefix, rd_written) = match I::INST_KIND { InsnKind::ANDI => ("ANDI", rs1_read & imm), @@ -64,7 +58,10 @@ fn verify( .namespace( || format!("{prefix}_({name})"), |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = LogicInstruction::::construct_circuit( + cb, + &ProgramParams::default(), + ); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 80d4c87a9..198bafbc5 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -18,7 +18,7 @@ use crate::{ utils::split_to_u8, witness::LkMultiplicity, }; -use ceno_emul::{InsnKind, StepRecord}; +use ceno_emul::InsnKind; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::FieldAlgebra; use witness::set_val; @@ -30,19 +30,16 @@ pub struct LuiConfig { pub rd_written: [WitIn; UINT_BYTE_LIMBS - 1], } -#[derive(Default)] pub struct LuiInstruction(PhantomData); impl Instruction for LuiInstruction { type InstructionConfig = LuiConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", InsnKind::LUI) } fn construct_circuit( - &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result, ZKVMError> { @@ -144,12 +141,12 @@ mod tests { fn test_opcode_lui(rd: u32, imm: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = LuiInstruction::default(); let config = cb .namespace( || "lui", |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = + LuiInstruction::::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/memory.rs b/ceno_zkvm/src/instructions/riscv/memory.rs index 612058667..bb29491f7 100644 --- a/ceno_zkvm/src/instructions/riscv/memory.rs +++ b/ceno_zkvm/src/instructions/riscv/memory.rs @@ -24,7 +24,6 @@ pub use crate::instructions::riscv::memory::store_v2::StoreInstruction; use ceno_emul::InsnKind; -#[derive(Default)] pub struct LwOp; impl RIVInstruction for LwOp { @@ -33,57 +32,43 @@ impl RIVInstruction for LwOp { pub type LwInstruction = LoadInstruction; -#[derive(Default)] pub struct LhOp; - impl RIVInstruction for LhOp { const INST_KIND: InsnKind = InsnKind::LH; } pub type LhInstruction = LoadInstruction; -#[derive(Default)] pub struct LhuOp; - impl RIVInstruction for LhuOp { const INST_KIND: InsnKind = InsnKind::LHU; } pub type LhuInstruction = LoadInstruction; -#[derive(Default)] pub struct LbOp; - impl RIVInstruction for LbOp { const INST_KIND: InsnKind = InsnKind::LB; } pub type LbInstruction = LoadInstruction; -#[derive(Default)] pub struct LbuOp; - impl RIVInstruction for LbuOp { const INST_KIND: InsnKind = InsnKind::LBU; } pub type LbuInstruction = LoadInstruction; -#[derive(Default)] pub struct SWOp; - impl RIVInstruction for SWOp { const INST_KIND: InsnKind = InsnKind::SW; } pub type SwInstruction = StoreInstruction; -#[derive(Default)] pub struct SHOp; - impl RIVInstruction for SHOp { const INST_KIND: InsnKind = InsnKind::SH; } pub type ShInstruction = StoreInstruction; -#[derive(Default)] pub struct SBOp; - impl RIVInstruction for SBOp { const INST_KIND: InsnKind = InsnKind::SB; } diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 2104a87b6..812e4020a 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -38,19 +38,16 @@ pub struct LoadConfig { signed_extend_config: Option>, } -#[derive(Default)] -pub struct LoadInstruction(PhantomData<(E, I)>); +pub struct LoadInstruction(PhantomData<(E, I)>); impl Instruction for LoadInstruction { type InstructionConfig = LoadConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) } fn construct_circuit( - &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index ac425ee9f..cb512975b 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -36,21 +36,18 @@ pub struct StoreConfig { next_memory_value: Option>, } -#[derive(Default)] pub struct StoreInstruction(PhantomData<(E, I)>); impl Instruction for StoreInstruction { type InstructionConfig = StoreConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) } fn construct_circuit( - &self, circuit_builder: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index ed7f20531..b2a04326b 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -76,21 +76,14 @@ fn load(mem_value: Word, insn: InsnKind, shift: u32) -> Word { } } -fn impl_opcode_store< - E: ExtensionField + Hash, - I: RIVInstruction, - Inst: Instruction + Default, ->( - imm: i32, -) { +fn impl_opcode_store>(imm: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = Inst::default(); let config = cb .namespace( || Inst::name(), |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = Inst::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) @@ -146,21 +139,14 @@ fn impl_opcode_store< MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } -fn impl_opcode_load< - E: ExtensionField + Hash, - I: RIVInstruction, - Inst: Instruction + Default, ->( - imm: i32, -) { +fn impl_opcode_load>(imm: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = Inst::default(); let config = cb .namespace( || Inst::name(), |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = Inst::construct_circuit(cb, &ProgramParams::default()); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/mulh.rs b/ceno_zkvm/src/instructions/riscv/mulh.rs index 1bff6c574..2b7b0217e 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh.rs @@ -11,34 +11,26 @@ use mulh_circuit::MulhInstructionBase; #[cfg(feature = "u16limb_circuit")] use mulh_circuit_v2::MulhInstructionBase; -#[derive(Default)] pub struct MulOp; - impl RIVInstruction for MulOp { const INST_KIND: InsnKind = InsnKind::MUL; } pub type MulInstruction = MulhInstructionBase; -#[derive(Default)] pub struct MulhOp; - impl RIVInstruction for MulhOp { const INST_KIND: InsnKind = InsnKind::MULH; } pub type MulhInstruction = MulhInstructionBase; -#[derive(Default)] pub struct MulhuOp; - impl RIVInstruction for MulhuOp { const INST_KIND: InsnKind = InsnKind::MULHU; } pub type MulhuInstruction = MulhInstructionBase; -#[derive(Default)] pub struct MulhsuOp; - impl RIVInstruction for MulhsuOp { const INST_KIND: InsnKind = InsnKind::MULHSU; } @@ -120,11 +112,15 @@ mod test { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = MulhInstructionBase::::default(); let config = cb .namespace( || format!("{:?}_({name})", I::INST_KIND), - |cb| Ok(inst.construct_circuit(cb, &ProgramParams::default())), + |cb| { + Ok(MulhInstructionBase::::construct_circuit( + cb, + &ProgramParams::default(), + )) + }, ) .unwrap() .unwrap(); @@ -204,11 +200,15 @@ mod test { fn verify_mulh(rs1: i32, rs2: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = MulhInstruction::::default(); let config = cb .namespace( || "mulh", - |cb| Ok(inst.construct_circuit(cb, &ProgramParams::default())), + |cb| { + Ok(MulhInstruction::construct_circuit( + cb, + &ProgramParams::default(), + )) + }, ) .unwrap() .unwrap(); @@ -284,11 +284,15 @@ mod test { fn verify_mulhsu(rs1: i32, rs2: u32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = MulhsuInstruction::::default(); let config = cb .namespace( || "mulhsu", - |cb| Ok(inst.construct_circuit(cb, &ProgramParams::default())), + |cb| { + Ok(MulhsuInstruction::construct_circuit( + cb, + &ProgramParams::default(), + )) + }, ) .unwrap() .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index bd3cdb19c..a94f63e74 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -23,8 +23,7 @@ use crate::e2e::ShardContext; use itertools::Itertools; use std::{array, marker::PhantomData}; -#[derive(Default)] -pub struct MulhInstructionBase(PhantomData<(E, I)>); +pub struct MulhInstructionBase(PhantomData<(E, I)>); pub struct MulhConfig { rs1_read: UInt, @@ -39,14 +38,12 @@ pub struct MulhConfig { impl Instruction for MulhInstructionBase { type InstructionConfig = MulhConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) } fn construct_circuit( - &self, circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result, ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index f3122acec..d09b98c89 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -11,25 +11,19 @@ use crate::instructions::riscv::shift::shift_circuit::ShiftLogicalInstruction; #[cfg(feature = "u16limb_circuit")] use crate::instructions::riscv::shift::shift_circuit_v2::ShiftLogicalInstruction; -#[derive(Default)] pub struct SllOp; - impl RIVInstruction for SllOp { const INST_KIND: InsnKind = InsnKind::SLL; } pub type SllInstruction = ShiftLogicalInstruction; -#[derive(Default)] pub struct SrlOp; - impl RIVInstruction for SrlOp { const INST_KIND: InsnKind = InsnKind::SRL; } pub type SrlInstruction = ShiftLogicalInstruction; -#[derive(Default)] pub struct SraOp; - impl RIVInstruction for SraOp { const INST_KIND: InsnKind = InsnKind::SRA; } @@ -129,7 +123,6 @@ mod tests { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = ShiftLogicalInstruction::::default(); let shift = rs2_read & 0b11111; let (prefix, insn_code, rd_written) = match I::INST_KIND { InsnKind::SLL => ( @@ -153,7 +146,12 @@ mod tests { let config = cb .namespace( || format!("{prefix}_({name})"), - |cb| Ok(inst.construct_circuit(cb, &ProgramParams::default())), + |cb| { + Ok(ShiftLogicalInstruction::::construct_circuit( + cb, + &ProgramParams::default(), + )) + }, ) .unwrap() .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 3650f48df..fac05279e 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -13,7 +13,7 @@ use crate::{ structs::ProgramParams, utils::{split_to_limb, split_to_u8}, }; -use ceno_emul::{InsnKind, StepRecord}; +use ceno_emul::InsnKind; use ff_ext::{ExtensionField, FieldInto}; use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn}; @@ -272,19 +272,16 @@ pub struct ShiftRTypeConfig { r_insn: RInstructionConfig, } -#[derive(Default)] pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); impl Instruction for ShiftLogicalInstruction { type InstructionConfig = ShiftRTypeConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) } fn construct_circuit( - &self, circuit_builder: &mut crate::circuit_builder::CircuitBuilder, _params: &ProgramParams, ) -> Result { @@ -371,19 +368,16 @@ pub struct ShiftImmConfig { imm: WitIn, } -#[derive(Default)] -pub struct ShiftImmInstruction(PhantomData<(E, I)>); +pub struct ShiftImmInstruction(PhantomData<(E, I)>); impl Instruction for ShiftImmInstruction { type InstructionConfig = ShiftImmConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) } fn construct_circuit( - &self, circuit_builder: &mut crate::circuit_builder::CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 40dad7eb6..1757a0fc7 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -9,25 +9,19 @@ use crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmInstruction; #[cfg(not(feature = "u16limb_circuit"))] use crate::instructions::riscv::shift_imm::shift_imm_circuit::ShiftImmInstruction; -#[derive(Default)] pub struct SlliOp; - impl RIVInstruction for SlliOp { const INST_KIND: InsnKind = InsnKind::SLLI; } pub type SlliInstruction = ShiftImmInstruction; -#[derive(Default)] pub struct SraiOp; - impl RIVInstruction for SraiOp { const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SRAI; } pub type SraiInstruction = ShiftImmInstruction; -#[derive(Default)] pub struct SrliOp; - impl RIVInstruction for SrliOp { const INST_KIND: ceno_emul::InsnKind = InsnKind::SRLI; } @@ -149,8 +143,10 @@ mod test { .namespace( || format!("{prefix}_({name})"), |cb| { - let inst = ShiftImmInstruction::::default(); - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = ShiftImmInstruction::::construct_circuit( + cb, + &ProgramParams::default(), + ); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index 5b3269ea2..3ba12bb39 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -7,9 +7,7 @@ use ceno_emul::InsnKind; use super::RIVInstruction; -#[derive(Default)] pub struct SltOp; - impl RIVInstruction for SltOp { const INST_KIND: InsnKind = InsnKind::SLT; } @@ -18,9 +16,7 @@ pub type SltInstruction = slt_circuit_v2::SetLessThanInstruction; #[cfg(not(feature = "u16limb_circuit"))] pub type SltInstruction = slt_circuit::SetLessThanInstruction; -#[derive(Default)] pub struct SltuOp; - impl RIVInstruction for SltuOp { const INST_KIND: InsnKind = InsnKind::SLTU; } @@ -60,12 +56,14 @@ mod test { ) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let inst = SetLessThanInstruction::::default(); let config = cb .namespace( || format!("{}/{name}", I::INST_KIND), |cb| { - let config = inst.construct_circuit(cb, &ProgramParams::default()); + let config = SetLessThanInstruction::<_, I>::construct_circuit( + cb, + &ProgramParams::default(), + ); Ok(config) }, ) diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index d24121e18..cd0b97ce4 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -15,7 +15,6 @@ use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use std::marker::PhantomData; -#[derive(Default)] pub struct SetLessThanInstruction(PhantomData<(E, I)>); /// This config handles R-Instructions that represent registers values as 2 * u16. @@ -31,14 +30,12 @@ pub struct SetLessThanConfig { } impl Instruction for SetLessThanInstruction { type InstructionConfig = SetLessThanConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) } fn construct_circuit( - &self, cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index dc76c3225..ff3a78043 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -12,17 +12,13 @@ use crate::instructions::riscv::slti::slti_circuit::SetLessThanImmInstruction; use super::RIVInstruction; -#[derive(Default)] pub struct SltiOp; - impl RIVInstruction for SltiOp { const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SLTI; } pub type SltiInstruction = SetLessThanImmInstruction; -#[derive(Default)] pub struct SltiuOp; - impl RIVInstruction for SltiuOp { const INST_KIND: ceno_emul::InsnKind = ceno_emul::InsnKind::SLTIU; } @@ -174,12 +170,16 @@ mod test { let mut cb = CircuitBuilder::new(&mut cs); let insn_code = encode_rv32(I::INST_KIND, 2, 0, 4, imm); - let inst = SetLessThanImmInstruction::::default(); let config = cb .namespace( || format!("{:?}_({name})", I::INST_KIND), - |cb| Ok(inst.construct_circuit(cb, &ProgramParams::default())), + |cb| { + Ok(SetLessThanImmInstruction::::construct_circuit( + cb, + &ProgramParams::default(), + )) + }, ) .unwrap() .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 47039b588..914424247 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -37,19 +37,16 @@ pub struct SetLessThanImmConfig { uint_lt_config: UIntLimbsLTConfig, } -#[derive(Default)] -pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); +pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); impl Instruction for SetLessThanImmInstruction { type InstructionConfig = SetLessThanImmConfig; - type Record = StepRecord; fn name() -> String { format!("{:?}", I::INST_KIND) } fn construct_circuit( - &self, cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/instructions/riscv/test.rs b/ceno_zkvm/src/instructions/riscv/test.rs index 41a3877fe..47c0ba178 100644 --- a/ceno_zkvm/src/instructions/riscv/test.rs +++ b/ceno_zkvm/src/instructions/riscv/test.rs @@ -17,15 +17,13 @@ fn test_multiple_opcode() { let params = ProgramParams::default(); let mut cs = ConstraintSystem::new(|| "riscv"); - let add_inst = AddInstruction::::default(); let _add_config = cs.namespace( || "add", - |cs| add_inst.construct_circuit(&mut CircuitBuilder::::new(cs), ¶ms), + |cs| AddInstruction::construct_circuit(&mut CircuitBuilder::::new(cs), ¶ms), ); - let sub_inst = SubInstruction::::default(); let _sub_config = cs.namespace( || "sub", - |cs| sub_inst.construct_circuit(&mut CircuitBuilder::::new(cs), ¶ms), + |cs| SubInstruction::construct_circuit(&mut CircuitBuilder::::new(cs), ¶ms), ); let param = Pcs::setup(1 << 10, SecurityLevel::default()).unwrap(); let (_, _) = Pcs::trim(param, 1 << 10).unwrap(); diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index e772fc563..89fc41449 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -55,7 +55,6 @@ pub type TowerRelationOutput = ( // accumulate N=2^n EC points into one EC point using affine coordinates // in one layer which borrows ideas from the [Quark paper](https://eprint.iacr.org/2020/1275.pdf) -#[derive(Default)] pub struct CpuEccProver; impl CpuEccProver { diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index a8dd0d015..73355017c 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -56,21 +56,18 @@ struct TestConfig { pub(crate) reg_id: WitIn, } -#[derive(Default)] struct TestCircuit { phantom: PhantomData, } impl Instruction for TestCircuit { type InstructionConfig = TestConfig; - type Record = StepRecord; fn name() -> String { "TEST".into() } fn construct_circuit( - &self, cb: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result { diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index e162da405..df669f416 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -969,7 +969,6 @@ impl TowerVerify { } } -#[derive(Default)] pub struct EccVerifier; impl EccVerifier { diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index f82918cfb..85bb61634 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -10,7 +10,7 @@ use crate::{ state::StateCircuit, tables::{RMMCollections, TableCircuit}, }; -use ceno_emul::{CENO_PLATFORM, Platform}; +use ceno_emul::{CENO_PLATFORM, Platform, StepRecord}; use ff_ext::{ExtensionField, PoseidonField}; use gkr_iop::{gkr::GKRCircuit, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use itertools::Itertools; @@ -216,15 +216,11 @@ impl ZKVMConstraintSystem { } } - pub fn register_opcode_circuit + Default>( - &mut self, - ) -> OC::InstructionConfig { + pub fn register_opcode_circuit>(&mut self) -> OC::InstructionConfig { let mut cs = ConstraintSystem::new(|| format!("riscv_opcode/{}", OC::name())); let mut circuit_builder = CircuitBuilder::::new(&mut cs); - let op_circuit = OC::default(); - let (config, gkr_iop_circuit) = op_circuit - .build_gkr_iop_circuit(&mut circuit_builder, &self.params) - .unwrap(); + let (config, gkr_iop_circuit) = + OC::build_gkr_iop_circuit(&mut circuit_builder, &self.params).unwrap(); let cs = ComposedConstrainSystem { zkvm_v1_css: cs, gkr_circuit: Some(gkr_iop_circuit), @@ -348,7 +344,7 @@ impl ZKVMWitnesses { cs: &ZKVMConstraintSystem, shard_ctx: &mut ShardContext, config: &OC::InstructionConfig, - records: Vec, + records: Vec, ) -> Result<(), ZKVMError> { assert!(self.combined_lk_mlt.is_none()); diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index 95063f1c4..a5e20f704 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -78,12 +78,9 @@ pub struct ProtocolVerifier, PCS>( PhantomData<(E, Trans, PCS)>, ); -#[derive( - Clone, Debug, Copy, Default, EnumIter, PartialEq, Eq, serde::Serialize, serde::Deserialize, -)] +#[derive(Clone, Debug, Copy, EnumIter, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[repr(usize)] pub enum RAMType { - #[default] GlobalState = 0, Register, Memory, From 5af8392c934dae9028bfd1270e2b2399fa9aae0c Mon Sep 17 00:00:00 2001 From: Ming Date: Fri, 31 Oct 2025 13:29:24 +0800 Subject: [PATCH 89/91] refactor ec points witness assignments (#1100) - simplify index computation of affine_add(left, right) - allocate buffer once and parallel initialization --- ceno_zkvm/src/instructions/global.rs | 83 ++++++++++++++-------------- 1 file changed, 41 insertions(+), 42 deletions(-) diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 6de0c2d8b..45f8ee29d 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -31,7 +31,10 @@ use p3::{ symmetric::Permutation, }; use rayon::{ - iter::{IndexedParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator}, + iter::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelExtend, + ParallelIterator, + }, prelude::ParallelSliceMut, slice::ParallelSlice, }; @@ -554,50 +557,45 @@ impl TableCircuit for GlobalChip { }) .collect::>()?; - // assign internal nodes in the binary tree for ec point summation - let mut cur_layer_points = steps - .iter() - .map(|step| step.ec_point.point.clone()) - .enumerate() - .collect_vec(); + // allocate num_rows_padded size, fill points on first half + let mut cur_layer_points_buffer: Vec<_> = (0..num_rows_padded) + .into_par_iter() + .map(|i| { + steps + .get(i) + .map(|step| step.ec_point.point.clone()) + .unwrap_or_else(SepticPoint::default) + }) + .collect(); + // raw_witin offset start from n. + // left node is at b, right node is at b + 1 + // op(left node, right node) = offset + b / 2 + let mut offset = num_rows_padded / 2; + let mut current_layer_len = cur_layer_points_buffer.len() / 2; // slope[1,b] = (input[b,0].y - input[b,1].y) / (input[b,0].x - input[b,1].x) loop { - if cur_layer_points.len() <= 1 { + if current_layer_len <= 1 { break; } - // 2b -> b + 2^log_n - let next_layer_offset = cur_layer_points.first().map(|(i, _)| *i / 2 + n).unwrap(); - cur_layer_points = cur_layer_points + let (current_layer, next_layer) = + cur_layer_points_buffer.split_at_mut(current_layer_len); + current_layer .par_chunks(2) - .zip(raw_witin.values[next_layer_offset * num_witin..].par_chunks_mut(num_witin)) - .with_min_len(64) - .map(|(pair, instance)| { - // input[1,b] = affine_add(input[b,0], input[b,1]) - // the left node is at index 2b, right node is at index 2b+1 - // the parent node is at index b + 2^n - let (o, slope, q) = match pair.len() { - 2 => { - // l = 2b, r = 2b+1 - let (l, p1) = &pair[0]; - let (r, p2) = &pair[1]; - assert_eq!(*r - *l, 1); - - // parent node idx = b + 2^log2_n - let o = n + l / 2; - let slope = (&p1.y - &p2.y) * (&p1.x - &p2.x).inverse().unwrap(); - let q = p1.clone() + p2.clone(); - - (o, slope, q) - } - 1 => { - let (l, p) = &pair[0]; - let o = n + l / 2; - (o, SepticExtension::zero(), p.clone()) - } - _ => unreachable!(), + .zip_eq(next_layer[..current_layer_len / 2].par_iter_mut()) + .zip(raw_witin.values[offset * num_witin..].par_chunks_mut(num_witin)) + .for_each(|((pair, parent), instance)| { + let p1 = &pair[0]; + let p2 = &pair[1]; + let (slope, q) = if p2.is_infinity { + // input[1,b] = bypass_left(input[b,0], input[b,1]) + (SepticExtension::zero(), p1.clone()) + } else { + // input[1,b] = affine_add(input[b,0], input[b,1]) + let slope = (&p1.y - &p2.y) * (&p1.x - &p2.x).inverse().unwrap(); + let q = p1.clone() + p2.clone(); + (slope, q) }; - config .x .iter() @@ -611,10 +609,11 @@ impl TableCircuit for GlobalChip { .for_each(|(witin, fe)| { set_val!(instance, *witin, *fe); }); - - (o, q) - }) - .collect::>(); + *parent = q.clone(); + }); + cur_layer_points_buffer = cur_layer_points_buffer.split_off(current_layer_len); + current_layer_len /= 2; + offset += current_layer_len; } let raw_witin = witness::RowMajorMatrix::new_by_inner_matrix( From 9917b619266153ffe8a80ff36ea0ed79f3df2350 Mon Sep 17 00:00:00 2001 From: xkx Date: Fri, 31 Oct 2025 17:07:06 +0800 Subject: [PATCH 90/91] #1061 cleanup (#1106) #1061 Cleanup - [x] clippy fixes - [x] remove old unused routines for generating witness for the ecc summation tower tree. Note this approach is abandoned. We prefer to pack $log_2(N)$ layers in one layer using the relation `p[1, b] = ec_add(p[0,b], p[1,b])` (first proposed in Quark paper. --- Cargo.lock | 1 - Cargo.toml | 1 - ceno_zkvm/Cargo.toml | 3 +- ceno_zkvm/src/gadgets/poseidon2.rs | 28 ++- ceno_zkvm/src/instructions/global.rs | 2 +- .../src/instructions/riscv/rv32im/mmu.rs | 2 +- ceno_zkvm/src/scheme.rs | 1 + ceno_zkvm/src/scheme/cpu/mod.rs | 20 +- ceno_zkvm/src/scheme/utils.rs | 189 ++---------------- ceno_zkvm/src/scheme/verifier.rs | 9 +- 10 files changed, 44 insertions(+), 212 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2c8d916a7..499d14457 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1040,7 +1040,6 @@ dependencies = [ "gkr_iop", "glob", "itertools 0.13.0", - "lazy_static", "mpcs", "multilinear_extensions", "ndarray", diff --git a/Cargo.toml b/Cargo.toml index c5d5b3784..a29a1807a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,7 +80,6 @@ tracing = { version = "0.1", features = [ tracing-forest = { version = "0.1.6" } tracing-subscriber = { version = "0.3", features = ["env-filter"] } uint = "0.8" -lazy_static = "1.5.0" ceno_gpu = { path = "utils/cuda_hal", package = "cuda_hal" } diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 80d69111b..07d1394ac 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -30,7 +30,6 @@ sumcheck.workspace = true transcript.workspace = true whir.workspace = true witness.workspace = true -lazy_static.workspace = true itertools.workspace = true ndarray.workspace = true @@ -49,12 +48,12 @@ derive = { path = "../derive" } generic-array.workspace = true generic_static = "0.2" num.workspace = true +num-bigint = "0.4.6" parse-size = "1.1" rand.workspace = true sp1-curves.workspace = true tempfile = "3.14" tiny-keccak.workspace = true -num-bigint = "0.4.6" [target.'cfg(unix)'.dependencies] tikv-jemalloc-ctl = { version = "0.6", features = ["stats"], optional = true } diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs index c4e7b621c..0eca74c50 100644 --- a/ceno_zkvm/src/gadgets/poseidon2.rs +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -41,25 +41,23 @@ impl) -> Self { let mut iter = value.into_iter(); let mut beginning_full_round_constants = [[F::ZERO; WIDTH]; HALF_FULL_ROUNDS]; - for round in 0..HALF_FULL_ROUNDS { - for i in 0..WIDTH { - beginning_full_round_constants[round][i] = - iter.next().expect("insufficient round constants"); - } - } + + beginning_full_round_constants.iter_mut().for_each(|arr| { + arr.iter_mut() + .for_each(|c| *c = iter.next().expect("insufficient round constants")) + }); let mut partial_round_constants = [F::ZERO; PARTIAL_ROUNDS]; - for round in 0..PARTIAL_ROUNDS { - partial_round_constants[round] = iter.next().expect("insufficient round constants"); - } + + partial_round_constants + .iter_mut() + .for_each(|arr| *arr = iter.next().expect("insufficient round constants")); let mut ending_full_round_constants = [[F::ZERO; WIDTH]; HALF_FULL_ROUNDS]; - for round in 0..HALF_FULL_ROUNDS { - for i in 0..WIDTH { - ending_full_round_constants[round][i] = - iter.next().expect("insufficient round constants"); - } - } + ending_full_round_constants.iter_mut().for_each(|arr| { + arr.iter_mut() + .for_each(|c| *c = iter.next().expect("insufficient round constants")) + }); assert!(iter.next().is_none(), "round constants are too many"); diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index 45f8ee29d..ce24fff92 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -298,7 +298,7 @@ pub struct GlobalChipInput { } impl GlobalChip { - fn assign_instance<'a>( + fn assign_instance( config: &GlobalConfig, instance: &mut [E::BaseField], _lk_multiplicity: &mut LkMultiplicity, diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index 1c4776cd0..900672a3d 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -163,7 +163,7 @@ impl MmuConfig<'_, E> { &self.local_final_circuit, &(shard_ctx, all_records.as_slice()), )?; - witness.assign_global_chip_circuit(cs, &shard_ctx, &self.ram_bus_circuit)?; + witness.assign_global_chip_circuit(cs, shard_ctx, &self.ram_bus_circuit)?; Ok(()) } diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 3fd2517c8..aa3928153 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -82,6 +82,7 @@ pub struct PublicValues { } impl PublicValues { + #[allow(clippy::too_many_arguments)] pub fn new( exit_code: u32, init_pc: u32, diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 89fc41449..cebe79899 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -58,12 +58,7 @@ pub type TowerRelationOutput = ( pub struct CpuEccProver; impl CpuEccProver { - pub fn new() -> Self { - Self {} - } - pub fn create_ecc_proof<'a, E: ExtensionField>( - &self, num_instances: usize, xs: Vec>>, ys: Vec>>, @@ -299,7 +294,13 @@ impl> EccQuarkProver>>, transcript: &mut impl Transcript, ) -> Result, ZKVMError> { - Ok(CpuEccProver::new().create_ecc_proof(num_instances, xs, ys, invs, transcript)) + Ok(CpuEccProver::create_ecc_proof( + num_instances, + xs, + ys, + invs, + transcript, + )) } } @@ -1224,8 +1225,7 @@ mod tests { let (ys, s) = rest.split_at(SEPTIC_EXTENSION_DEGREE); let mut transcript = BasicTranscript::new(b"test"); - let prover = CpuEccProver::new(); - let quark_proof = prover.create_ecc_proof( + let quark_proof = CpuEccProver::create_ecc_proof( n_points, xs.iter().cloned().map(Arc::new).collect_vec(), ys.iter().cloned().map(Arc::new).collect_vec(), @@ -1235,10 +1235,8 @@ mod tests { assert_eq!(quark_proof.sum, final_sum); let mut transcript = BasicTranscript::new(b"test"); - let verifier = EccVerifier::new(); assert!( - verifier - .verify_ecc_proof(&quark_proof, &mut transcript) + EccVerifier::verify_ecc_proof(&quark_proof, &mut transcript) .inspect_err(|err| println!("err {:?}", err)) .is_ok() ); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 74314aca0..cfa88175f 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -1,8 +1,7 @@ use crate::{ scheme::{ - constants::{MIN_PAR_SIZE, SEPTIC_JACOBIAN_NUM_MLES}, + constants::MIN_PAR_SIZE, hal::{MainSumcheckProver, ProofInput, ProverDevice}, - septic_curve::{SepticExtension, SepticJacobianPoint}, }, structs::ComposedConstrainSystem, }; @@ -27,7 +26,7 @@ use rayon::{ }, prelude::ParallelSliceMut, }; -use std::{array::from_fn, iter, sync::Arc}; +use std::{iter, sync::Arc}; use witness::next_pow2_instance_padding; // first computes the masked mle'[j] = mle[j] if j < num_instance, else default @@ -294,10 +293,7 @@ pub fn infer_tower_product_witness( .map(|index| { // avoid the overhead of vector initialization let mut evaluations: Vec = Vec::with_capacity(output_len); - unsafe { - // will be filled immediately - evaluations.set_len(output_len); - } + let remaining = evaluations.spare_capacity_mut(); input_layer.chunks_exact(2).enumerate().for_each(|(i, f)| { match (f[0].evaluations(), f[1].evaluations()) { @@ -307,24 +303,28 @@ pub fn infer_tower_product_witness( if i == 0 { (start..(start + output_len)) .into_par_iter() - .zip(evaluations.par_iter_mut()) + .zip(remaining.par_iter_mut()) .with_min_len(MIN_PAR_SIZE) .for_each(|(index, evaluations)| { - *evaluations = f1[index] * f2[index] + evaluations.write(f1[index] * f2[index]); }); } else { (start..(start + output_len)) .into_par_iter() - .zip(evaluations.par_iter_mut()) + .zip(remaining.par_iter_mut()) .with_min_len(MIN_PAR_SIZE) .for_each(|(index, evaluations)| { - *evaluations *= f1[index] * f2[index] + evaluations.write(f1[index] * f2[index]); }); } } _ => unreachable!("must be extension field"), } }); + + unsafe { + evaluations.set_len(output_len); + } evaluations.into_mle() }) .collect_vec(); @@ -336,94 +336,6 @@ pub fn infer_tower_product_witness( wit_layers } -/// Infer from input layer (layer n) to the output layer (layer 0) -/// Note that each layer has 3 * 7 * 2 multilinear polynomials since we use jacobian coordinates. -/// -/// The relation between layer i and layer i+1 is as follows: -/// (x1', y1', z1')[b] = jacobian_add( (x1, y1, z1)[0,b], (x2, y2, z2)[1,b] ) -/// (x2', y2', z2')[b] = jacobian_add( (x3, y3, z3)[0,b], (x4, y4, z4)[1,b] ) -/// -/// TODO handle jacobian_add & jacobian_double at the same time -pub fn infer_septic_sum_witness( - p_mles: Vec>, -) -> Vec>> { - assert_eq!(p_mles.len(), SEPTIC_JACOBIAN_NUM_MLES * 2); - assert!(p_mles.iter().map(|p| p.num_vars()).all_equal()); - - // +1 as the input layer has 2*N points where N = 2^num_vars - // and the output layer has 2 points - let num_layers = p_mles[0].num_vars() + 1; - println!("{num_layers} layers in total"); - - let mut layers = Vec::with_capacity(num_layers); - layers.push(p_mles); - - for layer in (0..num_layers - 1).rev() { - let input_layer = layers.last().unwrap(); - let p = input_layer[0..SEPTIC_JACOBIAN_NUM_MLES] - .iter() - .map(|mle| mle.get_base_field_vec()) - .collect_vec(); - let q = input_layer[SEPTIC_JACOBIAN_NUM_MLES..] - .iter() - .map(|mle| mle.get_base_field_vec()) - .collect_vec(); - - let output_len = p[0].len() / 2; - let mut outputs: Vec = - Vec::with_capacity(SEPTIC_JACOBIAN_NUM_MLES * 2 * output_len); - unsafe { - // will be filled immediately - outputs.set_len(SEPTIC_JACOBIAN_NUM_MLES * 2 * output_len); - } - - (0..2).for_each(|chunk| { - (0..output_len) - .into_par_iter() - .with_min_len(MIN_PAR_SIZE) - .zip_eq(outputs.par_chunks_mut(SEPTIC_JACOBIAN_NUM_MLES * 2)) - .for_each(|(idx, output)| { - let row = chunk * output_len + idx; - let offset = chunk * SEPTIC_JACOBIAN_NUM_MLES; - - let p1 = SepticJacobianPoint { - x: SepticExtension(from_fn(|i| p[i][row])), - y: SepticExtension(from_fn(|i| p[i + 7][row])), - z: SepticExtension(from_fn(|i| p[i + 14][row])), - }; - let p2 = SepticJacobianPoint { - x: SepticExtension(from_fn(|i| q[i][row])), - y: SepticExtension(from_fn(|i| q[i + 7][row])), - z: SepticExtension(from_fn(|i| q[i + 14][row])), - }; - assert!(p1.is_on_curve(), "{layer}, {row}"); - assert!(p2.is_on_curve(), "{layer}, {row}"); - - let p3 = &p1 + &p2; - - output[offset..offset + 7].clone_from_slice(&p3.x); - output[offset + 7..offset + 14].clone_from_slice(&p3.y); - output[offset + 14..offset + 21].clone_from_slice(&p3.z); - }); - }); - - // transpose - let output_mles = (0..SEPTIC_JACOBIAN_NUM_MLES * 2) - .map(|i| { - (0..output_len) - .into_par_iter() - .map(|j| outputs[j * SEPTIC_JACOBIAN_NUM_MLES * 2 + i]) - .collect::>() - .into_mle() - }) - .collect_vec(); - layers.push(output_mles); - } - - layers.reverse(); - layers -} - #[tracing::instrument( skip_all, name = "build_main_witness", @@ -641,23 +553,18 @@ pub fn gkr_witness< #[cfg(test)] mod tests { - use ff_ext::{BabyBearExt4, FieldInto, GoldilocksExt2}; + use ff_ext::{FieldInto, GoldilocksExt2}; use itertools::Itertools; use multilinear_extensions::{ commutative_op_mle_pair, mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, smart_slice::SmartSlice, - util::{ceil_log2, transpose}, + util::ceil_log2, }; - use p3::{babybear::BabyBear, field::FieldAlgebra}; + use p3::field::FieldAlgebra; - use crate::scheme::{ - constants::SEPTIC_JACOBIAN_NUM_MLES, - septic_curve::{SepticExtension, SepticJacobianPoint, SepticPoint}, - utils::{ - infer_septic_sum_witness, infer_tower_logup_witness, infer_tower_product_witness, - interleaving_mles_to_mles, - }, + use crate::scheme::utils::{ + infer_tower_logup_witness, infer_tower_product_witness, interleaving_mles_to_mles, }; #[test] @@ -964,68 +871,4 @@ mod tests { ])) ); } - - #[test] - fn test_infer_septic_addition_witness() { - type F = BabyBear; - type E = BabyBearExt4; - - let n_points = 1 << 6; - let mut rng = rand::thread_rng(); - // sample n points - let points: Vec> = (0..n_points) - .map(|_| SepticJacobianPoint::::random(&mut rng)) - .collect_vec(); - - // transform points to row major matrix - let trace = points[0..n_points / 2] - .iter() - .zip(points[n_points / 2..n_points].iter()) - .map(|(p, q)| { - [p, q] - .iter() - .flat_map(|p| p.x.0.iter().chain(p.y.0.iter()).chain(p.z.0.iter())) - .copied() - .collect_vec() - }) - .collect_vec(); - - // transpose row major matrix to column major matrix - let p_mles: Vec> = transpose(trace) - .into_iter() - .map(|v| v.into_mle()) - .collect_vec(); - - let layers = infer_septic_sum_witness(p_mles); - let output_layer = &layers[0]; - assert!(output_layer.iter().all(|mle| mle.num_vars() == 0)); - assert!(output_layer.len() == SEPTIC_JACOBIAN_NUM_MLES * 2); - - // recover points from output layer - let output_points: Vec> = output_layer - .chunks_exact(SEPTIC_JACOBIAN_NUM_MLES) - .map(|mles| { - mles.iter() - .map(|mle| mle.get_base_field_vec()[0]) - .collect_vec() - }) - .map(|chunk| SepticJacobianPoint { - x: SepticExtension(chunk[0..7].try_into().unwrap()), - y: SepticExtension(chunk[7..14].try_into().unwrap()), - z: SepticExtension(chunk[14..21].try_into().unwrap()), - }) - .collect_vec(); - assert!(output_points.iter().all(|p| p.is_on_curve())); - assert_eq!(output_points.len(), 2); - - let point_acc: SepticPoint = output_points - .into_iter() - .sum::>() - .into_affine(); - let expected_acc: SepticPoint = points - .into_iter() - .sum::>() - .into_affine(); - assert_eq!(point_acc, expected_acc); - } } diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index df669f416..4ed5a89e9 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -72,7 +72,7 @@ impl> ZKVMVerifier &self, vm_proof: ZKVMProof, transcript: impl ForkableTranscript, - expect_halt: bool, + _expect_halt: bool, ) -> Result { // require ecall/halt proof to exist, depending whether we expect a halt. // let has_halt = vm_proof.has_halt(&self.vk); @@ -405,7 +405,7 @@ impl> ZKVMVerifier // ecc_proof.sum // ); // assert ec sum in public input matches that in ecc proof - EccVerifier::new().verify_ecc_proof(ecc_proof, transcript)?; + EccVerifier::verify_ecc_proof(ecc_proof, transcript)?; tracing::debug!("ecc proof verified."); } @@ -972,12 +972,7 @@ impl TowerVerify { pub struct EccVerifier; impl EccVerifier { - pub fn new() -> Self { - Self {} - } - pub fn verify_ecc_proof( - &self, proof: &EccQuarkProof, transcript: &mut impl Transcript, ) -> Result<(), ZKVMError> { From 60aa5f541de49fafb2cf9eed02c8694fbf7258f8 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Sat, 1 Nov 2025 00:05:43 +0800 Subject: [PATCH 91/91] enforce global.shard < cur_shard --- ceno_zkvm/src/e2e.rs | 5 ++ ceno_zkvm/src/gadgets/poseidon2.rs | 5 ++ ceno_zkvm/src/instructions/global.rs | 78 ++++++++++++++++++++++------ ceno_zkvm/src/scheme/hal.rs | 1 - ceno_zkvm/src/scheme/utils.rs | 1 + ceno_zkvm/src/structs.rs | 2 +- gkr_iop/src/circuit_builder.rs | 9 ++-- 7 files changed, 80 insertions(+), 21 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 62f3e425f..8421137f1 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -274,6 +274,11 @@ impl<'a> ShardContext<'a> { } } + #[inline(always)] + pub fn cur_shard(&self) -> usize { + self.shards.shard_id + } + #[inline(always)] pub fn is_first_shard(&self) -> bool { self.shards.shard_id == 0 diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs index 0eca74c50..322ad675f 100644 --- a/ceno_zkvm/src/gadgets/poseidon2.rs +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -296,6 +296,11 @@ impl< } } + #[inline(always)] + pub fn num_polys(&self) -> usize { + self.cols.len() + } + pub fn inputs(&self) -> Vec> { let col_exprs = self.cols.iter().map(|c| c.expr()).collect::>(); diff --git a/ceno_zkvm/src/instructions/global.rs b/ceno_zkvm/src/instructions/global.rs index ce24fff92..c7a4fddb9 100644 --- a/ceno_zkvm/src/instructions/global.rs +++ b/ceno_zkvm/src/instructions/global.rs @@ -18,6 +18,7 @@ use gkr_iop::{ chip::Chip, circuit_builder::CircuitBuilder, error::CircuitBuilderError, + gadgets::IsLtConfig, gkr::{GKRCircuit, layer::Layer}, selector::SelectorType, }; @@ -160,6 +161,7 @@ pub struct GlobalConfig { global_clk: WitIn, local_clk: WitIn, nonce: WitIn, + is_shard_lt_cur: IsLtConfig, // if it's a write to global set, then insert a local read record // s.t. local offline memory checking can cancel out // this serves as propagating local write to global. @@ -183,7 +185,7 @@ impl GlobalConfig { .map(|i| cb.create_witin(|| format!("slope{}", i))) .collect(); let addr = cb.create_witin(|| "addr"); - let is_ram_register = cb.create_witin(|| "is_ram_register"); + let is_ram_register = cb.create_bit(|| "is_ram_register")?; let value = UInt::new_unchecked(|| "value", cb)?; let shard = cb.create_witin(|| "shard"); let global_clk = cb.create_witin(|| "global_clk"); @@ -196,9 +198,6 @@ impl GlobalConfig { let mem: Expression = RAMType::Memory.into(); let ram_type: Expression = is_ram_reg.clone() * reg + (1 - is_ram_reg) * mem; - let rc = ::get_default_perm_rc().into(); - let perm_config = Poseidon2Config::construct(cb, rc); - let mut input = vec![]; input.push(addr.expr()); input.push(ram_type.clone()); @@ -228,7 +227,31 @@ impl GlobalConfig { 1 - is_global_write.expr(), local_clk.expr(), )?; - // TODO: enforce shard = shard_id in the public values + + // if it's global write => shard == cur_shard + let cur_shard = cb.query_shard_id()?; + cb.condition_require_zero( + || "global_write = true => shard = instance.shard", + is_global_write.expr(), + shard.expr() - Expression::Instance(cur_shard), + )?; + + // global read => shard < cur_shard + let is_shard_lt_cur = IsLtConfig::construct_circuit( + cb, + || "shard < cur_shard", + shard.expr(), + Expression::Instance(cur_shard), + 16, + )?; + cb.condition_require_equal( + || "global read => shard < cur_shard", + is_global_write.expr(), + is_shard_lt_cur.expr(), + E::BaseField::ONE.expr(), // true + E::BaseField::ZERO.expr(), // false + )?; + cb.read_rlc_record( || "r_record", ram_type.clone(), @@ -251,6 +274,8 @@ impl GlobalConfig { final_sum.into_iter().map(|x| x.expr()).collect::>(), ); + let rc = ::get_default_perm_rc().into(); + let perm_config = Poseidon2Config::construct(cb, rc); // enforces x = poseidon2([addr, ram_type, value[0], value[1], shard, global_clk, nonce, 0, ..., 0]) for (input_expr, hasher_input) in input.into_iter().zip_eq(perm_config.inputs().into_iter()) { @@ -275,6 +300,7 @@ impl GlobalConfig { is_ram_register, value, shard, + is_shard_lt_cur, global_clk, local_clk, nonce, @@ -301,8 +327,9 @@ impl GlobalChip { fn assign_instance( config: &GlobalConfig, instance: &mut [E::BaseField], - _lk_multiplicity: &mut LkMultiplicity, + lk_multiplicity: &mut LkMultiplicity, input: &GlobalChipInput, + cur_shard: usize, ) -> Result<(), crate::error::ZKVMError> { // assign basic fields let record = &input.record; @@ -320,6 +347,13 @@ impl GlobalChip { set_val!(instance, config.local_clk, record.local_clk); set_val!(instance, config.is_global_write, record.is_write as u64); + config.is_shard_lt_cur.assign_instance( + instance, + lk_multiplicity, + record.shard, + cur_shard as u64, + )?; + // assign (x, y) and nonce let GlobalPoint { nonce, point } = &input.ec_point; set_val!(instance, config.nonce, *nonce as u64); @@ -346,10 +380,11 @@ impl GlobalChip { input[2 + k + 1] = E::BaseField::from_canonical_u64(record.global_clk); input[2 + k + 2] = E::BaseField::from_canonical_u32(*nonce); + let num_perm_polys = config.perm_config.num_polys(); + let offset = instance.len() - num_perm_polys; config .perm_config - // TODO: remove hardcoded constant 28 - .assign_instance(&mut instance[28 + UINT_LIMBS..], input); + .assign_instance(&mut instance[offset..], input); Ok(()) } @@ -358,7 +393,7 @@ impl GlobalChip { impl TableCircuit for GlobalChip { type TableConfig = GlobalConfig; type FixedInput = (); - type WitnessInput = Vec>; + type WitnessInput = (Vec>, usize); fn name() -> String { "Global".to_string() @@ -463,8 +498,10 @@ impl TableCircuit for GlobalChip { num_witin: usize, num_structural_witin: usize, _multiplicity: &[HashMap], - steps: &Self::WitnessInput, + input: &Self::WitnessInput, ) -> Result, ZKVMError> { + let steps = &input.0; + let cur_shard = input.1; if steps.is_empty() { return Ok([ witness::RowMajorMatrix::empty(), @@ -483,6 +520,8 @@ impl TableCircuit for GlobalChip { let nthreads = max_usable_threads(); // local read iff it's global write + // local reads are placed before local writes + // i.e. global writes are placed before global reads let num_local_reads = steps.iter().filter(|s| s.record.is_write).count(); tracing::debug!( "{} local reads / {} local writes in global chip", @@ -551,7 +590,13 @@ impl TableCircuit for GlobalChip { set_val!(structural_instance, selector_r_witin, sel_r); set_val!(structural_instance, selector_w_witin, sel_w); set_val!(structural_instance, selector_zero_witin, E::BaseField::ONE); - Self::assign_instance(config, instance, &mut lk_multiplicity, step) + Self::assign_instance( + config, + instance, + &mut lk_multiplicity, + step, + cur_shard, + ) }) .collect::>() }) @@ -683,6 +728,8 @@ mod tests { // create a bunch of random memory read/write records let n_global_reads = 1700; let n_global_writes = 1420; + let prev_shard = 0; + let cur_shard = 1; let global_reads = (0..n_global_reads) .map(|i| { let addr = i * 8; @@ -692,7 +739,7 @@ mod tests { addr: addr as u32, ram_type: RAMType::Memory, value: value as u32, - shard: 0, + shard: prev_shard, local_clk: 0, global_clk: i, is_write: false, @@ -709,7 +756,7 @@ mod tests { addr: addr as u32, ram_type: RAMType::Memory, value: value as u32, - shard: 1, + shard: cur_shard, local_clk: i, global_clk: i, is_write: true, @@ -737,7 +784,7 @@ mod tests { 0, 0, 0, - 0, + cur_shard as u32, vec![0], // dummy global_ec_sum .x @@ -747,13 +794,14 @@ mod tests { .collect_vec(), ); + tracing::debug!("num_witin: {}", cs.num_witin); // assign witness let witness = GlobalChip::assign_instances( &config, cs.num_witin as usize, cs.num_structural_witin as usize, &[], - &input, + &(input, cur_shard as usize), ) .unwrap(); diff --git a/ceno_zkvm/src/scheme/hal.rs b/ceno_zkvm/src/scheme/hal.rs index 44aa75c21..c57966217 100644 --- a/ceno_zkvm/src/scheme/hal.rs +++ b/ceno_zkvm/src/scheme/hal.rs @@ -61,7 +61,6 @@ impl<'a, PB: ProverBackend> ProofInput<'a, PB> { } } -#[derive(Clone)] pub struct TowerProverSpec<'a, PB: ProverBackend> { pub witness: Vec>>, } diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index cfa88175f..4bafca1c1 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -506,6 +506,7 @@ pub fn gkr_witness< Either::Right(iter::empty()) }) .chain(fixed.iter().cloned()) + .chain(pub_io.iter().cloned()) .collect_vec(); // infer current layer output diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 85bb61634..b0e971ffe 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -469,7 +469,7 @@ impl ZKVMWitnesses { cs.zkvm_v1_css.num_witin as usize, cs.zkvm_v1_css.num_structural_witin as usize, self.combined_lk_mlt.as_ref().unwrap(), - &global_input, + &(global_input, shard_ctx.cur_shard()), )?; // set num_read, num_write as separate instance assert!( diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 70de7f171..e26e3682e 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -471,10 +471,11 @@ impl ConstraintSystem { slopes: Vec>, final_sum: Vec>, ) { - assert_eq!(xs.len(), 7); - assert_eq!(ys.len(), 7); - assert_eq!(slopes.len(), 7); - assert_eq!(final_sum.len(), 7 * 2); + const SEPTIC_EXTENSION_DEGREE: usize = 7; + assert_eq!(xs.len(), SEPTIC_EXTENSION_DEGREE); + assert_eq!(ys.len(), SEPTIC_EXTENSION_DEGREE); + assert_eq!(slopes.len(), SEPTIC_EXTENSION_DEGREE); + assert_eq!(final_sum.len(), SEPTIC_EXTENSION_DEGREE * 2); assert_eq!(self.ec_point_exprs.len(), 0); self.ec_point_exprs.extend(xs);