From 45ad37d35ac3975acc65f365746aba5e4321ede1 Mon Sep 17 00:00:00 2001 From: Clement Delafargue Date: Wed, 2 Oct 2024 10:29:40 +0200 Subject: [PATCH 1/3] wip: datalog foreign function interface prototype MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This allows using external functions in datalog. This makes it easy to provide custom logic without extending the spec for every use-case, at the expense of portability: behaviour is no longer guaranteed to be consistent cross languages, and some languages won’t be able to support it at all (for instance JS as of now). Todo: - stricter conversions from datalog - feature-gating if possible Open questions: - enum index for the FFI variants (contiguous or not?) - how to provide functions (right now, function pointers: prevent mutability and closing over arguments) - how to provide arguments (right now, datalog::Term, so symbols have to be resolved, and functions returning strings have to register new symbols) --- biscuit-auth/src/capi.rs | 2 + biscuit-auth/src/datalog/expression.rs | 191 ++++++++++++++---- biscuit-auth/src/datalog/mod.rs | 62 ++++-- biscuit-auth/src/error.rs | 6 + biscuit-auth/src/format/convert.rs | 20 ++ biscuit-auth/src/format/schema.proto | 4 + biscuit-auth/src/format/schema.rs | 6 + biscuit-auth/src/parser.rs | 12 +- biscuit-auth/src/token/authorizer.rs | 41 +++- biscuit-auth/src/token/authorizer/snapshot.rs | 1 + biscuit-auth/src/token/builder.rs | 2 + biscuit-auth/tests/macros.rs | 8 + biscuit-parser/src/builder.rs | 4 + biscuit-parser/src/parser.rs | 38 +++- 14 files changed, 333 insertions(+), 64 deletions(-) diff --git a/biscuit-auth/src/capi.rs b/biscuit-auth/src/capi.rs index 3a80cbba..43b376eb 100644 --- a/biscuit-auth/src/capi.rs +++ b/biscuit-auth/src/capi.rs @@ -79,6 +79,7 @@ pub enum ErrorKind { FormatPublicKeyTableOverlap, FormatUnknownExternalKey, FormatUnknownSymbol, + FormatMissingFfiName, AppendOnSealed, LogicInvalidBlockRule, LogicUnauthorized, @@ -158,6 +159,7 @@ pub extern "C" fn error_kind() -> ErrorKind { ErrorKind::FormatUnknownExternalKey } Token::Format(Format::UnknownSymbol(_)) => ErrorKind::FormatUnknownSymbol, + Token::Format(Format::MissingFfiName) => ErrorKind::FormatMissingFfiName, Token::AppendOnSealed => ErrorKind::AppendOnSealed, Token::AlreadySealed => ErrorKind::AlreadySealed, Token::Language(_) => ErrorKind::LanguageError, diff --git a/biscuit-auth/src/datalog/expression.rs b/biscuit-auth/src/datalog/expression.rs index c07e42e4..fc5c6d2d 100644 --- a/biscuit-auth/src/datalog/expression.rs +++ b/biscuit-auth/src/datalog/expression.rs @@ -5,6 +5,16 @@ use super::{SymbolTable, TemporarySymbolTable}; use regex::Regex; use std::collections::{HashMap, HashSet}; +type ExternBinary = fn(&mut TemporarySymbolTable, &Term, &Term) -> Result; + +type ExternUnary = fn(&mut TemporarySymbolTable, &Term) -> Result; + +#[derive(Debug, Clone)] +pub enum ExternFunc { + Unary(ExternUnary), + Binary(ExternBinary), +} + #[derive(Debug, Clone, PartialEq, Hash, Eq)] pub struct Expression { pub ops: Vec, @@ -24,13 +34,15 @@ pub enum Unary { Negate, Parens, Length, + Ffi(String), } impl Unary { fn evaluate( &self, value: Term, - symbols: &TemporarySymbolTable, + symbols: &mut TemporarySymbolTable, + extern_funcs: &HashMap, ) -> Result { match (self, value) { (Unary::Negate, Term::Bool(b)) => Ok(Term::Bool(!b)), @@ -41,6 +53,15 @@ impl Unary { .ok_or(error::Expression::UnknownSymbol(i)), (Unary::Length, Term::Bytes(s)) => Ok(Term::Integer(s.len() as i64)), (Unary::Length, Term::Set(s)) => Ok(Term::Integer(s.len() as i64)), + (Unary::Ffi(name), i) => { + let fun = extern_funcs + .get(name) + .ok_or(error::Expression::UndefinedExtern(name.to_owned()))?; + match fun { + ExternFunc::Unary(fun) => fun(symbols, &i), + ExternFunc::Binary(_) => Err(error::Expression::IncorrectArityExtern), + } + } _ => { //println!("unexpected value type on the stack"); Err(error::Expression::InvalidType) @@ -53,6 +74,7 @@ impl Unary { Unary::Negate => format!("!{}", value), Unary::Parens => format!("({})", value), Unary::Length => format!("{}.length()", value), + Unary::Ffi(name) => format!("{value}.extern::{name}()"), } } } @@ -87,6 +109,7 @@ pub enum Binary { LazyOr, All, Any, + Ffi(String), } impl Binary { @@ -97,23 +120,24 @@ impl Binary { params: &[u32], values: &mut HashMap, symbols: &mut TemporarySymbolTable, + extern_func: &HashMap, ) -> Result { match (self, left, params) { (Binary::LazyOr, Term::Bool(true), []) => Ok(Term::Bool(true)), (Binary::LazyOr, Term::Bool(false), []) => { let e = Expression { ops: right.clone() }; - e.evaluate(values, symbols) + e.evaluate(values, symbols, extern_func) } (Binary::LazyAnd, Term::Bool(false), []) => Ok(Term::Bool(false)), (Binary::LazyAnd, Term::Bool(true), []) => { let e = Expression { ops: right.clone() }; - e.evaluate(values, symbols) + e.evaluate(values, symbols, extern_func) } (Binary::All, Term::Set(set_values), [param]) => { for value in set_values.iter() { values.insert(*param, value.clone()); let e = Expression { ops: right.clone() }; - let result = e.evaluate(values, symbols); + let result = e.evaluate(values, symbols, extern_func); values.remove(param); match result? { Term::Bool(true) => {} @@ -127,7 +151,7 @@ impl Binary { for value in set_values.iter() { values.insert(*param, value.clone()); let e = Expression { ops: right.clone() }; - let result = e.evaluate(values, symbols); + let result = e.evaluate(values, symbols, extern_func); values.remove(param); match result? { Term::Bool(false) => {} @@ -145,6 +169,7 @@ impl Binary { left: Term, right: Term, symbols: &mut TemporarySymbolTable, + extern_funcs: &HashMap, ) -> Result { match (self, left, right) { // integer @@ -303,9 +328,21 @@ impl Binary { (Binary::HeterogeneousNotEqual, Term::Null, _) => Ok(Term::Bool(true)), (Binary::HeterogeneousNotEqual, _, Term::Null) => Ok(Term::Bool(true)), + // heterogeneous equals catch all (Binary::HeterogeneousEqual, _, _) => Ok(Term::Bool(false)), (Binary::HeterogeneousNotEqual, _, _) => Ok(Term::Bool(true)), + // FFI + (Binary::Ffi(name), left, right) => { + let fun = extern_funcs + .get(name) + .ok_or(error::Expression::UndefinedExtern(name.to_owned()))?; + match fun { + ExternFunc::Binary(fun) => fun(symbols, &left, &right), + ExternFunc::Unary(_) => Err(error::Expression::IncorrectArityExtern), + } + } + _ => { //println!("unexpected value type on the stack"); Err(error::Expression::InvalidType) @@ -342,6 +379,7 @@ impl Binary { Binary::LazyOr => format!("{left} || {right}"), Binary::All => format!("{left}.all({right})"), Binary::Any => format!("{left}.any({right})"), + Binary::Ffi(name) => format!("{left}.extern::{name}({right})"), } } } @@ -357,6 +395,7 @@ impl Expression { &self, values: &HashMap, symbols: &mut TemporarySymbolTable, + extern_funcs: &HashMap, ) -> Result { let mut stack: Vec = Vec::new(); @@ -372,19 +411,24 @@ impl Expression { } }, Op::Value(term) => stack.push(StackElem::Term(term.clone())), - Op::Unary(unary) => match stack.pop() { - Some(StackElem::Term(term)) => { - stack.push(StackElem::Term(unary.evaluate(term, symbols)?)) - } - _ => { - return Err(error::Expression::InvalidStack); + Op::Unary(unary) => { + match stack.pop() { + Some(StackElem::Term(term)) => stack.push(StackElem::Term( + unary.evaluate(term, symbols, extern_funcs)?, + )), + _ => { + return Err(error::Expression::InvalidStack); + } } - }, + } Op::Binary(binary) => match (stack.pop(), stack.pop()) { (Some(StackElem::Term(right_term)), Some(StackElem::Term(left_term))) => stack - .push(StackElem::Term( - binary.evaluate(left_term, right_term, symbols)?, - )), + .push(StackElem::Term(binary.evaluate( + left_term, + right_term, + symbols, + extern_funcs, + )?)), ( Some(StackElem::Closure(params, right_ops)), Some(StackElem::Term(left_term)), @@ -405,6 +449,7 @@ impl Expression { ¶ms, &mut values, symbols, + extern_funcs, )?)) } @@ -502,7 +547,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); } @@ -532,7 +577,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&HashMap::new(), &mut tmp_symbols); + let res = e.evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Integer(expected))); } } @@ -549,7 +594,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Err(error::Expression::DivideByZero)); let ops = vec![ @@ -560,7 +605,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Err(error::Expression::Overflow)); let ops = vec![ @@ -571,7 +616,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Err(error::Expression::Overflow)); let ops = vec![ @@ -582,7 +627,7 @@ mod tests { let values = HashMap::new(); let e = Expression { ops }; - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Err(error::Expression::Overflow)); } @@ -649,7 +694,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(true))); } } @@ -673,7 +718,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(false))); } } @@ -697,7 +742,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(result))); } } @@ -741,7 +786,7 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - let res = e.evaluate(&values, &mut tmp_symbols); + let res = e.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res, Ok(Term::Bool(*result))); } } @@ -780,7 +825,8 @@ mod tests { let e = Expression { ops }; println!("print: {}", e.print(&symbols).unwrap()); - e.evaluate(&values, &mut tmp_symbols).unwrap_err(); + e.evaluate(&values, &mut tmp_symbols, &Default::default()) + .unwrap_err(); } } } @@ -805,7 +851,9 @@ mod tests { ]; let e2 = Expression { ops: ops1 }; - let res2 = e2.evaluate(&HashMap::new(), &mut symbols).unwrap(); + let res2 = e2 + .evaluate(&HashMap::new(), &mut symbols, &Default::default()) + .unwrap(); assert_eq!(res2, Term::Bool(true)); } @@ -823,7 +871,9 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{:?}", e1.print(&symbols)); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(true)); let ops2 = vec![ @@ -841,7 +891,9 @@ mod tests { let e2 = Expression { ops: ops2 }; println!("{:?}", e2.print(&symbols)); - let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res2 = e2 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res2, Term::Bool(false)); let ops3 = vec![ @@ -852,7 +904,9 @@ mod tests { let e3 = Expression { ops: ops3 }; println!("{:?}", e3.print(&symbols)); - let err3 = e3.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap_err(); + let err3 = e3 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap_err(); assert_eq!(err3, error::Expression::InvalidType); } @@ -877,7 +931,9 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{:?}", e1.print(&symbols)); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(true)); let ops2 = vec![ @@ -895,7 +951,9 @@ mod tests { let e2 = Expression { ops: ops2 }; println!("{:?}", e2.print(&symbols)); - let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res2 = e2 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res2, Term::Bool(false)); let ops3 = vec![ @@ -906,7 +964,9 @@ mod tests { let e3 = Expression { ops: ops3 }; println!("{:?}", e3.print(&symbols)); - let err3 = e3.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap_err(); + let err3 = e3 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap_err(); assert_eq!(err3, error::Expression::InvalidType); } @@ -952,7 +1012,9 @@ mod tests { let e1 = Expression { ops: ops1 }; println!("{}", e1.print(&symbols).unwrap()); - let res1 = e1.evaluate(&HashMap::new(), &mut tmp_symbols).unwrap(); + let res1 = e1 + .evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()) + .unwrap(); assert_eq!(res1, Term::Bool(true)); } @@ -979,7 +1041,7 @@ mod tests { let mut values = HashMap::new(); values.insert(p, Term::Null); - let res1 = e1.evaluate(&values, &mut tmp_symbols); + let res1 = e1.evaluate(&values, &mut tmp_symbols, &Default::default()); assert_eq!(res1, Err(error::Expression::ShadowedVariable)); let mut symbols = SymbolTable::new(); @@ -1021,7 +1083,64 @@ mod tests { let e2 = Expression { ops: ops2 }; println!("{}", e2.print(&symbols).unwrap()); - let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols); + let res2 = e2.evaluate(&HashMap::new(), &mut tmp_symbols, &Default::default()); assert_eq!(res2, Err(error::Expression::ShadowedVariable)); } + + #[test] + fn ffi() { + let mut symbols = SymbolTable::new(); + let i = symbols.insert("test"); + let j = symbols.insert("TeSt"); + let mut tmp_symbols = TemporarySymbolTable::new(&symbols); + let ops = vec![ + Op::Value(Term::Integer(60)), + Op::Value(Term::Integer(0)), + Op::Binary(Binary::Ffi("test_bin".to_owned())), + Op::Value(Term::Str(i)), + Op::Value(Term::Str(j)), + Op::Binary(Binary::Ffi("test_bin".to_owned())), + Op::Binary(Binary::And), + Op::Value(Term::Integer(42)), + Op::Unary(Unary::Ffi("test_un".to_owned())), + Op::Binary(Binary::And), + ]; + + let values = HashMap::new(); + let e = Expression { ops }; + let mut extern_funcs: HashMap = Default::default(); + extern_funcs.insert( + "test_bin".to_owned(), + ExternFunc::Binary(|sym, left, right| match (left, right) { + (Term::Integer(left), Term::Integer(right)) => { + println!("{left} {right}"); + Ok(Term::Bool((left % 60) == (right % 60))) + } + (Term::Str(left), Term::Str(right)) => { + let left = sym + .get_symbol(*left) + .ok_or(error::Expression::UnknownSymbol(*left))?; + let right = sym + .get_symbol(*right) + .ok_or(error::Expression::UnknownSymbol(*right))?; + + println!("{left} {right}"); + Ok(Term::Bool(left.to_lowercase() == right.to_lowercase())) + } + _ => Err(error::Expression::InvalidType), + }), + ); + extern_funcs.insert( + "test_un".to_owned(), + ExternFunc::Unary(|_, value| match value { + Term::Integer(value) => Ok(Term::Bool(*value == 42)), + _ => { + println!("{value:?}"); + Err(error::Expression::InvalidType) + } + }), + ); + let res = e.evaluate(&values, &mut tmp_symbols, &extern_funcs); + assert_eq!(res, Ok(Term::Bool(true))); + } } diff --git a/biscuit-auth/src/datalog/mod.rs b/biscuit-auth/src/datalog/mod.rs index 94671047..9e0e572a 100644 --- a/biscuit-auth/src/datalog/mod.rs +++ b/biscuit-auth/src/datalog/mod.rs @@ -128,6 +128,7 @@ impl Rule { facts: IT, rule_origin: usize, symbols: &'a SymbolTable, + extern_funcs: &'a HashMap, ) -> impl Iterator> + 'a where IT: Iterator + Clone + 'a, @@ -139,7 +140,7 @@ impl Rule { .map(move |(origin, variables)| { let mut temporary_symbols = TemporarySymbolTable::new(symbols); for e in self.expressions.iter() { - match e.evaluate(&variables, &mut temporary_symbols) { + match e.evaluate(&variables, &mut temporary_symbols, extern_funcs) { Ok(Term::Bool(true)) => {} Ok(Term::Bool(false)) => return Ok((origin, variables, false)), Ok(_) => return Err(error::Expression::InvalidType), @@ -184,9 +185,10 @@ impl Rule { origin: usize, scope: &TrustedOrigins, symbols: &SymbolTable, + extern_funcs: &HashMap, ) -> Result { let fact_it = facts.iterator(scope); - let mut it = self.apply(fact_it, origin, symbols); + let mut it = self.apply(fact_it, origin, symbols, extern_funcs); let next = it.next(); match next { @@ -201,6 +203,7 @@ impl Rule { facts: &FactSet, scope: &TrustedOrigins, symbols: &SymbolTable, + extern_funcs: &HashMap, ) -> Result { let fact_it = facts.iterator(scope); let variables = MatchedVariables::new(self.variables_set()); @@ -211,7 +214,7 @@ impl Rule { let mut temporary_symbols = TemporarySymbolTable::new(symbols); for e in self.expressions.iter() { - match e.evaluate(&variables, &mut temporary_symbols) { + match e.evaluate(&variables, &mut temporary_symbols, extern_funcs) { Ok(Term::Bool(true)) => {} Ok(Term::Bool(false)) => { //println!("expr returned {:?}", res); @@ -607,7 +610,7 @@ impl World { for (scope, rules) in self.rules.inner.iter() { let it = self.facts.iterator(scope); for (origin, rule) in rules { - for res in rule.apply(it.clone(), *origin, symbols) { + for res in rule.apply(it.clone(), *origin, symbols, &limits.extern_funcs) { match res { Ok((origin, fact)) => { new_facts.insert(&origin, fact); @@ -678,11 +681,12 @@ impl World { origin: usize, scope: &TrustedOrigins, symbols: &SymbolTable, + extern_funcs: &HashMap, ) -> Result { let mut new_facts = FactSet::default(); let it = self.facts.iterator(scope); //new_facts.extend(rule.apply(it, origin, symbols)); - for res in rule.apply(it.clone(), origin, symbols) { + for res in rule.apply(it.clone(), origin, symbols, extern_funcs) { match res { Ok((origin, fact)) => { new_facts.insert(&origin, fact); @@ -702,8 +706,9 @@ impl World { origin: usize, scope: &TrustedOrigins, symbols: &SymbolTable, + extern_funcs: &HashMap, ) -> Result { - rule.find_match(&self.facts, origin, scope, symbols) + rule.find_match(&self.facts, origin, scope, symbols, extern_funcs) } pub fn query_match_all( @@ -711,8 +716,9 @@ impl World { rule: Rule, scope: &TrustedOrigins, symbols: &SymbolTable, + extern_funcs: &HashMap, ) -> Result { - rule.check_match_all(&self.facts, scope, symbols) + rule.check_match_all(&self.facts, scope, symbols, extern_funcs) } } @@ -725,6 +731,8 @@ pub struct RunLimits { pub max_iterations: u64, /// maximum execution time pub max_time: Duration, + + pub extern_funcs: HashMap, } impl std::default::Default for RunLimits { @@ -733,6 +741,7 @@ impl std::default::Default for RunLimits { max_facts: 1000, max_iterations: 100, max_time: Duration::from_millis(1), + extern_funcs: Default::default(), } } } @@ -1034,7 +1043,8 @@ mod tests { println!("symbols: {:?}", syms); println!("testing r1: {}", syms.print_rule(&r1)); - let query_rule_result = w.query_rule(r1, 0, &[0].iter().collect(), &syms); + let query_rule_result = + w.query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()); println!("grandparents query_rules: {:?}", query_rule_result); println!("current facts: {:?}", w.facts); @@ -1079,6 +1089,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1096,7 +1107,8 @@ mod tests { ), 0, &[0].iter().collect(), - &syms + &syms, + &Default::default() ) ); println!( @@ -1112,7 +1124,8 @@ mod tests { ), 0, &[0].iter().collect(), - &syms + &syms, + &Default::default() ) ); w.add_fact(&[0].iter().collect(), fact(parent, &[&c, &e])); @@ -1130,6 +1143,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); println!("grandparents after inserting parent(C, E): {:?}", res); @@ -1205,6 +1219,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1254,6 +1269,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1340,6 +1356,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap() .iter_all() @@ -1420,7 +1437,9 @@ mod tests { ); println!("testing r1: {}", syms.print_rule(&r1)); - let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap(); + let res = w + .query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()) + .unwrap(); for (_, fact) in res.iter_all() { println!("\t{}", syms.print_fact(fact)); } @@ -1458,7 +1477,9 @@ mod tests { ); println!("testing r2: {}", syms.print_rule(&r2)); - let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap(); + let res = w + .query_rule(r2, 0, &[0].iter().collect(), &syms, &Default::default()) + .unwrap(); for (_, fact) in res.iter_all() { println!("\t{}", syms.print_fact(fact)); } @@ -1521,6 +1542,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1572,6 +1594,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1617,6 +1640,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1662,6 +1686,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1685,6 +1710,7 @@ mod tests { 0, &[0].iter().collect(), &syms, + &Default::default(), ) .unwrap(); @@ -1727,7 +1753,9 @@ mod tests { println!("world:\n{}\n", syms.print_world(&w)); println!("\ntesting r1: {}\n", syms.print_rule(&r1)); - let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap(); + let res = w + .query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()) + .unwrap(); for (_, fact) in res.iter_all() { println!("\t{}", syms.print_fact(fact)); } @@ -1766,7 +1794,9 @@ mod tests { ); println!("world:\n{}\n", syms.print_world(&w)); println!("\ntesting r1: {}\n", syms.print_rule(&r1)); - let res = w.query_rule(r1, 0, &[0].iter().collect(), &syms).unwrap(); + let res = w + .query_rule(r1, 0, &[0].iter().collect(), &syms, &Default::default()) + .unwrap(); println!("generated facts:"); for (_, fact) in res.iter_all() { @@ -1782,7 +1812,9 @@ mod tests { let r2 = rule(check, &[&read], &[pred(operation, &[&read])]); println!("world:\n{}\n", syms.print_world(&w)); println!("\ntesting r2: {}\n", syms.print_rule(&r2)); - let res = w.query_rule(r2, 0, &[0].iter().collect(), &syms).unwrap(); + let res = w + .query_rule(r2, 0, &[0].iter().collect(), &syms, &Default::default()) + .unwrap(); println!("generated facts:"); for (_, fact) in res.iter_all() { diff --git a/biscuit-auth/src/error.rs b/biscuit-auth/src/error.rs index 984369c5..5209703b 100644 --- a/biscuit-auth/src/error.rs +++ b/biscuit-auth/src/error.rs @@ -150,6 +150,8 @@ pub enum Format { UnknownExternalKey, #[error("the symbol id was not in the table")] UnknownSymbol(u64), + #[error("missing FFI name field")] + MissingFfiName, } /// Signature errors @@ -250,6 +252,10 @@ pub enum Expression { InvalidStack, #[error("Shadowed variable")] ShadowedVariable, + #[error("Incorrect arity for extern func")] + IncorrectArityExtern, + #[error("Undefined extern func: {0}")] + UndefinedExtern(String), } /// runtime limits errors diff --git a/biscuit-auth/src/format/convert.rs b/biscuit-auth/src/format/convert.rs index 5f1b4c31..061e84bd 100644 --- a/biscuit-auth/src/format/convert.rs +++ b/biscuit-auth/src/format/convert.rs @@ -608,7 +608,12 @@ pub mod v2 { Unary::Negate => Kind::Negate, Unary::Parens => Kind::Parens, Unary::Length => Kind::Length, + Unary::Ffi(_) => Kind::Ffi, } as i32, + ffi_name: match u { + Unary::Ffi(name) => Some(name.to_owned()), + _ => None, + }, }) } Op::Binary(b) => { @@ -643,7 +648,12 @@ pub mod v2 { Binary::LazyOr => Kind::LazyOr, Binary::All => Kind::All, Binary::Any => Kind::Any, + Binary::Ffi(_) => Kind::Ffi, } as i32, + ffi_name: match b { + Binary::Ffi(name) => Some(name.to_owned()), + _ => None, + }, }) } Op::Closure(params, ops) => schema::op::Content::Closure(schema::OpClosure { @@ -671,6 +681,11 @@ pub mod v2 { Some(op_unary::Kind::Negate) => Op::Unary(Unary::Negate), Some(op_unary::Kind::Parens) => Op::Unary(Unary::Parens), Some(op_unary::Kind::Length) => Op::Unary(Unary::Length), + Some(op_unary::Kind::Ffi) => match u.ffi_name.as_ref() { + // todo clementd error if ffi name is defined with another kind + Some(n) => Op::Unary(Unary::Ffi(n.to_owned())), + None => return Err(error::Format::MissingFfiName), + }, None => { return Err(error::Format::DeserializationError( "deserialization error: unary operation is empty".to_string(), @@ -707,6 +722,11 @@ pub mod v2 { Some(op_binary::Kind::LazyOr) => Op::Binary(Binary::LazyOr), Some(op_binary::Kind::All) => Op::Binary(Binary::All), Some(op_binary::Kind::Any) => Op::Binary(Binary::Any), + Some(op_binary::Kind::Ffi) => match b.ffi_name.as_ref() { + // todo clementd error if ffi name is defined with another kind + Some(n) => Op::Binary(Binary::Ffi(n.to_owned())), + None => return Err(error::Format::MissingFfiName), + }, None => { return Err(error::Format::DeserializationError( "deserialization error: binary operation is empty".to_string(), diff --git a/biscuit-auth/src/format/schema.proto b/biscuit-auth/src/format/schema.proto index 349bfb41..13465c19 100644 --- a/biscuit-auth/src/format/schema.proto +++ b/biscuit-auth/src/format/schema.proto @@ -124,9 +124,11 @@ message OpUnary { Negate = 0; Parens = 1; Length = 2; + Ffi = 1024; } required Kind kind = 1; + optional string ffiName = 2; } message OpBinary { @@ -158,9 +160,11 @@ message OpBinary { LazyOr = 24; All = 25; Any = 26; + Ffi = 1024; } required Kind kind = 1; + optional string ffiName = 2; } message OpClosure { diff --git a/biscuit-auth/src/format/schema.rs b/biscuit-auth/src/format/schema.rs index 58e7769a..ffde4661 100644 --- a/biscuit-auth/src/format/schema.rs +++ b/biscuit-auth/src/format/schema.rs @@ -197,6 +197,8 @@ pub mod op { pub struct OpUnary { #[prost(enumeration="op_unary::Kind", required, tag="1")] pub kind: i32, + #[prost(string, optional, tag="2")] + pub ffi_name: ::core::option::Option<::prost::alloc::string::String>, } /// Nested message and enum types in `OpUnary`. pub mod op_unary { @@ -206,12 +208,15 @@ pub mod op_unary { Negate = 0, Parens = 1, Length = 2, + Ffi = 1024, } } #[derive(Clone, PartialEq, ::prost::Message)] pub struct OpBinary { #[prost(enumeration="op_binary::Kind", required, tag="1")] pub kind: i32, + #[prost(string, optional, tag="2")] + pub ffi_name: ::core::option::Option<::prost::alloc::string::String>, } /// Nested message and enum types in `OpBinary`. pub mod op_binary { @@ -245,6 +250,7 @@ pub mod op_binary { LazyOr = 24, All = 25, Any = 26, + Ffi = 1024, } } #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/biscuit-auth/src/parser.rs b/biscuit-auth/src/parser.rs index 59510afb..77f0c3ef 100644 --- a/biscuit-auth/src/parser.rs +++ b/biscuit-auth/src/parser.rs @@ -383,7 +383,11 @@ mod tests { println!("print: {}", e.print(&syms).unwrap()); let h = HashMap::new(); let result = e - .evaluate(&h, &mut TemporarySymbolTable::new(&syms)) + .evaluate( + &h, + &mut TemporarySymbolTable::new(&syms), + &Default::default(), + ) .unwrap(); println!("evaluates to: {:?}", result); @@ -414,7 +418,11 @@ mod tests { println!("print: {}", e.print(&syms).unwrap()); let h = HashMap::new(); let result = e - .evaluate(&h, &mut TemporarySymbolTable::new(&syms)) + .evaluate( + &h, + &mut TemporarySymbolTable::new(&syms), + &Default::default(), + ) .unwrap(); println!("evaluates to: {:?}", result); diff --git a/biscuit-auth/src/token/authorizer.rs b/biscuit-auth/src/token/authorizer.rs index 5d9a2e99..00080b16 100644 --- a/biscuit-auth/src/token/authorizer.rs +++ b/biscuit-auth/src/token/authorizer.rs @@ -469,10 +469,15 @@ impl Authorizer { &self.public_key_to_block_id, ); + let extern_binary = limits.extern_funcs.clone(); self.world.run_with_limits(&self.symbols, limits)?; - let res = self - .world - .query_rule(rule, usize::MAX, &rule_trusted_origins, &self.symbols)?; + let res = self.world.query_rule( + rule, + usize::MAX, + &rule_trusted_origins, + &self.symbols, + &extern_binary, + )?; res.inner .into_iter() @@ -552,6 +557,7 @@ impl Authorizer { rule: datalog::Rule, limits: AuthorizerLimits, ) -> Result, error::Token> { + let extern_binary = limits.extern_funcs.clone(); self.world.run_with_limits(&self.symbols, limits)?; let rule_trusted_origins = if rule.scopes.is_empty() { @@ -568,9 +574,13 @@ impl Authorizer { ) }; - let res = self - .world - .query_rule(rule, 0, &rule_trusted_origins, &self.symbols)?; + let res = self.world.query_rule( + rule, + 0, + &rule_trusted_origins, + &self.symbols, + &extern_binary, + )?; let r: HashSet<_> = res.into_iter().map(|(_, fact)| fact).collect(); @@ -741,16 +751,20 @@ impl Authorizer { usize::MAX, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, + )?, + CheckKind::All => self.world.query_match_all( + query, + &rule_trusted_origins, + &self.symbols, + &limits.extern_funcs, )?, - CheckKind::All => { - self.world - .query_match_all(query, &rule_trusted_origins, &self.symbols)? - } CheckKind::Reject => !self.world.query_match( query, usize::MAX, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, }; @@ -799,17 +813,20 @@ impl Authorizer { 0, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, CheckKind::All => self.world.query_match_all( query.clone(), &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, CheckKind::Reject => !self.world.query_match( query.clone(), 0, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, }; @@ -849,6 +866,7 @@ impl Authorizer { usize::MAX, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?; let now = Instant::now(); @@ -898,17 +916,20 @@ impl Authorizer { i + 1, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, CheckKind::All => self.world.query_match_all( query.clone(), &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, CheckKind::Reject => !self.world.query_match( query.clone(), i + 1, &rule_trusted_origins, &self.symbols, + &limits.extern_funcs, )?, }; diff --git a/biscuit-auth/src/token/authorizer/snapshot.rs b/biscuit-auth/src/token/authorizer/snapshot.rs index 373aff9f..247dd890 100644 --- a/biscuit-auth/src/token/authorizer/snapshot.rs +++ b/biscuit-auth/src/token/authorizer/snapshot.rs @@ -31,6 +31,7 @@ impl super::Authorizer { max_facts: limits.max_facts, max_iterations: limits.max_iterations, max_time: Duration::from_nanos(limits.max_time), + extern_funcs: Default::default(), }; let execution_time = Duration::from_nanos(execution_time); diff --git a/biscuit-auth/src/token/builder.rs b/biscuit-auth/src/token/builder.rs index 5fcb010d..944c3ad4 100644 --- a/biscuit-auth/src/token/builder.rs +++ b/biscuit-auth/src/token/builder.rs @@ -999,6 +999,7 @@ impl From for Unary { biscuit_parser::builder::Unary::Negate => Unary::Negate, biscuit_parser::builder::Unary::Parens => Unary::Parens, biscuit_parser::builder::Unary::Length => Unary::Length, + biscuit_parser::builder::Unary::Ffi(name) => Unary::Ffi(name), } } } @@ -1033,6 +1034,7 @@ impl From for Binary { biscuit_parser::builder::Binary::LazyOr => Binary::LazyOr, biscuit_parser::builder::Binary::All => Binary::All, biscuit_parser::builder::Binary::Any => Binary::Any, + biscuit_parser::builder::Binary::Ffi(name) => Binary::Ffi(name), } } } diff --git a/biscuit-auth/tests/macros.rs b/biscuit-auth/tests/macros.rs index f38a14ea..e1ce7048 100644 --- a/biscuit-auth/tests/macros.rs +++ b/biscuit-auth/tests/macros.rs @@ -30,6 +30,14 @@ check if "my_value".starts_with("my"); check if [false, true].any($p -> true); "#, ); + + let b = block!(r#"check if "test".extern::toto() && "test".extern::test("test");"#); + + assert_eq!( + b.to_string(), + r#"check if "test".extern::toto() && "test".extern::test("test"); +"# + ); } #[test] diff --git a/biscuit-parser/src/builder.rs b/biscuit-parser/src/builder.rs index c8c13ee5..0c56fab8 100644 --- a/biscuit-parser/src/builder.rs +++ b/biscuit-parser/src/builder.rs @@ -209,6 +209,7 @@ pub enum Unary { Negate, Parens, Length, + Ffi(String), } #[derive(Debug, Clone, PartialEq, Eq)] @@ -240,6 +241,7 @@ pub enum Binary { LazyOr, All, Any, + Ffi(String), } #[cfg(feature = "datalog-macro")] @@ -266,6 +268,7 @@ impl ToTokens for Unary { Unary::Negate => quote! {::biscuit_auth::datalog::Unary::Negate }, Unary::Parens => quote! {::biscuit_auth::datalog::Unary::Parens }, Unary::Length => quote! {::biscuit_auth::datalog::Unary::Length }, + Unary::Ffi(name) => quote! {::biscuit_auth::datalog::Unary::Ffi(#name.to_string()) }, }); } } @@ -305,6 +308,7 @@ impl ToTokens for Binary { Binary::LazyOr => quote! { ::biscuit_auth::datalog::Binary::LazyOr }, Binary::All => quote! { ::biscuit_auth::datalog::Binary::All }, Binary::Any => quote! { ::biscuit_auth::datalog::Binary::Any }, + Binary::Ffi(name) => quote! {::biscuit_auth::datalog::Binary::Ffi(#name.to_string()) }, }); } } diff --git a/biscuit-parser/src/parser.rs b/biscuit-parser/src/parser.rs index ec7b62ab..39fdea06 100644 --- a/biscuit-parser/src/parser.rs +++ b/biscuit-parser/src/parser.rs @@ -494,6 +494,16 @@ fn binary_op_7(i: &str) -> IResult<&str, builder::Binary, Error> { alt((value(Binary::Mul, tag("*")), value(Binary::Div, tag("/"))))(i) } +fn extern_un(i: &str) -> IResult<&str, builder::Unary, Error> { + let (i, func) = preceded(tag("extern::"), name)(i)?; + Ok((i, builder::Unary::Ffi(func.to_string()))) +} + +fn extern_bin(i: &str) -> IResult<&str, builder::Binary, Error> { + let (i, func) = preceded(tag("extern::"), name)(i)?; + Ok((i, builder::Binary::Ffi(func.to_string()))) +} + fn binary_op_8(i: &str) -> IResult<&str, builder::Binary, Error> { use builder::Binary; @@ -506,6 +516,7 @@ fn binary_op_8(i: &str) -> IResult<&str, builder::Binary, Error> { value(Binary::Union, tag("union")), value(Binary::All, tag("all")), value(Binary::Any, tag("any")), + extern_bin, ))(i) } @@ -713,7 +724,7 @@ fn binary_method(i: &str) -> IResult<&str, (builder::Binary, Option> fn unary_method(i: &str) -> IResult<&str, builder::Unary, Error> { use builder::Unary; - let (i, op) = value(Unary::Length, tag("length"))(i)?; + let (i, op) = alt((value(Unary::Length, tag("length")), extern_un))(i)?; let (i, _) = char('(')(i)?; let (i, _) = space0(i)?; @@ -2496,4 +2507,29 @@ mod tests { )) ); } + + #[test] + fn extern_funcs() { + use builder::{int, Binary, Op}; + + assert_eq!( + super::expr("2.extern::toto()").map(|(i, o)| (i, o.opcodes())), + Ok(( + "", + vec![Op::Value(int(2)), Op::Unary(Unary::Ffi("toto".to_string()))], + )) + ); + + assert_eq!( + super::expr("2.extern::toto(3)").map(|(i, o)| (i, o.opcodes())), + Ok(( + "", + vec![ + Op::Value(int(2)), + Op::Value(int(3)), + Op::Binary(Binary::Ffi("toto".to_string())), + ], + )) + ); + } } From 3ed6b47704f9fda9d2a316304685f9b2e8207e6a Mon Sep 17 00:00:00 2001 From: Clement Delafargue Date: Tue, 22 Oct 2024 11:36:52 +0200 Subject: [PATCH 2/3] wip: allow closures in external funcs, use builder terms - using boxed functions instead of function pointers allow capturing the environment - using builder terms instead of datalog terms remove the need for manual symbol management --- biscuit-auth/src/datalog/expression.rs | 84 +++++++++++++++----------- biscuit-auth/src/error.rs | 4 +- biscuit-auth/src/token/builder.rs | 52 +++++++++++++++- 3 files changed, 103 insertions(+), 37 deletions(-) diff --git a/biscuit-auth/src/datalog/expression.rs b/biscuit-auth/src/datalog/expression.rs index fc5c6d2d..c98ba073 100644 --- a/biscuit-auth/src/datalog/expression.rs +++ b/biscuit-auth/src/datalog/expression.rs @@ -1,18 +1,45 @@ -use crate::error; +use crate::{builder, error}; use super::Term; use super::{SymbolTable, TemporarySymbolTable}; use regex::Regex; use std::collections::{HashMap, HashSet}; +use std::rc::Rc; -type ExternBinary = fn(&mut TemporarySymbolTable, &Term, &Term) -> Result; +#[derive(Clone)] +pub struct ExternFunc( + pub Rc) -> Result>, +); -type ExternUnary = fn(&mut TemporarySymbolTable, &Term) -> Result; +impl std::fmt::Debug for ExternFunc { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "") + } +} + +impl ExternFunc { + pub fn new( + f: Rc) -> Result>, + ) -> Self { + Self(f) + } -#[derive(Debug, Clone)] -pub enum ExternFunc { - Unary(ExternUnary), - Binary(ExternBinary), + pub fn call( + &self, + symbols: &mut TemporarySymbolTable, + name: &str, + left: Term, + right: Option, + ) -> Result { + let left = builder::Term::from_datalog(left, symbols)?; + let right = right + .map(|right| builder::Term::from_datalog(right, symbols)) + .transpose()?; + match self.0(left, right) { + Ok(t) => Ok(t.to_datalog(symbols)), + Err(e) => Err(error::Expression::ExternEvalError(name.to_string(), e)), + } + } } #[derive(Debug, Clone, PartialEq, Hash, Eq)] @@ -57,10 +84,7 @@ impl Unary { let fun = extern_funcs .get(name) .ok_or(error::Expression::UndefinedExtern(name.to_owned()))?; - match fun { - ExternFunc::Unary(fun) => fun(symbols, &i), - ExternFunc::Binary(_) => Err(error::Expression::IncorrectArityExtern), - } + fun.call(symbols, name, i, None) } _ => { //println!("unexpected value type on the stack"); @@ -337,10 +361,7 @@ impl Binary { let fun = extern_funcs .get(name) .ok_or(error::Expression::UndefinedExtern(name.to_owned()))?; - match fun { - ExternFunc::Binary(fun) => fun(symbols, &left, &right), - ExternFunc::Unary(_) => Err(error::Expression::IncorrectArityExtern), - } + fun.call(symbols, name, left, Some(right)) } _ => { @@ -1111,34 +1132,29 @@ mod tests { let mut extern_funcs: HashMap = Default::default(); extern_funcs.insert( "test_bin".to_owned(), - ExternFunc::Binary(|sym, left, right| match (left, right) { - (Term::Integer(left), Term::Integer(right)) => { + ExternFunc::new(Rc::new(|left, right| match (left, right) { + (builder::Term::Integer(left), Some(builder::Term::Integer(right))) => { println!("{left} {right}"); - Ok(Term::Bool((left % 60) == (right % 60))) + Ok(builder::Term::Bool((left % 60) == (right % 60))) } - (Term::Str(left), Term::Str(right)) => { - let left = sym - .get_symbol(*left) - .ok_or(error::Expression::UnknownSymbol(*left))?; - let right = sym - .get_symbol(*right) - .ok_or(error::Expression::UnknownSymbol(*right))?; - + (builder::Term::Str(left), Some(builder::Term::Str(right))) => { println!("{left} {right}"); - Ok(Term::Bool(left.to_lowercase() == right.to_lowercase())) + Ok(builder::Term::Bool( + left.to_lowercase() == right.to_lowercase(), + )) } - _ => Err(error::Expression::InvalidType), - }), + _ => Err("Expected two strings or two integers".to_string()), + })), ); extern_funcs.insert( "test_un".to_owned(), - ExternFunc::Unary(|_, value| match value { - Term::Integer(value) => Ok(Term::Bool(*value == 42)), + ExternFunc::new(Rc::new(|left, right| match (&left, &right) { + (builder::Term::Integer(left), None) => Ok(builder::boolean(*left == 42)), _ => { - println!("{value:?}"); - Err(error::Expression::InvalidType) + println!("{left:?}, {right:?}"); + Err("expecting a single integer".to_string()) } - }), + })), ); let res = e.evaluate(&values, &mut tmp_symbols, &extern_funcs); assert_eq!(res, Ok(Term::Bool(true))); diff --git a/biscuit-auth/src/error.rs b/biscuit-auth/src/error.rs index 5209703b..d5570ca9 100644 --- a/biscuit-auth/src/error.rs +++ b/biscuit-auth/src/error.rs @@ -252,10 +252,10 @@ pub enum Expression { InvalidStack, #[error("Shadowed variable")] ShadowedVariable, - #[error("Incorrect arity for extern func")] - IncorrectArityExtern, #[error("Undefined extern func: {0}")] UndefinedExtern(String), + #[error("Error while evaluating extern func {0}: {1}")] + ExternEvalError(String, String), } /// runtime limits errors diff --git a/biscuit-auth/src/token/builder.rs b/biscuit-auth/src/token/builder.rs index 944c3ad4..60d37f1a 100644 --- a/biscuit-auth/src/token/builder.rs +++ b/biscuit-auth/src/token/builder.rs @@ -1,7 +1,7 @@ //! helper functions and structure to create tokens and blocks use super::{default_symbol_table, Biscuit, Block}; use crate::crypto::{KeyPair, PublicKey}; -use crate::datalog::{self, get_schema_version, SymbolTable}; +use crate::datalog::{self, get_schema_version, SymbolTable, TemporarySymbolTable}; use crate::error; use crate::token::builder_ext::BuilderExt; use biscuit_parser::parser::parse_block_source; @@ -432,6 +432,56 @@ pub enum Term { Null, } +impl Term { + pub fn to_datalog(self, symbols: &mut TemporarySymbolTable) -> datalog::Term { + match self { + Term::Variable(s) => datalog::Term::Variable(symbols.insert(&s) as u32), + Term::Integer(i) => datalog::Term::Integer(i), + Term::Str(s) => datalog::Term::Str(symbols.insert(&s)), + Term::Date(d) => datalog::Term::Date(d), + Term::Bytes(s) => datalog::Term::Bytes(s), + Term::Bool(b) => datalog::Term::Bool(b), + Term::Set(s) => { + datalog::Term::Set(s.into_iter().map(|i| i.to_datalog(symbols)).collect()) + } + Term::Null => datalog::Term::Null, + // The error is caught in the `add_xxx` functions, so this should + // not happen™ + Term::Parameter(s) => panic!("Remaining parameter {}", &s), + } + } + + pub fn from_datalog( + term: datalog::Term, + symbols: &TemporarySymbolTable, + ) -> Result { + Ok(match term { + datalog::Term::Variable(s) => Term::Variable( + symbols + .get_symbol(s as u64) + .ok_or(error::Expression::UnknownVariable(s))? + .to_string(), + ), + datalog::Term::Integer(i) => Term::Integer(i), + datalog::Term::Str(s) => Term::Str( + symbols + .get_symbol(s) + .ok_or(error::Expression::UnknownSymbol(s))? + .to_string(), + ), + datalog::Term::Date(d) => Term::Date(d), + datalog::Term::Bytes(s) => Term::Bytes(s), + datalog::Term::Bool(b) => Term::Bool(b), + datalog::Term::Set(s) => Term::Set( + s.into_iter() + .map(|i| Self::from_datalog(i, symbols)) + .collect::>()?, + ), + datalog::Term::Null => Term::Null, + }) + } +} + impl Convert for Term { fn convert(&self, symbols: &mut SymbolTable) -> datalog::Term { match self { From 806b7f3f9a7f08e10d37e9ec2b2ea186e1e506cb Mon Sep 17 00:00:00 2001 From: Clement Delafargue Date: Tue, 22 Oct 2024 11:55:20 +0200 Subject: [PATCH 3/3] bump datalog version when external calls are used --- biscuit-auth/src/datalog/mod.rs | 2 ++ biscuit-auth/src/token/mod.rs | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/biscuit-auth/src/datalog/mod.rs b/biscuit-auth/src/datalog/mod.rs index 9e0e572a..888374da 100644 --- a/biscuit-auth/src/datalog/mod.rs +++ b/biscuit-auth/src/datalog/mod.rs @@ -978,6 +978,7 @@ fn contains_v3_3_op(expressions: &[Expression]) -> bool { expression.ops.iter().any(|op| match op { Op::Value(term) => contains_v3_3_term(term), Op::Closure(_, _) => true, + Op::Unary(Unary::Ffi(_)) => true, Op::Binary(binary) => matches!( binary, Binary::HeterogeneousEqual @@ -986,6 +987,7 @@ fn contains_v3_3_op(expressions: &[Expression]) -> bool { | Binary::LazyOr | Binary::All | Binary::Any + | Binary::Ffi(_) ), _ => false, }) diff --git a/biscuit-auth/src/token/mod.rs b/biscuit-auth/src/token/mod.rs index 057e81b0..c76fb5ab 100644 --- a/biscuit-auth/src/token/mod.rs +++ b/biscuit-auth/src/token/mod.rs @@ -36,7 +36,7 @@ pub const MAX_SCHEMA_VERSION: u32 = 6; pub const DATALOG_3_1: u32 = 4; /// starting version for 3rd party blocks (datalog 3.2) pub const DATALOG_3_2: u32 = 5; -/// starting version for datalog 3.3 features (reject if, closures, array/map, null, …) +/// starting version for datalog 3.3 features (reject if, closures, array/map, null, external functions, …) pub const DATALOG_3_3: u32 = 6; /// some symbols are predefined and available in every implementation, to avoid