Skip to content

Commit c542201

Browse files
committed
wip config as trait
1 parent 4963848 commit c542201

File tree

4 files changed

+204
-25
lines changed

4 files changed

+204
-25
lines changed

ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,16 @@ use crate::{
99
structs::{ProgramParams, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses},
1010
tables::{
1111
HeapCircuit, HintsCircuit, MemFinalRecord, MemInitRecord, NonVolatileTable, PubIOCircuit,
12-
PubIOTable, RegTable, RegTableCircuit, StackCircuit, StaticMemCircuit, StaticMemTable,
13-
TableCircuit,
12+
PubIOTable, RegTable, RegTableCircuit, RegTableInitCircuit, StackCircuit, StaticMemCircuit,
13+
StaticMemTable, TableCircuit,
1414
},
1515
};
1616

17+
pub struct RegConfigs<E: ExtensionField> {
18+
pub reg_init_config: <RegTableInitCircuit<E> as TableCircuit<E>>::TableConfig,
19+
pub reg_mem_bus: <RegTableInitCircuit<E> as TableCircuit<E>>::TableConfig,
20+
}
21+
1722
pub struct MmuConfig<E: ExtensionField> {
1823
/// Initialization of registers.
1924
pub reg_config: <RegTableCircuit<E> as TableCircuit<E>>::TableConfig,

ceno_zkvm/src/tables/ram.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::{
88

99
mod ram_circuit;
1010
mod ram_impl;
11+
use crate::tables::ram::ram_circuit::NonVolatileInitRamCircuit;
1112
pub use ram_circuit::{DynVolatileRamTable, MemFinalRecord, MemInitRecord, NonVolatileTable};
1213

1314
#[derive(Clone)]
@@ -108,7 +109,8 @@ impl NonVolatileTable for RegTable {
108109
}
109110
}
110111

111-
pub type RegTableCircuit<E> = NonVolatileRamCircuit<E, RegTable>;
112+
// pub type RegTableCircuit<E> = NonVolatileRamCircuit<E, RegTable>;
113+
pub type RegTableInitCircuit<E> = NonVolatileInitRamCircuit<E, RegTable>;
112114

113115
#[derive(Clone)]
114116
pub struct StaticMemTable;

ceno_zkvm/src/tables/ram/ram_circuit.rs

Lines changed: 78 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
use std::{collections::HashMap, marker::PhantomData};
22

3-
use ceno_emul::{Addr, Cycle, GetAddr, WORD_SIZE, Word};
4-
use ff_ext::ExtensionField;
5-
use witness::{InstancePaddingStrategy, RowMajorMatrix};
6-
73
use crate::{
84
circuit_builder::CircuitBuilder,
95
error::ZKVMError,
106
structs::{ProgramParams, RAMType},
117
tables::{RMMCollections, TableCircuit},
128
};
9+
use ceno_emul::{Addr, Cycle, GetAddr, WORD_SIZE, Word};
10+
use ff_ext::{ExtensionField, SmallField};
11+
use gkr_iop::error::CircuitBuilderError;
12+
use witness::{InstancePaddingStrategy, RowMajorMatrix};
1313

14-
use super::ram_impl::{DynVolatileRamTableConfig, NonVolatileTableConfig, PubIOTableConfig};
14+
use super::ram_impl::{
15+
DynVolatileRamTableConfig, NonVolatileInitTableConfig, NonVolatileTableConfig, PubIOTableConfig,
16+
};
1517

1618
#[derive(Clone, Debug)]
1719
pub struct MemInitRecord {
@@ -104,6 +106,55 @@ impl<E: ExtensionField, NVRAM: NonVolatileTable + Send + Sync + Clone> TableCirc
104106
}
105107
}
106108

