Skip to content

Commit d32c71f

Browse files
committed
aligned step cycle and prev_cycle to local version
1 parent 03092e9 commit d32c71f

File tree

20 files changed

+142
-76
lines changed

20 files changed

+142
-76
lines changed

ceno_emul/src/shards.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
pub struct Shards {
22
pub shard_id: usize,
3-
pub num_shards: usize,
3+
pub max_num_shards: usize,
44
}
55

66
impl Shards {
7-
pub fn new(shard_id: usize, num_shards: usize) -> Self {
8-
assert!(shard_id < num_shards);
7+
pub fn new(shard_id: usize, max_num_shards: usize) -> Self {
8+
assert!(shard_id < max_num_shards);
99
Self {
1010
shard_id,
11-
num_shards,
11+
max_num_shards,
1212
}
1313
}
1414

@@ -17,6 +17,6 @@ impl Shards {
1717
}
1818

1919
pub fn is_last_shard(&self) -> bool {
20-
self.shard_id == self.num_shards - 1
20+
self.shard_id == self.max_num_shards - 1
2121
}
2222
}

ceno_emul/src/tracer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::{
2525
/// - Any pair of `rs1 / rs2 / rd` **may be the same**. Then, one op will point to the other op in the same instruction but a different subcycle. The circuits may follow the operations **without special handling** of repeated registers.
2626
#[derive(Clone, Debug, Default, PartialEq, Eq)]
2727
pub struct StepRecord {
28-
cycle: Cycle,
28+
pub cycle: Cycle,
2929
pc: Change<ByteAddr>,
3030
pub insn: Instruction,
3131

ceno_zkvm/src/bin/e2e.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ struct Args {
115115

116116
// number of total shards
117117
#[arg(long, default_value = "1")]
118-
num_shards: u32,
118+
max_num_shards: u32,
119119
}
120120

121121
fn main() {
@@ -248,7 +248,7 @@ fn main() {
248248
.unwrap_or_default();
249249

250250
let max_steps = args.max_steps.unwrap_or(usize::MAX);
251-
let shards = Shards::new(args.shard_id as usize, args.num_shards as usize);
251+
let shards = Shards::new(args.shard_id as usize, args.max_num_shards as usize);
252252

253253
match (args.pcs, args.field) {
254254
(PcsKind::Basefold, FieldType::Goldilocks) => {

ceno_zkvm/src/e2e.rs

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ pub struct RAMRecord {
112112

113113
pub struct ShardContext<'a> {
114114
shard_id: usize,
115-
num_shards: usize,
115+
max_num_shards: usize,
116116
max_cycle: Cycle,
117117
addr_future_accesses: Cow<'a, HashMap<(WordAddr, Cycle), Cycle>>,
118118
read_thread_based_record_storage: Either<
@@ -131,7 +131,7 @@ impl<'a> Default for ShardContext<'a> {
131131
let max_threads = max_usable_threads();
132132
Self {
133133
shard_id: 0,
134-
num_shards: 1,
134+
max_num_shards: 1,
135135
max_cycle: Cycle::default(),
136136
addr_future_accesses: Cow::Owned(HashMap::new()),
137137
read_thread_based_record_storage: Either::Left(
@@ -154,20 +154,27 @@ impl<'a> Default for ShardContext<'a> {
154154
impl<'a> ShardContext<'a> {
155155
pub fn new(
156156
shard_id: usize,
157-
num_shards: usize,
157+
max_num_shards: usize,
158158
executed_instructions: usize,
159159
addr_future_accesses: HashMap<(WordAddr, Cycle), Cycle>,
160160
) -> Self {
161+
// current strategy: at least each shard deal with one instruction
162+
let max_num_shards = max_num_shards.min(executed_instructions);
163+
assert!(
164+
shard_id < max_num_shards,
165+
"implement mechanism to skip current shard proof"
166+
);
167+
161168
let max_threads = max_usable_threads();
162169
// let max_record_per_thread = max_insts.div_ceil(max_threads as u64);
163-
let expected_inst_per_shard = executed_instructions.div_ceil(num_shards) as usize;
170+
let expected_inst_per_shard = executed_instructions.div_ceil(max_num_shards) as usize;
164171
let max_cycle = (executed_instructions + 1) * 4; // cycle start from 4
165-
let cur_shard_cycle_range = (shard_id * expected_inst_per_shard * 4).max(4)
166-
..((shard_id + 1) * expected_inst_per_shard * 4).min(max_cycle);
172+
let cur_shard_cycle_range = (shard_id * expected_inst_per_shard * 4 + 4)
173+
..((shard_id + 1) * expected_inst_per_shard * 4 + 4).min(max_cycle);
167174

168175
ShardContext {
169176
shard_id,
170-
num_shards,
177+
max_num_shards,
171178
max_cycle: max_cycle as Cycle,
172179
addr_future_accesses: Cow::Owned(addr_future_accesses),
173180
// TODO with_capacity optimisation
@@ -201,7 +208,7 @@ impl<'a> ShardContext<'a> {
201208
.zip(write_thread_based_record_storage.iter_mut())
202209
.map(|(read, write)| ShardContext {
203210
shard_id: self.shard_id,
204-
num_shards: self.num_shards,
211+
max_num_shards: self.max_num_shards,
205212
max_cycle: self.max_cycle,
206213
addr_future_accesses: Cow::Borrowed(self.addr_future_accesses.as_ref()),
207214
read_thread_based_record_storage: Either::Right(read),
@@ -220,14 +227,28 @@ impl<'a> ShardContext<'a> {
220227

221228
#[inline(always)]
222229
pub fn is_last_shard(&self) -> bool {
223-
self.shard_id == self.num_shards - 1
230+
self.shard_id == self.max_num_shards - 1
224231
}
225232

226233
#[inline(always)]
227234
pub fn is_current_shard_cycle(&self, cycle: Cycle) -> bool {
228235
self.cur_shard_cycle_range.contains(&(cycle as usize))
229236
}
230237

238+
#[inline(always)]
239+
pub fn aligned_prev_ts(&self, prev_cycle: Cycle) -> Cycle {
240+
let mut ts = prev_cycle.saturating_sub(self.cur_shard_cycle_range.start as Cycle);
241+
if ts < 4 {
242+
ts = 0
243+
}
244+
ts
245+
}
246+
247+
pub fn current_shard_offset_cycle(&self) -> Cycle {
248+
// `-4` as cycle of each local shard start from 4
249+
(self.cur_shard_cycle_range.start as Cycle) - 4
250+
}
251+
231252
#[inline(always)]
232253
pub fn send(
233254
&mut self,
@@ -475,7 +496,7 @@ pub fn emulate_program<'a>(
475496

476497
let shard_ctx = ShardContext::new(
477498
shards.shard_id,
478-
shards.num_shards,
499+
shards.max_num_shards,
479500
insts,
480501
vm.take_tracer().next_accesses(),
481502
);

ceno_zkvm/src/instructions/riscv/b_insn.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ impl<E: ExtensionField> BInstructionConfig<E> {
9393
lk_multiplicity: &mut LkMultiplicity,
9494
step: &StepRecord,
9595
) -> Result<(), ZKVMError> {
96-
self.vm_state.assign_instance(instance, step)?;
96+
self.vm_state.assign_instance(instance, shard_ctx, step)?;
9797
self.rs1
9898
.assign_instance(instance, shard_ctx, lk_multiplicity, step)?;
9999
self.rs2

ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ impl<E: ExtensionField> DummyConfig<E> {
248248
step: &StepRecord,
249249
) -> Result<(), ZKVMError> {
250250
// State in and out
251-
self.vm_state.assign_instance(instance, step)?;
251+
self.vm_state.assign_instance(instance, shard_ctx, step)?;
252252

253253
// Fetch instruction
254254
lk_multiplicity.fetch(step.pc().before.0);

ceno_zkvm/src/instructions/riscv/ecall/keccak.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,9 @@ impl<E: ExtensionField> Instruction<E> for KeccakInstruction<E> {
223223
[round_index as usize * num_witin..][..num_witin];
224224

225225
// vm_state
226-
config.vm_state.assign_instance(instance, step)?;
226+
config
227+
.vm_state
228+
.assign_instance(instance, &shard_ctx, step)?;
227229

228230
config.ecall_id.assign_op(
229231
instance,

ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,9 @@ impl<E: ExtensionField, EC: EllipticCurve> Instruction<E>
274274
let ops = &step.syscall().expect("syscall step");
275275

276276
// vm_state
277-
config.vm_state.assign_instance(instance, step)?;
277+
config
278+
.vm_state
279+
.assign_instance(instance, &shard_ctx, step)?;
278280

279281
config.ecall_id.assign_op(
280282
instance,

ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,9 @@ impl<E: ExtensionField, EC: EllipticCurve + WeierstrassParameters> Instruction<E
246246
let ops = &step.syscall().expect("syscall step");
247247

248248
// vm_state
249-
config.vm_state.assign_instance(instance, step)?;
249+
config
250+
.vm_state
251+
.assign_instance(instance, &shard_ctx, step)?;
250252

251253
config.ecall_id.assign_op(
252254
instance,

ceno_zkvm/src/instructions/riscv/ecall_base.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@ impl<E: ExtensionField, const REG_ID: usize, const RW: bool> OpFixedRS<E, REG_ID
6969
cycle: Cycle,
7070
op: &WriteOp,
7171
) -> Result<(), ZKVMError> {
72-
set_val!(instance, self.prev_ts, op.previous_cycle);
72+
let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle);
73+
let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle();
74+
let shard_cycle = cycle - current_shard_offset_cycle;
75+
set_val!(instance, self.prev_ts, shard_prev_cycle);
7376

7477
// Register state
7578
if let Some(prev_value) = self.prev_value.as_ref() {
@@ -79,14 +82,20 @@ impl<E: ExtensionField, const REG_ID: usize, const RW: bool> OpFixedRS<E, REG_ID
7982
);
8083
}
8184

82-
let cycle = if RW {
83-
cycle + Tracer::SUBCYCLE_RD
85+
let (shard_cycle, cycle) = if RW {
86+
(
87+
shard_cycle + Tracer::SUBCYCLE_RD,
88+
cycle + Tracer::SUBCYCLE_RD,
89+
)
8490
} else {
85-
cycle + Tracer::SUBCYCLE_RS1
91+
(
92+
shard_cycle + Tracer::SUBCYCLE_RS1,
93+
cycle + Tracer::SUBCYCLE_RS1,
94+
)
8695
};
8796
// Register write
8897
self.lt_cfg
89-
.assign_instance(instance, lk_multiplicity, op.previous_cycle, cycle)?;
98+
.assign_instance(instance, lk_multiplicity, shard_prev_cycle, shard_cycle)?;
9099

91100
shard_ctx.send(
92101
RAMType::Register,

0 commit comments

Comments
 (0)