@@ -78,7 +78,8 @@ pub struct Poseidon2Config<
7878 const HALF_FULL_ROUNDS : usize ,
7979 const PARTIAL_ROUNDS : usize ,
8080> {
81- cols : Vec < WitIn > ,
81+ p3_cols : Vec < WitIn > , // columns in the plonky3-air
82+ post_linear_layer_cols : Vec < WitIn > , /* additional columns to hold the state after linear layers */
8283 constants : RoundConstants < E :: BaseField , STATE_WIDTH , HALF_FULL_ROUNDS , PARTIAL_ROUNDS > ,
8384}
8485
@@ -141,10 +142,6 @@ impl<
141142 }
142143 ( 7 , 1 ) => {
143144 let committed_x3: Expression < E > = sbox. 0 [ 0 ] . clone ( ) ;
144- // TODO: avoid x^3 as x may have ~STATE_WIDTH terms after the linear layer
145- // we can allocate one more column to store x^2 (which has ~STATE_WIDTH^2 terms)
146- // then x^3 = x * x^2
147- // but this will increase the number of columns (by FULL_ROUNDS * STATE_WIDTH + PARTIAL_ROUNDS)
148145 cb. require_zero ( || "x3 = x.cube()" , committed_x3. clone ( ) - x. cube ( ) ) ?;
149146 committed_x3. square ( ) * x. clone ( )
150147 }
@@ -169,7 +166,7 @@ impl<
169166 }
170167 Self :: external_linear_layer ( state) ;
171168 for ( state_i, post_i) in state. iter_mut ( ) . zip_eq ( full_round. post . iter ( ) ) {
172- cb. require_zero ( || "post_i = state_i" , state_i. clone ( ) - post_i) ?;
169+ cb. require_equal ( || "post_i = state_i" , state_i. clone ( ) , post_i. clone ( ) ) ?;
173170 * state_i = post_i. clone ( ) ;
174171 }
175172
@@ -178,11 +175,18 @@ impl<
178175
179176 fn eval_partial_round (
180177 state : & mut [ Expression < E > ; STATE_WIDTH ] ,
178+ post_linear_layer : & WitIn ,
181179 partial_round : & PartialRound < Expression < E > , STATE_WIDTH , SBOX_DEGREE , SBOX_REGISTERS > ,
182180 round_constant : & E :: BaseField ,
183181 cb : & mut CircuitBuilder < E > ,
184182 ) -> Result < ( ) , CircuitBuilderError > {
185- state[ 0 ] = state[ 0 ] . clone ( ) + round_constant. expr ( ) ;
183+ cb. require_equal (
184+ || "post_linear_layer[0] = state[0]" ,
185+ post_linear_layer. expr ( ) ,
186+ state[ 0 ] . clone ( ) + round_constant. expr ( ) ,
187+ ) ?;
188+ state[ 0 ] = post_linear_layer. expr ( ) ;
189+
186190 Self :: eval_sbox ( & partial_round. sbox , & mut state[ 0 ] , cb) ?;
187191
188192 cb. require_zero (
@@ -231,17 +235,40 @@ impl<
231235 PARTIAL_ROUNDS ,
232236 > ,
233237 ) -> Self {
234- let num_cols =
238+ let num_p3_cols =
235239 num_cols :: < STATE_WIDTH , SBOX_DEGREE , SBOX_REGISTERS , HALF_FULL_ROUNDS , PARTIAL_ROUNDS > (
236240 ) ;
237- let cols = from_fn ( || Some ( cb. create_witin ( || "poseidon2 col" ) ) )
238- . take ( num_cols )
241+ let p3_cols = from_fn ( || Some ( cb. create_witin ( || "poseidon2 col" ) ) )
242+ . take ( num_p3_cols )
239243 . collect :: < Vec < _ > > ( ) ;
240- let mut col_exprs = cols
244+ let mut col_exprs = p3_cols
241245 . iter ( )
242246 . map ( |c| c. expr ( ) )
243247 . collect :: < Vec < Expression < E > > > ( ) ;
244248
249+ // allocate columns to cache the state after each linear layer
250+ // 1. before 0th full round
251+ let mut post_linear_layer_cols = ( 0 ..STATE_WIDTH )
252+ . map ( |j| {
253+ cb. create_witin ( || format ! ( "[before 0th full round] post linear layer col[{j}]" ) )
254+ } )
255+ . collect :: < Vec < WitIn > > ( ) ;
256+ // 2. before each partial round
257+ for i in 0 ..PARTIAL_ROUNDS {
258+ post_linear_layer_cols. push ( cb. create_witin ( || {
259+ format ! ( "[round {}] post linear layer col" , i + HALF_FULL_ROUNDS )
260+ } ) ) ;
261+ }
262+ // 3. before HALF_FULL_ROUNDS-th full round
263+ post_linear_layer_cols. extend ( ( 0 ..STATE_WIDTH ) . map ( |j| {
264+ cb. create_witin ( || {
265+ format ! (
266+ "[before {}th full round] post linear layer col[{j}]" ,
267+ HALF_FULL_ROUNDS
268+ )
269+ } )
270+ } ) ) ;
271+
245272 let poseidon2_cols: & mut Poseidon2Cols <
246273 Expression < E > ,
247274 STATE_WIDTH ,
@@ -254,6 +281,23 @@ impl<
254281 // external linear layer
255282 Self :: external_linear_layer ( & mut poseidon2_cols. inputs ) ;
256283
284+ // after linear layer, each state_i has ~STATE_WIDTH terms
285+ // therefore, we want to reduce that to one as the number of terms
286+ // after sbox(state_i + rc_i) = (state_i + rc_i)^d will explode
287+ poseidon2_cols
288+ . inputs
289+ . iter_mut ( )
290+ . zip_eq ( post_linear_layer_cols[ 0 ..STATE_WIDTH ] . iter ( ) )
291+ . for_each ( |( input, post_linear) | {
292+ cb. require_equal (
293+ || "post_linear_layer = input" ,
294+ post_linear. expr ( ) ,
295+ input. clone ( ) ,
296+ )
297+ . unwrap ( ) ;
298+ * input = post_linear. expr ( ) ;
299+ } ) ;
300+
257301 // eval full round
258302 for round in 0 ..HALF_FULL_ROUNDS {
259303 Self :: eval_full_round (
@@ -269,15 +313,27 @@ impl<
269313 for round in 0 ..PARTIAL_ROUNDS {
270314 Self :: eval_partial_round (
271315 & mut poseidon2_cols. inputs ,
316+ & post_linear_layer_cols[ STATE_WIDTH + round] ,
272317 & poseidon2_cols. partial_rounds [ round] ,
273318 & round_constants. partial_round_constants [ round] ,
274319 cb,
275320 )
276321 . unwrap ( ) ;
277322 }
278323
279- // TODO: after the last partial round, each state_i has ~STATE_WIDTH terms
280- // which will make the next full round to have many terms
324+ poseidon2_cols
325+ . inputs
326+ . iter_mut ( )
327+ . zip_eq ( post_linear_layer_cols[ STATE_WIDTH + PARTIAL_ROUNDS ..] . iter ( ) )
328+ . for_each ( |( input, post_linear) | {
329+ cb. require_equal (
330+ || "post_linear_layer = input" ,
331+ post_linear. expr ( ) ,
332+ input. clone ( ) ,
333+ )
334+ . unwrap ( ) ;
335+ * input = post_linear. expr ( ) ;
336+ } ) ;
281337
282338 // eval full round
283339 for round in 0 ..HALF_FULL_ROUNDS {
@@ -291,13 +347,14 @@ impl<
291347 }
292348
293349 Poseidon2Config {
294- cols,
350+ p3_cols,
351+ post_linear_layer_cols,
295352 constants : round_constants,
296353 }
297354 }
298355
299356 pub fn inputs ( & self ) -> Vec < Expression < E > > {
300- let col_exprs = self . cols . iter ( ) . map ( |c| c. expr ( ) ) . collect :: < Vec < _ > > ( ) ;
357+ let col_exprs = self . p3_cols . iter ( ) . map ( |c| c. expr ( ) ) . collect :: < Vec < _ > > ( ) ;
301358
302359 let poseidon2_cols: & Poseidon2Cols <
303360 Expression < E > ,
@@ -312,7 +369,7 @@ impl<
312369 }
313370
314371 pub fn output ( & self ) -> Vec < Expression < E > > {
315- let col_exprs = self . cols . iter ( ) . map ( |c| c. expr ( ) ) . collect :: < Vec < _ > > ( ) ;
372+ let col_exprs = self . p3_cols . iter ( ) . map ( |c| c. expr ( ) ) . collect :: < Vec < _ > > ( ) ;
316373
317374 let poseidon2_cols: & Poseidon2Cols <
318375 Expression < E > ,
@@ -330,19 +387,29 @@ impl<
330387 . unwrap ( )
331388 }
332389
390+ fn num_p3_cols ( & self ) -> usize {
391+ self . p3_cols . len ( )
392+ }
393+
394+ pub fn num_cols ( & self ) -> usize {
395+ self . p3_cols . len ( ) + self . post_linear_layer_cols . len ( )
396+ }
397+
333398 pub fn assign_instance (
334399 & self ,
335400 instance : & mut [ E :: BaseField ] ,
336401 state : [ E :: BaseField ; STATE_WIDTH ] ,
337402 ) {
403+ let ( p3_cols, post_linear_layer_cols) = instance. split_at_mut ( self . num_p3_cols ( ) ) ;
404+
338405 let poseidon2_cols: & mut Poseidon2Cols <
339406 E :: BaseField ,
340407 STATE_WIDTH ,
341408 SBOX_DEGREE ,
342409 SBOX_REGISTERS ,
343410 HALF_FULL_ROUNDS ,
344411 PARTIAL_ROUNDS ,
345- > = instance . borrow_mut ( ) ;
412+ > = p3_cols . borrow_mut ( ) ;
346413
347414 generate_trace_rows_for_perm :: <
348415 E :: BaseField ,
@@ -352,7 +419,12 @@ impl<
352419 SBOX_REGISTERS ,
353420 HALF_FULL_ROUNDS ,
354421 PARTIAL_ROUNDS ,
355- > ( poseidon2_cols, state, & self . constants ) ;
422+ > (
423+ poseidon2_cols,
424+ post_linear_layer_cols,
425+ state,
426+ & self . constants ,
427+ ) ;
356428 }
357429}
358430
@@ -376,6 +448,7 @@ fn generate_trace_rows_for_perm<
376448 HALF_FULL_ROUNDS ,
377449 PARTIAL_ROUNDS ,
378450 > ,
451+ post_linear_layers : & mut [ F ] ,
379452 mut state : [ F ; WIDTH ] ,
380453 constants : & RoundConstants < F , WIDTH , HALF_FULL_ROUNDS , PARTIAL_ROUNDS > ,
381454) {
@@ -389,6 +462,15 @@ fn generate_trace_rows_for_perm<
389462
390463 LinearLayers :: external_linear_layer ( & mut state) ;
391464
465+ // 1. before 0th full round
466+ // post_linear_layer[i] = state[i]
467+ post_linear_layers[ 0 ..WIDTH ]
468+ . iter_mut ( )
469+ . zip ( state. iter ( ) )
470+ . for_each ( |( post, & x) | {
471+ * post = x;
472+ } ) ;
473+
392474 for ( full_round, constants) in perm
393475 . beginning_full_rounds
394476 . iter_mut ( )
@@ -399,18 +481,29 @@ fn generate_trace_rows_for_perm<
399481 ) ;
400482 }
401483
402- for ( partial_round, constant) in perm
484+ for ( i , ( partial_round, constant) ) in perm
403485 . partial_rounds
404486 . iter_mut ( )
405487 . zip ( & constants. partial_round_constants )
488+ . enumerate ( )
406489 {
407490 generate_partial_round :: < F , LinearLayers , WIDTH , SBOX_DEGREE , SBOX_REGISTERS > (
408491 & mut state,
492+ & mut post_linear_layers[ WIDTH + i] ,
409493 partial_round,
410494 * constant,
411495 ) ;
412496 }
413497
498+ // 3. before HALF_FULL_ROUNDS-th full round
499+ // post_linear_layer[i] = state[i]
500+ post_linear_layers[ WIDTH + PARTIAL_ROUNDS ..]
501+ . iter_mut ( )
502+ . zip ( state. iter ( ) )
503+ . for_each ( |( post, & x) | {
504+ * post = x;
505+ } ) ;
506+
414507 for ( full_round, constants) in perm
415508 . ending_full_rounds
416509 . iter_mut ( )
@@ -459,10 +552,12 @@ fn generate_partial_round<
459552 const SBOX_REGISTERS : usize ,
460553> (
461554 state : & mut [ F ; WIDTH ] ,
555+ post_linear_layer : & mut F ,
462556 partial_round : & mut PartialRound < F , WIDTH , SBOX_DEGREE , SBOX_REGISTERS > ,
463557 round_constant : F ,
464558) {
465559 state[ 0 ] += round_constant;
560+ * post_linear_layer = state[ 0 ] ;
466561 generate_sbox ( & mut partial_round. sbox , & mut state[ 0 ] ) ;
467562 partial_round. post_sbox = state[ 0 ] ;
468563 LinearLayers :: internal_linear_layer ( state) ;
0 commit comments