Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions crates/backend/air/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ edition.workspace = true
[dependencies]
field = { path = "../field", package = "mt-field" }
poly = { path = "../poly", package = "mt-poly" }

[dev-dependencies]
koala-bear = { path = "../koala-bear", package = "mt-koala-bear" }
261 changes: 246 additions & 15 deletions crates/backend/air/src/symbolic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use core::iter::{Product, Sum};
use core::marker::PhantomData;
use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use std::cell::RefCell;
use std::sync::atomic::{AtomicU32, Ordering};

use field::{Algebra, Field, InjectiveMonomial, PrimeCharacteristicRing};

Expand Down Expand Up @@ -73,37 +74,142 @@ pub struct SymbolicNode<F: Copy> {
pub rhs: SymbolicExpression<F>, // dummy (ZERO) for Neg
}

// We use an arena as a trick to allow SymbolicExpression to be Copy
// (ugly trick but fine in practice since SymbolicExpression is only used once at the start of the program)
/// Opaque handle into the thread-local symbolic arena.
///
/// Handles are scoped to a specific arena (thread) and generation (clear cycle).
/// Using a handle from a different thread or after the arena has been cleared will
/// produce a deterministic error instead of undefined behaviour.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct SymbolicNodeRef<F> {
arena_id: u32,
generation: u32,
offset: u32,
_phantom: PhantomData<fn() -> F>,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SymbolicNodeAccessError {
WrongArena,
StaleGeneration,
OutOfBounds,
}

impl core::fmt::Display for SymbolicNodeAccessError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::WrongArena => {
write!(f, "symbolic node handle belongs to a different thread's arena")
}
Self::StaleGeneration => {
write!(f, "symbolic node handle is stale (arena was cleared)")
}
Self::OutOfBounds => write!(f, "symbolic node handle offset is out of bounds"),
}
}
}

impl std::error::Error for SymbolicNodeAccessError {}

#[derive(Debug)]
struct ArenaState {
arena_id: u32,
generation: u32,
bytes: Vec<u8>,
}

impl ArenaState {
fn new() -> Self {
Self {
arena_id: NEXT_ARENA_ID.fetch_add(1, Ordering::Relaxed),
generation: 0,
bytes: Vec::new(),
}
}
}

static NEXT_ARENA_ID: AtomicU32 = AtomicU32::new(1);

// We use an arena as a trick to allow SymbolicExpression to be Copy.
// Handles carry arena_id + generation so that stale or cross-thread use
// is caught deterministically instead of reading garbage bytes.
thread_local! {
static ARENA: RefCell<Vec<u8>> = const { RefCell::new(Vec::new()) };
static ARENA: RefCell<ArenaState> = RefCell::new(ArenaState::new());
}

