Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the Verify mode #726

Merged
merged 17 commits into from
Feb 7, 2025
2 changes: 1 addition & 1 deletion crates/interpreter/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,4 @@ pub use module::Module;
pub use syntax::{
FuncType, GlobalType, ImportDesc, Limits, Mut, RefType, ResultType, TableType, ValType,
};
pub use valid::validate;
pub use valid::prepare;
4 changes: 2 additions & 2 deletions crates/interpreter/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::parser::{SkipData, SkipElem};
use crate::side_table::*;
use crate::syntax::*;
use crate::toctou::*;
use crate::valid::validate;
use crate::valid::prepare;
use crate::*;

/// Valid module.
Expand Down Expand Up @@ -52,7 +52,7 @@ impl ImportDesc {
impl<'m> Module<'m> {
/// Validates a WASM module in binary format.
pub fn new(binary: &'m [u8]) -> Result<Self, Error> {
let side_table = validate(binary)?;
let side_table = prepare(binary)?;
let mut module = unsafe { Self::new_unchecked(binary) };
// TODO(dev/fast-interp): We should take a buffer as argument to write to.
module.side_table = Box::leak(Box::new(side_table));
Expand Down
134 changes: 114 additions & 20 deletions crates/interpreter/src/valid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,102 @@ use crate::toctou::*;
use crate::util::*;
use crate::*;

/// Checks whether a WASM module in binary format is valid.
pub fn validate(binary: &[u8]) -> Result<Vec<MetadataEntry>, Error> {
/// Checks whether a WASM module in binary format is valid, and returns the side table.
pub fn prepare(binary: &[u8]) -> Result<Vec<MetadataEntry>, Error> {
Context::<Prepare>::default().check_module(&mut Parser::new(binary))
}

pub trait ValidMode: Default {}
#[allow(dead_code)]
#[allow(unused_variables)]
/// Checks whether a WASM module with the side table in binary format is valid.
pub fn verify(binary: &[u8]) -> Result<(), Error> {
todo!()
}

trait ValidMode: Default {
/// List of source branches, when preparing.
///
/// When verifying, contains at most one _target_ branch. Source branches are eagerly patched to
/// their target branch using the branch table.
type Branches<'m>: Default + IntoIterator<Item = SideTableBranch<'m>>;

/// Branch table to prepare or verify.
type BranchTable<'m>;

/// Updates the branch table for source according to target, when preparing.
///
/// When verifying, makes sure source is target branch.
fn stitch_branch<'m>(
expr: &mut Expr<'_, 'm, Self>, source: SideTableBranch<'m>, target: SideTableBranch<'m>,
) -> CheckResult;

/// Pushes a source branch, when preparing.
///
/// When verifying, only push if there are no branches already. If there is one, verify that
/// it's the same.
fn push_branch<'m>(
branches: &mut Self::Branches<'m>, branch: SideTableBranch<'m>,
) -> CheckResult;

/// Does nothing, when preparing.
///
/// When verifying, patches a source branch to its target branch using the branch table.
fn patch_branch<'m>(
table: &Self::BranchTable<'m>, source: SideTableBranch<'m>,
) -> Result<SideTableBranch<'m>, Error>;
}

#[derive(Default)]
pub struct Prepare;
impl ValidMode for Prepare {}
struct Prepare;
impl ValidMode for Prepare {
type Branches<'m> = Vec<SideTableBranch<'m>>;
type BranchTable<'m> = BranchTable;

fn stitch_branch<'m>(
expr: &mut Expr<'_, 'm, Self>, source: SideTableBranch<'m>, target: SideTableBranch<'m>,
) -> CheckResult {
expr.branch_table.stitch(source, target)
}

fn push_branch<'m>(
branches: &mut Self::Branches<'m>, branch: SideTableBranch<'m>,
) -> CheckResult {
Ok(branches.push(branch))
}

fn patch_branch<'m>(
_: &Self::BranchTable<'m>, source: SideTableBranch<'m>,
) -> Result<SideTableBranch<'m>, Error> {
Ok(source)
}
}

#[derive(Default)]
struct Verify;
impl ValidMode for Verify {
type Branches<'m> = Option<SideTableBranch<'m>>;
type BranchTable<'m> = BranchTableView<'m>;

fn stitch_branch<'m>(
_: &mut Expr<'_, 'm, Self>, source: SideTableBranch<'m>, target: SideTableBranch<'m>,
) -> CheckResult {
check(source == target)
}

fn push_branch<'m>(
branches: &mut Self::Branches<'m>, branch: SideTableBranch<'m>,
) -> CheckResult {
check(branches.replace(branch).is_none_or(|x| x == branch))
}

