Skip to content

Commit df0807f

Browse files
committed
Merge branch 'parallel-runner'
2 parents e7d8b16 + e4bba53 commit df0807f

File tree

37 files changed

+795
-761
lines changed

37 files changed

+795
-761
lines changed

crates/lean_compiler/snark_lib.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,10 +14,11 @@ def inline(fn):
1414
return fn
1515

1616

17-
# unroll(a, b) returns range(a, b) for Python execution
1817
def unroll(a: int, b: int):
1918
return range(a, b)
2019

20+
def parallel_range(a: int, b: int):
21+
return range(a, b)
2122

2223
# dynamic_unroll(start, end, n_bits) returns range(start, end) for Python execution
2324
def dynamic_unroll(start: int, end: int, n_bits: int):

crates/lean_compiler/src/a_simplify_lang.rs

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1196,9 +1196,10 @@ fn transform_mutable_in_loops_in_lines(
11961196
start,
11971197
end,
11981198
body,
1199-
loop_kind: LoopKind::Range,
1199+
loop_kind: loop_kind @ (LoopKind::Range | LoopKind::ParallelRange),
12001200
location,
12011201
} => {
1202+
let loop_kind = loop_kind.clone();
12021203
transform_mutable_in_loops_in_lines(body, const_arrays, counter);
12031204
let modified_vars = find_modified_external_vars(body, const_arrays);
12041205

@@ -1344,7 +1345,7 @@ fn transform_mutable_in_loops_in_lines(
13441345
start: start.clone(),
13451346
end: end.clone(),
13461347
body: new_body,
1347-
loop_kind: LoopKind::Range,
1348+
loop_kind,
13481349
location,
13491350
});
13501351

