@@ -18,7 +18,6 @@ use crate::{
1818use ceno_emul:: {
1919 Addr , ByteAddr , CENO_PLATFORM , Cycle , EmuContext , InsnKind , IterAddresses , Platform , Program ,
2020 StepRecord , Tracer , VMState , WORD_SIZE , Word , WordAddr , host_utils:: read_all_messages,
21- shards:: Shards ,
2221} ;
2322use clap:: ValueEnum ;
2423use either:: Either ;
@@ -111,9 +110,41 @@ pub struct RAMRecord {
111110 pub value : Word ,
112111}
113112
113+ #[ derive( Clone , Debug ) ]
114+ pub struct Shards {
115+ pub shard_id : usize ,
116+ pub max_num_shards : usize ,
117+ }
118+
119+ impl Shards {
120+ pub fn new ( shard_id : usize , max_num_shards : usize ) -> Self {
121+ assert ! ( shard_id < max_num_shards) ;
122+ Self {
123+ shard_id,
124+ max_num_shards,
125+ }
126+ }
127+
128+ pub fn is_first_shard ( & self ) -> bool {
129+ self . shard_id == 0
130+ }
131+
132+ pub fn is_last_shard ( & self ) -> bool {
133+ self . shard_id == self . max_num_shards - 1
134+ }
135+ }
136+
137+ impl Default for Shards {
138+ fn default ( ) -> Self {
139+ Self {
140+ shard_id : 0 ,
141+ max_num_shards : 1 ,
142+ }
143+ }
144+ }
145+
114146pub struct ShardContext < ' a > {
115- shard_id : usize ,
116- max_num_shards : usize ,
147+ shards : Shards ,
117148 max_cycle : Cycle ,
118149 // TODO optimize this map as it's super huge
119150 addr_future_accesses : Cow < ' a , HashMap < ( WordAddr , Cycle ) , Cycle > > ,
@@ -128,8 +159,7 @@ impl<'a> Default for ShardContext<'a> {
128159 fn default ( ) -> Self {
129160 let max_threads = max_usable_threads ( ) ;
130161 Self {
131- shard_id : 0 ,
132- max_num_shards : 1 ,
162+ shards : Shards :: default ( ) ,
133163 max_cycle : Cycle :: default ( ) ,
134164 addr_future_accesses : Cow :: Owned ( HashMap :: new ( ) ) ,
135165 read_thread_based_record_storage : Either :: Left (
@@ -151,30 +181,29 @@ impl<'a> Default for ShardContext<'a> {
151181
152182impl < ' a > ShardContext < ' a > {
153183 pub fn new (
154- shard_id : usize ,
155- max_num_shards : usize ,
184+ shards : Shards ,
156185 executed_instructions : usize ,
157186 addr_future_accesses : HashMap < ( WordAddr , Cycle ) , Cycle > ,
158187 ) -> Self {
159188 // current strategy: at least each shard deal with one instruction
160- let max_num_shards = max_num_shards. min ( executed_instructions) ;
189+ let max_num_shards = shards . max_num_shards . min ( executed_instructions) ;
161190 assert ! (
162- shard_id < max_num_shards,
191+ shards . shard_id < max_num_shards,
163192 "implement mechanism to skip current shard proof"
164193 ) ;
165194
166195 let subcycle_per_insn = Tracer :: SUBCYCLES_PER_INSN as usize ;
167196 let max_threads = max_usable_threads ( ) ;
168197 let expected_inst_per_shard = executed_instructions. div_ceil ( max_num_shards) ;
169198 let max_cycle = ( executed_instructions + 1 ) * subcycle_per_insn; // cycle start from subcycle_per_insn
170- let cur_shard_cycle_range = ( shard_id * expected_inst_per_shard * subcycle_per_insn
199+ let cur_shard_cycle_range = ( shards . shard_id * expected_inst_per_shard * subcycle_per_insn
171200 + subcycle_per_insn)
172- ..( ( shard_id + 1 ) * expected_inst_per_shard * subcycle_per_insn + subcycle_per_insn)
201+ ..( ( shards. shard_id + 1 ) * expected_inst_per_shard * subcycle_per_insn
202+ + subcycle_per_insn)
173203 . min ( max_cycle) ;
174204
175205 ShardContext {
176- shard_id,
177- max_num_shards,
206+ shards,
178207 max_cycle : max_cycle as Cycle ,
179208 addr_future_accesses : Cow :: Owned ( addr_future_accesses) ,
180209 // TODO with_capacity optimisation
@@ -207,8 +236,7 @@ impl<'a> ShardContext<'a> {
207236 . iter_mut ( )
208237 . zip ( write_thread_based_record_storage. iter_mut ( ) )
209238 . map ( |( read, write) | ShardContext {
210- shard_id : self . shard_id ,
211- max_num_shards : self . max_num_shards ,
239+ shards : self . shards . clone ( ) ,
212240 max_cycle : self . max_cycle ,
213241 addr_future_accesses : Cow :: Borrowed ( self . addr_future_accesses . as_ref ( ) ) ,
214242 read_thread_based_record_storage : Either :: Right ( read) ,
@@ -236,12 +264,12 @@ impl<'a> ShardContext<'a> {
236264
237265 #[ inline( always) ]
238266 pub fn is_first_shard ( & self ) -> bool {
239- self . shard_id == 0
267+ self . shards . shard_id == 0
240268 }
241269
242270 #[ inline( always) ]
243271 pub fn is_last_shard ( & self ) -> bool {
244- self . shard_id == self . max_num_shards - 1
272+ self . shards . shard_id == self . shards . max_num_shards - 1
245273 }
246274
247275 #[ inline( always) ]
@@ -511,12 +539,7 @@ pub fn emulate_program<'a>(
511539 ) ,
512540 ) ;
513541
514- let shard_ctx = ShardContext :: new (
515- shards. shard_id ,
516- shards. max_num_shards ,
517- insts,
518- vm. take_tracer ( ) . next_accesses ( ) ,
519- ) ;
542+ let shard_ctx = ShardContext :: new ( shards. clone ( ) , insts, vm. take_tracer ( ) . next_accesses ( ) ) ;
520543
521544 EmulationResult {
522545 pi,
0 commit comments