Skip to content

Commit d5849da

Browse files
committed
aligned step cycle and prev_cycle to local version
1 parent 03092e9 commit d5849da

File tree

7 files changed

+95
-57
lines changed

7 files changed

+95
-57
lines changed

ceno_emul/src/shards.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
pub struct Shards {
22
pub shard_id: usize,
3-
pub num_shards: usize,
3+
pub max_num_shards: usize,
44
}
55

66
impl Shards {
7-
pub fn new(shard_id: usize, num_shards: usize) -> Self {
8-
assert!(shard_id < num_shards);
7+
pub fn new(shard_id: usize, max_num_shards: usize) -> Self {
8+
assert!(shard_id < max_num_shards);
99
Self {
1010
shard_id,
11-
num_shards,
11+
max_num_shards,
1212
}
1313
}
1414

@@ -17,6 +17,6 @@ impl Shards {
1717
}
1818

1919
pub fn is_last_shard(&self) -> bool {
20-
self.shard_id == self.num_shards - 1
20+
self.shard_id == self.max_num_shards - 1
2121
}
2222
}

ceno_emul/src/tracer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::{
2525
/// - Any pair of `rs1 / rs2 / rd` **may be the same**. Then, one op will point to the other op in the same instruction but a different subcycle. The circuits may follow the operations **without special handling** of repeated registers.
2626
#[derive(Clone, Debug, Default, PartialEq, Eq)]
2727
pub struct StepRecord {
28-
cycle: Cycle,
28+
pub cycle: Cycle,
2929
pc: Change<ByteAddr>,
3030
pub insn: Instruction,
3131

ceno_zkvm/src/bin/e2e.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ struct Args {
115115

116116
// number of total shards
117117
#[arg(long, default_value = "1")]
118-
num_shards: u32,
118+
max_num_shards: u32,
119119
}
120120

121121
fn main() {
@@ -248,7 +248,7 @@ fn main() {
248248
.unwrap_or_default();
249249

250250
let max_steps = args.max_steps.unwrap_or(usize::MAX);
251-
let shards = Shards::new(args.shard_id as usize, args.num_shards as usize);
251+
let shards = Shards::new(args.shard_id as usize, args.max_num_shards as usize);
252252

253253
match (args.pcs, args.field) {
254254
(PcsKind::Basefold, FieldType::Goldilocks) => {

ceno_zkvm/src/e2e.rs

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ pub struct RAMRecord {
112112

113113
pub struct ShardContext<'a> {
114114
shard_id: usize,
115-
num_shards: usize,
115+
max_num_shards: usize,
116116
max_cycle: Cycle,
117117
addr_future_accesses: Cow<'a, HashMap<(WordAddr, Cycle), Cycle>>,
118118
read_thread_based_record_storage: Either<
@@ -131,7 +131,7 @@ impl<'a> Default for ShardContext<'a> {
131131
let max_threads = max_usable_threads();
132132
Self {
133133
shard_id: 0,
134-
num_shards: 1,
134+
max_num_shards: 1,
135135
max_cycle: Cycle::default(),
136136
addr_future_accesses: Cow::Owned(HashMap::new()),
137137
read_thread_based_record_storage: Either::Left(
@@ -154,20 +154,27 @@ impl<'a> Default for ShardContext<'a> {
154154
impl<'a> ShardContext<'a> {
155155
pub fn new(
156156
shard_id: usize,
157-
num_shards: usize,
157+
max_num_shards: usize,
158158
executed_instructions: usize,
159159
addr_future_accesses: HashMap<(WordAddr, Cycle), Cycle>,
160160
) -> Self {
161+
// current strategy: at least each shard deal with one instruction
162+
let max_num_shards = max_num_shards.min(executed_instructions);
163+
assert!(
164+
shard_id < max_num_shards,
165+
"implement mechanism to skip current shard proof"
166+
);
167+
161168
let max_threads = max_usable_threads();
162169
// let max_record_per_thread = max_insts.div_ceil(max_threads as u64);
163-
let expected_inst_per_shard = executed_instructions.div_ceil(num_shards) as usize;
170+
let expected_inst_per_shard = executed_instructions.div_ceil(max_num_shards) as usize;
164171
let max_cycle = (executed_instructions + 1) * 4; // cycle start from 4
165-
let cur_shard_cycle_range = (shard_id * expected_inst_per_shard * 4).max(4)
166-
..((shard_id + 1) * expected_inst_per_shard * 4).min(max_cycle);
172+
let cur_shard_cycle_range = (shard_id * expected_inst_per_shard * 4 + 4)
173+
..((shard_id + 1) * expected_inst_per_shard * 4 + 4).min(max_cycle);
167174

168175
ShardContext {
169176
shard_id,
170-
num_shards,
177+
max_num_shards,
171178
max_cycle: max_cycle as Cycle,
172179
addr_future_accesses: Cow::Owned(addr_future_accesses),
173180
// TODO with_capacity optimisation
@@ -201,7 +208,7 @@ impl<'a> ShardContext<'a> {
201208
.zip(write_thread_based_record_storage.iter_mut())
202209
.map(|(read, write)| ShardContext {
203210
shard_id: self.shard_id,
204-
num_shards: self.num_shards,
211+
max_num_shards: self.max_num_shards,
205212
max_cycle: self.max_cycle,
206213
addr_future_accesses: Cow::Borrowed(self.addr_future_accesses.as_ref()),
207214
read_thread_based_record_storage: Either::Right(read),
@@ -220,14 +227,28 @@ impl<'a> ShardContext<'a> {
220227

221228
#[inline(always)]
222229
pub fn is_last_shard(&self) -> bool {
223-
self.shard_id == self.num_shards - 1
230+
self.shard_id == self.max_num_shards - 1
224231
}
225232

226233
#[inline(always)]
227234
pub fn is_current_shard_cycle(&self, cycle: Cycle) -> bool {
228235
self.cur_shard_cycle_range.contains(&(cycle as usize))
229236
}
230237

238+
#[inline(always)]
239+
pub fn aligned_prev_ts(&self, prev_cycle: Cycle) -> Cycle {
240+
let mut ts = prev_cycle.saturating_sub(self.cur_shard_cycle_range.start as Cycle);
241+
if ts < 4 {
242+
ts = 0
243+
}
244+
ts
245+
}
246+
247+
pub fn current_shard_offset_cycle(&self) -> Cycle {
248+
// `-4` as cycle of each local shard start from 4
249+
(self.cur_shard_cycle_range.start as Cycle) - 4
250+
}
251+
231252
#[inline(always)]
232253
pub fn send(
233254
&mut self,
@@ -475,7 +496,7 @@ pub fn emulate_program<'a>(
475496

476497
let shard_ctx = ShardContext::new(
477498
shards.shard_id,
478-
shards.num_shards,
499+
shards.max_num_shards,
479500
insts,
480501
vm.take_tracer().next_accesses(),
481502
);

ceno_zkvm/src/instructions/riscv/ecall_base.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ impl<E: ExtensionField, const REG_ID: usize, const RW: bool> OpFixedRS<E, REG_ID
6969
cycle: Cycle,
7070
op: &WriteOp,
7171
) -> Result<(), ZKVMError> {
72-
set_val!(instance, self.prev_ts, op.previous_cycle);
72+
let shard_prev_ts = shard_ctx.aligned_prev_ts(op.previous_cycle);
73+
set_val!(instance, self.prev_ts, shard_prev_ts);
7374

7475
// Register state
7576
if let Some(prev_value) = self.prev_value.as_ref() {
@@ -86,7 +87,7 @@ impl<E: ExtensionField, const REG_ID: usize, const RW: bool> OpFixedRS<E, REG_ID
8687
};
8788
// Register write
8889
self.lt_cfg
89-
.assign_instance(instance, lk_multiplicity, op.previous_cycle, cycle)?;
90+
.assign_instance(instance, lk_multiplicity, shard_prev_ts, cycle)?;
9091

9192
shard_ctx.send(
9293
RAMType::Register,

ceno_zkvm/src/instructions/riscv/insn_base.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,15 @@ impl<E: ExtensionField> ReadRS1<E> {
113113
step: &StepRecord,
114114
) -> Result<(), ZKVMError> {
115115
let op = step.rs1().expect("rs1 op");
116+
let shard_prev_ts = shard_ctx.aligned_prev_ts(op.previous_cycle);
116117
set_val!(instance, self.id, op.register_index() as u64);
117-
set_val!(instance, self.prev_ts, op.previous_cycle);
118+
set_val!(instance, self.prev_ts, shard_prev_ts);
118119

119120
// Register read
120121
self.lt_cfg.assign_instance(
121122
instance,
122123
lk_multiplicity,
123-
op.previous_cycle,
124+
shard_prev_ts,
124125
step.cycle() + Tracer::SUBCYCLE_RS1,
125126
)?;
126127
shard_ctx.send(
@@ -177,14 +178,15 @@ impl<E: ExtensionField> ReadRS2<E> {
177178
step: &StepRecord,
178179
) -> Result<(), ZKVMError> {
179180
let op = step.rs2().expect("rs2 op");
181+
let shard_prev_ts = shard_ctx.aligned_prev_ts(op.previous_cycle);
180182
set_val!(instance, self.id, op.register_index() as u64);
181-
set_val!(instance, self.prev_ts, op.previous_cycle);
183+
set_val!(instance, self.prev_ts, shard_prev_ts);
182184

183185
// Register read
184186
self.lt_cfg.assign_instance(
185187
instance,
186188
lk_multiplicity,
187-
op.previous_cycle,
189+
shard_prev_ts,
188190
step.cycle() + Tracer::SUBCYCLE_RS2,
189191
)?;
190192

@@ -255,8 +257,9 @@ impl<E: ExtensionField> WriteRD<E> {
255257
cycle: Cycle,
256258
op: &WriteOp,
257259
) -> Result<(), ZKVMError> {
260+
let shard_prev_ts = shard_ctx.aligned_prev_ts(op.previous_cycle);
258261
set_val!(instance, self.id, op.register_index() as u64);
259-
set_val!(instance, self.prev_ts, op.previous_cycle);
262+
set_val!(instance, self.prev_ts, shard_prev_ts);
260263

261264
// Register state
262265
self.prev_value.assign_limbs(
@@ -268,7 +271,7 @@ impl<E: ExtensionField> WriteRD<E> {
268271
self.lt_cfg.assign_instance(
269272
instance,
270273
lk_multiplicity,
271-
op.previous_cycle,
274+
shard_prev_ts,
272275
cycle + Tracer::SUBCYCLE_RD,
273276
)?;
274277
shard_ctx.send(
@@ -323,14 +326,15 @@ impl<E: ExtensionField> ReadMEM<E> {
323326
step: &StepRecord,
324327
) -> Result<(), ZKVMError> {
325328
let op = step.memory_op().unwrap();
329+
let shard_prev_ts = shard_ctx.aligned_prev_ts(op.previous_cycle);
326330
// Memory state
327-
set_val!(instance, self.prev_ts, op.previous_cycle);
331+
set_val!(instance, self.prev_ts, shard_prev_ts);
328332

329333
// Memory read
330334
self.lt_cfg.assign_instance(
331335
instance,
332336
lk_multiplicity,
333-
op.previous_cycle,
337+
shard_prev_ts,
334338
step.cycle() + Tracer::SUBCYCLE_MEM,
335339
)?;
336340

@@ -395,12 +399,13 @@ impl WriteMEM {
395399
cycle: Cycle,
396400
op: &WriteOp,
397401
) -> Result<(), ZKVMError> {
398-
set_val!(instance, self.prev_ts, op.previous_cycle);
402+
let shard_prev_ts = shard_ctx.aligned_prev_ts(op.previous_cycle);
403+
set_val!(instance, self.prev_ts, shard_prev_ts);
399404

400405
self.lt_cfg.assign_instance(
401406
instance,
402407
lk_multiplicity,
403-
op.previous_cycle,
408+
shard_prev_ts,
404409
cycle + Tracer::SUBCYCLE_MEM,
405410
)?;
406411

ceno_zkvm/src/instructions/riscv/rv32im.rs

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -414,35 +414,46 @@ impl<E: ExtensionField> Rv32imConfig<E> {
414414
let mut bn254_double_records = Vec::new();
415415
let mut secp256k1_add_records = Vec::new();
416416
let mut secp256k1_double_records = Vec::new();
417-
steps.into_iter().for_each(|record| {
418-
let insn_kind = record.insn.kind;
419-
match insn_kind {
420-
// ecall / halt
421-
InsnKind::ECALL if record.rs1().unwrap().value == Platform::ecall_halt() => {
422-
halt_records.push(record);
417+
let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle();
418+
steps
419+
.into_iter()
420+
.filter_map(|mut step| {
421+
if shard_ctx.is_current_shard_cycle(step.cycle()) {
422+
step.cycle = step.cycle() - current_shard_offset_cycle;
423+
Some(step)
424+
} else {
425+
None
423426
}
424-
InsnKind::ECALL if record.rs1().unwrap().value == KeccakSpec::CODE => {
425-
keccak_records.push(record);
427+
})
428+
.for_each(|record| {
429+
let insn_kind = record.insn.kind;
430+
match insn_kind {
431+
// ecall / halt
432+
InsnKind::ECALL if record.rs1().unwrap().value == Platform::ecall_halt() => {
433+
halt_records.push(record);
434+
}
435+
InsnKind::ECALL if record.rs1().unwrap().value == KeccakSpec::CODE => {
436+
keccak_records.push(record);
437+
}
438+
InsnKind::ECALL if record.rs1().unwrap().value == Bn254AddSpec::CODE => {
439+
bn254_add_records.push(record);
440+
}
441+
InsnKind::ECALL if record.rs1().unwrap().value == Bn254DoubleSpec::CODE => {
442+
bn254_double_records.push(record);
443+
}
444+
InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1AddSpec::CODE => {
445+
secp256k1_add_records.push(record);
446+
}
447+
InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DoubleSpec::CODE => {
448+
secp256k1_double_records.push(record);
449+
}
450+
// other type of ecalls are handled by dummy ecall instruction
451+
_ => {
452+
// it's safe to unwrap as all_records are initialized with Vec::new()
453+
all_records.get_mut(&insn_kind).unwrap().push(record);
454+
}
426455
}
427-
InsnKind::ECALL if record.rs1().unwrap().value == Bn254AddSpec::CODE => {
428-
bn254_add_records.push(record);
429-
}
430-
InsnKind::ECALL if record.rs1().unwrap().value == Bn254DoubleSpec::CODE => {
431-
bn254_double_records.push(record);
432-
}
433-
InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1AddSpec::CODE => {
434-
secp256k1_add_records.push(record);
435-
}
436-
InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DoubleSpec::CODE => {
437-
secp256k1_double_records.push(record);
438-
}
439-
// other type of ecalls are handled by dummy ecall instruction
440-
_ => {
441-
// it's safe to unwrap as all_records are initialized with Vec::new()
442-
all_records.get_mut(&insn_kind).unwrap().push(record);
443-
}
444-
}
445-
});
456+
});
446457

447458
for (insn_kind, (_, records)) in
448459
izip!(InsnKind::iter(), &all_records).sorted_by_key(|(_, (_, a))| Reverse(a.len()))

0 commit comments

Comments
 (0)