Skip to content

Commit 69dd886

Browse files
authored
Fix: ShardRamCircuit runs too slow (#1117)
# Issue There is a recurring pattern in the Poseidon2 gadget that incurs a large number of monomial terms, i.e. 1. `state'[i] = \sum_j M[i][j] * state[j]`; 2. `state''[i] = state'[i]^3` => cubing a sum of ~STATE_WIDTH terms creates a lot of monomial terms once it is expanded via the distributive law. ## Workaround The temporary solution to this issue is to allocate extra polynomials that store the state after each linear layer operation (step 1), so that the s-box in step 2 is applied to a single polynomial rather than to a multi-term sum. | before/after | num_polys | num monomial terms of zerocheck | |------------|------------|------------------------------------| | before | 329 | 7139 | | after | 374 (16*2 + 13 ⬆️) | 957 |
1 parent 584d9d9 commit 69dd886

File tree

3 files changed

+122
-20
lines changed

3 files changed

+122
-20
lines changed

ceno_zkvm/src/gadgets/poseidon2.rs

Lines changed: 114 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ pub struct Poseidon2Config<
7878
const HALF_FULL_ROUNDS: usize,
7979
const PARTIAL_ROUNDS: usize,
8080
> {
81-
cols: Vec<WitIn>,
81+
p3_cols: Vec<WitIn>, // columns in the plonky3-air
82+
post_linear_layer_cols: Vec<WitIn>, /* additional columns to hold the state after linear layers */
8283
constants: RoundConstants<E::BaseField, STATE_WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
8384
}
8485

@@ -141,10 +142,6 @@ impl<
141142
}
142143
(7, 1) => {
143144
let committed_x3: Expression<E> = sbox.0[0].clone();
144-
// TODO: avoid x^3 as x may have ~STATE_WIDTH terms after the linear layer
145-
// we can allocate one more column to store x^2 (which has ~STATE_WIDTH^2 terms)
146-
// then x^3 = x * x^2
147-
// but this will increase the number of columns (by FULL_ROUNDS * STATE_WIDTH + PARTIAL_ROUNDS)
148145
cb.require_zero(|| "x3 = x.cube()", committed_x3.clone() - x.cube())?;
149146
committed_x3.square() * x.clone()
150147
}
@@ -169,7 +166,7 @@ impl<
169166
}
170167
Self::external_linear_layer(state);
171168
for (state_i, post_i) in state.iter_mut().zip_eq(full_round.post.iter()) {
172-
cb.require_zero(|| "post_i = state_i", state_i.clone() - post_i)?;
169+
cb.require_equal(|| "post_i = state_i", state_i.clone(), post_i.clone())?;
173170
*state_i = post_i.clone();
174171
}
175172

@@ -178,11 +175,18 @@ impl<
178175

