Skip to content
9 changes: 9 additions & 0 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,15 @@ jobs:
RUSTFLAGS: "-C opt-level=3"
run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno examples/target/riscv32im-ceno-zkvm-elf/release/examples/bn254_curve_syscalls

- name: Run fibonacci (release) in 3 shards with CENO_CROSS_SHARD_LIMIT
env:
RUST_LOG: debug
RUSTFLAGS: "-C opt-level=3"
MOCK_PROVING: 1
CENO_CROSS_SHARD_LIMIT: 32
run: cargo run --release --package ceno_zkvm --features sanity-check --bin e2e -- --platform=ceno --min-cycle-per-shard=10 --max-cycle-per-shard=20000 --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/release/examples/fibonacci


- name: Install cargo make
run: |
cargo make --version || cargo install cargo-make
Expand Down
45 changes: 34 additions & 11 deletions ceno_zkvm/src/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::{
hal::ProverDevice,
mock_prover::{LkMultiplicityKey, MockProver},
prover::ZKVMProver,
septic_curve::SepticPoint,
verifier::ZKVMVerifier,
},
state::GlobalState,
Expand Down Expand Up @@ -44,6 +45,7 @@ use witness::next_pow2_instance_padding;

pub const DEFAULT_MIN_CYCLE_PER_SHARDS: Cycle = 1 << 24;
pub const DEFAULT_MAX_CYCLE_PER_SHARDS: Cycle = 1 << 27;
pub const DEFAULT_CROSS_SHARD_ACCESS_LIMIT: usize = 1 << 20;

/// The polynomial commitment scheme kind
#[derive(
Expand Down Expand Up @@ -175,11 +177,16 @@ pub struct ShardContext<'a> {
Either<Vec<BTreeMap<WordAddr, RAMRecord>>, &'a mut BTreeMap<WordAddr, RAMRecord>>,
pub cur_shard_cycle_range: std::ops::Range<usize>,
pub expected_inst_per_shard: usize,
pub max_num_cross_shard_accesses: usize,
}

impl<'a> Default for ShardContext<'a> {
fn default() -> Self {
let max_threads = max_usable_threads();
let max_num_cross_shard_accesses = std::env::var("CENO_CROSS_SHARD_LIMIT")
.map(|v| v.parse().unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT))
.unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT);

Self {
shard_id: 0,
num_shards: 1,
Expand All @@ -202,6 +209,7 @@ impl<'a> Default for ShardContext<'a> {
),
cur_shard_cycle_range: Tracer::SUBCYCLES_PER_INSN as usize..usize::MAX,
expected_inst_per_shard: usize::MAX,
max_num_cross_shard_accesses,
}
}
}
Expand Down Expand Up @@ -231,6 +239,10 @@ impl<'a> ShardContext<'a> {
let subcycle_per_insn = Tracer::SUBCYCLES_PER_INSN as usize;
let max_threads = max_usable_threads();

let max_num_cross_shard_accesses = std::env::var("CENO_CROSS_SHARD_LIMIT")
.map(|v| v.parse().unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT))
.unwrap_or(DEFAULT_CROSS_SHARD_ACCESS_LIMIT);

