Equality Saturation in MimIR
Equality Saturation is a compiler optimization technique that utilizes E-Graphs to represent every possible way of equivalently rewriting a program and find the most optimal version of it, thereby solving the Phase-Ordering Problem. This repository contains Equality Saturation implementations in egg and slotted-egraphs as a plugin for the DSL compiler intermediate representation MimIR.
You may use this plugin through the MimIR C++ API or its textual representation Mim. Consider the following lightweight examples to get you started.
#include <fstream>
#include <mim/driver.h>
#include <mim/ast/parser.h>
#include <mim/pass/optimize.h>
#include <mim/util/sys.h>
#include <mim/plug/eqsat/eqsat.h>
using namespace mim;
using namespace mim::plug;
int main(int, char**) {
try {
Driver driver;
auto& w = driver.world();
driver.log().set(&std::cerr).set(Log::Level::Debug);
ast::load_plugins(w, View<std::string>{"compile", "core", "opt", "eqsat"});
// rule foo (x: Nat) = %core.nat.add (x, 0) => x;
auto foo = w.mut_rule(w.type_nat())->set("foo");
foo->var()->set("x");
auto lhs = w.call(core::nat::add, w.tuple(foo->var(), lit_nat(0)))
auto rhs = foo->var();
foo.set_lhs(lhs);
foo.set_rhs(rhs);
foo.set_guard(w.lit_tt());
// Use the provided methods to quickly define config functions
eqsat_impl(w, eqsat::slotted);
eqsat_cost_fun(w, eqsat::AstSize);
eqsat_rulesets(w, {eqsat::standard});
eqsat_rules(w, {foo});
// fun extern main(x: Nat): Nat = return %core.nat.add (x, 0);
auto main = w.mut_fun({w.type_nat()}, {w.type_nat()})->set("main");
auto x = main->var(2, 0)->set("x");
auto ret = main->var(2, 1);
main->app(false, ret, x);
main->externalize();
// Equality saturation is performed as part of optimization
optimize(w);
std::ofstream ofs("eqsat.ll");
driver.backend("ll")(w, ofs);
ofs.close();
sys::system("clang eqsat.ll -o eqsat -Wno-override-module");
outln("exit code: {}", sys::system("./eqsat"));
} catch (const std::exception& e) {
errln("{}", e.what());
return EXIT_FAILURE;
} catch (...) {
errln("error: unknown exception");
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}plugin core;
plugin eqsat;
// Here you can specify whether the plugin should use its `egg` or `slotted-egraphs` backend.
// The default implementation when nothing gets specified is `egg`.
fun extern _impl(): %eqsat.Impl =
return %eqsat.slotted;
// To define the cost function that should be used for term extraction,
// simply provide the following config function.
//
// Config values:
// Egg: AstSize (default), AstDepth
// Slotted: AstSize (default)
fun extern _cost_fun(): %eqsat.CostFun =
return %eqsat.AstSize;
// To use a set of rules directly implemented in `egg` or `slotted-egraphs`, define
// the following config function.
//
// To see the existing rulesets, have a look at `src\mim_[egg|slotted]\rulesets`.
// To implement and use your own ruleset, follow the instructions under **Adding rulesets**.
fun extern _rulesets(): %eqsat.Ruleset =
return %eqsat.rulesets ( %eqsat.standard );
// You can also define your own syntactic rewrite-rules in `MimIR`.
//
// To differentiate between slots: "(var $x)" and patterns: "?x" you should
// prefix variables with "slot_" or "pat_" when using the slotted implementation.
rule foo (slot_x: Nat) = %core.nat.add (slot_x, 0) => slotx;
// And then tell the eqsat plugin to use them for term rewriting.
fun extern _rules(): %eqsat.Rules =
return %eqsat.rules ( foo );
// Using your rewrite-rule 'foo', this will be rewritten to:
//
// fun extern main(x: Nat): Nat =
// return x;
//
fun extern main(x: Nat): Nat =
return %core.nat.add (x, 0);
To install this plugin simply follow the instructions below:
1. Clone the mimir repository if you haven't already
git clone --recursive https://github.com/mimir/mimir.git2. Clone the eqsat repository into mimir/extra
cd mimir/extra
git clone https://github.com/ashiven/eqsat.git
cd ..3. Ensure that Rust and Cargo are installed
curl https://sh.rustup.rs -sSf | sh4. Build the project according to the instructions
cmake -S . -B build -DBUILD_TESTING=ON -DMIM_BUILD_EXAMPLES=ON
cmake --build build -j$(nproc)You may want to define a set of rewrite-rules that are more complex than the syntactic rewrite-rules that can be defined in MimIR. In this case, you should follow this implementation guide on adding a set of rules directly in egg or slotted-egraphs. (The example defines a ruleset for egg)
1. First you should fork and then clone this repository
# After forking:
git clone https://github.com/your-username/eqsat.git2. Define a set of rules in src/mim_egg/rulesets/myrules.rs
use crate::mim_egg::Mim;
use crate::mim_egg::analysis::MimAnalysis;
use egg::{Rewrite, Pattern};
pub fn rules() -> Vec<Rewrite<Mim, MimAnalysis>> {
let rules = vec![
my_rule(),
];
rules
}
fn my_rule() -> Rewrite<Mim, MimAnalysis> {
let pat: Pattern<Mim> = "(app %foo.bar ?baz)".parse().unwrap();
let outpat: Pattern<Mim> = "?baz".parse().unwrap();
Rewrite::new("my_rule", pat, outpat).unwrap()
}3. Add your ruleset to the RuleSet enum in src/ffi.rs
// ...
#[cxx::bridge]
pub mod bridge {
#[derive(Debug)]
enum RuleSet {
// Egg
Core,
MyRules,
// Slotted
Standard,
}
// ...4. Ensure that your ruleset is registered in src/mim_egg/rulesets/mod.rs
use crate::RuleSet;
use crate::mim_egg::Mim;
use crate::mim_egg::analysis::MimAnalysis;
use egg::Rewrite;
pub mod core;
// Add the module:
pub mod myrules;
pub fn get_rules(rulesets: Vec<RuleSet>) -> Vec<Rewrite<Mim, MimAnalysis>> {
let mut rules = Vec::new();
for ruleset in rulesets {
match ruleset {
RuleSet::Core => rules.extend(core::rules()),
// Add the ruleset:
RuleSet::MyRules => rules.extend(myrules::rules()),
_ => (),
}
}
rules
}5. Add your ruleset as a new axiom to eqsat.mim
/// ...
/// ## Rulesets
///
/// ### Egg
///
axm %eqsat.core: %eqsat.Ruleset;
axm %eqsat.myrules: %eqsat.Ruleset;
///
/// ### Slotted
///
axm %eqsat.standard: %eqsat.Ruleset;
/// ...
6. Patch the rewrite phase in plug/phase/rewrite_egg.cpp
std::pair<rust::Vec<RuleSet>, CostFn> RewriteEgg::import_config() {
// ...
rust::Vec<RuleSet> rulesets;
CostFn cost_fn = CostFn::AstSize;
for (auto lam : lams) {
auto body = lam->as<Lam>()->body();
if (auto body_app = body->isa<App>()) {
if (auto ruleset_config = Axm::isa<eqsat::rulesets>(body_app->arg())) {
for (auto ruleset : ruleset_config->args())
if (Axm::isa<eqsat::core>(ruleset))
rulesets.push_back(RuleSet::Core);
// Add this:
else if (Axm::isa<eqsat::myrules>(ruleset))
rulesets.push_back(RuleSet::MyRules);
else
assert(false && "Provided ruleset does not exist for egg");
// ...
}There are two separate implementations in egg and slotted-egraphs that expose the following methods:
/**
* Rewrites an sexpr in `egg` format
*
* sexpr: a symbolic expr in `egg` format (emitted by the `mim` compiler via `--output-sexpr`)
* rulesets: provides a list of identifiers to rulesets that should be used for rewriting (see src/mim_egg/rulesets)
* cost_fn: provides a cost function that should be used for extraction (currently only AstSize and AstDepth)
*/
rust::Vec<RecExprFFI> eqsat_egg(std::string sexpr, rust::Vec<RuleSet> rulesets, CostFn cost_fn);/**
* Rewrites an sexpr in `slotted-egraphs` format
*
* sexpr: a symbolic expr in `slotted-egraphs` format (emitted by the `mim` compiler via `--slotted --output-sexpr`)
* rulesets: provides a list of identifiers to rulesets that should be used for rewriting (see src/mim_slotted/rulesets)
* cost_fn: provides a cost function that should be used for extraction (currently only AstSize)
*/
rust::Vec<RecExprFFI> eqsat_slotted(std::string sexpr, rust::Vec<RuleSet> rulesets, CostFn cost_fn);/**
* Pretty-prints an sexpr in `egg` format
*
* sexpr: a symbolic expr in `egg` format (emitted by the `mim` compiler via `--output-sexpr`)
* line_len: the maximal line length after which the sexpr continues on a new line
*/
rust::String pretty_egg(std::string sexpr, size_t line_len);/**
* Pretty-prints an sexpr in `slotted-egraphs` format
*
* sexpr: a symbolic expr in `slotted-egraphs` format (emitted by the `mim` compiler via `--slotted --output-sexpr`)
* line_len: the maximal line length after which the sexpr continues on a new line
*/
rust::String pretty_slotted(std::string sexpr, size_t line_len);/**
* Pretty-prints an sexpr represented by a Vec<RecExprFFI>
*
* sexprs: a vector of symbolic expressions in RecExprFFI format (the result of equality saturation)
* line_len: the maximal line length after which the sexpr continues on a new line
*/
rust::String pretty_ffi(rust::Vec<RecExprFFI> sexprs, size_t line_len);