fn alloc_node<F: Field>(node: SymbolicNode<F>) -> u32 {
fn clear_arena() {
ARENA.with(|arena| {
let mut bytes = arena.borrow_mut();
let mut state = arena.borrow_mut();
state.generation = state
.generation
.checked_add(1)
.expect("symbolic arena generation overflow");
state.bytes.clear();
});
}

fn alloc_node<F: Field>(node: SymbolicNode<F>) -> SymbolicNodeRef<F> {
ARENA.with(|arena| {
let mut state = arena.borrow_mut();
let node_size = std::mem::size_of::<SymbolicNode<F>>();
let idx = bytes.len();
bytes.resize(idx + node_size, 0);
let offset = state.bytes.len();
let offset_u32 = u32::try_from(offset).expect("symbolic arena exceeded u32::MAX bytes");
let end = offset
.checked_add(node_size)
.expect("symbolic arena allocation overflow");
state.bytes.resize(end, 0);
// SAFETY: We just resized the buffer to `end` bytes, so `offset..end` is valid.
unsafe {
std::ptr::write_unaligned(bytes.as_mut_ptr().add(idx) as *mut SymbolicNode<F>, node);
std::ptr::write_unaligned(
state.bytes.as_mut_ptr().add(offset).cast::<SymbolicNode<F>>(),
node,
);
}
SymbolicNodeRef {
arena_id: state.arena_id,
generation: state.generation,
offset: offset_u32,
_phantom: PhantomData,
}
idx as u32
})
}

pub fn get_node<F: Field>(idx: u32) -> SymbolicNode<F> {
pub fn try_get_node<F: Field>(
handle: SymbolicNodeRef<F>,
) -> Result<SymbolicNode<F>, SymbolicNodeAccessError> {
ARENA.with(|arena| {
let bytes = arena.borrow();
unsafe { std::ptr::read_unaligned(bytes.as_ptr().add(idx as usize) as *const SymbolicNode<F>) }
let state = arena.borrow();
if state.arena_id != handle.arena_id {
return Err(SymbolicNodeAccessError::WrongArena);
}
if state.generation != handle.generation {
return Err(SymbolicNodeAccessError::StaleGeneration);
}
let offset = handle.offset as usize;
let node_size = std::mem::size_of::<SymbolicNode<F>>();
let end = offset
.checked_add(node_size)
.ok_or(SymbolicNodeAccessError::OutOfBounds)?;
if end > state.bytes.len() {
return Err(SymbolicNodeAccessError::OutOfBounds);
}
// SAFETY: We verified that `offset..end` is within the arena buffer.
Ok(unsafe {
std::ptr::read_unaligned(
state.bytes.as_ptr().add(offset).cast::<SymbolicNode<F>>(),
)
})
})
}

pub fn get_node<F: Field>(handle: SymbolicNodeRef<F>) -> SymbolicNode<F> {
try_get_node(handle).expect("invalid or stale symbolic node handle")
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SymbolicExpression<F: Copy> {
Variable(SymbolicVariable<F>),
Constant(F),
Operation(u32), // index into thread-local arena
Operation(SymbolicNodeRef<F>),
}

impl<F: Field> Default for SymbolicExpression<F> {
Expand Down Expand Up @@ -325,8 +431,7 @@ pub fn get_symbolic_constraints_and_bus_data_values<F: Field, A: Air>(
where
A::ExtraData: Default,
{
// Clear the arena before building constraints
ARENA.with(|arena| arena.borrow_mut().clear());
clear_arena();

let mut builder = SymbolicAirBuilder::<F>::new(air.n_columns(), air.n_down_columns());
air.eval(&mut builder, &Default::default());
Expand All @@ -336,3 +441,129 @@ where
builder.bus_data_values.unwrap(),
)
}

#[cfg(test)]
mod tests {
use super::*;
use koala_bear::KoalaBear;

type F = KoalaBear;

const _: () = {
const fn assert_copy<T: Copy>() {}
assert_copy::<SymbolicExpression<F>>();
assert_copy::<SymbolicNodeRef<F>>();
};

#[test]
fn roundtrip_alloc_get() {
clear_arena();
let a = SymbolicExpression::<F>::Constant(F::ONE);
let b = SymbolicExpression::<F>::Constant(F::TWO);
let handle = alloc_node(SymbolicNode {
op: SymbolicOperation::Add,
lhs: a,
rhs: b,
});
let node = get_node::<F>(handle);
assert_eq!(node.op, SymbolicOperation::Add);
assert_eq!(node.lhs, a);
assert_eq!(node.rhs, b);
}

#[test]
fn stale_handle_rejected_after_clear() {
clear_arena();
let handle = alloc_node(SymbolicNode {
op: SymbolicOperation::Mul,
lhs: SymbolicExpression::<F>::ONE,
rhs: SymbolicExpression::<F>::TWO,
});
assert!(try_get_node::<F>(handle).is_ok());
clear_arena();
assert!(
matches!(try_get_node::<F>(handle), Err(SymbolicNodeAccessError::StaleGeneration))
);
}

#[test]
fn old_handle_cannot_read_new_generation_bytes() {
clear_arena();
let old_handle = alloc_node(SymbolicNode {
op: SymbolicOperation::Add,
lhs: SymbolicExpression::<F>::ONE,
rhs: SymbolicExpression::<F>::TWO,
});
clear_arena();
let _new_handle = alloc_node(SymbolicNode {
op: SymbolicOperation::Sub,
lhs: SymbolicExpression::<F>::ZERO,
rhs: SymbolicExpression::<F>::ONE,
});
assert!(
matches!(try_get_node::<F>(old_handle), Err(SymbolicNodeAccessError::StaleGeneration))
);
}

#[test]
fn wrong_thread_handle_rejected() {
clear_arena();
let handle = alloc_node(SymbolicNode {
op: SymbolicOperation::Neg,
lhs: SymbolicExpression::<F>::ONE,
rhs: SymbolicExpression::<F>::ZERO,
});
let result = std::thread::spawn(move || try_get_node::<F>(handle))
.join()
.unwrap();
assert!(matches!(result, Err(SymbolicNodeAccessError::WrongArena)));
}

#[test]
fn out_of_bounds_handle_rejected() {
clear_arena();
let bogus = SymbolicNodeRef::<F> {
arena_id: ARENA.with(|a| a.borrow().arena_id),
generation: ARENA.with(|a| a.borrow().generation),
offset: 999_999,
_phantom: PhantomData,
};
assert!(matches!(
try_get_node::<F>(bogus),
Err(SymbolicNodeAccessError::OutOfBounds)
));
}

#[test]
fn offset_truncation_detected() {
fn checked_offset(len: usize) -> u32 {
u32::try_from(len).expect("symbolic arena exceeded u32::MAX bytes")
}
assert!(std::panic::catch_unwind(|| checked_offset(u32::MAX as usize + 1)).is_err());
}

#[test]
fn arithmetic_produces_valid_handles() {
clear_arena();
let var = SymbolicExpression::<F>::Variable(SymbolicVariable::new(0));
let c = SymbolicExpression::<F>::Constant(F::TWO);
let sum = var + c;
if let SymbolicExpression::Operation(handle) = sum {
let node = get_node::<F>(handle);
assert_eq!(node.op, SymbolicOperation::Add);
assert_eq!(node.lhs, var);
assert_eq!(node.rhs, c);
} else {
panic!("expected Operation variant from variable + constant");
}

let neg = -var;
if let SymbolicExpression::Operation(handle) = neg {
let node = get_node::<F>(handle);
assert_eq!(node.op, SymbolicOperation::Neg);
assert_eq!(node.lhs, var);
} else {
panic!("expected Operation variant from neg(variable)");
}
}
}
15 changes: 7 additions & 8 deletions crates/rec_aggregation/src/compilation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ where
{
let (constraints, bus_flag, bus_data) = get_symbolic_constraints_and_bus_data_values::<F, _>(&table);
let mut vars_counter = Counter::new();
let mut cache: HashMap<u32, String> = HashMap::new();
let mut cache: HashMap<SymbolicNodeRef<F>, String> = HashMap::new();

let mut res = format!(
"def evaluate_air_constraints_table_{}({}, air_alpha_powers, bus_beta, logup_alphas_eq_poly):\n",
Expand Down Expand Up @@ -434,7 +434,7 @@ where
fn eval_air_constraint(
expr: SymbolicExpression<F>,
dest: Option<&str>,
cache: &mut HashMap<u32, String>,
cache: &mut HashMap<SymbolicNodeRef<F>, String>,
res: &mut String,
ctr: &mut Counter,
) -> String {
Expand All @@ -445,14 +445,14 @@ fn eval_air_constraint(
v
}
SymbolicExpression::Variable(v) => format!("{} + DIM * {}", AIR_INNER_VALUES_VAR, v.index),
SymbolicExpression::Operation(idx) => {
if let Some(v) = cache.get(&idx) {
SymbolicExpression::Operation(handle) => {
if let Some(v) = cache.get(&handle) {
if let Some(d) = dest {
res.push_str(&format!("\n copy_5({}, {})", v, d));
}
return v.clone();
}
let node = get_node::<F>(idx);
let node = get_node::<F>(handle);
let v = match node.op {
SymbolicOperation::Neg => {
let a = eval_air_constraint(node.lhs, None, cache, res, ctr);
Expand All @@ -462,13 +462,12 @@ fn eval_air_constraint(
}
_ => eval_air_binop(node.op, node.lhs, node.rhs, dest, cache, res, ctr),
};
// If dest was requested but the result landed elsewhere, copy it
if let Some(d) = dest
&& v != d
{
res.push_str(&format!("\n copy_5({}, {})", v, d));
}
cache.insert(idx, v.clone());
cache.insert(handle, v.clone());
v
}
}
Expand All @@ -481,7 +480,7 @@ fn eval_air_binop(
lhs: SymbolicExpression<F>,
rhs: SymbolicExpression<F>,
dest: Option<&str>,
cache: &mut HashMap<u32, String>,
cache: &mut HashMap<SymbolicNodeRef<F>, String>,
res: &mut String,
ctr: &mut Counter,
) -> String {
Expand Down