fn patch_branch<'m>(
table: &Self::BranchTable<'m>, source: SideTableBranch<'m>,
) -> Result<SideTableBranch<'m>, Error> {
// source.parser += delta_ip;
// source.branch_table += delta_stp;
todo!()
}
}

type Parser<'m> = parser::Parser<'m, Check>;
type CheckResult = MResult<(), Check>;
Expand All @@ -49,7 +135,6 @@ struct Context<'m, M: ValidMode> {
globals: Vec<GlobalType>,
elems: Vec<RefType>,
datas: Option<usize>,
#[allow(dead_code)]
mode: PhantomData<M>,
}

Expand Down Expand Up @@ -427,8 +512,8 @@ struct Expr<'a, 'm, M: ValidMode> {
is_const: Result<&'a mut [bool], &'a [bool]>,
is_body: bool,
locals: Vec<ValType>,
labels: Vec<Label<'m>>,
branch_table: BranchTable,
labels: Vec<Label<'m, M>>,
branch_table: M::BranchTable<'m>,
}

#[derive(Default)]
Expand Down Expand Up @@ -499,7 +584,7 @@ impl BranchTable {
}
}

#[derive(Debug, Copy, Clone)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct SideTableBranch<'m> {
parser: &'m [u8],
branch_table: usize,
Expand All @@ -508,14 +593,14 @@ struct SideTableBranch<'m> {
}

#[derive(Debug, Default)]
struct Label<'m> {
struct Label<'m, M: ValidMode> {
type_: FuncType<'m>,
/// Whether an `else` is possible before `end`.
kind: LabelKind<'m>,
/// Whether the bottom of the stack is polymorphic.
polymorphic: bool,
stack: Vec<OpdType>,
branches: Vec<SideTableBranch<'m>>,
branches: M::Branches<'m>,
/// Total stack length of the labels in this function up to this label.
prev_stack: usize,
}
Expand Down Expand Up @@ -620,7 +705,7 @@ impl<'a, 'm, M: ValidMode> Expr<'a, 'm, M> {
let result = self.label().type_.results.len();
let mut target = self.branch_target(result);
target.branch_table += 1;
self.branch_table.stitch(source, target)?
M::stitch_branch(self, source, target)?;
}
_ => Err(invalid())?,
}
Expand Down Expand Up @@ -787,11 +872,11 @@ impl<'a, 'm, M: ValidMode> Expr<'a, 'm, M> {
self.locals.get(x as usize).cloned().ok_or_else(invalid)
}

fn label(&mut self) -> &mut Label<'m> {
fn label(&mut self) -> &mut Label<'m, M> {
self.labels.last_mut().unwrap()
}

fn immutable_label(&self) -> &Label<'m> {
fn immutable_label(&self) -> &Label<'m, M> {
self.labels.last().unwrap()
}

Expand Down Expand Up @@ -866,7 +951,14 @@ impl<'a, 'm, M: ValidMode> Expr<'a, 'm, M> {
let stack = type_.params.iter().cloned().map(OpdType::from).collect();
let prev_label = self.immutable_label();
let prev_stack = prev_label.prev_stack + prev_label.stack.len();
let label = Label { type_, kind, polymorphic: false, stack, branches: vec![], prev_stack };
let label = Label {
type_,
kind,
polymorphic: false,
stack,
branches: Default::default(),
prev_stack,
};
self.labels.push(label);
Ok(())
}
Expand All @@ -875,14 +967,14 @@ impl<'a, 'm, M: ValidMode> Expr<'a, 'm, M> {
let results_len = self.label().type_.results.len();
let mut target = self.branch_target(results_len);
for source in core::mem::take(&mut self.label().branches) {
self.branch_table.stitch(source, target)?;
M::stitch_branch(self, source, target)?;
}
let label = self.label();
if let LabelKind::If(source) = label.kind {
check(label.type_.params == label.type_.results)?;
// SAFETY: This function is only called after parsing an End instruction.
target.parser = offset_front(target.parser, -1);
self.branch_table.stitch(source, target)?;
M::stitch_branch(self, source, target)?;
}
let results = self.label().type_.results;
self.pops(results)?;
Expand All @@ -898,15 +990,17 @@ impl<'a, 'm, M: ValidMode> Expr<'a, 'm, M> {
let n = self.labels.len();
check(l < n)?;
let source = self.branch_source();
let source = M::patch_branch(&self.branch_table, source)?;
let label = &mut self.labels[n - l - 1];
Ok(match label.kind {
LabelKind::Block | LabelKind::If(_) => {
label.branches.push(source);
M::push_branch(&mut label.branches, source)?;
label.type_.results
}
LabelKind::Loop(target) => {
self.branch_table.stitch(source, target)?;
label.type_.params
let params = label.type_.params;
M::stitch_branch(self, source, target)?;
params
}
})
}
Expand Down
Loading