Skip to content

Commit 03092e9

Browse files
committed
complete local finalized mem chip logic
1 parent b84f74e commit 03092e9

File tree

8 files changed

+403
-433
lines changed

8 files changed

+403
-433
lines changed

ceno_zkvm/src/chip_handler/general.rs

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::{
55
circuit_builder::CircuitBuilder,
66
instructions::riscv::constants::{
77
END_CYCLE_IDX, END_PC_IDX, END_SHARD_ID_IDX, EXIT_CODE_IDX, INIT_CYCLE_IDX, INIT_PC_IDX,
8-
MEM_BUS_WITH_READ_IDX, MEM_BUS_WITH_WRITE_IDX, PUBLIC_IO_IDX, UINT_LIMBS,
8+
PUBLIC_IO_IDX, UINT_LIMBS,
99
},
1010
tables::InsnRecord,
1111
};
@@ -23,8 +23,6 @@ pub trait PublicIOQuery {
2323
fn query_end_cycle(&mut self) -> Result<Instance, CircuitBuilderError>;
2424
fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError>;
2525
fn query_shard_id(&mut self) -> Result<Instance, CircuitBuilderError>;
26-
fn query_mem_bus_with_read(&mut self) -> Result<Instance, CircuitBuilderError>;
27-
fn query_mem_bus_with_write(&mut self) -> Result<Instance, CircuitBuilderError>;
2826
}
2927