179176
fn eval_partial_round(
180177
state: &mut [Expression<E>; STATE_WIDTH],
178+
post_linear_layer: &WitIn,
181179
partial_round: &PartialRound<Expression<E>, STATE_WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
182180
round_constant: &E::BaseField,
183181
cb: &mut CircuitBuilder<E>,
184182
) -> Result<(), CircuitBuilderError> {
185-
state[0] = state[0].clone() + round_constant.expr();
183+
cb.require_equal(
184+
|| "post_linear_layer[0] = state[0]",
185+
post_linear_layer.expr(),
186+
state[0].clone() + round_constant.expr(),
187+
)?;
188+
state[0] = post_linear_layer.expr();
189+
186190
Self::eval_sbox(&partial_round.sbox, &mut state[0], cb)?;
187191

188192
cb.require_zero(
@@ -231,17 +235,40 @@ impl<
231235
PARTIAL_ROUNDS,
232236
>,
233237
) -> Self {
234-
let num_cols =
238+
let num_p3_cols =
235239
num_cols::<STATE_WIDTH, SBOX_DEGREE, SBOX_REGISTERS, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>(
236240
);
237-
let cols = from_fn(|| Some(cb.create_witin(|| "poseidon2 col")))
238-
.take(num_cols)
241+
let p3_cols = from_fn(|| Some(cb.create_witin(|| "poseidon2 col")))
242+
.take(num_p3_cols)
239243
.collect::<Vec<_>>();
240-
let mut col_exprs = cols
244+
let mut col_exprs = p3_cols
241245
.iter()
242246
.map(|c| c.expr())
243247
.collect::<Vec<Expression<E>>>();
244248

249+
// allocate columns to cache the state after each linear layer
250+
// 1. before 0th full round
251+
let mut post_linear_layer_cols = (0..STATE_WIDTH)
252+
.map(|j| {
253+
cb.create_witin(|| format!("[before 0th full round] post linear layer col[{j}]"))
254+
})
255+
.collect::<Vec<WitIn>>();
256+
// 2. before each partial round
257+
for i in 0..PARTIAL_ROUNDS {
258+
post_linear_layer_cols.push(cb.create_witin(|| {
259+
format!("[round {}] post linear layer col", i + HALF_FULL_ROUNDS)
260+
}));
261+
}
262+
// 3. before HALF_FULL_ROUNDS-th full round
263+
post_linear_layer_cols.extend((0..STATE_WIDTH).map(|j| {
264+
cb.create_witin(|| {
265+
format!(
266+
"[before {}th full round] post linear layer col[{j}]",
267+
HALF_FULL_ROUNDS
268+
)
269+
})
270+
}));
271+
245272
let poseidon2_cols: &mut Poseidon2Cols<
246273
Expression<E>,
247274
STATE_WIDTH,
@@ -254,6 +281,23 @@ impl<
254281
// external linear layer
255282
Self::external_linear_layer(&mut poseidon2_cols.inputs);
256283

284+
// after linear layer, each state_i has ~STATE_WIDTH terms
285+
// therefore, we want to reduce that to one as the number of terms
286+
// after sbox(state_i + rc_i) = (state_i + rc_i)^d will explode
287+
poseidon2_cols
288+
.inputs
289+
.iter_mut()
290+
.zip_eq(post_linear_layer_cols[0..STATE_WIDTH].iter())
291+
.for_each(|(input, post_linear)| {
292+
cb.require_equal(
293+
|| "post_linear_layer = input",
294+
post_linear.expr(),
295+
input.clone(),
296+
)
297+
.unwrap();
298+
*input = post_linear.expr();
299+
});
300+
257301
// eval full round
258302
for round in 0..HALF_FULL_ROUNDS {
259303
Self::eval_full_round(
@@ -269,15 +313,27 @@ impl<
269313
for round in 0..PARTIAL_ROUNDS {
270314
Self::eval_partial_round(
271315
&mut poseidon2_cols.inputs,
316+
&post_linear_layer_cols[STATE_WIDTH + round],
272317
&poseidon2_cols.partial_rounds[round],
273318
&round_constants.partial_round_constants[round],
274319
cb,
275320
)
276321
.unwrap();
277322
}
278323

279-
// TODO: after the last partial round, each state_i has ~STATE_WIDTH terms
280-
// which will make the next full round to have many terms
324+
poseidon2_cols
325+
.inputs
326+
.iter_mut()
327+
.zip_eq(post_linear_layer_cols[STATE_WIDTH + PARTIAL_ROUNDS..].iter())
328+
.for_each(|(input, post_linear)| {
329+
cb.require_equal(
330+
|| "post_linear_layer = input",
331+
post_linear.expr(),
332+
input.clone(),
333+
)
334+
.unwrap();
335+
*input = post_linear.expr();
336+
});
281337

282338
// eval full round
283339
for round in 0..HALF_FULL_ROUNDS {
@@ -291,13 +347,14 @@ impl<
291347
}
292348

293349
Poseidon2Config {
294-
cols,
350+
p3_cols,
351+
post_linear_layer_cols,
295352
constants: round_constants,
296353
}
297354
}
298355

299356
pub fn inputs(&self) -> Vec<Expression<E>> {
300-
let col_exprs = self.cols.iter().map(|c| c.expr()).collect::<Vec<_>>();
357+
let col_exprs = self.p3_cols.iter().map(|c| c.expr()).collect::<Vec<_>>();
301358

302359
let poseidon2_cols: &Poseidon2Cols<
303360
Expression<E>,
@@ -312,7 +369,7 @@ impl<
312369
}
313370

314371
pub fn output(&self) -> Vec<Expression<E>> {
315-
let col_exprs = self.cols.iter().map(|c| c.expr()).collect::<Vec<_>>();
372+
let col_exprs = self.p3_cols.iter().map(|c| c.expr()).collect::<Vec<_>>();
316373

317374
let poseidon2_cols: &Poseidon2Cols<
318375
Expression<E>,
@@ -330,19 +387,29 @@ impl<
330387
.unwrap()
331388
}
332389

390+
fn num_p3_cols(&self) -> usize {
391+
self.p3_cols.len()
392+
}
393+
394+
pub fn num_cols(&self) -> usize {
395+
self.p3_cols.len() + self.post_linear_layer_cols.len()
396+
}
397+
333398
pub fn assign_instance(
334399
&self,
335400
instance: &mut [E::BaseField],
336401
state: [E::BaseField; STATE_WIDTH],
337402
) {
403+
let (p3_cols, post_linear_layer_cols) = instance.split_at_mut(self.num_p3_cols());
404+
338405
let poseidon2_cols: &mut Poseidon2Cols<
339406
E::BaseField,
340407
STATE_WIDTH,
341408
SBOX_DEGREE,
342409
SBOX_REGISTERS,
343410
HALF_FULL_ROUNDS,
344411
PARTIAL_ROUNDS,
345-
> = instance.borrow_mut();
412+
> = p3_cols.borrow_mut();
346413

347414
generate_trace_rows_for_perm::<
348415
E::BaseField,
@@ -352,7 +419,12 @@ impl<
352419
SBOX_REGISTERS,
353420
HALF_FULL_ROUNDS,
354421
PARTIAL_ROUNDS,
355-
>(poseidon2_cols, state, &self.constants);
422+
>(
423+
poseidon2_cols,
424+
post_linear_layer_cols,
425+
state,
426+
&self.constants,
427+
);
356428
}
357429
}
358430

@@ -376,6 +448,7 @@ fn generate_trace_rows_for_perm<
376448
HALF_FULL_ROUNDS,
377449
PARTIAL_ROUNDS,
378450
>,
451+
post_linear_layers: &mut [F],
379452
mut state: [F; WIDTH],
380453
constants: &RoundConstants<F, WIDTH, HALF_FULL_ROUNDS, PARTIAL_ROUNDS>,
381454
) {
@@ -389,6 +462,15 @@ fn generate_trace_rows_for_perm<
389462

390463
LinearLayers::external_linear_layer(&mut state);
391464

465+
// 1. before 0th full round
466+
// post_linear_layer[i] = state[i]
467+
post_linear_layers[0..WIDTH]
468+
.iter_mut()
469+
.zip(state.iter())
470+
.for_each(|(post, &x)| {
471+
*post = x;
472+
});
473+
392474
for (full_round, constants) in perm
393475
.beginning_full_rounds
394476
.iter_mut()
@@ -399,18 +481,29 @@ fn generate_trace_rows_for_perm<
399481
);
400482
}
401483

