diff --git a/Cargo.lock b/Cargo.lock index 40b84807..0e38d744 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -544,6 +544,7 @@ name = "mt-air" version = "0.1.0" dependencies = [ "mt-field", + "mt-koala-bear", "mt-poly", ] diff --git a/crates/backend/air/Cargo.toml b/crates/backend/air/Cargo.toml index 3868ecf2..2494d529 100644 --- a/crates/backend/air/Cargo.toml +++ b/crates/backend/air/Cargo.toml @@ -6,3 +6,6 @@ edition.workspace = true [dependencies] field = { path = "../field", package = "mt-field" } poly = { path = "../poly", package = "mt-poly" } + +[dev-dependencies] +koala-bear = { path = "../koala-bear", package = "mt-koala-bear" } diff --git a/crates/backend/air/src/symbolic.rs b/crates/backend/air/src/symbolic.rs index 218e477d..f32c2182 100644 --- a/crates/backend/air/src/symbolic.rs +++ b/crates/backend/air/src/symbolic.rs @@ -5,6 +5,7 @@ use core::iter::{Product, Sum}; use core::marker::PhantomData; use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::cell::RefCell; +use std::sync::atomic::{AtomicU32, Ordering}; use field::{Algebra, Field, InjectiveMonomial, PrimeCharacteristicRing}; @@ -73,37 +74,142 @@ pub struct SymbolicNode { pub rhs: SymbolicExpression, // dummy (ZERO) for Neg } -// We use an arena as a trick to allow SymbolicExpression to be Copy -// (ugly trick but fine in practice since SymbolicExpression is only used once at the start of the program) +/// Opaque handle into the thread-local symbolic arena. +/// +/// Handles are scoped to a specific arena (thread) and generation (clear cycle). +/// Using a handle from a different thread or after the arena has been cleared will +/// produce a deterministic error instead of undefined behaviour. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct SymbolicNodeRef { + arena_id: u32, + generation: u32, + offset: u32, + _phantom: PhantomData F>, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum SymbolicNodeAccessError { + WrongArena, + StaleGeneration, + OutOfBounds, +} + +impl core::fmt::Display for SymbolicNodeAccessError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::WrongArena => { + write!(f, "symbolic node handle belongs to a different thread's arena") + } + Self::StaleGeneration => { + write!(f, "symbolic node handle is stale (arena was cleared)") + } + Self::OutOfBounds => write!(f, "symbolic node handle offset is out of bounds"), + } + } +} + +impl std::error::Error for SymbolicNodeAccessError {} + +#[derive(Debug)] +struct ArenaState { + arena_id: u32, + generation: u32, + bytes: Vec, +} + +impl ArenaState { + fn new() -> Self { + Self { + arena_id: NEXT_ARENA_ID.fetch_add(1, Ordering::Relaxed), + generation: 0, + bytes: Vec::new(), + } + } +} + +static NEXT_ARENA_ID: AtomicU32 = AtomicU32::new(1); + +// We use an arena as a trick to allow SymbolicExpression to be Copy. +// Handles carry arena_id + generation so that stale or cross-thread use +// is caught deterministically instead of reading garbage bytes. thread_local! { - static ARENA: RefCell> = const { RefCell::new(Vec::new()) }; + static ARENA: RefCell = RefCell::new(ArenaState::new()); } -fn alloc_node(node: SymbolicNode) -> u32 { +fn clear_arena() { ARENA.with(|arena| { - let mut bytes = arena.borrow_mut(); + let mut state = arena.borrow_mut(); + state.generation = state + .generation + .checked_add(1) + .expect("symbolic arena generation overflow"); + state.bytes.clear(); + }); +} + +fn alloc_node(node: SymbolicNode) -> SymbolicNodeRef { + ARENA.with(|arena| { + let mut state = arena.borrow_mut(); let node_size = std::mem::size_of::>(); - let idx = bytes.len(); - bytes.resize(idx + node_size, 0); + let offset = state.bytes.len(); + let offset_u32 = u32::try_from(offset).expect("symbolic arena exceeded u32::MAX bytes"); + let end = offset + .checked_add(node_size) + .expect("symbolic arena allocation overflow"); + state.bytes.resize(end, 0); + // SAFETY: We just resized the buffer to `end` bytes, so `offset..end` is valid. unsafe { - std::ptr::write_unaligned(bytes.as_mut_ptr().add(idx) as *mut SymbolicNode, node); + std::ptr::write_unaligned( + state.bytes.as_mut_ptr().add(offset).cast::>(), + node, + ); + } + SymbolicNodeRef { + arena_id: state.arena_id, + generation: state.generation, + offset: offset_u32, + _phantom: PhantomData, } - idx as u32 }) } -pub fn get_node(idx: u32) -> SymbolicNode { +pub fn try_get_node( + handle: SymbolicNodeRef, +) -> Result, SymbolicNodeAccessError> { ARENA.with(|arena| { - let bytes = arena.borrow(); - unsafe { std::ptr::read_unaligned(bytes.as_ptr().add(idx as usize) as *const SymbolicNode) } + let state = arena.borrow(); + if state.arena_id != handle.arena_id { + return Err(SymbolicNodeAccessError::WrongArena); + } + if state.generation != handle.generation { + return Err(SymbolicNodeAccessError::StaleGeneration); + } + let offset = handle.offset as usize; + let node_size = std::mem::size_of::>(); + let end = offset + .checked_add(node_size) + .ok_or(SymbolicNodeAccessError::OutOfBounds)?; + if end > state.bytes.len() { + return Err(SymbolicNodeAccessError::OutOfBounds); + } + // SAFETY: We verified that `offset..end` is within the arena buffer. + Ok(unsafe { + std::ptr::read_unaligned( + state.bytes.as_ptr().add(offset).cast::>(), + ) + }) }) } +pub fn get_node(handle: SymbolicNodeRef) -> SymbolicNode { + try_get_node(handle).expect("invalid or stale symbolic node handle") +} + #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum SymbolicExpression { Variable(SymbolicVariable), Constant(F), - Operation(u32), // index into thread-local arena + Operation(SymbolicNodeRef), } impl Default for SymbolicExpression { @@ -325,8 +431,7 @@ pub fn get_symbolic_constraints_and_bus_data_values( where A::ExtraData: Default, { - // Clear the arena before building constraints - ARENA.with(|arena| arena.borrow_mut().clear()); + clear_arena(); let mut builder = SymbolicAirBuilder::::new(air.n_columns(), air.n_down_columns()); air.eval(&mut builder, &Default::default()); @@ -336,3 +441,129 @@ where builder.bus_data_values.unwrap(), ) } + +#[cfg(test)] +mod tests { + use super::*; + use koala_bear::KoalaBear; + + type F = KoalaBear; + + const _: () = { + const fn assert_copy() {} + assert_copy::>(); + assert_copy::>(); + }; + + #[test] + fn roundtrip_alloc_get() { + clear_arena(); + let a = SymbolicExpression::::Constant(F::ONE); + let b = SymbolicExpression::::Constant(F::TWO); + let handle = alloc_node(SymbolicNode { + op: SymbolicOperation::Add, + lhs: a, + rhs: b, + }); + let node = get_node::(handle); + assert_eq!(node.op, SymbolicOperation::Add); + assert_eq!(node.lhs, a); + assert_eq!(node.rhs, b); + } + + #[test] + fn stale_handle_rejected_after_clear() { + clear_arena(); + let handle = alloc_node(SymbolicNode { + op: SymbolicOperation::Mul, + lhs: SymbolicExpression::::ONE, + rhs: SymbolicExpression::::TWO, + }); + assert!(try_get_node::(handle).is_ok()); + clear_arena(); + assert!( + matches!(try_get_node::(handle), Err(SymbolicNodeAccessError::StaleGeneration)) + ); + } + + #[test] + fn old_handle_cannot_read_new_generation_bytes() { + clear_arena(); + let old_handle = alloc_node(SymbolicNode { + op: SymbolicOperation::Add, + lhs: SymbolicExpression::::ONE, + rhs: SymbolicExpression::::TWO, + }); + clear_arena(); + let _new_handle = alloc_node(SymbolicNode { + op: SymbolicOperation::Sub, + lhs: SymbolicExpression::::ZERO, + rhs: SymbolicExpression::::ONE, + }); + assert!( + matches!(try_get_node::(old_handle), Err(SymbolicNodeAccessError::StaleGeneration)) + ); + } + + #[test] + fn wrong_thread_handle_rejected() { + clear_arena(); + let handle = alloc_node(SymbolicNode { + op: SymbolicOperation::Neg, + lhs: SymbolicExpression::::ONE, + rhs: SymbolicExpression::::ZERO, + }); + let result = std::thread::spawn(move || try_get_node::(handle)) + .join() + .unwrap(); + assert!(matches!(result, Err(SymbolicNodeAccessError::WrongArena))); + } + + #[test] + fn out_of_bounds_handle_rejected() { + clear_arena(); + let bogus = SymbolicNodeRef:: { + arena_id: ARENA.with(|a| a.borrow().arena_id), + generation: ARENA.with(|a| a.borrow().generation), + offset: 999_999, + _phantom: PhantomData, + }; + assert!(matches!( + try_get_node::(bogus), + Err(SymbolicNodeAccessError::OutOfBounds) + )); + } + + #[test] + fn offset_truncation_detected() { + fn checked_offset(len: usize) -> u32 { + u32::try_from(len).expect("symbolic arena exceeded u32::MAX bytes") + } + assert!(std::panic::catch_unwind(|| checked_offset(u32::MAX as usize + 1)).is_err()); + } + + #[test] + fn arithmetic_produces_valid_handles() { + clear_arena(); + let var = SymbolicExpression::::Variable(SymbolicVariable::new(0)); + let c = SymbolicExpression::::Constant(F::TWO); + let sum = var + c; + if let SymbolicExpression::Operation(handle) = sum { + let node = get_node::(handle); + assert_eq!(node.op, SymbolicOperation::Add); + assert_eq!(node.lhs, var); + assert_eq!(node.rhs, c); + } else { + panic!("expected Operation variant from variable + constant"); + } + + let neg = -var; + if let SymbolicExpression::Operation(handle) = neg { + let node = get_node::(handle); + assert_eq!(node.op, SymbolicOperation::Neg); + assert_eq!(node.lhs, var); + } else { + panic!("expected Operation variant from neg(variable)"); + } + } +} diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index 04cfa60a..b75fa01f 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -380,7 +380,7 @@ where { let (constraints, bus_flag, bus_data) = get_symbolic_constraints_and_bus_data_values::(&table); let mut vars_counter = Counter::new(); - let mut cache: HashMap = HashMap::new(); + let mut cache: HashMap, String> = HashMap::new(); let mut res = format!( "def evaluate_air_constraints_table_{}({}, air_alpha_powers, bus_beta, logup_alphas_eq_poly):\n", @@ -434,7 +434,7 @@ where fn eval_air_constraint( expr: SymbolicExpression, dest: Option<&str>, - cache: &mut HashMap, + cache: &mut HashMap, String>, res: &mut String, ctr: &mut Counter, ) -> String { @@ -445,14 +445,14 @@ fn eval_air_constraint( v } SymbolicExpression::Variable(v) => format!("{} + DIM * {}", AIR_INNER_VALUES_VAR, v.index), - SymbolicExpression::Operation(idx) => { - if let Some(v) = cache.get(&idx) { + SymbolicExpression::Operation(handle) => { + if let Some(v) = cache.get(&handle) { if let Some(d) = dest { res.push_str(&format!("\n copy_5({}, {})", v, d)); } return v.clone(); } - let node = get_node::(idx); + let node = get_node::(handle); let v = match node.op { SymbolicOperation::Neg => { let a = eval_air_constraint(node.lhs, None, cache, res, ctr); @@ -462,13 +462,12 @@ fn eval_air_constraint( } _ => eval_air_binop(node.op, node.lhs, node.rhs, dest, cache, res, ctr), }; - // If dest was requested but the result landed elsewhere, copy it if let Some(d) = dest && v != d { res.push_str(&format!("\n copy_5({}, {})", v, d)); } - cache.insert(idx, v.clone()); + cache.insert(handle, v.clone()); v } } @@ -481,7 +480,7 @@ fn eval_air_binop( lhs: SymbolicExpression, rhs: SymbolicExpression, dest: Option<&str>, - cache: &mut HashMap, + cache: &mut HashMap, String>, res: &mut String, ctr: &mut Counter, ) -> String {