Skip to content

Commit 119c1e9

Browse files
committed
fix mock prover
1 parent b4e7499 commit 119c1e9

File tree

1 file changed

+40
-25
lines changed

1 file changed

+40
-25
lines changed

ceno_zkvm/src/scheme/mock_prover.rs

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,7 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
523523
structural_witin,
524524
&[],
525525
&[],
526+
&[],
526527
Some(challenge),
527528
lkm,
528529
)
@@ -534,7 +535,7 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
534535
program: &[ceno_emul::Instruction],
535536
lkm: Option<Multiplicity<u64>>,
536537
) -> Result<(), Vec<MockProverError<E, u64>>> {
537-
Self::run_maybe_challenge(cb, &[], wits_in, &[], program, &[], None, lkm)
538+
Self::run_maybe_challenge(cb, &[], wits_in, &[], program, &[], &[], None, lkm)
538539
}
539540

540541
#[allow(clippy::too_many_arguments)]
@@ -544,7 +545,8 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
544545
wits_in: &[ArcMultilinearExtension<'a, E>],
545546
structural_witin: &[ArcMultilinearExtension<'a, E>],
546547
program: &[ceno_emul::Instruction],
547-
pi: &[ArcMultilinearExtension<'a, E>],
548+
pi_mles: &[ArcMultilinearExtension<'a, E>],
549+
pub_io_evals: &[Either<E::BaseField, E>],
548550
challenge: Option<[E; 2]>,
549551
lkm: Option<Multiplicity<u64>>,
550552
) -> Result<(), Vec<MockProverError<E, u64>>> {
@@ -557,7 +559,8 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
557559
fixed,
558560
wits_in,
559561
structural_witin,
560-
pi,
562+
pi_mles,
563+
pub_io_evals,
561564
1,
562565
challenge,
563566
lkm,
@@ -572,22 +575,15 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
572575
fixed: &[ArcMultilinearExtension<'a, E>],
573576
wits_in: &[ArcMultilinearExtension<'a, E>],
574577
structural_witin: &[ArcMultilinearExtension<'a, E>],
575-
pi: &[ArcMultilinearExtension<'a, E>],
578+
pi_mles: &[ArcMultilinearExtension<'a, E>],
579+
pub_io_evals: &[Either<E::BaseField, E>],
576580
num_instances: usize,
577581
challenge: [E; 2],
578582
expected_lkm: Option<Multiplicity<u64>>,
579583
) -> Result<LkMultiplicityRaw<E>, Vec<MockProverError<E, u64>>> {
580584
let mut shared_lkm = LkMultiplicityRaw::<E>::default();
581585
let mut errors = vec![];
582586

583-
let pub_io_evals = pi.iter().map(|v| v.index(0)).collect_vec();
584-
585-
let pi_mles = cs
586-
.instance_openings
587-
.iter()
588-
.map(|instance| pi[instance.0].clone())
589-
.collect_vec();
590-
591587
// Assert zero expressions
592588
for (expr, name) in cs
593589
.assert_zero_expressions
@@ -611,10 +607,14 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
611607
if let Some(zero_selector) = &cs.zero_selector {
612608
structural_witin[zero_selector.selector_expr().id()].clone()
613609
} else {
610+
let num_instance_padded = next_pow2_instance_padding(num_instances);
614611
let mut selector = vec![E::BaseField::ONE; num_instances];
615-
selector.resize(wits_in[0].evaluations().len(), E::BaseField::ZERO);
616-
MultilinearExtension::from_evaluation_vec_smart(wits_in[0].num_vars(), selector)
617-
.into()
612+
selector.resize(num_instance_padded, E::BaseField::ZERO);
613+
MultilinearExtension::from_evaluation_vec_smart(
614+
ceil_log2(num_instance_padded),
615+
selector,
616+
)
617+
.into()
618618
};
619619

620620
// require_equal does not always have the form of Expr::Sum as
@@ -702,18 +702,22 @@ impl<'a, E: ExtensionField + Hash> MockProver<E> {
702702
let lk_selector: ArcMultilinearExtension<_> = if let Some(lk_selector) = &cs.lk_selector {
703703
structural_witin[lk_selector.selector_expr().id()].clone()
704704
} else {
705+
let num_instance_padded = next_pow2_instance_padding(num_instances);
705706
let mut selector = vec![E::BaseField::ONE; num_instances];
706-
selector.resize(wits_in[0].evaluations().len(), E::BaseField::ZERO);
707-
MultilinearExtension::from_evaluation_vec_smart(wits_in[0].num_vars(), selector).into()
707+
selector.resize(num_instance_padded, E::BaseField::ZERO);
708+
MultilinearExtension::from_evaluation_vec_smart(
709+
ceil_log2(num_instance_padded),
710+
selector,
711+
)
712+
.into()
708713
};
709714

710715
// Lookup expressions
711-
for ((expr, name), (rom_type, _)) in cs
712-
.lk_expressions
713-
.iter()
714-
.zip_eq(cs.lk_expressions_namespace_map.iter())
715-
.zip_eq(cs.lk_expressions_items_map.iter())
716-
{
716+
for (expr, (name, (rom_type, _))) in cs.lk_expressions.iter().zip(
717+
cs.lk_expressions_namespace_map
718+
.iter()
719+
.zip_eq(cs.lk_expressions_items_map.iter()),
720+
) {
717721
let expr_evaluated = wit_infer_by_expr(
718722
expr,
719723
cs.num_witin,
@@ -1016,7 +1020,6 @@ Hints:
10161020
.iter()
10171021
.map(|instance| pi_mles[instance.0].clone())
10181022
.collect_vec();
1019-
let is_opcode = gkr_circuit.is_some();
10201023
let [witness, structural_witness] = witnesses
10211024
.get_opcode_witness(circuit_name)
10221025
.or_else(|| witnesses.get_table_witness(circuit_name))
@@ -1055,7 +1058,8 @@ Hints:
10551058
.map_or(vec![], |fixed| {
10561059
fixed.to_mles().into_iter().map(|f| f.into()).collect_vec()
10571060
});
1058-
if is_opcode {
1061+
// not lookup table
1062+
if cs.lk_table_expressions.is_empty() {
10591063
tracing::info!(
10601064
"Mock proving opcode {} with {} entries",
10611065
circuit_name,
@@ -1072,6 +1076,7 @@ Hints:
10721076
&witness,
10731077
&structural_witness,
10741078
&pi_mles,
1079+
&pub_io_evals,
10751080
num_rows,
10761081
challenges,
10771082
lkm_from_assignments,
@@ -1171,6 +1176,11 @@ Hints:
11711176
let fixed = fixed_mles.get(circuit_name).unwrap();
11721177
let witness = wit_mles.get(circuit_name).unwrap();
11731178
let structural_witness = structural_wit_mles.get(circuit_name).unwrap();
1179+
let pi_mles = cs
1180+
.instance_openings
1181+
.iter()
1182+
.map(|instance| pi_mles[instance.0].clone())
1183+
.collect_vec();
11741184

11751185
let num_rows = num_instances.get(circuit_name).unwrap();
11761186
if *num_rows == 0 {
@@ -1272,6 +1282,11 @@ Hints:
12721282
let fixed = fixed_mles.get(circuit_name).unwrap();
12731283
let witness = wit_mles.get(circuit_name).unwrap();
12741284
let structural_witness = structural_wit_mles.get(circuit_name).unwrap();
1285+
let pi_mles = cs
1286+
.instance_openings
1287+
.iter()
1288+
.map(|instance| pi_mles[instance.0].clone())
1289+
.collect_vec();
12751290
let num_rows = num_instances.get(circuit_name).unwrap();
12761291
if *num_rows == 0 {
12771292
continue;

0 commit comments

Comments
 (0)