3028
impl<'a, E: ExtensionField> InstFetch<E> for CircuitBuilder<'a, E> {
@@ -67,16 +65,6 @@ impl<'a, E: ExtensionField> PublicIOQuery for CircuitBuilder<'a, E> {
6765
self.cs.query_instance(|| "shard_id", END_SHARD_ID_IDX)
6866
}
6967

70-
fn query_mem_bus_with_read(&mut self) -> Result<Instance, CircuitBuilderError> {
71-
self.cs
72-
.query_instance(|| "mem_bus_with_read", MEM_BUS_WITH_READ_IDX)
73-
}
74-
75-
fn query_mem_bus_with_write(&mut self) -> Result<Instance, CircuitBuilderError> {
76-
self.cs
77-
.query_instance(|| "mem_bus_with_write", MEM_BUS_WITH_WRITE_IDX)
78-
}
79-
8068
fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError> {
8169
Ok([
8270
self.cs.query_instance(|| "public_io_low", PUBLIC_IO_IDX)?,

ceno_zkvm/src/e2e.rs

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,21 @@ impl<'a> ShardContext<'a> {
213213
}
214214
}
215215

216+
#[inline(always)]
217+
pub fn is_first_shard(&self) -> bool {
218+
self.shard_id == 0
219+
}
220+
221+
#[inline(always)]
222+
pub fn is_last_shard(&self) -> bool {
223+
self.shard_id == self.num_shards - 1
224+
}
225+
226+
#[inline(always)]
227+
pub fn is_current_shard_cycle(&self, cycle: Cycle) -> bool {
228+
self.cur_shard_cycle_range.contains(&(cycle as usize))
229+
}
230+
216231
#[inline(always)]
217232
pub fn send(
218233
&mut self,
@@ -226,7 +241,7 @@ impl<'a> ShardContext<'a> {
226241
) {
227242
// check read from external mem bus
228243
if prev_cycle < self.cur_shard_cycle_range.start as Cycle
229-
&& self.cur_shard_cycle_range.contains(&(cycle as usize))
244+
&& self.is_current_shard_cycle(cycle)
230245
{
231246
let ram_record = self
232247
.read_thread_based_record_storage
@@ -248,7 +263,7 @@ impl<'a> ShardContext<'a> {
248263
// check write to external mem bus
249264
if let Some(future_touch_cycle) = self.addr_future_accesses.get(&(addr, cycle)) {
250265
if *future_touch_cycle >= self.cur_shard_cycle_range.end as Cycle
251-
&& self.cur_shard_cycle_range.contains(&(cycle as usize))
266+
&& self.is_current_shard_cycle(cycle)
252267
{
253268
let ram_record = self
254269
.write_thread_based_record_storage
@@ -348,13 +363,15 @@ pub fn emulate_program<'a>(
348363
if index < VMState::REG_COUNT {
349364
let vma: WordAddr = Platform::register_vma(index).into();
350365
MemFinalRecord {
366+
ram_type: RAMType::Memory,
351367
addr: rec.addr,
352368
value: vm.peek_register(index),
353369
cycle: *final_access.get(&vma).unwrap_or(&0),
354370
}
355371
} else {
356372
// The table is padded beyond the number of registers.
357373
MemFinalRecord {
374+
ram_type: RAMType::Memory,
358375
addr: rec.addr,
359376
value: 0,
360377
cycle: 0,
@@ -369,6 +386,7 @@ pub fn emulate_program<'a>(
369386
.map(|rec| {
370387
let vma: WordAddr = rec.addr.into();
371388
MemFinalRecord {
389+
ram_type: RAMType::Memory,
372390
addr: rec.addr,
373391
value: vm.peek_memory(vma),
374392
cycle: *final_access.get(&vma).unwrap_or(&0),
@@ -380,6 +398,7 @@ pub fn emulate_program<'a>(
380398
let io_final = io_init
381399
.iter()
382400
.map(|rec| MemFinalRecord {
401+
ram_type: RAMType::Memory,
383402
addr: rec.addr,
384403
value: rec.value,
385404
cycle: *final_access.get(&rec.addr.into()).unwrap_or(&0),
@@ -390,6 +409,7 @@ pub fn emulate_program<'a>(
390409
let hints_final = hints_init
391410
.iter()
392411
.map(|rec| MemFinalRecord {
412+
ram_type: RAMType::Memory,
393413
addr: rec.addr,
394414
value: rec.value,
395415
cycle: *final_access.get(&rec.addr.into()).unwrap_or(&0),
@@ -407,6 +427,7 @@ pub fn emulate_program<'a>(
407427
.map(|vma| {
408428
let byte_addr = vma.baddr();
409429
MemFinalRecord {
430+
ram_type: RAMType::Memory,
410431
addr: byte_addr.0,
411432
value: vm.peek_memory(vma),
412433
cycle: *final_access.get(&vma).unwrap_or(&0),
@@ -430,6 +451,7 @@ pub fn emulate_program<'a>(
430451
.map(|vma| {
431452
let byte_addr = vma.baddr();
432453
MemFinalRecord {
454+
ram_type: RAMType::Memory,
433455
addr: byte_addr.0,
434456
value: vm.peek_memory(vma),
435457
cycle: *final_access.get(&vma).unwrap_or(&0),
@@ -578,17 +600,17 @@ pub fn init_static_addrs(program: &Program) -> Vec<MemInitRecord> {
578600
program_addrs
579601
}
580602

581-
pub struct ConstraintSystemConfig<E: ExtensionField> {
603+
pub struct ConstraintSystemConfig<'a, E: ExtensionField> {
582604
pub zkvm_cs: ZKVMConstraintSystem<E>,
583605
pub config: Rv32imConfig<E>,
584-
pub mmu_config: MmuConfig<E>,
606+
pub mmu_config: MmuConfig<'a, E>,
585607
pub dummy_config: DummyExtraConfig<E>,
586608
pub prog_config: ProgramTableConfig,
587609
}
588610

589-
pub fn construct_configs<E: ExtensionField>(
611+
pub fn construct_configs<'a, E: ExtensionField>(
590612
program_params: ProgramParams,
591-
) -> ConstraintSystemConfig<E> {
613+
) -> ConstraintSystemConfig<'a, E> {
592614
let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params);
593615

594616
let config = Rv32imConfig::<E>::construct_circuits(&mut zkvm_cs);
@@ -673,6 +695,7 @@ pub fn generate_witness<E: ExtensionField>(
673695
.mmu_config
674696
.assign_table_circuit(
675697
&system_config.zkvm_cs,
698+
&emul_result.shard_ctx,
676699
&mut zkvm_witness,
677700
&emul_result.final_mem_state.reg,
678701
&emul_result.final_mem_state.mem,
@@ -714,13 +737,13 @@ pub enum Checkpoint {
714737
pub type IntermediateState<E, PCS> = (Option<ZKVMProof<E, PCS>>, Option<ZKVMVerifyingKey<E, PCS>>);
715738

716739
/// Context construct from a program and given platform
717-
pub struct E2EProgramCtx<E: ExtensionField> {
740+
pub struct E2EProgramCtx<'a, E: ExtensionField> {
718741
pub program: Arc<Program>,
719742
pub platform: Platform,
720743
pub shards: Shards,
721744
pub static_addrs: Vec<MemInitRecord>,
722745
pub pubio_len: usize,
723-
pub system_config: ConstraintSystemConfig<E>,
746+
pub system_config: ConstraintSystemConfig<'a, E>,
724747
pub reg_init: Vec<MemInitRecord>,
725748
pub io_init: Vec<MemInitRecord>,
726749
pub zkvm_fixed_traces: ZKVMFixedTraces<E>,
@@ -745,11 +768,11 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> E2ECheckpointResult<
745768
}
746769

747770
/// Set up a program with the given platform
748-
pub fn setup_program<E: ExtensionField>(
771+
pub fn setup_program<'a, E: ExtensionField>(
749772
program: Program,
750773
platform: Platform,
751774
shards: Shards,
752-
) -> E2EProgramCtx<E> {
775+
) -> E2EProgramCtx<'a, E> {
753776
let static_addrs = init_static_addrs(&program);
754777
let pubio_len = platform.public_io.iter_addresses().len();
755778
let program_params = ProgramParams {
@@ -784,7 +807,7 @@ pub fn setup_program<E: ExtensionField>(
784807
}
785808
}
786809

787-
impl<E: ExtensionField> E2EProgramCtx<E> {
810+
impl<E: ExtensionField> E2EProgramCtx<'_, E> {
788811
pub fn keygen<PCS: PolynomialCommitmentScheme<E> + 'static>(
789812
&self,
790813
max_num_variables: usize,

ceno_zkvm/src/instructions/riscv/constants.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@ pub const INIT_CYCLE_IDX: usize = 3;
1010
pub const END_PC_IDX: usize = 4;
1111
pub const END_CYCLE_IDX: usize = 5;
1212
pub const END_SHARD_ID_IDX: usize = 6;
13-
pub const MEM_BUS_WITH_READ_IDX: usize = 7;
14-
pub const MEM_BUS_WITH_WRITE_IDX: usize = 8;
15-
pub const PUBLIC_IO_IDX: usize = 9;
13+
pub const PUBLIC_IO_IDX: usize = 7;
1614

1715
pub const LIMB_BITS: usize = 16;
1816
pub const LIMB_MASK: u32 = 0xFFFF;

ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs

Lines changed: 83 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,63 @@
1-
use std::{collections::HashSet, iter::zip, ops::Range};
2-
3-
use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word};
4-
use ff_ext::ExtensionField;
5-
use itertools::{Itertools, chain};
6-
71
use crate::{
2+
e2e::ShardContext,
83
error::ZKVMError,
94
structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses},
105
tables::{
11-
HeapCircuit, HintsCircuit, MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit,
12-
PubIOTable, RegTable, RegTableCircuit, RegTableInitCircuit, StackCircuit, StaticMemCircuit,
6+
DynVolatileRamTable, HeapInitCircuit, HeapTable, HintsCircuit, LocalFinalCircuit,
7+
MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit, PubIOTable, RBCircuit,
8+
RegTable, RegTableInitCircuit, StackInitCircuit, StackTable, StaticMemInitCircuit,
139
StaticMemTable, TableCircuit,
1410
},
1511
};
12+
use ceno_emul::{Addr, Cycle, IterAddresses, WORD_SIZE, Word};
13+
use ff_ext::ExtensionField;
14+
use itertools::{Itertools, chain};
15+
use std::{collections::HashSet, iter::zip, ops::Range, sync::Arc};
16+
use witness::InstancePaddingStrategy;
1617

17-
pub struct RegConfigs<E: ExtensionField> {
18-
pub reg_init_config: <RegTableInitCircuit<E> as TableCircuit<E>>::TableConfig,
19-
pub reg_final_config: <RegTableFinalCircuit<E> as TableCircuit<E>>::TableConfig,
20-
pub reg_mem_bus: <RegTableInitCircuit<E> as TableCircuit<E>>::TableConfig,
21-
}
22-
23-
pub struct MmuConfig<E: ExtensionField> {
18+
pub struct MmuConfig<'a, E: ExtensionField> {
2419
/// Initialization of registers.
25-
pub reg_config: <RegTableCircuit<E> as TableCircuit<E>>::TableConfig,
20+
pub reg_init_config: <RegTableInitCircuit<E> as TableCircuit<E>>::TableConfig,
2621
/// Initialization of memory with static addresses.
27-
pub static_mem_config: <StaticMemCircuit<E> as TableCircuit<E>>::TableConfig,
22+
pub static_mem_init_config: <StaticMemInitCircuit<E> as TableCircuit<E>>::TableConfig,
2823
/// Initialization of public IO.
2924
pub public_io_config: <PubIOCircuit<E> as TableCircuit<E>>::TableConfig,
3025
/// Initialization of hints.
3126
pub hints_config: <HintsCircuit<E> as TableCircuit<E>>::TableConfig,
3227
/// Initialization of heap.
33-
pub heap_config: <HeapCircuit<E> as TableCircuit<E>>::TableConfig,
28+
pub heap_init_config: <HeapInitCircuit<E> as TableCircuit<E>>::TableConfig,
3429
/// Initialization of stack.
35-
pub stack_config: <StackCircuit<E> as TableCircuit<E>>::TableConfig,
30+
pub stack_init_config: <StackInitCircuit<E> as TableCircuit<E>>::TableConfig,
31+
/// finalized circuit for all MMIO
32+
pub local_final_circuit: <LocalFinalCircuit<'a, E> as TableCircuit<E>>::TableConfig,
33+
/// ram bus to deal with cross shard read/write
34+
pub ram_bus_circuit: <RBCircuit<E> as TableCircuit<E>>::TableConfig,
3635
pub params: ProgramParams,
3736
}
3837

39-
impl<E: ExtensionField> MmuConfig<E> {
38+
impl<E: ExtensionField> MmuConfig<'_, E> {
4039
pub fn construct_circuits(cs: &mut ZKVMConstraintSystem<E>) -> Self {
41-
let reg_config = cs.register_table_circuit::<RegTableCircuit<E>>();
40+
let reg_init_config = cs.register_table_circuit::<RegTableInitCircuit<E>>();
4241

43-
let static_mem_config = cs.register_table_circuit::<StaticMemCircuit<E>>();
42+
let static_mem_init_config = cs.register_table_circuit::<StaticMemInitCircuit<E>>();
4443

4544
let public_io_config = cs.register_table_circuit::<PubIOCircuit<E>>();
4645

4746
let hints_config = cs.register_table_circuit::<HintsCircuit<E>>();
48-
let stack_config = cs.register_table_circuit::<StackCircuit<E>>();
49-
let heap_config = cs.register_table_circuit::<HeapCircuit<E>>();
47+
let stack_init_config = cs.register_table_circuit::<StackInitCircuit<E>>();
48+
let heap_init_config = cs.register_table_circuit::<HeapInitCircuit<E>>();
49+
let local_final_circuit = cs.register_table_circuit::<LocalFinalCircuit<E>>();
50+
let ram_bus_circuit = cs.register_table_circuit::<RBCircuit<E>>();
5051

5152
Self {
52-
reg_config,
53-
static_mem_config,
53+
reg_init_config,
54+
static_mem_init_config,
5455
public_io_config,
5556
hints_config,
56-
stack_config,
57-
heap_config,
57+
stack_init_config,
58+
heap_init_config,
59+
local_final_circuit,
60+
ram_bus_circuit,
5861
params: cs.params.clone(),
5962
}
6063
}
@@ -78,24 +81,27 @@ impl<E: ExtensionField> MmuConfig<E> {
7881
"memory addresses must be unique"
7982
);
8083

81-
fixed.register_table_circuit::<RegTableCircuit<E>>(cs, &self.reg_config, reg_init);
84+
fixed.register_table_circuit::<RegTableInitCircuit<E>>(cs, &self.reg_init_config, reg_init);
8285

83-
fixed.register_table_circuit::<StaticMemCircuit<E>>(
86+
fixed.register_table_circuit::<StaticMemInitCircuit<E>>(
8487
cs,
85-
&self.static_mem_config,
88+
&self.static_mem_init_config,
8689
static_mem_init,
8790
);
8891

8992
fixed.register_table_circuit::<PubIOCircuit<E>>(cs, &self.public_io_config, io_addrs);
9093
fixed.register_table_circuit::<HintsCircuit<E>>(cs, &self.hints_config, &());
91-
fixed.register_table_circuit::<StackCircuit<E>>(cs, &self.stack_config, &());
92-
fixed.register_table_circuit::<HeapCircuit<E>>(cs, &self.heap_config, &());
94+
fixed.register_table_circuit::<StackInitCircuit<E>>(cs, &self.stack_init_config, &());
95+
fixed.register_table_circuit::<HeapInitCircuit<E>>(cs, &self.heap_init_config, &());
96+
fixed.register_table_circuit::<LocalFinalCircuit<E>>(cs, &self.local_final_circuit, &());
97+
fixed.register_table_circuit::<RBCircuit<E>>(cs, &self.ram_bus_circuit, &());
9398
}
9499

95100
#[allow(clippy::too_many_arguments)]
96101
pub fn assign_table_circuit(
97102
&self,
98103
cs: &ZKVMConstraintSystem<E>,
104+
shard_ctx: &ShardContext,
99105
witness: &mut ZKVMWitnesses<E>,
100106
reg_final: &[MemFinalRecord],
101107
static_mem_final: &[MemFinalRecord],
@@ -104,18 +110,57 @@ impl<E: ExtensionField> MmuConfig<E> {
104110
stack_final: &[MemFinalRecord],
105111
heap_final: &[MemFinalRecord],
106112
) -> Result<(), ZKVMError> {
107-
witness.assign_table_circuit::<RegTableCircuit<E>>(cs, &self.reg_config, reg_final)?;
113+
witness.assign_table_circuit::<RegTableInitCircuit<E>>(
114+
cs,
115+
&self.reg_init_config,
116+
reg_final,
117+
)?;
108118

109-
witness.assign_table_circuit::<StaticMemCircuit<E>>(
119+
witness.assign_table_circuit::<StaticMemInitCircuit<E>>(
110120
cs,
111-
&self.static_mem_config,
121+
&self.static_mem_init_config,
112122
static_mem_final,
113123
)?;
114124

115125
witness.assign_table_circuit::<PubIOCircuit<E>>(cs, &self.public_io_config, io_cycles)?;
116126
witness.assign_table_circuit::<HintsCircuit<E>>(cs, &self.hints_config, hints_final)?;
117-
witness.assign_table_circuit::<StackCircuit<E>>(cs, &self.stack_config, stack_final)?;
118-
witness.assign_table_circuit::<HeapCircuit<E>>(cs, &self.heap_config, heap_final)?;
127+
witness.assign_table_circuit::<StackInitCircuit<E>>(
128+
cs,
129+
&self.stack_init_config,
130+
stack_final,
131+
)?;
132+
witness.assign_table_circuit::<HeapInitCircuit<E>>(
133+
cs,
134+
&self.heap_init_config,
135+
heap_final,
136+
)?;
137+
138+
let all_records = vec![
139+
(InstancePaddingStrategy::Default, reg_final),
140+
(InstancePaddingStrategy::Default, static_mem_final),
141+
(
142+
InstancePaddingStrategy::Custom({
143+
let params = cs.params.clone();
144+
Arc::new(move |row: u64, _: u64| StackTable::addr(&params, row as usize) as u64)
145+
}),
146+
stack_final,
147+
),
148+
(
149+
InstancePaddingStrategy::Custom({
150+
let params = cs.params.clone();
151+
Arc::new(move |row: u64, _: u64| HeapTable::addr(&params, row as usize) as u64)
152+
}),
153+
heap_final,
154+
),
155+
];
156+
// take all mem result and
157+
witness.assign_table_circuit::<LocalFinalCircuit<E>>(
158+
cs,
159+
&self.local_final_circuit,
160+
&(shard_ctx, all_records.as_slice()),
161+
)?;
162+
163+
witness.assign_table_circuit::<RBCircuit<E>>(cs, &self.ram_bus_circuit, todo!())?;
119164

120165
Ok(())
121166
}

0 commit comments

Comments
 (0)