Skip to content

Commit bf3ef29

Browse files
add the program segment to the preprocessed trace
1 parent d155537 commit bf3ef29

9 files changed

Lines changed: 224 additions & 9 deletions

File tree

stwo_cairo_prover/crates/cairo-air/src/lib.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,19 @@ pub enum PreProcessedTraceVariant {
2323
Canonical,
2424
CanonicalWithoutPedersen,
2525
CanonicalSmall,
26+
CanonicalWithProgram,
2627
}
2728
impl PreProcessedTraceVariant {
28-
pub fn to_preprocessed_trace(&self) -> PreProcessedTrace {
29+
pub fn to_preprocessed_trace(&self, program: &[(u32, [u32; 8])]) -> PreProcessedTrace {
2930
match self {
3031
PreProcessedTraceVariant::Canonical => PreProcessedTrace::canonical(),
3132
PreProcessedTraceVariant::CanonicalWithoutPedersen => {
3233
PreProcessedTrace::canonical_without_pedersen()
3334
}
3435
PreProcessedTraceVariant::CanonicalSmall => PreProcessedTrace::canonical_small(),
36+
PreProcessedTraceVariant::CanonicalWithProgram => {
37+
PreProcessedTrace::canonical_with_program(program)
38+
}
3539
}
3640
}
3741
}

stwo_cairo_prover/crates/cairo-air/src/verifier.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -346,10 +346,11 @@ pub fn verify_cairo_ex<MC: MerkleChannel>(
346346
pcs_config.mix_into(channel);
347347
let commitment_scheme_verifier = &mut CommitmentSchemeVerifier::<MC>::new(pcs_config);
348348

349+
let preprocessed_trace =
350+
preprocessed_trace_variant.to_preprocessed_trace(&claim.public_data.public_memory.program);
351+
349352
let mut log_sizes = claim.log_sizes();
350-
log_sizes[PREPROCESSED_TRACE_IDX] = preprocessed_trace_variant
351-
.to_preprocessed_trace()
352-
.log_sizes();
353+
log_sizes[PREPROCESSED_TRACE_IDX] = preprocessed_trace.log_sizes();
353354

354355
// Preproccessed trace.
355356
commitment_scheme_verifier.commit(stark_proof.commitments[0], &log_sizes[0], channel);
@@ -375,7 +376,7 @@ pub fn verify_cairo_ex<MC: MerkleChannel>(
375376
&claim,
376377
&interaction_elements,
377378
&interaction_claim,
378-
&preprocessed_trace_variant.to_preprocessed_trace().ids(),
379+
&preprocessed_trace.ids(),
379380
);
380381
let components = component_generator.components();
381382

stwo_cairo_prover/crates/common/src/preprocessed_columns/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ pub mod poseidon;
66
pub mod poseidon_round_keys;
77
pub mod preprocessed_trace;
88
pub mod preprocessed_utils;
9+
pub mod program;
910
#[cfg(feature = "prover")]
1011
pub mod simd_prelude;

stwo_cairo_prover/crates/common/src/preprocessed_columns/preprocessed_trace.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use super::pedersen::{PedersenPoints, PEDERSEN_TABLE_N_COLUMNS};
1111
use super::poseidon::{PoseidonRoundKeys, N_WORDS as POSEIDON_N_WORDS};
1212
#[cfg(feature = "prover")]
1313
use super::simd_prelude::*;
14+
use crate::preprocessed_columns::program;
1415

1516
// Size to initialize the preprocessed trace with for `PreprocessedColumn::BitwiseXor`.
1617
const XOR_N_BITS: [u32; 5] = [4, 7, 8, 9, 10];
@@ -151,6 +152,23 @@ impl PreProcessedTrace {
151152
Self::from_columns(columns)
152153
}
153154

155+
pub fn canonical_with_program(program: &[(u32, [u32; 8])]) -> Self {
156+
let canonical_without_program = Self::canonical().columns;
157+
let curr_program_columns = (0..program::PROGRAM_N_COLUMNS).map(|x| {
158+
Box::new(program::ProgramColumn::new(x, program)) as Box<dyn PreProcessedColumn>
159+
});
160+
let columns = chain!(canonical_without_program, curr_program_columns)
161+
.sorted_by_key(|column| column.log_size())
162+
.collect_vec();
163+
164+
assert!(
165+
columns.iter().map(|col| 1 << col.log_size()).sum::<u32>() == CANONICAL_SIZE,
166+
"Canonical preprocessed trace has unexpected size"
167+
);
168+
169+
Self::from_columns(columns)
170+
}
171+
154172
pub fn log_sizes(&self) -> Vec<u32> {
155173
self.columns.iter().map(|c| c.log_size()).collect()
156174
}
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
use stwo::core::fields::m31::M31;
2+
use stwo_constraint_framework::preprocessed_columns::PreProcessedColumnId;
3+
4+
use super::preprocessed_trace::PreProcessedColumn;
5+
#[cfg(feature = "prover")]
6+
use super::simd_prelude::*;
7+
use crate::prover_types::cpu::FELT252_N_WORDS;
8+
9+
pub const PROGRAM_N_COLUMNS: usize = FELT252_N_WORDS;
10+
11+
/// Extracts a 9-bit limb from a 256-bit value stored as `[u32; 8]` little-endian limbs.
12+
fn extract_9bit_limb(value: &[u32; 8], limb_index: usize) -> u32 {
13+
let bit_offset = limb_index * 9;
14+
let word_index = bit_offset / 32;
15+
let bit_shift = bit_offset % 32;
16+
let mut result = value[word_index] >> bit_shift;
17+
if bit_shift + 9 > 32 {
18+
result |= value[word_index + 1] << (32 - bit_shift);
19+
}
20+
result & 0x1FF
21+
}
22+
23+
#[derive(Debug)]
24+
pub struct ProgramColumn {
25+
col_index: usize,
26+
column_data: Vec<M31>,
27+
}
28+
impl ProgramColumn {
29+
pub fn new(col_index: usize, program: &[(u32, [u32; 8])]) -> Self {
30+
let padded_len = program.len().next_power_of_two();
31+
let column_data = (0..padded_len)
32+
.map(|i| {
33+
if i < program.len() {
34+
M31::from_u32_unchecked(extract_9bit_limb(&program[i].1, col_index))
35+
} else {
36+
M31(0)
37+
}
38+
})
39+
.collect();
40+
Self {
41+
col_index,
42+
column_data,
43+
}
44+
}
45+
46+
pub fn get_data(&self) -> &Vec<M31> {
47+
&self.column_data
48+
}
49+
}
50+
51+
impl PreProcessedColumn for ProgramColumn {
52+
fn log_size(&self) -> u32 {
53+
self.column_data.len().ilog2()
54+
}
55+
56+
fn id(&self) -> PreProcessedColumnId {
57+
PreProcessedColumnId {
58+
id: format!("curr_program_{}", self.col_index),
59+
}
60+
}
61+
62+
#[cfg(feature = "prover")]
63+
fn packed_at(&self, vec_row: usize) -> PackedM31 {
64+
let array = self.get_data()[(vec_row * N_LANES)..((vec_row + 1) * N_LANES)]
65+
.try_into()
66+
.unwrap();
67+
PackedM31::from_array(array)
68+
}
69+
70+
#[cfg(feature = "prover")]
71+
fn gen_column_simd(&self) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder> {
72+
CircleEvaluation::new(
73+
CanonicCoset::new(self.log_size()).circle_domain(),
74+
BaseColumn::from_cpu(self.get_data()),
75+
)
76+
}
77+
}
78+
79+
#[cfg(test)]
80+
pub mod tests {
81+
use super::*;
82+
83+
#[test]
84+
fn test_extract_9bit_limb() {
85+
// Value: 0b_000_111_010_110_001_011_101 = 0x1D6B5 (in 7 limbs of 9 bits)
86+
// limb 0: 101 = 0b_000_101_101 ... let's use a concrete example.
87+
// value[0] = 0x01FF_03FE = bits: limb0=0x1FE(510), limb1=0x1FF(511), ...
88+
// Actually let's just test specific values.
89+
90+
// All zeros.
91+
let value = [0u32; 8];
92+
for i in 0..7 {
93+
assert_eq!(extract_9bit_limb(&value, i), 0);
94+
}
95+
96+
// limb 0 = lower 9 bits of value[0].
97+
let value = [0x1AB, 0, 0, 0, 0, 0, 0, 0]; // 0x1AB = 0b110101011 = 427
98+
assert_eq!(extract_9bit_limb(&value, 0), 0x1AB);
99+
100+
// limb 3 spans value[0] bits 27-31 and value[1] bits 0-3.
101+
// bit_offset = 27, word_index = 0, bit_shift = 27.
102+
// Place 0b_101010111 at bits 27-35:
103+
// value[0] bits 27-31 = lower 5 bits of limb = 0b10111 => value[0] |= 0b10111 << 27
104+
// value[1] bits 0-3 = upper 4 bits of limb = 0b1010 => value[1] |= 0b1010
105+
let value = [0b10111 << 27, 0b1010, 0, 0, 0, 0, 0, 0];
106+
assert_eq!(extract_9bit_limb(&value, 3), 0b101010111);
107+
}
108+
109+
#[test]
110+
fn test_program_column_new() {
111+
let program = vec![
112+
(0u32, [0x1FFu32, 0, 0, 0, 0, 0, 0, 0]), // limb 0 = 0x1FF (511)
113+
(1, [0, 0, 0, 0, 0, 0, 0, 0]), // all zeros
114+
];
115+
116+
let col0 = ProgramColumn::new(0, &program);
117+
// Padded to next power of 2 = 2.
118+
assert_eq!(col0.column_data.len(), 2);
119+
assert_eq!(col0.column_data[0], M31(511));
120+
assert_eq!(col0.column_data[1], M31(0));
121+
122+
let col1 = ProgramColumn::new(1, &program);
123+
// 0x1FF in bits 0-8, limb 1 = bits 9-17 = 0.
124+
assert_eq!(col1.column_data[0], M31(0));
125+
}
126+
127+
#[test]
128+
fn test_program_column_pads_to_power_of_two() {
129+
let program = vec![
130+
(0u32, [1u32, 0, 0, 0, 0, 0, 0, 0]),
131+
(1, [2, 0, 0, 0, 0, 0, 0, 0]),
132+
(2, [3, 0, 0, 0, 0, 0, 0, 0]),
133+
];
134+
135+
let col = ProgramColumn::new(0, &program);
136+
// 3 entries padded to 4.
137+
assert_eq!(col.column_data.len(), 4);
138+
assert_eq!(col.column_data[0], M31(1));
139+
assert_eq!(col.column_data[1], M31(2));
140+
assert_eq!(col.column_data[2], M31(3));
141+
assert_eq!(col.column_data[3], M31(0));
142+
}
143+
144+
#[test]
145+
fn test_program_column_log_size() {
146+
let program = vec![(0u32, [0u32; 8]), (1, [0; 8]), (2, [0; 8])];
147+
let col = ProgramColumn::new(0, &program);
148+
// 3 entries padded to 4 = 2^2.
149+
assert_eq!(col.log_size(), 2);
150+
}
151+
152+
#[test]
153+
fn test_program_column_id() {
154+
let program = vec![(0u32, [0u32; 8])];
155+
let col = ProgramColumn::new(3, &program);
156+
assert_eq!(col.id().id, "curr_program_3");
157+
}
158+
159+
#[cfg(feature = "prover")]
160+
#[test]
161+
fn test_packed_at() {
162+
// Create a program with N_LANES entries so packed_at(0) covers them all.
163+
let program: Vec<(u32, [u32; 8])> = (0..N_LANES)
164+
.map(|i| (i as u32, [i as u32, 0, 0, 0, 0, 0, 0, 0]))
165+
.collect();
166+
167+
let col = ProgramColumn::new(0, &program);
168+
let packed = col.packed_at(0);
169+
let array = packed.to_array();
170+
171+
for (i, val) in array.iter().enumerate() {
172+
assert_eq!(*val, M31(i as u32));
173+
}
174+
}
175+
}

stwo_cairo_prover/crates/prover/src/prover.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,11 @@ where
9292
.half_coset,
9393
);
9494

95-
let preprocessed_trace = Arc::new(prover_params.preprocessed_trace.to_preprocessed_trace());
95+
let preprocessed_trace = Arc::new(
96+
prover_params
97+
.preprocessed_trace
98+
.to_preprocessed_trace(&input.program),
99+
);
96100
let preprocessed_trace_polys =
97101
SimdBackend::interpolate_columns(gen_trace(preprocessed_trace.clone()), &twiddles);
98102

stwo_cairo_prover/crates/prover/src/witness/components/memory_id_to_big.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ mod tests {
672672
// Preprocessed trace.
673673
let mut tree_builder = commitment_scheme.tree_builder();
674674
tree_builder.extend_evals(gen_trace(Arc::new(
675-
PreProcessedTraceVariant::CanonicalWithoutPedersen.to_preprocessed_trace(),
675+
PreProcessedTraceVariant::CanonicalWithoutPedersen.to_preprocessed_trace(&[]),
676676
)));
677677
tree_builder.finalize_interaction();
678678

stwo_cairo_prover/crates/prover/src/witness/preprocessed_trace.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ pub fn generate_preprocessed_commitment_root<MC: MerkleChannel>(
2222
log_blowup_factor: u32,
2323
preprocessed_trace: PreProcessedTraceVariant,
2424
lifting_log_size: Option<u32>,
25+
program: &[(u32, [u32; 8])],
2526
) -> <<MC as MerkleChannel>::H as MerkleHasherLifted>::Hash
2627
where
2728
SimdBackend: BackendForChannel<MC>,
2829
{
29-
let preprocessed_trace = Arc::new(preprocessed_trace.to_preprocessed_trace());
30+
let preprocessed_trace = Arc::new(preprocessed_trace.to_preprocessed_trace(program));
3031

3132
// Precompute twiddles for the commitment scheme.
3233
let mut max_log_size = preprocessed_trace.log_sizes().into_iter().max().unwrap();
@@ -80,6 +81,7 @@ fn test_canonical_preprocessed_root_regression() {
8081
log_blowup_factor,
8182
PreProcessedTraceVariant::Canonical,
8283
None,
84+
&[],
8385
);
8486

8587
assert_eq!(root, expected);
@@ -101,6 +103,7 @@ fn test_small_canonical_preprocessed_root_regression() {
101103
log_blowup_factor,
102104
PreProcessedTraceVariant::CanonicalSmall,
103105
None,
106+
&[],
104107
);
105108

106109
assert_eq!(root, expected);

stwo_cairo_prover/crates/prover/src/witness/utils.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,19 @@ fn get_preprocessed_roots<MC: MerkleChannel>(
136136
max_log_blowup_factor: u32,
137137
preprocessed_trace: PreProcessedTraceVariant,
138138
lifting_log_size: Option<u32>,
139+
program: &[(u32, [u32; 8])],
139140
) -> Vec<<MC::H as MerkleHasherLifted>::Hash>
140141
where
141142
stwo::prover::backend::simd::SimdBackend: BackendForChannel<MC>,
142143
{
143144
(1..=max_log_blowup_factor)
144145
.map(|i| {
145-
generate_preprocessed_commitment_root::<MC>(i, preprocessed_trace, lifting_log_size)
146+
generate_preprocessed_commitment_root::<MC>(
147+
i,
148+
preprocessed_trace,
149+
lifting_log_size,
150+
program,
151+
)
146152
})
147153
.collect_vec()
148154
}
@@ -158,6 +164,7 @@ pub fn export_preprocessed_roots() {
158164
max_log_blowup_factor,
159165
PreProcessedTraceVariant::Canonical,
160166
None,
167+
&[],
161168
);
162169
blake_roots.iter().enumerate().for_each(|(i, root)| {
163170
let root_bytes = root.0;
@@ -179,6 +186,7 @@ pub fn export_preprocessed_roots() {
179186
max_log_blowup_factor,
180187
PreProcessedTraceVariant::CanonicalWithoutPedersen,
181188
None,
189+
&[],
182190
)
183191
.into_iter()
184192
.enumerate()
@@ -203,6 +211,7 @@ pub fn export_circuit_cairo_verifier_preprocessed_roots() {
203211
max_log_blowup_factor,
204212
PreProcessedTraceVariant::CanonicalSmall,
205213
Some(lifting_log_size),
214+
&[],
206215
);
207216
blake_m31_small_roots
208217
.iter()

0 commit comments

Comments
 (0)