402-
for (partial_round, constant) in perm
484+
for (i, (partial_round, constant)) in perm
403485
.partial_rounds
404486
.iter_mut()
405487
.zip(&constants.partial_round_constants)
488+
.enumerate()
406489
{
407490
generate_partial_round::<F, LinearLayers, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>(
408491
&mut state,
492+
&mut post_linear_layers[WIDTH + i],
409493
partial_round,
410494
*constant,
411495
);
412496
}
413497

498+
// 3. before HALF_FULL_ROUNDS-th full round
499+
// post_linear_layer[i] = state[i]
500+
post_linear_layers[WIDTH + PARTIAL_ROUNDS..]
501+
.iter_mut()
502+
.zip(state.iter())
503+
.for_each(|(post, &x)| {
504+
*post = x;
505+
});
506+
414507
for (full_round, constants) in perm
415508
.ending_full_rounds
416509
.iter_mut()
@@ -459,10 +552,12 @@ fn generate_partial_round<
459552
const SBOX_REGISTERS: usize,
460553
>(
461554
state: &mut [F; WIDTH],
555+
post_linear_layer: &mut F,
462556
partial_round: &mut PartialRound<F, WIDTH, SBOX_DEGREE, SBOX_REGISTERS>,
463557
round_constant: F,
464558
) {
465559
state[0] += round_constant;
560+
*post_linear_layer = state[0];
466561
generate_sbox(&mut partial_round.sbox, &mut state[0]);
467562
partial_round.post_sbox = state[0];
468563
LinearLayers::internal_linear_layer(state);

ceno_zkvm/src/tables/shard_ram.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ mod tests {
690690
ShardRamCircuit::build_gkr_iop_circuit(&mut cb, &ProgramParams::default()).unwrap();
691691

692692
// create a bunch of random memory read/write records
693-
let n_global_reads = 1700;
693+
let n_global_reads = 170000;
694694
let n_global_writes = 1420;
695695
let global_reads = (0..n_global_reads)
696696
.map(|i| {

gkr_iop/src/gkr/layer/zerocheck_layer.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,11 +165,18 @@ impl<E: ExtensionField> ZerocheckLayer<E> for Layer<E> {
165165
self.n_fixed as WitnessId,
166166
self.n_instance,
167167
);
168+
tracing::debug!("main sumcheck degree: {}", zero_expr.degree());
168169
self.main_sumcheck_expression = Some(zero_expr);
169170
self.main_sumcheck_expression_monomial_terms = self
170171
.main_sumcheck_expression
171172
.as_ref()
172173
.map(|expr| expr.get_monomial_terms());
174+
tracing::debug!(
175+
"main sumcheck monomial terms count: {}",
176+
self.main_sumcheck_expression_monomial_terms
177+
.as_ref()
178+
.map_or(0, |terms| terms.len()),
179+
);
173180
exit_span!(span);
174181
}
175182

0 commit comments

Comments
 (0)