Skip to content

Commit

Permalink
wip: allow closures in external funcs, use builder terms
Browse files Browse the repository at this point in the history
- using boxed functions instead of function pointers allow capturing the environment
- using builder terms instead of datalog terms remove the need for manual symbol management
  • Loading branch information
divarvel committed Oct 24, 2024
1 parent 642cf85 commit 3175834
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 37 deletions.
84 changes: 50 additions & 34 deletions biscuit-auth/src/datalog/expression.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,45 @@
use crate::error;
use crate::{builder, error};

use super::Term;
use super::{SymbolTable, TemporarySymbolTable};
use regex::Regex;
use std::collections::{HashMap, HashSet};
use std::rc::Rc;

type ExternBinary = fn(&mut TemporarySymbolTable, &Term, &Term) -> Result<Term, error::Expression>;
#[derive(Clone)]
pub struct ExternFunc(
pub Rc<dyn Fn(builder::Term, Option<builder::Term>) -> Result<builder::Term, String>>,
);

type ExternUnary = fn(&mut TemporarySymbolTable, &Term) -> Result<Term, error::Expression>;
impl std::fmt::Debug for ExternFunc {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "<function>")

Check warning on line 16 in biscuit-auth/src/datalog/expression.rs

View check run for this annotation

Codecov / codecov/patch

biscuit-auth/src/datalog/expression.rs#L15-L16

Added lines #L15 - L16 were not covered by tests
}
}

