Skip to content

Commit 517f8d5

Browse files
committed
cleanup
1 parent f347310 commit 517f8d5

File tree

4 files changed

+22
-18
lines changed

4 files changed

+22
-18
lines changed

ceno_cli/src/commands/common_args/ceno.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use ceno_zkvm::{
1313
use clap::Args;
1414
use ff_ext::{BabyBearExt4, ExtensionField, GoldilocksExt2};
1515

16+
use ceno_emul::shards::Shards;
1617
use mpcs::{
1718
Basefold, BasefoldRSParams, PolynomialCommitmentScheme, SecurityLevel, Whir, WhirDefaultSpec,
1819
};
@@ -78,6 +79,14 @@ pub struct CenoOptions {
7879
#[arg(long)]
7980
pub out_vk: Option<PathBuf>,
8081

82+
/// shard id
83+
#[arg(long, default_value = "0")]
84+
shard_id: u32,
85+
86+
/// number of total shards.
87+
#[arg(long, default_value = "1")]
88+
max_num_shards: u32,
89+
8190
/// Profiling granularity.
8291
/// Setting any value restricts logs to profiling information
8392
#[arg(long)]
@@ -337,6 +346,7 @@ fn run_elf_inner<
337346
std::fs::read(elf_path).context(format!("failed to read {}", elf_path.display()))?;
338347
let program = Program::load_elf(&elf_bytes, u32::MAX).context("failed to load elf")?;
339348
print_cargo_message("Loaded", format_args!("{}", elf_path.display()));
349+
let shards = Shards::new(options.shard_id as usize, options.max_num_shards as usize);
340350

341351
let public_io = options
342352
.read_public_io()
@@ -385,6 +395,7 @@ fn run_elf_inner<
385395
create_prover(backend.clone()),
386396
program,
387397
platform,
398+
shards,
388399
&hints,
389400
&public_io,
390401
options.max_steps,

ceno_zkvm/src/e2e.rs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ impl<'a> Default for ShardContext<'a> {
142142
.map(|_| BTreeMap::new())
143143
.collect::<Vec<_>>(),
144144
),
145-
cur_shard_cycle_range: 0..usize::MAX,
145+
cur_shard_cycle_range: Tracer::SUBCYCLES_PER_INSN as usize..usize::MAX,
146146
}
147147
}
148148
}
@@ -161,12 +161,15 @@ impl<'a> ShardContext<'a> {
161161
"implement mechanism to skip current shard proof"
162162
);
163163

164+
let subcycle_per_insn = Tracer::SUBCYCLES_PER_INSN as usize;
164165
let max_threads = max_usable_threads();
165166
// let max_record_per_thread = max_insts.div_ceil(max_threads as u64);
166167
let expected_inst_per_shard = executed_instructions.div_ceil(max_num_shards) as usize;
167-
let max_cycle = (executed_instructions + 1) * 4; // cycle start from 4
168-
let cur_shard_cycle_range = (shard_id * expected_inst_per_shard * 4 + 4)
169-
..((shard_id + 1) * expected_inst_per_shard * 4 + 4).min(max_cycle);
168+
let max_cycle = (executed_instructions + 1) * subcycle_per_insn; // cycle start from subcycle_per_insn
169+
let cur_shard_cycle_range = (shard_id * expected_inst_per_shard * subcycle_per_insn
170+
+ subcycle_per_insn)
171+
..((shard_id + 1) * expected_inst_per_shard * subcycle_per_insn + subcycle_per_insn)
172+
.min(max_cycle);
170173

171174
ShardContext {
172175
shard_id,
@@ -248,15 +251,15 @@ impl<'a> ShardContext<'a> {
248251
#[inline(always)]
249252
pub fn aligned_prev_ts(&self, prev_cycle: Cycle) -> Cycle {
250253
let mut ts = prev_cycle.saturating_sub(self.cur_shard_cycle_range.start as Cycle);
251-
if ts < 4 {
254+
if ts < Tracer::SUBCYCLES_PER_INSN {
252255
ts = 0
253256
}
254257
ts
255258
}
256259

257260
pub fn current_shard_offset_cycle(&self) -> Cycle {
258-
// `-4` as cycle of each local shard start from 4
259-
(self.cur_shard_cycle_range.start as Cycle) - 4
261+
// cycle of each local shard start from Tracer::SUBCYCLES_PER_INSN
262+
(self.cur_shard_cycle_range.start as Cycle) - Tracer::SUBCYCLES_PER_INSN
260263
}
261264

262265
#[inline(always)]
@@ -383,8 +386,6 @@ pub fn emulate_program<'a>(
383386
vm.get_pc().into(),
384387
end_cycle,
385388
shards.shard_id as u32,
386-
!shards.is_first_shard(), // first shard disable global read
387-
!shards.is_last_shard(), // last shard disable global write
388389
io_init.iter().map(|rec| rec.value).collect_vec(),
389390
);
390391

ceno_zkvm/src/scheme.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,6 @@ pub struct PublicValues {
7373
end_pc: u32,
7474
end_cycle: u64,
7575
shard_id: u32,
76-
mem_bus_with_read: bool,
77-
mem_bus_with_write: bool,
7876
public_io: Vec<u32>,
7977
}
8078

@@ -86,8 +84,6 @@ impl PublicValues {
8684
end_pc: u32,
8785
end_cycle: u64,
8886
shard_id: u32,
89-
mem_bus_with_read: bool,
90-
mem_bus_with_write: bool,
9187
public_io: Vec<u32>,
9288
) -> Self {
9389
Self {
@@ -97,8 +93,6 @@ impl PublicValues {
9793
end_pc,
9894
end_cycle,
9995
shard_id,
100-
mem_bus_with_read,
101-
mem_bus_with_write,
10296
public_io,
10397
}
10498
}
@@ -113,8 +107,6 @@ impl PublicValues {
113107
vec![E::BaseField::from_canonical_u32(self.end_pc)],
114108
vec![E::BaseField::from_canonical_u64(self.end_cycle)],
115109
vec![E::BaseField::from_canonical_u32(self.shard_id)],
116-
vec![E::BaseField::from_bool(self.mem_bus_with_read)],
117-
vec![E::BaseField::from_bool(self.mem_bus_with_write)],
118110
]
119111
.into_iter()
120112
.chain(

ceno_zkvm/src/scheme/tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ fn test_single_add_instance_e2e() {
370370
.assign_table_circuit::<ProgramTableCircuit<E>>(&zkvm_cs, &prog_config, &program)
371371
.unwrap();
372372

373-
let pi = PublicValues::new(0, 0, 0, 0, 0, vec![0]);
373+
let pi = PublicValues::new(0, 0, 0, 0, 0, 0, vec![0]);
374374
let transcript = BasicTranscript::new(b"riscv");
375375
let zkvm_proof = prover
376376
.create_proof(zkvm_witness, pi, transcript)

0 commit comments

Comments
 (0)