109+
/// NonVolatileRamCircuit initializes and finalizes memory
110+
/// - at fixed addresses,
111+
/// - with fixed initial content,
112+
/// - with witnessed final content that the program wrote, if WRITABLE,
113+
/// - or final content equal to initial content, if not WRITABLE.
114+
pub struct NonVolatileInitRamCircuit<E, R>(PhantomData<(E, R)>);
115+
116+
impl<E: ExtensionField, NVRAM: NonVolatileTable + Send + Sync + Clone> TableCircuit<E>
117+
for NonVolatileInitRamCircuit<E, NVRAM>
118+
{
119+
type TableConfig = NonVolatileInitTableConfig<NVRAM>;
120+
type FixedInput = [MemInitRecord];
121+
type WitnessInput = [MemFinalRecord];
122+
123+
fn name() -> String {
124+
format!("RAM_{:?}_{}", NVRAM::RAM_TYPE, NVRAM::name())
125+
}
126+
127+
fn construct_circuit(
128+
cb: &mut CircuitBuilder<E>,
129+
params: &ProgramParams,
130+
) -> Result<Self::TableConfig, ZKVMError> {
131+
Ok(cb.namespace(
132+
|| Self::name(),
133+
|cb| Self::TableConfig::construct_circuit(cb, params),
134+
)?)
135+
}
136+
137+
fn generate_fixed_traces(
138+
config: &Self::TableConfig,
139+
num_fixed: usize,
140+
init_v: &Self::FixedInput,
141+
) -> RowMajorMatrix<E::BaseField> {
142+
// assume returned table is well-formed include padding
143+
config.gen_init_state(num_fixed, init_v)
144+
}
145+
146+
fn assign_instances(
147+
config: &Self::TableConfig,
148+
num_witin: usize,
149+
num_structural_witin: usize,
150+
_multiplicity: &[HashMap<u64, usize>],
151+
final_v: &Self::WitnessInput,
152+
) -> Result<RMMCollections<E::BaseField>, ZKVMError> {
153+
// assume returned table is well-formed include padding
154+
Ok(config.assign_instances(num_witin, num_structural_witin, final_v)?)
155+
}
156+
}
157+
107158
/// PubIORamCircuit initializes and finalizes memory
108159
/// - at fixed addresses,
109160
/// - with content from the public input of proofs.
@@ -189,6 +240,20 @@ pub trait DynVolatileRamTable {
189240
}
190241
}
191242