@@ -2633,10 +2634,12 @@ fn simplify_lines(
26332634
location,
26342635
} => {
26352636
assert!(
2636-
matches!(loop_kind, LoopKind::Range),
2637+
matches!(loop_kind, LoopKind::Range | LoopKind::ParallelRange),
26372638
"Unrolled/dynamic_unroll loops should have been handled already"
26382639
);
26392640

2641+
let is_parallel = loop_kind.is_parallel();
2642+
26402643
let mut loop_const_malloc = ConstMalloc {
26412644
counter: const_malloc.counter,
26422645
..ConstMalloc::default()
@@ -2653,7 +2656,8 @@ fn simplify_lines(
26532656
const_malloc.counter = loop_const_malloc.counter;
26542657
state.array_manager.valid = valid_aux_vars_in_array_manager_before; // restore the valid aux vars
26552658

2656-
let func_name = format!("@loop_{}_{}", state.counters.loops.get_next(), location);
2659+
let loop_prefix = if is_parallel { "@parallel_loop" } else { "@loop" };
2660+
let func_name = format!("{}_{}_{}", loop_prefix, state.counters.loops.get_next(), location);
26572661

26582662
// Find variables used inside loop but defined outside
26592663
let (_, mut external_vars) = find_variable_usage(body, ctx.const_arrays);

crates/lean_compiler/src/b_compile_intermediate.rs

Lines changed: 25 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -149,16 +149,28 @@ fn compile_function(
149149
compiler.const_malloc_vars.clear();
150150
(compiler.dead_fp_relative_vars, compiler.dead_store_vars) = compute_dead_vars(&function.instructions);
151151

152-
compile_lines(
153-
&Label::function(function.name.clone()),
154-
&function.instructions,
155-
compiler,
156-
None,
157-
)
152+
let mut instructions = Vec::new();
153+
154+
// Emit ParallelBatchStart for parallel loop functions.
155+
// The first instruction is always `diff = iterator - end`. The end value is either
156+
// a function argument (runtime) or a compile-time constant.
157+
if function.name.starts_with("@parallel_loop_") {
158+
let SimpleLine::Assignment { arg1, .. } = &function.instructions[0] else {
159+
panic!("parallel loop: expected first instruction to be `diff = i - end`");
160+
};
161+
let end_value = IntermediateValue::from_simple_expr(arg1, compiler);
162+
instructions.push(IntermediateInstruction::ParallelBatchStart {
163+
n_args: function.arguments.len(),
164+
end_value,
165+
});
166+
}
167+
168+
instructions.extend(compile_lines(&function.instructions, compiler, None)?);
169+
170+
Ok(instructions)
158171
}
159172

160173
fn compile_lines(
161-
function_name: &Label,
162174
lines: &[SimpleLine],
163175
compiler: &mut Compiler,
164176
final_jump: Option<Label>,
@@ -263,14 +275,13 @@ fn compile_lines(
263275
for arm in arms.iter() {
264276
compiler.stack_pos = saved_stack_pos;
265277
compiler.stack_frame_layout.scopes.push(ScopeLayout::default());
266-
let arm_instructions = compile_lines(function_name, arm, compiler, Some(end_label.clone()))?;
278+
let arm_instructions = compile_lines(arm, compiler, Some(end_label.clone()))?;
267279
compiled_arms.push(arm_instructions);
268280
compiler.stack_frame_layout.scopes.pop();
269281
new_stack_pos = new_stack_pos.max(compiler.stack_pos);
270282
}
271283
compiler.stack_pos = new_stack_pos;
272284
compiler.match_blocks.push(MatchBlock {
273-
function_name: function_name.clone(),
274285
match_cases: compiled_arms,
275286
});
276287
// Get the actual index AFTER pushing (nested matches may have pushed their blocks first)
@@ -317,7 +328,7 @@ fn compile_lines(
317328
updated_fp: None,
318329
});
319330

320-
let remaining = compile_lines(function_name, &lines[i + 1..], compiler, final_jump)?;
331+
let remaining = compile_lines(&lines[i + 1..], compiler, final_jump)?;
321332
compiler.bytecode.insert(end_label, remaining);
322333

323334
compiler.stack_frame_layout.scopes.pop();
@@ -409,22 +420,22 @@ fn compile_lines(
409420
let saved_stack_pos = compiler.stack_pos;
410421

411422
compiler.stack_frame_layout.scopes.push(ScopeLayout::default());
412-
let then_instructions = compile_lines(function_name, then_branch, compiler, Some(end_label.clone()))?;
423+
let then_instructions = compile_lines(then_branch, compiler, Some(end_label.clone()))?;
413424

414425
let then_stack_pos = compiler.stack_pos;
415426
compiler.stack_pos = saved_stack_pos;
416427
compiler.stack_frame_layout.scopes.pop();
417428
compiler.stack_frame_layout.scopes.push(ScopeLayout::default());
418429

419-
let else_instructions = compile_lines(function_name, else_branch, compiler, Some(end_label.clone()))?;
430+
let else_instructions = compile_lines(else_branch, compiler, Some(end_label.clone()))?;
420431

421432
compiler.bytecode.insert(if_label, then_instructions);
422433
compiler.bytecode.insert(else_label, else_instructions);
423434

424435
compiler.stack_frame_layout.scopes.pop();
425436
compiler.stack_pos = compiler.stack_pos.max(then_stack_pos);
426437

427-
let remaining = compile_lines(function_name, &lines[i + 1..], compiler, final_jump)?;
438+
let remaining = compile_lines(&lines[i + 1..], compiler, final_jump)?;
428439
compiler.bytecode.insert(end_label, remaining);
429440
// It is not necessary to update compiler.stack_size here because the preceding call to
430441
// compile_lines should have done so.
@@ -485,7 +496,7 @@ fn compile_lines(
485496
});
486497
}
487498

488-
instructions.extend(compile_lines(function_name, &lines[i + 1..], compiler, final_jump)?);
499+
instructions.extend(compile_lines(&lines[i + 1..], compiler, final_jump)?);
489500

490501
instructions
491502
};

crates/lean_compiler/src/c_compile_final.rs

Lines changed: 15 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,8 @@ impl IntermediateInstruction {
1414
| Self::LocationReport { .. }
1515
| Self::DebugAssert { .. }
1616
| Self::DerefHint { .. }
17-
| Self::PanicHint { .. } => true,
17+
| Self::PanicHint { .. }
18+
| Self::ParallelBatchStart { .. } => true,
1819
Self::Computation { .. }
1920
| Self::Panic
2021
| Self::Deref { .. }
@@ -66,42 +67,19 @@ pub fn compile_to_low_level_bytecode(
6667
.bytecode
6768
.remove(&Label::function("main"))
6869
.ok_or("No main function found in the compiled program")?;
69-
hints.insert(
70-
STARTING_PC,
71-
vec![Hint::StackFrame {
72-
label: Label::function("main"),
73-
size: starting_frame_memory,
74-
}],
75-
);
7670

7771
let mut pc = count_real_instructions(&exit_point) + count_real_instructions(&entrypoint);
78-
let mut code_blocks = vec![
79-
(Label::EndProgram, ENDING_PC, exit_point),
80-
(Label::function("main"), STARTING_PC, entrypoint),
81-
];
72+
let mut code_blocks = vec![(ENDING_PC, exit_point), (STARTING_PC, entrypoint)];
8273

8374
for (label, instructions) in &intermediate_bytecode.bytecode {
8475
label_to_pc.insert(label.clone(), pc);
85-
if let Label::Function(function_name) = label {
86-
hints.entry(pc).or_insert_with(Vec::new).push(Hint::StackFrame {
87-
label: label.clone(),
88-
size: *intermediate_bytecode
89-
.memory_size_per_function
90-
.get(function_name)
91-
.unwrap(),
92-
});
93-
}
94-
code_blocks.push((label.clone(), pc, instructions.clone()));
76+
code_blocks.push((pc, instructions.clone()));
9577
pc += count_real_instructions(instructions);
9678
}
9779

9880
let mut match_block_sizes = Vec::new();
9981
let mut match_first_block_starts = Vec::new();
100-
for MatchBlock {
101-
function_name,
102-
match_cases,
103-
} in intermediate_bytecode.match_blocks
104-
{
82+
for MatchBlock { match_cases } in intermediate_bytecode.match_blocks {
10583
let max_block_size = match_cases
10684
.iter()
10785
.map(|block| count_real_instructions(block))
@@ -116,7 +94,7 @@ pub fn compile_to_low_level_bytecode(
11694
IntermediateInstruction::Panic;
11795
max_block_size - count_real_instructions(&block)
11896
]);
119-
code_blocks.push((function_name.clone(), pc, block));
97+
code_blocks.push((pc, block));
12098
pc += max_block_size;
12199
}
122100
}
@@ -134,15 +112,8 @@ pub fn compile_to_low_level_bytecode(
134112

135113
let mut instructions = Vec::new();
136114

137-
for (function_name, pc_start, block) in code_blocks {
138-
compile_block(
139-
&compiler,
140-
&function_name,
141-
&block,
142-
pc_start,
143-
&mut instructions,
144-
&mut hints,
145-
);
115+
for (pc_start, block) in code_blocks {
116+
compile_block(&compiler, &block, pc_start, &mut instructions, &mut hints);
146117
}
147118
let instructions_encoded = instructions.par_iter().map(field_representation).collect::<Vec<_>>();
148119

@@ -195,7 +166,6 @@ pub fn compile_to_low_level_bytecode(
195166

196167
fn compile_block(
197168
compiler: &Compiler,
198-
function_name: &Label,
199169
block: &[IntermediateInstruction],
200170
pc_start: CodeAddress,
201171
low_level_bytecode: &mut Vec<Instruction>,
@@ -341,7 +311,6 @@ fn compile_block(
341311
IntermediateInstruction::RequestMemory { offset, size } => {
342312
let size = try_as_mem_or_constant(&size).unwrap();
343313
let hint = Hint::RequestMemory {
344-
function_name: function_name.clone(),
345314
offset: eval_const_expression_usize(&offset, compiler),
346315
size,
347316
};
@@ -386,6 +355,13 @@ fn compile_block(
386355
let hint = Hint::Panic { message };
387356
hints.entry(pc).or_default().push(hint);
388357
}
358+
IntermediateInstruction::ParallelBatchStart { n_args, end_value } => {
359+
let end_value = try_as_mem_or_constant(&end_value).expect("parallel loop end value");
360+
hints
361+
.entry(pc)
362+
.or_default()
363+
.push(Hint::ParallelBatchStart { n_args, end_value });
364+
}
389365
}
390366

391367
if !instruction.is_hint() {

crates/lean_compiler/src/grammar.pest

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -90,8 +90,9 @@ elif_clause = { "elif" ~ condition ~ ":" ~ newline ~ statement* ~ end_block }
9090

9191
else_clause = { "else" ~ ":" ~ newline ~ statement* ~ end_block }
9292

93-
for_statement = { "for" ~ identifier ~ "in" ~ (dynamic_unroll_range | unroll_range | range) ~ ":" ~ newline ~ statement* ~ end_block }
93+
for_statement = { "for" ~ identifier ~ "in" ~ (dynamic_unroll_range | unroll_range | parallel_range | range) ~ ":" ~ newline ~ statement* ~ end_block }
9494
range = { "range" ~ "(" ~ expression ~ "," ~ expression ~ ")" }
95+
parallel_range = { "parallel_range" ~ "(" ~ expression ~ "," ~ expression ~ ")" }
9596
unroll_range = { "unroll" ~ "(" ~ expression ~ "," ~ expression ~ ")" }
9697
dynamic_unroll_range = { "dynamic_unroll" ~ "(" ~ expression ~ "," ~ expression ~ "," ~ expression ~ ")" }
9798

crates/lean_compiler/src/ir/bytecode.rs

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,9 +6,6 @@ use std::fmt::{Display, Formatter};
66
/// A match statement bytecode block
77
#[derive(Debug, Clone)]
88
pub struct MatchBlock {
9-
/// Name of the function containing the match block
10-
pub function_name: Label,
11-
129
/// Cases of the match block
1310
pub match_cases: Vec<Vec<IntermediateInstruction>>,
1411
}

crates/lean_compiler/src/ir/instruction.rs

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,11 @@ pub enum IntermediateInstruction {
6565
PanicHint {
6666
message: Option<String>,
6767
},
68+
/// Marks the start of a parallelizable loop
69+
ParallelBatchStart {
70+
n_args: usize,
71+
end_value: IntermediateValue,
72+
},
6873
}
6974

7075
impl IntermediateInstruction {
@@ -196,6 +201,9 @@ impl Display for IntermediateInstruction {
196201
Some(msg) => write!(f, "panic hint: \"{msg}\""),
197202
None => write!(f, "panic hint"),
198203
},
204+
Self::ParallelBatchStart { n_args, end_value } => {
205+
write!(f, "parallel_batch_start(n_args={n_args}, end={end_value})")
206+
}
199207
}
200208
}
201209
}

crates/lean_compiler/src/lang.rs

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -555,6 +555,7 @@ impl VecLiteral {
555555
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
556556
pub enum LoopKind {
557557
Range,
558+
ParallelRange,
558559
Unroll,
559560
/// `for i in dynamic_unroll(0, a, n_bits): body` — unrolls over runtime-bounded range
560561
/// using bit decomposition. `n_bits` must be compile-time known.
@@ -567,6 +568,10 @@ impl LoopKind {
567568
pub fn is_unroll(&self) -> bool {
568569
matches!(self, Self::Unroll | Self::DynamicUnroll { .. })
569570
}
571+
572+
pub fn is_parallel(&self) -> bool {
573+
matches!(self, Self::ParallelRange)
574+
}
570575
}
571576

572577
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
@@ -798,7 +803,13 @@ impl Line {
798803
iterator, start, end, n_bits, body_str, spaces
799804
),
800805
_ => {
801-
let range_fn = if loop_kind.is_unroll() { "unroll" } else { "range" };
806+
let range_fn = if loop_kind.is_unroll() {
807+
"unroll"
808+
} else if loop_kind.is_parallel() {
809+
"parallel_range"
810+
} else {
811+
"range"
812+
};
802813
format!(
803814
"for {} in {}({}, {}) {{\n{}\n{}}}",
804815
iterator, range_fn, start, end, body_str, spaces

crates/lean_compiler/src/parser/parsers/function.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ pub const RESERVED_FUNCTION_NAMES: &[&str] = &[
2424
"next_multiple_of",
2525
"saturating_sub",
2626
"range",
27+
"parallel_range",
2728
"match_range",
2829
];
2930

crates/lean_compiler/src/parser/parsers/statement.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -203,6 +203,7 @@ impl Parse<Line> for ForStatementParser {
203203
let n_bits = ExpressionParser.parse(next_inner_pair(&mut range_inner, "n_bits")?, ctx)?;
204204
LoopKind::DynamicUnroll { n_bits }
205205
}
206+
Rule::parallel_range => LoopKind::ParallelRange,
206207
_ => LoopKind::Range,
207208
};
208209

0 commit comments

Comments
 (0)