impl ExternFunc {
pub fn new(
f: Rc<dyn Fn(builder::Term, Option<builder::Term>) -> Result<builder::Term, String>>,
) -> Self {
Self(f)
}

#[derive(Debug, Clone)]
pub enum ExternFunc {
Unary(ExternUnary),
Binary(ExternBinary),
pub fn call(
&self,
symbols: &mut TemporarySymbolTable,
name: &str,
left: Term,
right: Option<Term>,
) -> Result<Term, error::Expression> {
let left = builder::Term::from_datalog(left, symbols)?;
let right = right
.map(|right| builder::Term::from_datalog(right, symbols))
.transpose()?;
match self.0(left, right) {
Ok(t) => Ok(t.to_datalog(symbols)),
Err(e) => Err(error::Expression::ExternEvalError(name.to_string(), e)),

Check warning on line 40 in biscuit-auth/src/datalog/expression.rs

View check run for this annotation

Codecov / codecov/patch

biscuit-auth/src/datalog/expression.rs#L40

Added line #L40 was not covered by tests
}
}
}

#[derive(Debug, Clone, PartialEq, Hash, Eq)]
Expand Down Expand Up @@ -72,10 +99,7 @@ impl Unary {
let fun = extern_funcs
.get(name)
.ok_or(error::Expression::UndefinedExtern(name.to_owned()))?;
match fun {
ExternFunc::Unary(fun) => fun(symbols, &i),
ExternFunc::Binary(_) => Err(error::Expression::IncorrectArityExtern),
}
fun.call(symbols, name, i, None)
}
_ => {
//println!("unexpected value type on the stack");
Expand Down Expand Up @@ -353,10 +377,7 @@ impl Binary {
let fun = extern_funcs
.get(name)
.ok_or(error::Expression::UndefinedExtern(name.to_owned()))?;
match fun {
ExternFunc::Binary(fun) => fun(symbols, &left, &right),
ExternFunc::Unary(_) => Err(error::Expression::IncorrectArityExtern),
}
fun.call(symbols, name, left, Some(right))
}

_ => {
Expand Down Expand Up @@ -1127,34 +1148,29 @@ mod tests {
let mut extern_funcs: HashMap<String, ExternFunc> = Default::default();
extern_funcs.insert(
"test_bin".to_owned(),
ExternFunc::Binary(|sym, left, right| match (left, right) {
(Term::Integer(left), Term::Integer(right)) => {
ExternFunc::new(Rc::new(|left, right| match (left, right) {
(builder::Term::Integer(left), Some(builder::Term::Integer(right))) => {
println!("{left} {right}");
Ok(Term::Bool((left % 60) == (right % 60)))
Ok(builder::Term::Bool((left % 60) == (right % 60)))
}
(Term::Str(left), Term::Str(right)) => {
let left = sym
.get_symbol(*left)
.ok_or(error::Expression::UnknownSymbol(*left))?;
let right = sym
.get_symbol(*right)
.ok_or(error::Expression::UnknownSymbol(*right))?;

(builder::Term::Str(left), Some(builder::Term::Str(right))) => {
println!("{left} {right}");
Ok(Term::Bool(left.to_lowercase() == right.to_lowercase()))
Ok(builder::Term::Bool(
left.to_lowercase() == right.to_lowercase(),
))
}
_ => Err(error::Expression::InvalidType),
}),
_ => Err("Expected two strings or two integers".to_string()),
})),
);
extern_funcs.insert(
"test_un".to_owned(),
ExternFunc::Unary(|_, value| match value {
Term::Integer(value) => Ok(Term::Bool(*value == 42)),
ExternFunc::new(Rc::new(|left, right| match (&left, &right) {
(builder::Term::Integer(left), None) => Ok(builder::boolean(*left == 42)),
_ => {
println!("{value:?}");
Err(error::Expression::InvalidType)
println!("{left:?}, {right:?}");
Err("expecting a single integer".to_string())
}
}),
})),
);
let res = e.evaluate(&values, &mut tmp_symbols, &extern_funcs);
assert_eq!(res, Ok(Term::Bool(true)));
Expand Down
4 changes: 2 additions & 2 deletions biscuit-auth/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,10 @@ pub enum Expression {
InvalidStack,
#[error("Shadowed variable")]
ShadowedVariable,
#[error("Incorrect arity for extern func")]
IncorrectArityExtern,
#[error("Undefined extern func: {0}")]
UndefinedExtern(String),
#[error("Error while evaluating extern func {0}: {1}")]
ExternEvalError(String, String),
}

/// runtime limits errors
Expand Down
52 changes: 51 additions & 1 deletion biscuit-auth/src/token/builder.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! helper functions and structure to create tokens and blocks
use super::{default_symbol_table, Biscuit, Block};
use crate::crypto::{KeyPair, PublicKey};
use crate::datalog::{self, get_schema_version, SymbolTable};
use crate::datalog::{self, get_schema_version, SymbolTable, TemporarySymbolTable};
use crate::error;
use crate::token::builder_ext::BuilderExt;
use biscuit_parser::parser::parse_block_source;
Expand Down Expand Up @@ -432,6 +432,56 @@ pub enum Term {
Null,
}

impl Term {
pub fn to_datalog(self, symbols: &mut TemporarySymbolTable) -> datalog::Term {
match self {
Term::Variable(s) => datalog::Term::Variable(symbols.insert(&s) as u32),
Term::Integer(i) => datalog::Term::Integer(i),
Term::Str(s) => datalog::Term::Str(symbols.insert(&s)),
Term::Date(d) => datalog::Term::Date(d),
Term::Bytes(s) => datalog::Term::Bytes(s),

Check warning on line 442 in biscuit-auth/src/token/builder.rs

View check run for this annotation

Codecov / codecov/patch

biscuit-auth/src/token/builder.rs#L438-L442

Added lines #L438 - L442 were not covered by tests
Term::Bool(b) => datalog::Term::Bool(b),
Term::Set(s) => {
datalog::Term::Set(s.into_iter().map(|i| i.to_datalog(symbols)).collect())

Check warning on line 445 in biscuit-auth/src/token/builder.rs

View check run for this annotation

Codecov / codecov/patch

biscuit-auth/src/token/builder.rs#L444-L445

Added lines #L444 - L445 were not covered by tests
}
Term::Null => datalog::Term::Null,

Check warning on line 447 in biscuit-auth/src/token/builder.rs

View check run for this annotation

Codecov / codecov/patch

biscuit-auth/src/token/builder.rs#L447

Added line #L447 was not covered by tests
// The error is caught in the `add_xxx` functions, so this should
// not happen™
Term::Parameter(s) => panic!("Remaining parameter {}", &s),

Check warning on line 450 in biscuit-auth/src/token/builder.rs

View check run for this annotation

Codecov / codecov/patch

biscuit-auth/src/token/builder.rs#L450

Added line #L450 was not covered by tests
}
}

pub fn from_datalog(
term: datalog::Term,
symbols: &TemporarySymbolTable,
) -> Result<Self, error::Expression> {
Ok(match term {
datalog::Term::Variable(s) => Term::Variable(
symbols
.get_symbol(s as u64)
.ok_or(error::Expression::UnknownVariable(s))?

Check warning on line 462 in biscuit-auth/src/token/builder.rs

View check run for this annotation

Codecov / codecov/patch

biscuit-auth/src/token/builder.rs#L459-L462

Added lines #L459 - L462 were not covered by tests
.to_string(),
),
datalog::Term::Integer(i) => Term::Integer(i),
datalog::Term::Str(s) => Term::Str(
symbols
.get_symbol(s)
.ok_or(error::Expression::UnknownSymbol(s))?
.to_string(),
),
datalog::Term::Date(d) => Term::Date(d),
datalog::Term::Bytes(s) => Term::Bytes(s),
datalog::Term::Bool(b) => Term::Bool(b),
datalog::Term::Set(s) => Term::Set(
s.into_iter()
.map(|i| Self::from_datalog(i, symbols))
.collect::<Result<_, _>>()?,

Check warning on line 478 in biscuit-auth/src/token/builder.rs

View check run for this annotation

Codecov / codecov/patch

biscuit-auth/src/token/builder.rs#L472-L478

Added lines #L472 - L478 were not covered by tests
),
datalog::Term::Null => Term::Null,

Check warning on line 480 in biscuit-auth/src/token/builder.rs

View check run for this annotation

Codecov / codecov/patch

biscuit-auth/src/token/builder.rs#L480

Added line #L480 was not covered by tests
})
}
}

impl Convert<datalog::Term> for Term {
fn convert(&self, symbols: &mut SymbolTable) -> datalog::Term {
match self {
Expand Down

0 comments on commit 3175834

Please sign in to comment.