243+
pub trait DynVolatileRamTableConfigTrait<DVRAM>: Sized + Send + Sync {
244+
type Output: Sized + Send + Sync;
245+
fn construct_circuit<E: ExtensionField>(
246+
cb: &mut CircuitBuilder<E>,
247+
params: &ProgramParams,
248+
) -> Result<Self::Output, CircuitBuilderError>;
249+
fn assign_instances<F: SmallField>(
250+
&self,
251+
num_witin: usize,
252+
num_structural_witin: usize,
253+
final_mem: &[MemFinalRecord],
254+
) -> Result<[RowMajorMatrix<F>; 2], CircuitBuilderError>;
255+
}
256+
192257
/// DynVolatileRamCircuit initializes and finalizes memory
193258
/// - at witnessed addresses, in a contiguous range chosen by the prover,
194259
/// - with zeros as initial content if ZERO_INIT,
@@ -197,12 +262,15 @@ pub trait DynVolatileRamTable {
197262
/// If not ZERO_INIT:
198263
/// - The initial content is an unconstrained prover hint.
199264
/// - The final content is equal to this initial content.
200-
pub struct DynVolatileRamCircuit<E, R>(PhantomData<(E, R)>);
265+
pub struct DynVolatileRamCircuit<E, R, C>(PhantomData<(E, R, C)>);
201266

202-
impl<E: ExtensionField, DVRAM: DynVolatileRamTable + Send + Sync + Clone> TableCircuit<E>
203-
for DynVolatileRamCircuit<E, DVRAM>
267+
impl<
268+
E: ExtensionField,
269+
DVRAM: DynVolatileRamTable + Send + Sync + Clone,
270+
C: DynVolatileRamTableConfigTrait<DVRAM>,
271+
> TableCircuit<E> for DynVolatileRamCircuit<E, DVRAM, C>
204272
{
205-
type TableConfig = DynVolatileRamTableConfig<DVRAM>;
273+
type TableConfig = C::Output;
206274
type FixedInput = ();
207275
type WitnessInput = [MemFinalRecord];
208276

@@ -214,10 +282,7 @@ impl<E: ExtensionField, DVRAM: DynVolatileRamTable + Send + Sync + Clone> TableC
214282
cb: &mut CircuitBuilder<E>,
215283
params: &ProgramParams,
216284
) -> Result<Self::TableConfig, ZKVMError> {
217-
Ok(cb.namespace(
218-
|| Self::name(),
219-
|cb| Self::TableConfig::construct_circuit(cb, params),
220-
)?)
285+
Ok(cb.namespace(|| Self::name(), |cb| C::construct_circuit(cb, params))?)
221286
}
222287

223288
fn generate_fixed_traces(

ceno_zkvm/src/tables/ram/ram_impl.rs

Lines changed: 116 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use crate::{
1919
e2e::RAMRecord,
2020
instructions::riscv::constants::{LIMB_BITS, LIMB_MASK},
2121
structs::ProgramParams,
22+
tables::ram::ram_circuit::DynVolatileRamTableConfigTrait,
2223
};
2324
use ff_ext::FieldInto;
2425
use multilinear_extensions::{
@@ -186,6 +187,105 @@ impl<NVRAM: NonVolatileTable + Send + Sync + Clone> NonVolatileTableConfig<NVRAM
186187
}
187188
}
188189

190+
/// define a non-volatile memory with init value
191+
#[derive(Clone, Debug)]
192+
pub struct NonVolatileInitTableConfig<NVRAM: NonVolatileTable + Send + Sync + Clone> {
193+
init_v: Vec<Fixed>,
194+
addr: Fixed,
195+
196+
phantom: PhantomData<NVRAM>,
197+
params: ProgramParams,
198+
}
199+
200+
impl<NVRAM: NonVolatileTable + Send + Sync + Clone> NonVolatileInitTableConfig<NVRAM> {
201+
pub fn construct_circuit<E: ExtensionField>(
202+
cb: &mut CircuitBuilder<E>,
203+
params: &ProgramParams,
204+
) -> Result<Self, CircuitBuilderError> {
205+
let init_v = (0..NVRAM::V_LIMBS)
206+
.map(|i| cb.create_fixed(|| format!("init_v_limb_{i}")))
207+
.collect_vec();
208+
let addr = cb.create_fixed(|| "addr");
209+
210+
let init_table = [
211+
vec![(NVRAM::RAM_TYPE as usize).into()],
212+
vec![Expression::Fixed(addr)],
213+
init_v.iter().map(|v| v.expr()).collect_vec(),
214+
vec![Expression::ZERO], // Initial cycle.
215+
]
216+
.concat();
217+
218+
cb.w_table_record(
219+
|| "init_table",
220+
NVRAM::RAM_TYPE,
221+
SetTableSpec {
222+
len: Some(NVRAM::len(params)),
223+
structural_witins: vec![],
224+
},
225+
init_table,
226+
)?;
227+
228+
Ok(Self {
229+
init_v,
230+
addr,
231+
phantom: PhantomData,
232+
params: params.clone(),
233+
})
234+
}
235+
236+
pub fn gen_init_state<F: SmallField>(
237+
&self,
238+
num_fixed: usize,
239+
init_mem: &[MemInitRecord],
240+
) -> RowMajorMatrix<F> {
241+
assert!(
242+
NVRAM::len(&self.params).is_power_of_two(),
243+
"{} len {} must be a power of 2",
244+
NVRAM::name(),
245+
NVRAM::len(&self.params)
246+
);
247+
248+
let mut init_table = RowMajorMatrix::<F>::new(
249+
NVRAM::len(&self.params),
250+
num_fixed,
251+
InstancePaddingStrategy::Default,
252+
);
253+
assert_eq!(init_table.num_padding_instances(), 0);
254+
255+
init_table
256+
.par_rows_mut()
257+
.zip_eq(init_mem)
258+
.for_each(|(row, rec)| {
259+
if self.init_v.len() == 1 {
260+
// Assign value directly.
261+
set_fixed_val!(row, self.init_v[0], (rec.value as u64).into_f());
262+
} else {
263+
// Assign value limbs.
264+
self.init_v.iter().enumerate().for_each(|(l, limb)| {
265+
let val = (rec.value >> (l * LIMB_BITS)) & LIMB_MASK;
266+
set_fixed_val!(row, limb, (val as u64).into_f());
267+
});
268+
}
269+
set_fixed_val!(row, self.addr, (rec.addr as u64).into_f());
270+
});
271+
272+
init_table
273+
}
274+
275+
/// TODO consider taking RowMajorMatrix as argument to save allocations.
276+
pub fn assign_instances<F: SmallField>(
277+
&self,
278+
num_witin: usize,
279+
num_structural_witin: usize,
280+
_final_mem: &[MemFinalRecord],
281+
) -> Result<[RowMajorMatrix<F>; 2], CircuitBuilderError> {
282+
assert_eq!(num_structural_witin, 0);
283+
assert!(_final_mem.is_empty());
284+
285+
Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()])
286+
}
287+
}
288+
189289
/// define public io
190290
/// init value set by instance
191291
#[derive(Clone, Debug)]
@@ -315,8 +415,11 @@ pub struct DynVolatileRamTableConfig<DVRAM: DynVolatileRamTable + Send + Sync +
315415
params: ProgramParams,
316416
}
317417

