Skip to content

ashiven/eqsat

Repository files navigation

eqsat

Equality Saturation in MimIR

License: MIT GitHub Release GitHub Issues or Pull Requests GitHub Issues or Pull Requests

About

Equality Saturation is a compiler optimization technique that utilizes E-Graphs to represent every possible way of equivalently rewriting a program and find the most optimal version of it, thereby solving the Phase-Ordering Problem. This repository contains Equality Saturation implementations in egg and slotted-egraphs as a plugin for the DSL compiler intermediate representation MimIR.

Usage

You may use this plugin through the MimIR C++ API or its textual representation Mim. Consider the following lightweight examples to get you started.

Option 1: C++ API

#include <fstream>
#include <mim/driver.h>
#include <mim/ast/parser.h>
#include <mim/pass/optimize.h>
#include <mim/util/sys.h>
#include <mim/plug/eqsat/eqsat.h>

using namespace mim;
using namespace mim::plug;

int main(int, char**) {
    try {
        Driver driver;
        auto& w = driver.world();
        driver.log().set(&std::cerr).set(Log::Level::Debug);
        ast::load_plugins(w, View<std::string>{"compile", "core", "opt", "eqsat"});

        // rule foo (x: Nat) = %core.nat.add (x, 0) => x;
        auto foo = w.mut_rule(w.type_nat())->set("foo");
        foo->var()->set("x");
        auto lhs = w.call(core::nat::add, w.tuple(foo->var(), lit_nat(0)))
        auto rhs = foo->var();
        foo.set_lhs(lhs);
        foo.set_rhs(rhs);
        foo.set_guard(w.lit_tt());

        // Use the provided methods to quickly define config functions
        eqsat_impl(w, eqsat::slotted);
        eqsat_cost_fun(w, eqsat::AstSize);
        eqsat_rulesets(w, {eqsat::standard});
        eqsat_rules(w, {foo});

        // fun extern main(x: Nat): Nat = return %core.nat.add (x, 0);
        auto main   = w.mut_fun({w.type_nat()}, {w.type_nat()})->set("main");
        auto x = main->var(2, 0)->set("x");
        auto ret               = main->var(2, 1);
        main->app(false, ret, x);
        main->externalize();

        // Equality saturation is performed as part of optimization
        optimize(w);
        std::ofstream ofs("eqsat.ll");
        driver.backend("ll")(w, ofs);
        ofs.close();

        sys::system("clang eqsat.ll -o eqsat -Wno-override-module");
        outln("exit code: {}", sys::system("./eqsat"));
    } catch (const std::exception& e) {
        errln("{}", e.what());
        return EXIT_FAILURE;
    } catch (...) {
        errln("error: unknown exception");
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}

Option 2: Mim

plugin core;
plugin eqsat;

// Here you can specify whether the plugin should use its `egg` or `slotted-egraphs` backend.
// The default implementation when nothing gets specified is `egg`.
fun extern _impl(): %eqsat.Impl =
    return %eqsat.slotted;

// To define the cost function that should be used for term extraction,
// simply provide the following config function.
//
// Config values:
// Egg:       AstSize (default), AstDepth
// Slotted:   AstSize (default)
fun extern _cost_fun(): %eqsat.CostFun =
    return %eqsat.AstSize;

// To use a set of rules directly implemented in `egg` or `slotted-egraphs`, define
// the following config function.
//
// To see the existing rulesets, have a look at `src\mim_[egg|slotted]\rulesets`.
// To implement and use your own ruleset, follow the instructions under **Adding rulesets**.
fun extern _rulesets(): %eqsat.Ruleset =
    return %eqsat.rulesets ( %eqsat.standard );

// You can also define your own syntactic rewrite-rules in `MimIR`.
//
// To differentiate between slots: "(var $x)" and patterns: "?x" you should
// prefix variables with "slot_" or "pat_" when using the slotted implementation.
rule foo (slot_x: Nat) = %core.nat.add (slot_x, 0) => slotx;

// And then tell the eqsat plugin to use them for term rewriting.
fun extern _rules(): %eqsat.Rules =
    return %eqsat.rules ( foo );


// Using your rewrite-rule 'foo', this will be rewritten to:
//
//    fun extern main(x: Nat): Nat =
//        return x;
//
fun extern main(x: Nat): Nat =
    return %core.nat.add (x, 0);

Installation

To install this plugin simply follow the instructions below:

1. Clone the mimir repository if you haven't already

git clone --recursive https://github.com/mimir/mimir.git

2. Clone the eqsat repository into mimir/extra

cd mimir/extra
git clone https://github.com/ashiven/eqsat.git
cd ..

3. Ensure that Rust and Cargo are installed

curl https://sh.rustup.rs -sSf | sh

4. Build the project according to the instructions

cmake -S . -B build -DBUILD_TESTING=ON -DMIM_BUILD_EXAMPLES=ON
cmake --build build -j$(nproc)

Rulesets

You may want to define a set of rewrite-rules that are more complex than the syntactic rewrite-rules that can be defined in MimIR. In this case, you should follow this implementation guide on adding a set of rules directly in egg or slotted-egraphs. (The example defines a ruleset for egg)

1. First you should fork and then clone this repository

# After forking:
git clone https://github.com/your-username/eqsat.git

2. Define a set of rules in src/mim_egg/rulesets/myrules.rs

use crate::mim_egg::Mim;
use crate::mim_egg::analysis::MimAnalysis;
use egg::{Rewrite, Pattern};

pub fn rules() -> Vec<Rewrite<Mim, MimAnalysis>> {
    let rules = vec![
        my_rule(),
    ];
    rules
}

fn my_rule() -> Rewrite<Mim, MimAnalysis> {
    let pat: Pattern<Mim> = "(app %foo.bar ?baz)".parse().unwrap();
    let outpat: Pattern<Mim> = "?baz".parse().unwrap();
    Rewrite::new("my_rule", pat, outpat).unwrap()
}

3. Add your ruleset to the RuleSet enum in src/ffi.rs

// ...
#[cxx::bridge]
pub mod bridge {
    #[derive(Debug)]
    enum RuleSet {
        // Egg
        Core,
        MyRules,
        // Slotted
        Standard,
    }
// ...

4. Ensure that your ruleset is registered in src/mim_egg/rulesets/mod.rs

use crate::RuleSet;
use crate::mim_egg::Mim;
use crate::mim_egg::analysis::MimAnalysis;
use egg::Rewrite;

pub mod core;
// Add the module:
pub mod myrules;

pub fn get_rules(rulesets: Vec<RuleSet>) -> Vec<Rewrite<Mim, MimAnalysis>> {
    let mut rules = Vec::new();
    for ruleset in rulesets {
        match ruleset {
            RuleSet::Core => rules.extend(core::rules()),
            // Add the ruleset:
            RuleSet::MyRules => rules.extend(myrules::rules()),
            _ => (),
        }
    }
    rules
}

5. Add your ruleset as a new axiom to eqsat.mim

/// ...
/// ## Rulesets
///
/// ### Egg
///
axm %eqsat.core: %eqsat.Ruleset;
axm %eqsat.myrules: %eqsat.Ruleset;
///
/// ### Slotted
///
axm %eqsat.standard: %eqsat.Ruleset;
/// ...

6. Patch the rewrite phase in plug/phase/rewrite_egg.cpp

std::pair<rust::Vec<RuleSet>, CostFn> RewriteEgg::import_config() {
    // ...
    rust::Vec<RuleSet> rulesets;
    CostFn cost_fn = CostFn::AstSize;
    for (auto lam : lams) {
        auto body = lam->as<Lam>()->body();
        if (auto body_app = body->isa<App>()) {
            if (auto ruleset_config = Axm::isa<eqsat::rulesets>(body_app->arg())) {
                for (auto ruleset : ruleset_config->args())
                    if (Axm::isa<eqsat::core>(ruleset))
                        rulesets.push_back(RuleSet::Core);
                    // Add this:
                    else if (Axm::isa<eqsat::myrules>(ruleset))
                        rulesets.push_back(RuleSet::MyRules);
                    else
                        assert(false && "Provided ruleset does not exist for egg");
    // ...
}

Provided Methods

There are two separate implementations in egg and slotted-egraphs that expose the following methods:

Rewriting

/**
 *  Rewrites an sexpr in `egg` format
 *
 *  sexpr:     a symbolic expr in `egg` format (emitted by the `mim` compiler via `--output-sexpr`)
 *  rulesets:  provides a list of identifiers to rulesets that should be used for rewriting (see src/mim_egg/rulesets)
 *  cost_fn:   provides a cost function that should be used for extraction (currently only AstSize and AstDepth)
 */
rust::Vec<RecExprFFI> eqsat_egg(std::string sexpr, rust::Vec<RuleSet> rulesets, CostFn cost_fn);
/**
 *  Rewrites an sexpr in `slotted-egraphs` format
 *
 *  sexpr:     a symbolic expr in `slotted-egraphs` format (emitted by the `mim` compiler via `--slotted --output-sexpr`)
 *  rulesets:  provides a list of identifiers to rulesets that should be used for rewriting (see src/mim_slotted/rulesets)
 *  cost_fn:   provides a cost function that should be used for extraction (currently only AstSize)
 */
rust::Vec<RecExprFFI> eqsat_slotted(std::string sexpr, rust::Vec<RuleSet> rulesets, CostFn cost_fn);

Pretty-printing

/**
 *  Pretty-prints an sexpr in `egg` format
 *
 *  sexpr:     a symbolic expr in `egg` format (emitted by the `mim` compiler via `--output-sexpr`)
 *  line_len:  the maximal line length after which the sexpr continues on a new line
 */
rust::String pretty_egg(std::string sexpr, size_t line_len);
/**
 *  Pretty-prints an sexpr in `slotted-egraphs` format
 *
 *  sexpr:     a symbolic expr in `slotted-egraphs` format (emitted by the `mim` compiler via `--slotted --output-sexpr`)
 *  line_len:  the maximal line length after which the sexpr continues on a new line
 */
rust::String pretty_slotted(std::string sexpr, size_t line_len);
/**
 *  Pretty-prints an sexpr represented by a Vec<RecExprFFI>
 *
 *  sexprs:    a vector of symbolic expressions in RecExprFFI format (the result of equality saturation)
 *  line_len:  the maximal line length after which the sexpr continues on a new line
 */
rust::String pretty_ffi(rust::Vec<RecExprFFI> sexprs, size_t line_len);

About

Equality Saturation in MimIR

Resources

License

Stars

Watchers

Forks

Contributors