From 898ca2a7aeb3a603e4800abfea8244778b64b552 Mon Sep 17 00:00:00 2001 From: Tarek Date: Tue, 12 Aug 2025 11:27:43 -0700 Subject: [PATCH] feat: decouple the solver into a standalone crate --- Cargo.lock | 14 + Cargo.toml | 1 + solver/Cargo.toml | 27 ++ solver/README.md | 151 ++++++++++ solver/src/backend/mod.rs | 3 + solver/src/backend/z3/mod.rs | 58 ++++ solver/src/backend/z3/node.rs | 186 +++++++++++++ solver/src/backend/z3/serdes.rs | 193 +++++++++++++ solver/src/backend/z3/solve.rs | 263 ++++++++++++++++++ solver/src/bin/main.rs | 286 +++++++++++++++++++ solver/src/lib.rs | 85 ++++++ solver/test_data/hello_world_simple.jsonl | 1 + solver/test_data/is_sorted_complex.jsonl | 2 + solver/tests/basic.rs | 133 +++++++++ solver/tests/binary_integration.rs | 104 +++++++ solver/tests/samples.rs | 324 ++++++++++++++++++++++ 16 files changed, 1831 insertions(+) create mode 100644 solver/Cargo.toml create mode 100644 solver/README.md create mode 100644 solver/src/backend/mod.rs create mode 100644 solver/src/backend/z3/mod.rs create mode 100644 solver/src/backend/z3/node.rs create mode 100644 solver/src/backend/z3/serdes.rs create mode 100644 solver/src/backend/z3/solve.rs create mode 100644 solver/src/bin/main.rs create mode 100644 solver/src/lib.rs create mode 100644 solver/test_data/hello_world_simple.jsonl create mode 100644 solver/test_data/is_sorted_complex.jsonl create mode 100644 solver/tests/basic.rs create mode 100644 solver/tests/binary_integration.rs create mode 100644 solver/tests/samples.rs diff --git a/Cargo.lock b/Cargo.lock index 4a7209ec..bf452f9c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1366,6 +1366,20 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +[[package]] +name = "leafsolver" +version = "0.2.0" +dependencies = [ + "clap", + "common", + "delegate", + "derive_more", + "serde", + "serde_json", + "z3", + "z3-sys", +] + [[package]] name = "libafl" version = "0.15.2" diff --git a/Cargo.toml b/Cargo.toml index 1be31d1e..5599b038 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ members = [ "integration/libafl/fuzzers/pure_concolic", "macros", "orchestrator", + "solver", ] exclude = [ "runtime/shim", diff --git a/solver/Cargo.toml b/solver/Cargo.toml new file mode 100644 index 00000000..4a94f785 --- /dev/null +++ b/solver/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "leafsolver" +version = { workspace = true } +edition = "2021" + +[lib] +name = "leafsolver" +path = "src/lib.rs" + +[[bin]] +name = "leafsolver" +path = "src/bin/main.rs" + + +[features] +default = ["serde"] +serde = ["dep:serde", "common/serde"] + +[dependencies] +serde = { workspace = true, optional = true } +serde_json = { workspace = true } +clap = { workspace = true } +delegate = { workspace = true } +z3 = { workspace = true } +z3-sys = { workspace = true } +derive_more = { workspace = true } +common = { workspace = true, features = ["z3", "logging", "serde"] } diff --git a/solver/README.md b/solver/README.md new file mode 100644 index 00000000..515dd99b --- /dev/null +++ b/solver/README.md @@ -0,0 +1,151 @@ +# Leaf Solver + +A standalone SMT solver crate built on Z3 + +## Example Usage + +```rust +use leafsolver::{ + AstAndVars, AstNode, BVNode, Constraint, ConstraintKind, Z3Solver, + Config, Context, SolveResult, ast, Ast +}; + +// Create a Z3 context and solver +let context = Context::new(&Config::new()); +let solver: Z3Solver<'_, i32> = Z3Solver::new(&context); + +// Create symbolic variables: a + b == 15, a == 10 +let a = ast::BV::new_const(&context, "a", 32); +let b = ast::BV::new_const(&context, "b", 32); +let ten = ast::BV::from_i64(&context, 10, 32); +let fifteen = ast::BV::from_i64(&context, 15, 32); + +let sum = a.bvadd(&b); +let eq_constraint = sum._eq(&fifteen); // a + b == 15 +let a_eq_ten = a._eq(&ten); // a == 10 + +let variables = vec![ + (1, AstNode::BitVector(BVNode::new(a, true))), + (2, AstNode::BitVector(BVNode::new(b, true))), +]; + +let constraints = vec![ + Constraint { + discr: AstAndVars { value: AstNode::Bool(eq_constraint), variables: variables.clone() }, + kind: ConstraintKind::True, + }, + Constraint { + discr: AstAndVars { value: AstNode::Bool(a_eq_ten), variables }, + kind: ConstraintKind::True, + }, +]; + +// Solve +let result = solver.check(constraints.into_iter()); +match result { + SolveResult::Sat(model) => { + // Solver found: a = 10, b = 5 + println!("Solution found!"); + } + SolveResult::Unsat => println!("No solution exists"), + SolveResult::Unknown => println!("Could not determine"), +} +``` + +> See the [`tests/`](tests/) directory for more examples + +## CLI + +The solver provides a standalone binary for solving constraints from JSONL files + +### Installation + +From the root of the repository, run: + +```bash +cargo install --path solver --bin leafsolver +``` + +### Usage + +```bash +leafsolver [OPTIONS] +``` + +**Options:** +- `-i, --input ` - Input JSONL file with constraints (default: `sym_decisions.jsonl`) +- `-o, --output ` - Output JSON file with results (default: `solver_result.json`) +- `--format ` - Output format for the model: `standard` or `bytes` (default: `standard`) + +**Output Formats:** +- `standard`: Full SmtLibExpr format with SMT-LIB representation +- `bytes`: Raw byte values only (u8) + +> The binary is currently tested using constraint files generated from the `samples/` directory. The `test_data/` folder contains constraint files produced by running Leaf on sample programs. These files are used in the test suite. + +Example usage with the `is_sorted` test case: +```bash +# Standard format (default) +leafsolver -i test_data/is_sorted_complex.jsonl -o result.json + +# Bytes format for direct byte values +leafsolver -i test_data/hello_world_simple.jsonl -o result.json --format bytes +``` + +### Input Format + +The binary expects JSONL format where each line contains a constraint entry: + +```json +{ + "step": { + "value": "0:4:2", + "index": 3 + }, + "constraint": { + "discr": { + "decls": { + "1": { + "name": "k!1", + "sort": {"BitVector": {"is_signed": false}}, + "smtlib_rep": "(declare-fun k!1 () (_ BitVec 8))" + } + }, + "sort": "Bool", + "smtlib_rep": "(bvult k!1 #x05)" + }, + "kind": "False" + } +} +``` + +### Output Format + +#### Standard Format (default) + +```json +{ + "result": "sat", + "model": { + "1": { + "decls": {}, + "sort": {"BitVector": {"is_signed": false}}, + "smtlib_rep": "#x08" + } + } +} +``` + +#### Bytes Format + +```json +{ + "result": "sat", + "model": { + "1": 8 + } +} +``` +> The bytes output format only supports 8-bit `BitVector` variables. Other sorts (Bool, Int, etc.) will cause an error. +> +> Byte values must be in the range 0-255 (u8). Larger BitVectors are not supported in bytes format diff --git a/solver/src/backend/mod.rs b/solver/src/backend/mod.rs new file mode 100644 index 00000000..977803f0 --- /dev/null +++ b/solver/src/backend/mod.rs @@ -0,0 +1,3 @@ +pub mod z3; + +pub use z3::Z3Solver; diff --git a/solver/src/backend/z3/mod.rs b/solver/src/backend/z3/mod.rs new file mode 100644 index 00000000..47e55f55 --- /dev/null +++ b/solver/src/backend/z3/mod.rs @@ -0,0 +1,58 @@ +mod node; +#[cfg(feature = "serde")] +pub mod serdes; +mod solve; + +use std::{collections::HashMap, hash::Hash}; + +use z3::ast::{self, Ast}; + +use crate::{solver::SolveResult, solver::Solver}; +use common::types::trace::Constraint; + +pub use node::*; +pub use solve::{WrappedSolver, set_global_params}; + +pub trait BVExt { + fn as_u128(&self) -> Option; +} + +impl<'ctx> BVExt for ast::BV<'ctx> { + fn as_u128(&self) -> Option { + if self.get_size() <= 128 { + unsafe { + use std::ffi::CStr; + Some(z3_sys::Z3_get_numeral_string( + self.get_ctx().get_z3_context(), + self.get_z3_ast(), + )) + .filter(|x| !x.is_null()) + .map(|x| CStr::from_ptr(x)) + .and_then(|s| s.to_str().ok()) + .and_then(|s| u128::from_str_radix(s, 10).ok()) + } + } else { + None + } + } +} + +/// Z3-based solver implementation +pub type Z3Solver<'ctx, I> = WrappedSolver<'ctx, I>; + +impl<'a, 'ctx: 'a, I> Solver for Z3Solver<'ctx, I> +where + I: Eq + Hash + Clone, + Self: 'ctx, +{ + type Value = AstAndVars<'ctx, I>; + type Case = AstNode<'ctx>; + type Model = HashMap>; + + fn check( + &mut self, + constraints: impl Iterator>, + ) -> SolveResult { + Z3Solver::check(self, constraints) + } +} diff --git a/solver/src/backend/z3/node.rs b/solver/src/backend/z3/node.rs new file mode 100644 index 00000000..a3749089 --- /dev/null +++ b/solver/src/backend/z3/node.rs @@ -0,0 +1,186 @@ +use std::prelude::rust_2021::*; + +use derive_more as dm; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; +use z3::ast::{self, Ast}; + +/* NOTE: Why not using `Dynamic`? + * In this way we have a little more freedom to include our information such + * as whether the bit vector is signed or not. + */ +#[derive(Debug, Clone, PartialEq, Eq, dm::Display)] +#[display("{_0}")] +pub enum AstNode<'ctx> { + Bool(ast::Bool<'ctx>), + BitVector(BVNode<'ctx>), + Array(ArrayNode<'ctx>), +} + +impl<'ctx> From> for AstNode<'ctx> { + fn from(node: BVNode<'ctx>) -> Self { + Self::BitVector(node) + } +} + +impl<'ctx> From> for AstNode<'ctx> { + fn from(node: ArrayNode<'ctx>) -> Self { + Self::Array(node) + } +} + +#[derive(Debug, Clone, dm::Display, PartialEq, Eq)] +#[display("{_0}")] +pub struct BVNode<'ctx>(pub ast::BV<'ctx>, pub BVSort); + +impl<'ctx> BVNode<'ctx> { + pub fn new(ast: ast::BV<'ctx>, is_signed: bool) -> Self { + Self(ast, BVSort { is_signed }) + } + + #[inline] + pub fn map(&self, f: F) -> Self + where + F: FnOnce(&ast::BV<'ctx>) -> ast::BV<'ctx>, + { + Self(f(&self.0), self.1) + } + + #[inline(always)] + pub fn is_signed(&self) -> bool { + self.1.is_signed + } + + #[inline(always)] + pub fn size(&self) -> u32 { + self.0.get_size() + } +} + +#[derive(Debug, Clone, PartialEq, Eq, dm::Display)] +#[display("{_0}")] +pub struct ArrayNode<'ctx>(pub ast::Array<'ctx>, pub ArraySort); + +#[derive(Debug, Clone, PartialEq, Eq, dm::From)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum AstNodeSort { + Bool, + BitVector(BVSort), + Array(ArraySort), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct BVSort { + pub is_signed: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq, dm::From)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct ArraySort { + pub range: Box, +} + +impl<'ctx> From> for AstNode<'ctx> { + fn from(ast: ast::Bool<'ctx>) -> Self { + Self::Bool(ast) + } +} + +impl<'ctx> AstNode<'ctx> { + pub fn from_ubv(ast: ast::BV<'ctx>) -> Self { + BVNode::new(ast, false).into() + } + + pub fn from_ast(ast: ast::Dynamic<'ctx>, sort: &AstNodeSort) -> Self { + match sort { + AstNodeSort::Bool => ast.as_bool().map(Self::Bool), + AstNodeSort::BitVector(sort) => { + ast.as_bv().map(|ast| Self::BitVector(BVNode(ast, *sort))) + } + AstNodeSort::Array(sort) => ast + .as_array() + .map(|ast| Self::Array(ArrayNode(ast, sort.clone()))), + } + .unwrap_or_else(|| { + panic!( + "Sort of ${:?} is not compatible with the expected one.", + ast + ) + }) + } +} + +impl<'ctx> AstNode<'ctx> { + pub fn as_bool(&self) -> &ast::Bool<'ctx> { + match self { + Self::Bool(ast) => ast, + _ => panic!("Expected the value to be a boolean expression."), + } + } + + pub fn as_bit_vector(&self) -> &ast::BV<'ctx> { + match self { + Self::BitVector(BVNode(ast, _)) => ast, + _ => panic!("Expected the value to be a bit vector: {:?}", self), + } + } + + pub fn unwrap_as_bit_vector(self) -> ast::BV<'ctx> { + match self { + Self::BitVector(BVNode(ast, _)) => ast, + _ => panic!("Expected the value to be a bit vector: {:?}", self), + } + } +} + +impl<'ctx> AstNode<'ctx> { + pub fn ast(&self) -> &dyn ast::Ast<'ctx> { + match self { + Self::Bool(ast) => ast, + Self::BitVector(BVNode(ast, _)) => ast, + Self::Array(ArrayNode(ast, _)) => ast, + } + } + + pub fn dyn_ast(&self) -> ast::Dynamic<'ctx> { + ast::Dynamic::from_ast(self.ast()) + } + + pub fn sort(&self) -> AstNodeSort { + match self { + Self::Bool(_) => AstNodeSort::Bool, + Self::BitVector(BVNode(_, sort)) => AstNodeSort::BitVector(*sort), + Self::Array(ArrayNode(_, sort)) => AstNodeSort::Array(sort.clone()), + } + } + + pub fn z3_sort(&self) -> z3::Sort<'ctx> { + match self { + Self::Bool(ast) => ast.get_sort(), + Self::BitVector(BVNode(ast, _)) => ast.get_sort(), + Self::Array(ArrayNode(ast, _)) => ast.get_sort(), + } + } + + pub fn to_smtlib2(&self) -> String { + macro_rules! to_smt_string { + ($ast:expr) => { + $ast.simplify().to_string() + }; + } + match self { + Self::Bool(ast) => to_smt_string!(ast), + Self::BitVector(BVNode(ast, _)) => to_smt_string!(ast), + Self::Array(ArrayNode(ast, _)) => to_smt_string!(ast), + } + } +} + +#[derive(Debug, Clone, dm::Deref, dm::Display)] +#[display("{value}")] +pub struct AstAndVars<'ctx, I> { + #[deref] + pub value: AstNode<'ctx>, + pub variables: Vec<(I, AstNode<'ctx>)>, +} diff --git a/solver/src/backend/z3/serdes.rs b/solver/src/backend/z3/serdes.rs new file mode 100644 index 00000000..375891cb --- /dev/null +++ b/solver/src/backend/z3/serdes.rs @@ -0,0 +1,193 @@ +use core::hash::Hash; +use core::{iter, str::FromStr}; +use std::prelude::rust_2021::*; +use std::{collections::HashMap, ffi, format}; + +use derive_more as dm; +use serde::{Deserialize, Serialize}; +use z3::{Context, ast, ast::Ast}; +use z3_sys::{ + Z3_ast_vector_get, Z3_ast_vector_size, Z3_get_app_decl, Z3_get_decl_name, + Z3_parse_smtlib2_string, Z3_to_app, +}; + +use super::node::{AstAndVars, AstNode, AstNodeSort}; + +#[derive(Debug, Serialize, Deserialize)] +struct VarDecl { + name: String, + sort: AstNodeSort, + smtlib_rep: String, +} + +#[derive(Debug, Serialize, Deserialize, dm::Display)] +#[display("{smtlib_rep}")] +pub struct Expr { + pub sort: AstNodeSort, + pub smtlib_rep: String, +} + +#[derive(Debug, Serialize, Deserialize, dm::Display)] +#[display("{expr}")] +pub struct SmtLibExpr { + #[serde(default)] + decls: HashMap, + #[serde(flatten)] + expr: Expr, +} + +impl From for SmtLibExpr { + fn from(value: Expr) -> Self { + Self { + decls: Default::default(), + expr: value, + } + } +} + +impl<'ctx, I: ToString + FromStr> AstAndVars<'ctx, I> { + pub fn serializable(&self) -> impl Serialize { + SmtLibExpr { + expr: Expr { + sort: self.value.sort(), + smtlib_rep: self.value.to_smtlib2(), + }, + decls: self + .variables + .iter() + .map(|(id, node)| { + (id.to_string(), { + let decl = node + .ast() + .safe_decl() + .expect("Variable is expected to have a declaration"); + VarDecl { + name: decl.name(), + sort: node.sort(), + smtlib_rep: decl.to_string(), + } + }) + }) + .collect(), + } + } + + pub fn parse(context: &'ctx Context, smtlib: &SmtLibExpr) -> Self { + let variables = smtlib + .decls + .iter() + .map(|(id, decl)| { + ( + I::from_str(id).unwrap_or_else(|_| panic!("Invalid id: {id}")), + parse_var_decl(context, decl), + ) + }) + .collect::>(); + let value = parse_expr( + context, + &smtlib.expr, + variables.iter().map(|(_, node)| node), + ); + Self { variables, value } + } +} + +impl SmtLibExpr { + pub fn parse<'ctx, I: FromStr + Eq + Hash + Clone>( + &self, + context: &'ctx Context, + vars: &mut HashMap>, + ) -> AstAndVars<'ctx, I> { + let variables = self + .decls + .iter() + .map(|(id, decl)| { + let id = I::from_str(id).unwrap_or_else(|_| panic!("Invalid id: {id}")); + let decl = vars + .entry(id.clone()) + .or_insert_with(|| parse_var_decl(context, decl)) + .clone(); + (id, decl) + }) + .collect::>(); + let value = parse_expr(context, &self.expr, variables.iter().map(|(_, node)| node)); + AstAndVars { variables, value } + } + + pub fn parse_as_const<'ctx>(&self, context: &'ctx Context) -> Option> { + self.decls + .is_empty() + .then(|| parse_expr(context, &self.expr, iter::empty())) + } +} + +impl<'ctx> AstNode<'ctx> { + pub fn serializable(&self) -> impl Serialize { + Expr { + sort: self.sort(), + smtlib_rep: self.to_smtlib2(), + } + } +} + +fn parse_var_decl<'ctx>(context: &'ctx Context, decl: &VarDecl) -> AstNode<'ctx> { + let smtlib = [ + decl.smtlib_rep.as_str(), + dummy_assertion(&decl.name).as_str(), + ] + .concat(); + let dummy_ast = unsafe { parse_single_expr(context, smtlib, iter::empty()) }; + extract_expr_from_dummy(dummy_ast, &decl.sort) +} + +fn parse_expr<'ctx: 'a, 'a>( + context: &'ctx Context, + expr: &Expr, + decls: impl Iterator>, +) -> AstNode<'ctx> { + let smtlib = dummy_assertion(&expr.smtlib_rep); + let dummy_ast = unsafe { parse_single_expr(context, smtlib, decls.map(|d| d.ast())) }; + extract_expr_from_dummy(dummy_ast, &expr.sort) +} + +fn dummy_assertion(expr: &str) -> String { + format!("(assert (= {expr} {expr}))") +} + +fn extract_expr_from_dummy<'ctx>(ast: ast::Dynamic<'ctx>, sort: &AstNodeSort) -> AstNode<'ctx> { + assert_eq!(ast.num_children(), 2); + AstNode::from_ast( + ast.nth_child(0).expect("Unexpected structure").simplify(), + sort, + ) +} + +unsafe fn parse_single_expr<'ctx: 'a, 'a, S: Into>>( + context: &'ctx Context, + smtlib: S, + decls: impl Iterator + 'a)>, +) -> ast::Dynamic<'ctx> { + let c = context.get_z3_context(); + let decls = decls + .map(|d| { + let app = Z3_to_app(d.get_ctx().get_z3_context(), d.get_z3_ast()); + Z3_get_app_decl(context.get_z3_context(), app) + }) + .collect::>(); + let decl_names = decls + .iter() + .map(|d| Z3_get_decl_name(context.get_z3_context(), *d)) + .collect::>(); + let vec = Z3_parse_smtlib2_string( + c, + ffi::CString::new(smtlib).unwrap().as_ptr(), + 0, + core::ptr::null(), + core::ptr::null(), + decls.len() as u32, + decl_names.as_ptr(), + decls.as_ptr(), + ); + assert_eq!(Z3_ast_vector_size(c, vec), 1); + ast::Dynamic::wrap(context, Z3_ast_vector_get(c, vec, 0)) +} diff --git a/solver/src/backend/z3/solve.rs b/solver/src/backend/z3/solve.rs new file mode 100644 index 00000000..b2fc70cf --- /dev/null +++ b/solver/src/backend/z3/solve.rs @@ -0,0 +1,263 @@ +use std::prelude::rust_2021::*; +use std::{collections::HashMap, hash::Hash}; + +use common::{log_debug, utils}; +use delegate::delegate; +use z3::{ + self, Context, Model, Optimize, SatResult, Solver, + ast::{self, Ast}, +}; + +use super::node::*; +use common::types::trace::{Constraint, ConstraintKind}; + +enum SolverImpl<'ctx> { + Solver(Solver<'ctx>), + Optimize(Optimize<'ctx>), +} + +trait Z3Solver<'ctx> { + fn push(&self); + fn pop(&self); + + fn assert(&self, ast: &ast::Bool<'ctx>); + fn check(&self) -> SatResult; + fn get_model(&self) -> Option>; +} + +impl<'ctx> Z3Solver<'ctx> for Solver<'ctx> { + delegate! { + to self { + fn push(&self); + + fn assert(&self, ast: &ast::Bool<'ctx>); + fn check(&self) -> SatResult; + fn get_model(&self) -> Option>; + } + } + + fn pop(&self) { + self.pop(1); + } +} + +impl<'ctx> Z3Solver<'ctx> for Optimize<'ctx> { + delegate! { + to self { + fn push(&self); + fn pop(&self); + + fn assert(&self, ast: &ast::Bool<'ctx>); + fn get_model(&self) -> Option>; + } + } + + fn check(&self) -> SatResult { + self.check(&[]) + } +} + +impl<'ctx> Z3Solver<'ctx> for SolverImpl<'ctx> { + delegate! { + to match self { + Self::Solver(solver) => solver, + Self::Optimize(optimize) => optimize, + } { + #[through(Z3Solver)] + fn push(&self); + #[through(Z3Solver)] + fn pop(&self); + + #[through(Z3Solver)] + fn assert(&self, ast: &ast::Bool<'ctx>); + #[through(Z3Solver)] + fn check(&self) -> SatResult; + #[through(Z3Solver)] + fn get_model(&self) -> Option>; + } + } +} + +pub struct WrappedSolver<'ctx, I> { + context: &'ctx Context, + solver: SolverImpl<'ctx>, + _phantom: core::marker::PhantomData<(I,)>, +} + +impl<'ctx, I> WrappedSolver<'ctx, I> { + pub fn new_in_global_context() -> Self { + Self::new(context::get_context_for_thread()) + } + + pub fn new(context: &'ctx Context) -> Self { + Self { + context, + solver: SolverImpl::Solver(Solver::new(context)), + _phantom: Default::default(), + } + } + + pub fn context(&self) -> &'ctx Context { + self.context + } +} + +impl Default for WrappedSolver<'_, I> { + fn default() -> Self { + Self::new_in_global_context() + } +} + +impl<'ctx, I> Clone for WrappedSolver<'ctx, I> { + fn clone(&self) -> Self { + // Prevent cloning the assumptions in the solver + Self::new(self.context) + } +} + +impl<'ctx, I> WrappedSolver<'ctx, I> +where + I: Eq + Hash, +{ + pub fn check( + &self, + constraints: impl Iterator, AstNode<'ctx>>>, + ) -> crate::solver::SolveResult>> { + let mut all_vars = HashMap::::new(); + let asts = constraints + .map(|constraint| { + let Constraint { discr, kind } = constraint; + use ConstraintKind::*; + let (kind, negated) = match kind { + True => (True, false), + False => (True, true), + OneOf(options) => (OneOf(options), false), + NoneOf(options) => (OneOf(options), true), + }; + + let ast = match kind { + True => discr.value.as_bool().clone(), + OneOf(cases) => { + let value_ast = ast::Dynamic::from_ast(discr.value.ast()); + cases + .iter() + .map(|c| ast::Dynamic::from_ast(c.ast())) + .map(|c| value_ast._eq(&c)) + .reduce(|all, m| all.xor(&m)) + .unwrap() + } + _ => unreachable!(), + }; + all_vars.extend(discr.variables.into_iter()); + if negated { ast.not() } else { ast } + }) + .collect::>(); + + self.check_using(&self.solver, &asts, all_vars) + } + + fn check_using( + &self, + solver: &(impl Z3Solver<'ctx> + ?Sized), + constraints: &[ast::Bool<'ctx>], + vars: HashMap>, + ) -> crate::solver::SolveResult>> { + log_debug!("Sending constraints to Z3: {:#?}", constraints); + + solver.push(); + + for constraint in constraints { + solver.assert(constraint); + } + + let result = match solver.check() { + SatResult::Sat => { + let model = solver.get_model().unwrap(); + let mut values = HashMap::new(); + for (id, node) in vars { + let value = match node { + AstNode::Bool(ast) => AstNode::Bool(model.eval(&ast, true).unwrap()), + AstNode::BitVector(BVNode(ast, is_signed)) => { + AstNode::BitVector(BVNode(model.eval(&ast, true).unwrap(), is_signed)) + } + AstNode::Array(ArrayNode(ast, sort)) => { + AstNode::Array(ArrayNode(model.eval(&ast, true).unwrap(), sort)) + } + }; + values.insert(id, value.into()); + } + crate::solver::SolveResult::Sat(values) + } + SatResult::Unsat => crate::solver::SolveResult::Unsat, + SatResult::Unknown => crate::solver::SolveResult::Unknown, + }; + + solver.pop(); + result + } +} + +impl<'ctx, I> WrappedSolver<'ctx, I> +where + I: Eq + Hash, +{ + pub fn consider_possible_answer(&mut self, var: AstNode<'ctx>, answer: AstNode<'ctx>) { + if let SolverImpl::Solver(..) = self.solver { + self.solver = SolverImpl::Optimize(Optimize::new(self.context)); + } + let SolverImpl::Optimize(optimize) = &mut self.solver else { + unreachable!(); + }; + + optimize.assert_soft(&var.dyn_ast()._eq(&answer.dyn_ast()), 1, None); + } +} + +mod context { + use std::{ + collections::HashMap, + sync::{Mutex, OnceLock}, + thread::ThreadId, + }; + + use z3::Config; + + use super::*; + use utils::{UnsafeSend, UnsafeSync}; + + static CONTEXTS: OnceLock>>> = OnceLock::new(); + static THREAD_MAP: OnceLock>> = OnceLock::new(); + + pub fn set_global_params, V: AsRef>(params: impl Iterator) { + for (k, v) in params { + log_debug!("Setting global param: {} = {}", k.as_ref(), v.as_ref()); + z3::set_global_param(k.as_ref(), v.as_ref()); + } + } + + fn init_contexts() -> Vec>> { + // Statically allocate some in advance. + const TOTAL_CONTEXTS: usize = 1; + + let mut list = Vec::with_capacity(TOTAL_CONTEXTS); + for _ in 0..TOTAL_CONTEXTS { + list.push(UnsafeSync::new(UnsafeSend::new(Context::new( + &Config::new(), + )))); + } + list + } + + pub(super) fn get_context_for_thread() -> &'static Context { + let contexts = CONTEXTS.get_or_init(init_contexts); + let thread_id = std::thread::current().id(); + let mut thread_map = THREAD_MAP.get_or_init(Default::default).lock().unwrap(); + let accessor_count = thread_map.len(); + let index = *thread_map.entry(thread_id).or_insert(accessor_count); + let context = &contexts + .get(index) + .expect("Unexpected number of threads to access Z3 context"); + *context + } +} +pub use context::set_global_params; diff --git a/solver/src/bin/main.rs b/solver/src/bin/main.rs new file mode 100644 index 00000000..d51b864f --- /dev/null +++ b/solver/src/bin/main.rs @@ -0,0 +1,286 @@ +use std::{ + collections::HashMap, + fs::File, + io::{BufRead, BufReader, Write}, + path::{Path, PathBuf}, + process::ExitCode, +}; + +use clap::Parser; + +use leafsolver::{ + Config, Constraint, Context, SolveResult, + backend::z3::{AstAndVars, AstNode, Z3Solver, serdes::SmtLibExpr}, + format::{ModelFormat, OutputFormat, SolverOutput}, +}; + +#[derive(Parser, Debug)] +#[command(name = "leafsolver")] +#[command(version = "0.2.0")] +#[command(about = "Leaf SMT Solver - solves constraints from JSONL files")] +struct Args { + /// Input JSONL file with constraints + #[arg(short, long, default_value = "sym_decisions.jsonl")] + input: PathBuf, + /// Output JSON file with results + #[arg(short, long, default_value = "solver_result.json")] + output: PathBuf, + /// Output format for the model + #[arg(long, default_value = "standard")] + format: OutputFormat, +} + +fn main() -> ExitCode { + let args = Args::parse(); + + match run(args) { + Ok(()) => ExitCode::SUCCESS, + Err(e) => { + eprintln!("Error: {}", e); + ExitCode::FAILURE + } + } +} + +fn run(args: Args) -> Result<(), Box> { + println!("Reading constraints from: {}", args.input.display()); + println!("Writing results to: {}", args.output.display()); + + let constraint_entries = read_constraint_entries(&args.input)?; + let constraint_count = constraint_entries.len(); + println!("Loaded {} constraint entries", constraint_count); + + if constraint_entries.is_empty() { + let output = SolverOutput { + result: "unsat".to_string(), + model: None, + }; + write_result(&args.output, &output)?; + println!("No constraints found - wrote UNSAT result"); + return Ok(()); + } + + let constraints = constraint_entries; + let output = solve_constraints(constraints, args.format)?; + + write_result(&args.output, &output)?; + + match output.result.as_str() { + "sat" => println!( + "✓ SAT - Solution found and written to {}", + args.output.display() + ), + "unsat" => println!("✗ UNSAT - No solution exists"), + "unknown" => println!("? UNKNOWN - Could not determine satisfiability"), + _ => unreachable!("Unexpected result: {}", output.result), + } + + Ok(()) +} + +fn read_constraint_entries( + filename: &Path, +) -> Result>, Box> { + if !filename.exists() { + return Err(format!("Input file '{}' not found", filename.display()).into()); + } + + let file = File::open(filename)?; + let reader = BufReader::new(file); + let mut entries = Vec::new(); + + for (line_num, line) in reader.lines().enumerate() { + let line = line?; + if line.trim().is_empty() { + continue; + } + + // Parse JSON and extract only the constraint field (ignore step info) + let mut json_value: serde_json::Value = match serde_json::from_str(&line) { + Ok(value) => value, + Err(e) => { + return Err(format!("Error parsing JSON on line {}: {}", line_num + 1, e).into()); + } + }; + + let constraint_value = json_value + .as_object_mut() + .and_then(|obj| obj.remove("constraint")) + .ok_or_else(|| format!("Missing 'constraint' field on line {}", line_num + 1))?; + + match serde_json::from_value::>(constraint_value) { + Ok(constraint) => entries.push(constraint), + Err(e) => { + return Err( + format!("Error parsing constraint on line {}: {}", line_num + 1, e).into(), + ); + } + } + } + + Ok(entries) +} + +fn solve_constraints( + constraints: Vec< + Constraint< + leafsolver::backend::z3::serdes::SmtLibExpr, + leafsolver::backend::z3::serdes::SmtLibExpr, + >, + >, + format: OutputFormat, +) -> Result> { + let context = Context::new(&Config::new()); + let solver: Z3Solver<'_, String> = Z3Solver::new(&context); + + // Convert serializable constraints to internal format + let mut variable_map: HashMap = HashMap::new(); + let internal_constraints: Vec, AstNode<'_>>> = constraints + .into_iter() + .map(|c| convert_constraint(&context, c, &mut variable_map)) + .collect::, _>>()?; + + println!( + "Solving {} constraints with {} variables", + internal_constraints.len(), + variable_map.len() + ); + + let result = solver.check(internal_constraints.into_iter()); + match result { + SolveResult::Sat(model) => { + let model_format = match format { + OutputFormat::Standard => { + let serializable_model: HashMap< + String, + leafsolver::backend::z3::serdes::SmtLibExpr, + > = model + .into_iter() + .map(|(id, ast_node)| { + let expr = leafsolver::backend::z3::serdes::Expr { + sort: ast_node.sort(), + smtlib_rep: ast_node.to_smtlib2(), + }; + (id, expr.into()) + }) + .collect(); + ModelFormat::Standard(serializable_model) + } + OutputFormat::Bytes => { + let bytes_model: Result, Box> = + model + .into_iter() + .map(|(id, ast_node)| { + let id_clone = id.clone(); + convert_ast_node_to_byte(&ast_node) + .map(|byte_val| (id, byte_val)) + .map_err(|e| { + format!("Error converting variable {}: {}", id_clone, e) + .into() + }) + }) + .collect(); + ModelFormat::Bytes(bytes_model?) + } + }; + + Ok(SolverOutput { + result: "sat".to_string(), + model: Some(model_format), + }) + } + SolveResult::Unsat => Ok(SolverOutput { + result: "unsat".to_string(), + model: None, + }), + SolveResult::Unknown => Ok(SolverOutput { + result: "unknown".to_string(), + model: None, + }), + } +} + +/// Converts serialized constraints from JSONL format to live Z3 constraint objects. +/// +/// This conversion is needed because: +/// - `SmtLibExpr` contains serialized constraint data (strings, metadata) that can be stored in files +/// - `AstAndVars`/`AstNode` are live Z3 AST objects bound to a specific Z3 Context +/// - Z3Solver.check() requires live Z3 objects, not serialized strings +fn convert_constraint<'ctx>( + context: &'ctx Context, + constraint: Constraint< + leafsolver::backend::z3::serdes::SmtLibExpr, + leafsolver::backend::z3::serdes::SmtLibExpr, + >, + variable_map: &mut HashMap>, +) -> Result, AstNode<'ctx>>, Box> { + // Parse the discriminant (the expression being constrained) + let discr = constraint.discr.parse(context, variable_map); + + // Convert the constraint kind + let kind = match constraint.kind { + leafsolver::ConstraintKind::True => leafsolver::ConstraintKind::True, + leafsolver::ConstraintKind::False => leafsolver::ConstraintKind::False, + leafsolver::ConstraintKind::OneOf(cases) => { + let parsed_cases: Result, _> = cases + .into_iter() + .map(|case| { + case.parse_as_const(context) + .ok_or_else(|| "Case contains variables - not supported as constant") + }) + .collect(); + leafsolver::ConstraintKind::OneOf(parsed_cases?) + } + leafsolver::ConstraintKind::NoneOf(cases) => { + let parsed_cases: Result, _> = cases + .into_iter() + .map(|case| { + case.parse_as_const(context) + .ok_or_else(|| "Case contains variables - not supported as constant") + }) + .collect(); + leafsolver::ConstraintKind::NoneOf(parsed_cases?) + } + }; + + Ok(Constraint { discr, kind }) +} + +/// Converts a Z3 AST node to a byte value +/// Only supports u8 bit vectors - errors for other types +fn convert_ast_node_to_byte(ast_node: &AstNode) -> Result> { + match ast_node { + AstNode::BitVector(bv_node) => { + let bv = &bv_node.0; + let size = bv.get_size(); + + if size != 8 { + return Err(format!( + "Expected 8-bit value for byte conversion, got {}-bit value", + size + ) + .into()); + } + + match bv.as_u64() { + Some(value) if value <= u8::MAX as u64 => Ok(value as u8), + Some(value) => { + Err(format!("Value {} is too large for u8 (max: {})", value, u8::MAX).into()) + } + None => Err("Could not convert bit vector to integer".into()), + } + } + _ => Err(format!( + "Cannot convert {:?} to byte - only BitVector supported", + ast_node.sort() + ) + .into()), + } +} + +fn write_result(filename: &Path, result: &SolverOutput) -> Result<(), Box> { + let json = serde_json::to_string_pretty(result)?; + let mut file = File::create(filename)?; + file.write_all(json.as_bytes())?; + Ok(()) +} diff --git a/solver/src/lib.rs b/solver/src/lib.rs new file mode 100644 index 00000000..f59dd5e1 --- /dev/null +++ b/solver/src/lib.rs @@ -0,0 +1,85 @@ +pub mod backend; + +pub use crate::backend::z3::{AstAndVars, AstNode, BVNode, Z3Solver, set_global_params}; +pub use crate::solver::{Model, SolveResult, Solver}; +pub use common::types::trace::{Constraint, ConstraintKind}; + +#[cfg(feature = "serde")] +pub use crate::format::{ModelFormat, OutputFormat}; + +// Re-export essential Z3 types for standalone usage +pub use z3::ast::{self, Ast}; +pub use z3::{Config, Context}; + +pub mod solver { + use common::types::trace::Constraint; + use std::collections::HashMap; + + /// Core solver trait that all backend implementations must provide + pub trait Solver { + type Value; + type Case; + type Model; + + /// Check satisfiability of the given constraints + fn check( + &mut self, + constraints: impl Iterator>, + ) -> SolveResult; + } + + pub type Model = HashMap; + + /// The result of the checking performed by [`Solver`] + pub enum SolveResult { + Sat(M), + Unsat, + Unknown, + } +} + +// Serialization support for binary interface +#[cfg(feature = "serde")] +pub mod format { + use crate::backend::z3::serdes::SmtLibExpr; + use common::types::trace::Constraint; + use serde::{Deserialize, Serialize}; + + /// Input format for JSONL constraint files + #[derive(Debug, Serialize, Deserialize)] + pub struct ConstraintEntry { + pub step: StepInfo, + pub constraint: Constraint, + } + + /// Step information from the execution trace + #[derive(Debug, Serialize, Deserialize)] + pub struct StepInfo { + pub value: String, + pub index: u32, + } + + /// Output format configuration + #[derive(Debug, Clone, Copy, clap::ValueEnum)] + pub enum OutputFormat { + /// Standard format with SmtLibExpr values + Standard, + /// Raw bytes format - variables mapped to byte values + Bytes, + } + + /// Model representation based on output format + #[derive(Debug, Serialize, Deserialize)] + #[serde(untagged)] + pub enum ModelFormat { + Standard(std::collections::HashMap), + Bytes(std::collections::HashMap), + } + + /// Output format for solver results + #[derive(Debug, Serialize, Deserialize)] + pub struct SolverOutput { + pub result: String, // "sat", "unsat", or "unknown" + pub model: Option, + } +} diff --git a/solver/test_data/hello_world_simple.jsonl b/solver/test_data/hello_world_simple.jsonl new file mode 100644 index 00000000..09895735 --- /dev/null +++ b/solver/test_data/hello_world_simple.jsonl @@ -0,0 +1 @@ +{"step":{"value":"0:4:2","index":3},"constraint":{"discr":{"decls":{"1":{"name":"k!1","sort":{"BitVector":{"is_signed":false}},"smtlib_rep":"(declare-fun k!1 () (_ BitVec 8))"}},"sort":"Bool","smtlib_rep":"(not (bvule #x05 k!1))"},"kind":"False"}} diff --git a/solver/test_data/is_sorted_complex.jsonl b/solver/test_data/is_sorted_complex.jsonl new file mode 100644 index 00000000..6bc4cd92 --- /dev/null +++ b/solver/test_data/is_sorted_complex.jsonl @@ -0,0 +1,2 @@ +{"step":{"value":"0:10:5","index":5},"constraint":{"discr":{"decls":{"2":{"name":"k!2","sort":{"BitVector":{"is_signed":false}},"smtlib_rep":"(declare-fun k!2 () (_ BitVec 8))"},"1":{"name":"k!1","sort":{"BitVector":{"is_signed":false}},"smtlib_rep":"(declare-fun k!1 () (_ BitVec 8))"}},"sort":"Bool","smtlib_rep":"(not (bvule k!1 k!2))"},"kind":"True"}} +{"step":{"value":"0:5:5","index":16},"constraint":{"discr":{"decls":{"3":{"name":"k!3","sort":{"BitVector":{"is_signed":false}},"smtlib_rep":"(declare-fun k!3 () (_ BitVec 8))"},"2":{"name":"k!2","sort":{"BitVector":{"is_signed":false}},"smtlib_rep":"(declare-fun k!2 () (_ BitVec 8))"},"1":{"name":"k!1","sort":{"BitVector":{"is_signed":false}},"smtlib_rep":"(declare-fun k!1 () (_ BitVec 8))"},"4":{"name":"k!4","sort":{"BitVector":{"is_signed":false}},"smtlib_rep":"(declare-fun k!4 () (_ BitVec 8))"}},"sort":"Bool","smtlib_rep":"(and (bvule k!1 k!2) (bvule k!2 k!3) (bvule k!3 k!4))"},"kind":"False"}} diff --git a/solver/tests/basic.rs b/solver/tests/basic.rs new file mode 100644 index 00000000..d62a26d4 --- /dev/null +++ b/solver/tests/basic.rs @@ -0,0 +1,133 @@ +use std::collections::HashMap; + +use leafsolver::{ + Ast, AstAndVars, AstNode, BVNode, Config, Constraint, ConstraintKind, Context, SolveResult, + Z3Solver, ast, +}; + +#[test] +fn test_constraint_creation() { + let constraint = Constraint::equality("x", 5); + + assert_eq!(constraint.discr, "x"); + match constraint.kind { + ConstraintKind::OneOf(ref values) => { + assert_eq!(values.len(), 1); + assert_eq!(values[0], 5); + } + _ => panic!("Expected OneOf constraint"), + } +} + +#[test] +fn test_constraint_negation() { + let constraint = Constraint::equality("x", 5); + let negated = constraint.not(); + + match negated.kind { + ConstraintKind::NoneOf(ref values) => { + assert_eq!(values.len(), 1); + assert_eq!(values[0], 5); + } + _ => panic!("Expected NoneOf constraint after negation"), + } +} + +#[test] +fn test_constraint_kind_operations() { + assert_eq!( + ConstraintKind::::True.not(), + ConstraintKind::::False + ); + assert_eq!( + ConstraintKind::::False.not(), + ConstraintKind::::True + ); + + let one_of = ConstraintKind::OneOf(vec![1, 2, 3]); + let negated = one_of.not(); + match negated { + ConstraintKind::NoneOf(values) => { + assert_eq!(values, vec![1, 2, 3]); + } + _ => panic!("Expected NoneOf after negating OneOf"), + } +} + +#[test] +fn test_solve_result_enum() { + let sat_result: SolveResult> = SolveResult::Sat(HashMap::new()); + let unsat_result: SolveResult> = SolveResult::Unsat; + let unknown_result: SolveResult> = SolveResult::Unknown; + + match sat_result { + SolveResult::Sat(_) => (), + _ => panic!("Expected Sat"), + } + + match unsat_result { + SolveResult::Unsat => (), + _ => panic!("Expected Unsat"), + } + + match unknown_result { + SolveResult::Unknown => (), + _ => panic!("Expected Unknown"), + } +} + +#[test] +fn test_z3_solver_basic() { + let context = Context::new(&Config::new()); + let solver: Z3Solver<'_, i32> = Z3Solver::new(&context); + + // Simple boolean constraint: true + let true_ast = ast::Bool::from_bool(&context, true); + let constraint = Constraint { + discr: AstAndVars { + value: AstNode::Bool(true_ast), + variables: vec![], + }, + kind: ConstraintKind::True, + }; + + let result = solver.check(std::iter::once(constraint)); + + match result { + SolveResult::Sat(model) => { + assert!(model.is_empty()); + } + _ => panic!("Expected SAT result for true constraint"), + } +} + +#[test] +fn test_bitvector_constraint() { + let context = Context::new(&Config::new()); + let solver: Z3Solver<'_, i32> = Z3Solver::new(&context); + + let x_var = ast::BV::new_const(&context, "x", 8); + let eq_ast = x_var._eq(&x_var); // x == x (always true) + + let constraint = Constraint { + discr: AstAndVars { + value: AstNode::Bool(eq_ast), + variables: vec![(1, AstNode::BitVector(BVNode::new(x_var, false)))], + }, + kind: ConstraintKind::True, + }; + + let result = solver.check(std::iter::once(constraint)); + + match result { + SolveResult::Sat(model) => { + assert!(model.contains_key(&1)); + if let Some(AstNode::BitVector(_)) = model.get(&1) { + // The only correct case + } else { + panic!("Expected BitVector solution for variable 1"); + } + } + _ => panic!("Expected SAT result for x == x constraint"), + } +} diff --git a/solver/tests/binary_integration.rs b/solver/tests/binary_integration.rs new file mode 100644 index 00000000..d37af001 --- /dev/null +++ b/solver/tests/binary_integration.rs @@ -0,0 +1,104 @@ +use std::{fs, path::Path, process::Command}; + +#[test] +fn test_hello_world_simple_constraint() { + // Test case from samples/hello_world.rs: + // A simple program that checks if x < 5 where x is marked symbolic. + // Generated constraint: NOT (5 <= k!1) should be False + // This means we want 5 <= k!1 to be True (i.e., k!1 >= 5) + let binary_path = env!("CARGO_BIN_EXE_leafsolver"); + let test_data_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("test_data"); + + let input_file = test_data_dir.join("hello_world_simple.jsonl"); + let output_file = test_data_dir.join("hello_world_result.json"); + + assert!(input_file.exists(), "Hello world test input should exist"); + + let output = Command::new(binary_path) + .args([ + "-i", + input_file.to_str().unwrap(), + "-o", + output_file.to_str().unwrap(), + ]) + .output() + .expect("Failed to run leafsolver binary"); + + assert!( + output.status.success(), + "Binary should handle simple constraint from hello_world.rs" + ); + assert!(output_file.exists(), "Output file should be created"); + + let result_json = fs::read_to_string(&output_file).expect("Should read result file"); + let result: serde_json::Value = serde_json::from_str(&result_json).expect("Should parse JSON"); + + assert_eq!(result["result"], "sat"); + assert!(result["model"].is_object()); + + // Verify the solution: k!1 should be >= 5 (to make "NOT (5 <= k!1)" false) + let model = result["model"].as_object().unwrap(); + let k1_value = model["1"]["smtlib_rep"].as_str().unwrap(); + // Should be a hex value >= #x05 + assert!(k1_value.starts_with("#x")); + + fs::remove_file(output_file).ok(); +} + +#[test] +fn test_is_sorted_complex_constraint() { + // Test case from samples/basic/is_sorted/main.rs: + // A program that tests array sorting with 4 symbolic u8 elements. + // Tests both regular and non-short-circuiting sorting implementations. + // Generated constraints involve multiple comparison operations between array elements. + let binary_path = env!("CARGO_BIN_EXE_leafsolver"); + let test_data_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("test_data"); + + let input_file = test_data_dir.join("is_sorted_complex.jsonl"); + let output_file = test_data_dir.join("is_sorted_result.json"); + + assert!(input_file.exists(), "Is sorted test input should exist"); + + let output = Command::new(binary_path) + .args([ + "-i", + input_file.to_str().unwrap(), + "-o", + output_file.to_str().unwrap(), + ]) + .output() + .expect("Failed to run leafsolver binary"); + + assert!( + output.status.success(), + "Binary should handle complex multi-variable constraints from is_sorted.rs" + ); + assert!(output_file.exists(), "Output file should be created"); + + let result_json = fs::read_to_string(&output_file).expect("Should read result file"); + let result: serde_json::Value = serde_json::from_str(&result_json).expect("Should parse JSON"); + + assert_eq!(result["result"], "sat"); + assert!(result["model"].is_object()); + + // Verify we have multiple variables in the model (4 array elements) + let model = result["model"].as_object().unwrap(); + assert!( + model.len() >= 4, + "Should have at least 4 variables (array elements), got {}", + model.len() + ); + + // All model values should be valid u8 bit vectors + for (var_id, value) in model { + let smtlib_value = value["smtlib_rep"].as_str().unwrap(); + assert!( + smtlib_value.starts_with("#x"), + "Variable {} should have hex bit vector value, got {}", + var_id, + smtlib_value + ); + } + + fs::remove_file(output_file).ok(); +} diff --git a/solver/tests/samples.rs b/solver/tests/samples.rs new file mode 100644 index 00000000..7e856eed --- /dev/null +++ b/solver/tests/samples.rs @@ -0,0 +1,324 @@ +use leafsolver::{ + Ast, AstAndVars, AstNode, BVNode, Config, Constraint, ConstraintKind, Context, SolveResult, + Z3Solver, ast, +}; + +/// Tests based on real constraint patterns from `samples/` directory + +#[test] +fn test_arithmetic_operations() { + // Based on samples/assignment/bin_op/main.rs + // Test: a + b == 15 where a = 10, should find b = 5 + let context = Context::new(&Config::new()); + let solver: Z3Solver<'_, i32> = Z3Solver::new(&context); + + let a = ast::BV::new_const(&context, "a", 32); + let b = ast::BV::new_const(&context, "b", 32); + let ten = ast::BV::from_i64(&context, 10, 32); + let fifteen = ast::BV::from_i64(&context, 15, 32); + + let sum = a.bvadd(&b); + let eq_constraint = sum._eq(&fifteen); + let a_eq_ten = a._eq(&ten); + + let variables = vec![ + (1, AstNode::BitVector(BVNode::new(a, true))), + (2, AstNode::BitVector(BVNode::new(b, true))), + ]; + + let constraints = vec![ + Constraint { + discr: AstAndVars { + value: AstNode::Bool(eq_constraint), + variables: variables.clone(), + }, + kind: ConstraintKind::True, + }, + Constraint { + discr: AstAndVars { + value: AstNode::Bool(a_eq_ten), + variables, + }, + kind: ConstraintKind::True, + }, + ]; + + let result = solver.check(constraints.into_iter()); + + match result { + SolveResult::Sat(model) => { + let a_val = model + .get(&1) + .and_then(|v| match v { + AstNode::BitVector(bv) => bv.0.as_i64(), + _ => None, + }) + .unwrap_or(0); + let b_val = model + .get(&2) + .and_then(|v| match v { + AstNode::BitVector(bv) => bv.0.as_i64(), + _ => None, + }) + .unwrap_or(0); + + assert_eq!(a_val, 10); + assert_eq!(b_val, 5); + assert_eq!(a_val + b_val, 15); + } + _ => panic!("Expected SAT result for arithmetic constraint"), + } +} + +#[test] +fn test_modulo_branching() { + // Based on samples/branching/match_basic/main.rs + // Test: x % 3 == 2 + let context = Context::new(&Config::new()); + let solver: Z3Solver<'_, i32> = Z3Solver::new(&context); + + let x = ast::BV::new_const(&context, "x", 32); + let three = ast::BV::from_i64(&context, 3, 32); + let two = ast::BV::from_i64(&context, 2, 32); + + let mod_result = x.bvsrem(&three); + let eq_two = mod_result._eq(&two); + + let constraint = Constraint { + discr: AstAndVars { + value: AstNode::Bool(eq_two), + variables: vec![(1, AstNode::BitVector(BVNode::new(x, true)))], + }, + kind: ConstraintKind::True, + }; + + let result = solver.check(std::iter::once(constraint)); + + match result { + SolveResult::Sat(model) => { + let x_val = model + .get(&1) + .and_then(|v| match v { + AstNode::BitVector(bv) => bv.0.as_i64(), + _ => None, + }) + .unwrap_or(0); + + assert_eq!(x_val % 3, 2, "Expected x % 3 == 2, got x = {}", x_val); + } + _ => panic!("Expected SAT result for modulo constraint"), + } +} + +#[test] +fn test_range_constraints() { + // Based on samples/branching/match_range/main.rs + // Test: x in range [1, 3] using OneOf constraint + let context = Context::new(&Config::new()); + let solver: Z3Solver<'_, i32> = Z3Solver::new(&context); + + let x = ast::BV::new_const(&context, "x", 8); + let one = ast::BV::from_u64(&context, 1, 8); + let two = ast::BV::from_u64(&context, 2, 8); + let three = ast::BV::from_u64(&context, 3, 8); + + let constraint = Constraint { + discr: AstAndVars { + value: AstNode::BitVector(BVNode::new(x.clone(), false)), + variables: vec![(1, AstNode::BitVector(BVNode::new(x, false)))], + }, + kind: ConstraintKind::OneOf(vec![ + AstNode::BitVector(BVNode::new(one, false)), + AstNode::BitVector(BVNode::new(two, false)), + AstNode::BitVector(BVNode::new(three, false)), + ]), + }; + + let result = solver.check(std::iter::once(constraint)); + + match result { + SolveResult::Sat(model) => { + let x_val = model + .get(&1) + .and_then(|v| match v { + AstNode::BitVector(bv) => bv.0.as_u64(), + _ => None, + }) + .unwrap_or(0); + + assert!( + x_val >= 1 && x_val <= 3, + "Expected x ∈ [1, 3], got {}", + x_val + ); + } + _ => panic!("Expected SAT result for range constraint"), + } +} + +#[test] +fn test_array_bounds_check() { + // Based on samples/branching/assert/boundscheck.rs + // Test: index < array_length for array access + let context = Context::new(&Config::new()); + let solver: Z3Solver<'_, i32> = Z3Solver::new(&context); + + let index = ast::BV::new_const(&context, "index", 8); + let array_len = ast::BV::from_u64(&context, 5, 8); + + let valid_constraint = index.bvult(&array_len); + + let constraint = Constraint { + discr: AstAndVars { + value: AstNode::Bool(valid_constraint), + variables: vec![(1, AstNode::BitVector(BVNode::new(index, false)))], + }, + kind: ConstraintKind::True, + }; + + let result = solver.check(std::iter::once(constraint)); + + match result { + SolveResult::Sat(model) => { + let index_val = model + .get(&1) + .and_then(|v| match v { + AstNode::BitVector(bv) => bv.0.as_u64(), + _ => None, + }) + .unwrap_or(0); + + assert!(index_val < 5, "Expected valid index < 5, got {}", index_val); + } + _ => panic!("Expected SAT result for bounds check constraint"), + } +} + +#[test] +fn test_overflow_detection() { + // Based on samples/assignment/bin_op/with_overflow_ops/ + // Test: detect when a + b overflows u8 + let context = Context::new(&Config::new()); + let solver: Z3Solver<'_, i32> = Z3Solver::new(&context); + + let a = ast::BV::new_const(&context, "a", 8); + let b = ast::BV::new_const(&context, "b", 8); + + let sum = a.bvadd(&b); + // Overflow occurred if sum < a (wraparound happened) + let overflow_condition = sum.bvult(&a); + + let variables = vec![ + (1, AstNode::BitVector(BVNode::new(a, false))), + (2, AstNode::BitVector(BVNode::new(b, false))), + ]; + + let constraint = Constraint { + discr: AstAndVars { + value: AstNode::Bool(overflow_condition), + variables, + }, + kind: ConstraintKind::True, + }; + + let result = solver.check(std::iter::once(constraint)); + + match result { + SolveResult::Sat(model) => { + let a_val = model + .get(&1) + .and_then(|v| match v { + AstNode::BitVector(bv) => bv.0.as_u64(), + _ => None, + }) + .unwrap_or(0); + let b_val = model + .get(&2) + .and_then(|v| match v { + AstNode::BitVector(bv) => bv.0.as_u64(), + _ => None, + }) + .unwrap_or(0); + + // Verify that adding these would indeed overflow u8 + assert!( + a_val + b_val > 255, + "Expected overflow: {} + {} = {} > 255", + a_val, + b_val, + a_val + b_val + ); + } + _ => panic!("Expected SAT result for overflow constraint"), + } +} + +#[test] +fn test_multiple_variables() { + // Based on samples/misc/multiple_vars/main.rs + // Test: complex interactions between multiple symbolic variables + let context = Context::new(&Config::new()); + let solver: Z3Solver<'_, i32> = Z3Solver::new(&context); + + let x = ast::BV::new_const(&context, "x", 32); + let y = ast::BV::new_const(&context, "y", 32); + + // Constraint: (x + 5) == 15 AND (y + 3) == 25 + // Therefore: x == 10 AND y == 22 + let five = ast::BV::from_u64(&context, 5, 32); + let three = ast::BV::from_u64(&context, 3, 32); + let fifteen = ast::BV::from_u64(&context, 15, 32); + let twenty_five = ast::BV::from_u64(&context, 25, 32); + + let calc_x = x.bvadd(&five); + let calc_y = y.bvadd(&three); + let eq1 = calc_x._eq(&fifteen); + let eq2 = calc_y._eq(&twenty_five); + + let variables = vec![ + (1, AstNode::BitVector(BVNode::new(x, false))), + (2, AstNode::BitVector(BVNode::new(y, false))), + ]; + + let constraints = vec![ + Constraint { + discr: AstAndVars { + value: AstNode::Bool(eq1), + variables: variables.clone(), + }, + kind: ConstraintKind::True, + }, + Constraint { + discr: AstAndVars { + value: AstNode::Bool(eq2), + variables, + }, + kind: ConstraintKind::True, + }, + ]; + + let result = solver.check(constraints.into_iter()); + + match result { + SolveResult::Sat(model) => { + let x_val = model + .get(&1) + .and_then(|v| match v { + AstNode::BitVector(bv) => bv.0.as_u64(), + _ => None, + }) + .unwrap_or(0); + let y_val = model + .get(&2) + .and_then(|v| match v { + AstNode::BitVector(bv) => bv.0.as_u64(), + _ => None, + }) + .unwrap_or(0); + + assert_eq!(x_val, 10, "Expected x = 10, got {}", x_val); + assert_eq!(y_val, 22, "Expected y = 22, got {}", y_val); + } + _ => panic!("Expected SAT result for multiple variable constraints"), + } +}