318-
impl<DVRAM: DynVolatileRamTable + Send + Sync + Clone> DynVolatileRamTableConfig<DVRAM> {
319-
pub fn construct_circuit<E: ExtensionField>(
418+
impl<DVRAM: DynVolatileRamTable + Send + Sync + Clone> DynVolatileRamTableConfigTrait<DVRAM>
419+
for DynVolatileRamTableConfig<DVRAM>
420+
{
421+
type Output = DynVolatileRamTableConfig<DVRAM>;
422+
fn construct_circuit<E: ExtensionField>(
320423
cb: &mut CircuitBuilder<E>,
321424
params: &ProgramParams,
322425
) -> Result<Self, CircuitBuilderError> {
@@ -389,7 +492,7 @@ impl<DVRAM: DynVolatileRamTable + Send + Sync + Clone> DynVolatileRamTableConfig
389492
}
390493

391494
/// TODO consider taking RowMajorMatrix as argument to save allocations.
392-
pub fn assign_instances<F: SmallField>(
495+
fn assign_instances<F: SmallField>(
393496
&self,
394497
num_witin: usize,
395498
num_structural_witin: usize,
@@ -457,8 +560,10 @@ pub struct DynVolatileRamTableInitConfig<DVRAM: DynVolatileRamTable + Send + Syn
457560
params: ProgramParams,
458561
}
459562

460-
impl<DVRAM: DynVolatileRamTable + Send + Sync + Clone> DynVolatileRamTableInitConfig<DVRAM> {
461-
pub fn construct_circuit<E: ExtensionField>(
563+
impl<DVRAM: DynVolatileRamTable + Send + Sync + Clone> DynVolatileRamTableConfigTrait<DVRAM>
564+
for DynVolatileRamTableInitConfig<DVRAM>
565+
{
566+
fn construct_circuit<E: ExtensionField>(
462567
cb: &mut CircuitBuilder<E>,
463568
params: &ProgramParams,
464569
) -> Result<Self, CircuitBuilderError> {
@@ -503,7 +608,7 @@ impl<DVRAM: DynVolatileRamTable + Send + Sync + Clone> DynVolatileRamTableInitCo
503608
}
504609

505610
/// TODO consider taking RowMajorMatrix as argument to save allocations.
506-
pub fn assign_instances<F: SmallField>(
611+
fn assign_instances<F: SmallField>(
507612
&self,
508613
num_witin: usize,
509614
num_structural_witin: usize,
@@ -564,8 +669,10 @@ pub struct DynVolatileRamTableFinalConfig<DVRAM: DynVolatileRamTable + Send + Sy
564669
params: ProgramParams,
565670
}
566671

567-
impl<DVRAM: DynVolatileRamTable + Send + Sync + Clone> DynVolatileRamTableFinalConfig<DVRAM> {
568-
pub fn construct_circuit<E: ExtensionField>(
672+
impl<DVRAM: DynVolatileRamTable + Send + Sync + Clone> DynVolatileRamTableConfigTrait<DVRAM>
673+
for DynVolatileRamTableFinalConfig<DVRAM>
674+
{
675+
fn construct_circuit<E: ExtensionField>(
569676
cb: &mut CircuitBuilder<E>,
570677
params: &ProgramParams,
571678
) -> Result<Self, CircuitBuilderError> {
@@ -620,7 +727,7 @@ impl<DVRAM: DynVolatileRamTable + Send + Sync + Clone> DynVolatileRamTableFinalC
620727
}
621728

622729
/// TODO consider taking RowMajorMatrix as argument to save allocations.
623-
pub fn assign_instances<F: SmallField>(
730+
fn assign_instances<F: SmallField>(
624731
&self,
625732
num_witin: usize,
626733
num_structural_witin: usize,

0 commit comments

Comments
 (0)