// strategies
// 0. set cur_num_shards = num_provers
// 1. split instructions evenly by cur_num_shards
Expand Down Expand Up @@ -323,6 +335,7 @@ impl<'a> ShardContext<'a> {
),
cur_shard_cycle_range,
expected_inst_per_shard,
max_num_cross_shard_accesses,
}
})
.collect_vec()
Expand Down Expand Up @@ -355,6 +368,7 @@ impl<'a> ShardContext<'a> {
write_records_tbs: Either::Right(write),
cur_shard_cycle_range: self.cur_shard_cycle_range.clone(),
expected_inst_per_shard: self.expected_inst_per_shard,
max_num_cross_shard_accesses: self.max_num_cross_shard_accesses,
},
)
.collect_vec(),
Expand Down Expand Up @@ -1125,17 +1139,26 @@ pub fn generate_witness<'a, E: ExtensionField>(
pi.end_pc = current_shard_end_pc;
pi.end_cycle = current_shard_end_cycle;
// set shard ram bus expected output to pi
let shard_ram_witness = zkvm_witness.get_table_witness(&ShardRamCircuit::<E>::name());
if let Some(shard_ram_witness) = shard_ram_witness
&& shard_ram_witness[0].num_instances() > 0
{
for (f, v) in ShardRamCircuit::<E>::extract_ec_sum(
&system_config.mmu_config.ram_bus_circuit,
&shard_ram_witness[0],
)
.into_iter()
.zip_eq(pi.shard_rw_sum.as_mut_slice())
{
let shard_ram_witnesses = zkvm_witness.get_witness(&ShardRamCircuit::<E>::name());

if let Some(shard_ram_witnesses) = shard_ram_witnesses {
let shard_ram_ec_sum: SepticPoint<E::BaseField> = shard_ram_witnesses
.iter()
.filter(|shard_ram_witness| shard_ram_witness.num_instances[0] > 0)
.map(|shard_ram_witness| {
ShardRamCircuit::<E>::extract_ec_sum(
&system_config.mmu_config.ram_bus_circuit,
&shard_ram_witness.witness_rmms[0],
)
})
.sum();

let xy = shard_ram_ec_sum
.x
.0
.iter()
.chain(shard_ram_ec_sum.y.0.iter());
for (f, v) in xy.zip_eq(pi.shard_rw_sum.as_mut_slice()) {
*v = f.to_canonical_u64() as u32;
}
}
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ impl<E: ExtensionField> MmuConfig<'_, E> {
&self.local_final_circuit,
&(shard_ctx, all_records.as_slice()),
)?;
witness.assign_global_chip_circuit(
witness.assign_shared_circuit(
cs,
&(shard_ctx, all_records.as_slice()),
&self.ram_bus_circuit,
Expand Down
3 changes: 3 additions & 0 deletions ceno_zkvm/src/keygen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ impl<E: ExtensionField> ZKVMConstraintSystem<E> {
fixed_traces.insert(circuit_index, fixed_trace_rmm);
}

vm_pk
.circuit_name_to_index
.insert(c_name.clone(), circuit_index);
let circuit_pk = cs.key_gen();
assert!(vm_pk.circuit_pks.insert(c_name, circuit_pk).is_none());
}
Expand Down
20 changes: 17 additions & 3 deletions ceno_zkvm/src/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::{
collections::{BTreeMap, HashMap},
fmt::{self, Debug},
iter,
ops::Div,
rc::Rc,
};
Expand Down Expand Up @@ -156,7 +157,8 @@ pub struct ZKVMProof<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> {
pub raw_pi: Vec<Vec<E::BaseField>>,
// the evaluation of raw_pi.
pub pi_evals: Vec<E>,
pub chip_proofs: BTreeMap<usize, ZKVMChipProof<E>>,
// each circuit may have multiple proof instances
pub chip_proofs: BTreeMap<usize, Vec<ZKVMChipProof<E>>>,
pub witin_commit: <PCS as PolynomialCommitmentScheme<E>>::Commitment,
pub opening_proof: PCS::Proof,
}
Expand All @@ -165,7 +167,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProof<E, PCS> {
pub fn new(
raw_pi: Vec<Vec<E::BaseField>>,
pi_evals: Vec<E>,
chip_proofs: BTreeMap<usize, ZKVMChipProof<E>>,
chip_proofs: BTreeMap<usize, Vec<ZKVMChipProof<E>>>,
witin_commit: <PCS as PolynomialCommitmentScheme<E>>::Commitment,
opening_proof: PCS::Proof,
) -> Self {
Expand Down Expand Up @@ -211,7 +213,13 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProof<E, PCS> {
let halt_instance_count = self
.chip_proofs
.get(&halt_circuit_index)
.map_or(0, |proof| proof.num_instances.iter().sum());
.map_or(0, |proofs| {
proofs
.iter()
.flat_map(|proof| &proof.num_instances)
.copied()
.sum()
});
if halt_instance_count > 0 {
assert_eq!(
halt_instance_count, 1,
Expand Down Expand Up @@ -240,6 +248,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E> + Serialize> fmt::Dis
let tower_proof = self
.chip_proofs
.iter()
.flat_map(|(circuit_index, proofs)| {
iter::repeat_n(circuit_index, proofs.len()).zip(proofs)
})
.map(|(circuit_index, proof)| {
let size = bincode::serialized_size(&proof.tower_proof);
size.inspect(|size| {
Expand All @@ -254,6 +265,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E> + Serialize> fmt::Dis
let main_sumcheck = self
.chip_proofs
.iter()
.flat_map(|(circuit_index, proofs)| {
iter::repeat_n(circuit_index, proofs.len()).zip(proofs)
})
.map(|(circuit_index, proof)| {
let size = bincode::serialized_size(&proof.main_sumcheck_proofs);
size.inspect(|size| {
Expand Down
77 changes: 53 additions & 24 deletions ceno_zkvm/src/scheme/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,21 @@ use gkr_iop::{
use itertools::{Itertools, chain};
use mpcs::{Point, PolynomialCommitmentScheme};
use multilinear_extensions::{
Expression,
Expression, ToExpr,
mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension},
util::ceil_log2,
virtual_poly::build_eq_x_r_vec,
virtual_poly::{build_eq_x_r_vec, eq_eval},
virtual_polys::VirtualPolynomialsBuilder,
};
use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
IntoParallelRefMutIterator, ParallelIterator,
};
use std::{collections::BTreeMap, sync::Arc};
use std::{
collections::BTreeMap,
iter::{once, repeat_n},
sync::Arc,
};
use sumcheck::{
macros::{entered_span, exit_span},
structs::{IOPProverMessage, IOPProverState},
Expand Down Expand Up @@ -75,9 +79,9 @@ impl CpuEccProver {
let out_rt = transcript.sample_and_append_vec(b"ecc", n);
let num_threads = optimal_sumcheck_threads(out_rt.len());

// expression with add (3 zero constrains) and bypass (2 zero constrains)
// expression with add (3 zero constraints), bypass (2 zero constraints), export (2 zero constraints)
let alpha_pows = transcript.sample_and_append_challenge_pows(
SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 2,
SEPTIC_EXTENSION_DEGREE * 3 + SEPTIC_EXTENSION_DEGREE * 2 + SEPTIC_EXTENSION_DEGREE * 2,
b"ecc_alpha",
);
let mut alpha_pows_iter = alpha_pows.iter();
Expand All @@ -92,6 +96,17 @@ impl CpuEccProver {
};
let mut sel_add_mle: MultilinearExtension<'_, E> =
sel_add.compute(&out_rt, &sel_add_ctx).unwrap();

// [1,1,...,1,0]
let last_evaluation_index = (1 << n) - 2;
let lsi_on_hypercube = repeat_n(E::ONE, n - 1).chain(once(E::ZERO)).collect_vec();
let mut sel_export = (0..(1 << n))
.into_par_iter()
.map(|_| E::ZERO)
.collect::<Vec<_>>();
sel_export[last_evaluation_index] = eq_eval(&out_rt, lsi_on_hypercube.as_slice());
let mut sel_export_mle = sel_export.into_mle();

// we construct sel_bypass witness here
// verifier can derive it via `sel_bypass = eq - sel_add - sel_last_onehot`
let mut sel_bypass_mle: Vec<E> = build_eq_x_r_vec(&out_rt);
Expand All @@ -110,6 +125,7 @@ impl CpuEccProver {
let mut sel_bypass_mle = sel_bypass_mle.into_mle();
let sel_add_expr = expr_builder.lift(sel_add_mle.to_either());
let sel_bypass_expr = expr_builder.lift(sel_bypass_mle.to_either());
let sel_export_expr = expr_builder.lift(sel_export_mle.to_either());

let mut exprs_add = vec![];
let mut exprs_bypass = vec![];
Expand Down Expand Up @@ -219,33 +235,43 @@ impl CpuEccProver {
.zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE))
.map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha))),
);
assert!(alpha_pows_iter.next().is_none());

// export x[1,...,1,0], y[1,...,1,0] for final result
let xp = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec();
let yp = ys.iter().map(|y| y.as_view_slice(2, 1)).collect_vec();
let final_sum_x: SepticExtension<E::BaseField> = (xp.iter())
.map(|x| x.get_base_field_vec()[last_evaluation_index]) // x[1,...,1,0]
.collect_vec()
.into();
let final_sum_y: SepticExtension<E::BaseField> = (yp.iter())
.map(|y| y.get_base_field_vec()[last_evaluation_index]) // x[1,...,1,0]
.collect_vec()
.into();
// 0 = sel_export * (x[1,b] - final_sum.x)
// 0 = sel_export * (y[1,b] - final_sum.y)
let export_expr =
x3.0.iter()
.zip_eq(final_sum_x.0.iter())
// .chain(y3.0.iter().zip_eq(final_sum_y.0.iter()))
.map(|(x, final_x)| x - final_x.expr())
.zip_eq(alpha_pows_iter.by_ref().take(SEPTIC_EXTENSION_DEGREE))
.map(|(e, alpha)| e * Expression::Constant(Either::Right(*alpha)))
.sum::<Expression<E>>()
* sel_export_expr;
// assert!(alpha_pows_iter.next().is_none());

let exprs_bypass = exprs_bypass.into_iter().sum::<Expression<E>>() * sel_bypass_expr;

let (zerocheck_proof, state) = IOPProverState::prove(
expr_builder.to_virtual_polys(&[exprs_add + exprs_bypass], &[]),
expr_builder.to_virtual_polys(&[exprs_add + exprs_bypass + export_expr], &[]),
transcript,
);

let rt = state.collect_raw_challenges();
let evals = state.get_mle_flatten_final_evaluations();

// 7 for x[rt,0], x[rt,1], y[rt,0], y[rt,1], x[1,rt], y[1,rt], s[1,rt]
assert_eq!(evals.len(), 2 + SEPTIC_EXTENSION_DEGREE * 7);

let last_evaluation_index = (1 << n) - 1;
let x3 = xs.iter().map(|x| x.as_view_slice(2, 1)).collect_vec();
let y3 = ys.iter().map(|y| y.as_view_slice(2, 1)).collect_vec();
let final_sum_x: SepticExtension<E::BaseField> = (x3.iter())
.map(|x| x.get_base_field_vec()[last_evaluation_index - 1]) // x[1,...,1,0]
.collect_vec()
.into();
let final_sum_y: SepticExtension<E::BaseField> = (y3.iter())
.map(|y| y.get_base_field_vec()[last_evaluation_index - 1]) // x[1,...,1,0]
.collect_vec()
.into();
let final_sum = SepticPoint::from_affine(final_sum_x, final_sum_y);
assert_eq!(evals.len(), 3 + SEPTIC_EXTENSION_DEGREE * 7);

#[cfg(feature = "sanity-check")]
{
Expand All @@ -254,19 +280,22 @@ impl CpuEccProver {
let y0 = filter_bj(&ys, 0);
let x1 = filter_bj(&xs, 1);
let y1 = filter_bj(&ys, 1);
let sel_export = eq_eval(&out_rt, &lsi_on_hypercube) * eq_eval(&rt, &lsi_on_hypercube);
assert_eq!(sel_export, evals[2]);

let evals = &evals[2..];
let evals = &evals[3..];
// check evaluations
for i in 0..SEPTIC_EXTENSION_DEGREE {
assert_eq!(s[i].evaluate(&rt), evals[i]);
assert_eq!(x0[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE + i]);
assert_eq!(y0[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 2 + i]);
assert_eq!(x1[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 3 + i]);
assert_eq!(y1[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 4 + i]);
assert_eq!(x3[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 5 + i]);
assert_eq!(y3[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 6 + i]);
assert_eq!(xp[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 5 + i]);
assert_eq!(yp[i].evaluate(&rt), evals[SEPTIC_EXTENSION_DEGREE * 6 + i]);
}
}
let final_sum = SepticPoint::from_affine(final_sum_x, final_sum_y);
assert_eq!(zerocheck_proof.extract_sum(), E::ZERO);

EccQuarkProof {
Expand Down
Loading
Loading