Skip to content

feat: Add CallGraph struct, and dead-function-removal pass #1796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
3a0ee56
Add call_graph.rs, start writing docs
acl-cqc Dec 17, 2024
fcd5321
roots == Some(empty) meaningless => make non-Opt; pub CallGraphEdge; …
acl-cqc Dec 17, 2024
c497e4d
Remove remove_polyfuncs
acl-cqc Dec 17, 2024
c3dd939
Warn on missing docs
acl-cqc Dec 17, 2024
3bc33bc
Reinstate remove_polyfuncs but deprecate: guess next version number, …
acl-cqc Dec 17, 2024
1e95bc6
Test module entry_points
acl-cqc Dec 17, 2024
9061dc9
Move reachable_funcs outside of CallGraph
acl-cqc Dec 17, 2024
e29ffa2
Rename entry_points<->roots, use extend + assert
acl-cqc Dec 17, 2024
5f89cac
Merge branch 'main' into acl/remove_dead_funcs
acl-cqc Dec 17, 2024
4ee87aa
Merge 'origin/main' into acl/remove_dead_funcs, deprecation msgs
acl-cqc Dec 18, 2024
220bf67
Add RemoveDeadFuncsPass. TODO make remove_dead_funcs use ValidationLe…
acl-cqc Dec 18, 2024
466123d
enclosing{=>_func}, switch order, comment
acl-cqc Dec 18, 2024
f8008d9
Use Pass in tests
acl-cqc Dec 18, 2024
7ba818d
Add CallGraphNode enum and accessors
acl-cqc Dec 20, 2024
03cac78
Move remove_dead_funcs stuff into separate file
acl-cqc Dec 20, 2024
e39c279
Add (rather useless atm) error type
acl-cqc Dec 20, 2024
3f1caa8
switch from Bfs to Dfs
acl-cqc Dec 20, 2024
c47a99e
Don't auto-insert 'main'; error not panic on bad entry-point
acl-cqc Dec 20, 2024
4f36e56
Sneakily-without-tests remove FuncDecls too
acl-cqc Dec 20, 2024
eaca2e7
Use petgraph::visit::Walker rather than std::iter::from_fn
acl-cqc Dec 20, 2024
393a476
dead_func_removal -> dead_funcs
acl-cqc Dec 23, 2024
53389c7
Reinstate monomorphize calling remove_polyfuncs with note re. planned…
acl-cqc Dec 23, 2024
6b496f1
fmt
acl-cqc Dec 23, 2024
4a07dee
Also deprecate remove_polyfuncs_ref; fix docs
acl-cqc Dec 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions hugr-passes/src/call_graph.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#![warn(missing_docs)]
//! Data structure for call graphs of a Hugr
use std::collections::HashMap;

use hugr_core::{ops::OpType, HugrView, Node};
use petgraph::{graph::NodeIndex, Graph};

/// Weight for an edge in a [CallGraph].
///
/// Each variant carries the Hugr [Node] of the operation that gave rise to the edge.
pub enum CallGraphEdge {
/// Edge corresponds to a [Call](OpType::Call) node (specified) in the Hugr
Call(Node),
/// Edge corresponds to a [LoadFunction](OpType::LoadFunction) node (specified) in the Hugr
LoadFunction(Node),
}

/// Weight for a petgraph-node in a [CallGraph].
///
/// The function variants carry the Hugr [Node] they stand for; the root variant carries
/// nothing.
pub enum CallGraphNode {
/// petgraph-node corresponds to a [FuncDecl](OpType::FuncDecl) node (specified) in the Hugr
FuncDecl(Node),
/// petgraph-node corresponds to a [FuncDefn](OpType::FuncDefn) node (specified) in the Hugr
FuncDefn(Node),
/// petgraph-node corresponds to the root node of the hugr, that is not
/// a [FuncDefn](OpType::FuncDefn). Note that it will not be a [Module](OpType::Module)
/// either, as such a node could not have outgoing edges, so is not represented in the petgraph.
NonFuncRoot,
}

/// Details the [Call]s and [LoadFunction]s in a Hugr.
/// Each node in the `CallGraph` corresponds to a [FuncDefn] or [FuncDecl] in the Hugr;
/// each edge corresponds to a [Call]/[LoadFunction] of the edge's target, contained in
/// the edge's source.
///
/// For Hugrs whose root is neither a [Module](OpType::Module) nor a [FuncDefn], the call graph
/// will have an additional [CallGraphNode::NonFuncRoot] corresponding to the Hugr's root, with no incoming edges.
///
/// [Call]: OpType::Call
/// [FuncDecl]: OpType::FuncDecl
/// [FuncDefn]: OpType::FuncDefn
/// [LoadFunction]: OpType::LoadFunction
pub struct CallGraph {
// The underlying petgraph; weights identify the Hugr nodes behind each node and edge.
g: Graph<CallGraphNode, CallGraphEdge>,
// Maps each Hugr node that is represented in `g` to its petgraph index.
node_to_g: HashMap<Node, NodeIndex<u32>>,
}

impl CallGraph {
    /// Makes a new CallGraph for a specified (subview) of a Hugr.
    /// Calls to functions outside the view will be dropped.
    ///
    /// Every [FuncDefn](OpType::FuncDefn) and [FuncDecl](OpType::FuncDecl) in the view
    /// gets a petgraph node, as does the root of the view when it is neither of those
    /// nor a [Module](OpType::Module). Edges are then added for each
    /// [Call](OpType::Call) / [LoadFunction](OpType::LoadFunction) found beneath each
    /// such node.
    pub fn new(hugr: &impl HugrView) -> Self {
        /// Adds to `graph` one edge out of `caller` for every Call / LoadFunction
        /// strictly beneath `parent`. Nested FuncDefns are skipped: they have their
        /// own petgraph node and are processed separately by the loop in `new`.
        fn add_edges_below(
            hugr: &impl HugrView,
            caller: NodeIndex<u32>,
            parent: Node, // Nonstrict-descendant of `caller`
            graph: &mut Graph<CallGraphNode, CallGraphEdge>,
            indices: &HashMap<Node, NodeIndex<u32>>,
        ) {
            for child in hugr.children(parent) {
                let op = hugr.get_optype(child);
                if op.is_func_defn() {
                    continue;
                }
                // Descend first, so edges from deeper nodes are inserted before this one's.
                add_edges_below(hugr, caller, child, graph, indices);
                let weight = match op {
                    OpType::Call(_) => CallGraphEdge::Call(child),
                    OpType::LoadFunction(_) => CallGraphEdge::LoadFunction(child),
                    _ => continue,
                };
                // No static source visible: nothing to link this call to.
                let Some(callee) = hugr.static_source(child) else {
                    continue;
                };
                let callee_index = *indices.get(&callee).unwrap();
                graph.add_edge(caller, callee_index, weight);
            }
        }

        let mut g = Graph::default();
        let root = hugr.root();
        let root_is_module = hugr.get_optype(root).is_module();
        let node_to_g: HashMap<_, _> = hugr
            .nodes()
            .filter_map(|n| {
                let weight = match hugr.get_optype(n) {
                    OpType::FuncDecl(_) => CallGraphNode::FuncDecl(n),
                    OpType::FuncDefn(_) => CallGraphNode::FuncDefn(n),
                    // A non-module, non-function root still needs a node to hang edges off.
                    _ if !root_is_module && n == root => CallGraphNode::NonFuncRoot,
                    _ => return None,
                };
                Some((n, g.add_node(weight)))
            })
            .collect();

        for (&func, &cg_node) in &node_to_g {
            add_edges_below(hugr, cg_node, func, &mut g, &node_to_g);
        }

        Self { g, node_to_g }
    }

    /// Allows access to the petgraph
    pub fn graph(&self) -> &Graph<CallGraphNode, CallGraphEdge> {
        &self.g
    }

    /// Convert a Hugr [Node] into a petgraph node index.
    /// Result will be `None` if `n` is not a [FuncDefn](OpType::FuncDefn),
    /// [FuncDecl](OpType::FuncDecl) or the (non-[Module](OpType::Module)) hugr root.
    pub fn node_index(&self, n: Node) -> Option<NodeIndex<u32>> {
        let index = self.node_to_g.get(&n)?;
        Some(*index)
    }
}
197 changes: 197 additions & 0 deletions hugr-passes/src/dead_funcs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
#![warn(missing_docs)]
//! Pass for removing statically-unreachable functions from a Hugr

use std::collections::HashSet;

use hugr_core::{
hugr::hugrmut::HugrMut,
ops::{OpTag, OpTrait},
HugrView, Node,
};
use petgraph::visit::{Dfs, Walker};

use crate::validation::{ValidatePassError, ValidationLevel};

use super::call_graph::{CallGraph, CallGraphNode};

#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
/// Errors produced by [RemoveDeadFuncsPass] and [remove_dead_funcs].
pub enum RemoveDeadFuncsError {
/// An entry point was given that is not a FuncDefn child of a Module root.
/// This is also reported for any entry point supplied when the Hugr's root
/// is not a Module at all.
#[error("Node {0} was not a FuncDefn child of the Module root")]
InvalidEntryPoint(Node),
// Wraps an error from the validation run around the pass.
#[error(transparent)]
#[allow(missing_docs)]
ValidationError(#[from] ValidatePassError),
}

/// Finds the Hugr nodes whose call-graph nodes are reachable from the starting points.
///
/// * For a Module-rooted Hugr the search starts from `entry_points`, each of which
///   must be a FuncDefn that is a direct child of the root.
/// * Otherwise `entry_points` must be empty and the search starts from the root.
///
/// The returned iterator is lazy: it performs the depth-first search over `cg` as it
/// is consumed, yielding the Hugr node behind each call-graph node visited.
///
/// # Errors
/// [RemoveDeadFuncsError::InvalidEntryPoint] for the first entry point that breaks the
/// rules above.
fn reachable_funcs<'a>(
    cg: &'a CallGraph,
    h: &'a impl HugrView,
    entry_points: impl IntoIterator<Item = Node>,
) -> Result<impl Iterator<Item = Node> + 'a, RemoveDeadFuncsError> {
    let g = cg.graph();
    let mut entry_points = entry_points.into_iter();
    let searcher = if h.get_optype(h.root()).is_module() {
        // Begin with nothing to visit, then seed the stack with each valid entry point.
        let mut d = Dfs::empty(g);
        for n in entry_points {
            if !h.get_optype(n).is_func_defn() || h.get_parent(n) != Some(h.root()) {
                return Err(RemoveDeadFuncsError::InvalidEntryPoint(n));
            }
            // Just checked to be a FuncDefn, so the call graph has a node for it.
            d.stack.push(cg.node_index(n).unwrap());
        }
        d
    } else {
        if let Some(n) = entry_points.next() {
            // Can't be a child of the module root as there isn't a module root!
            return Err(RemoveDeadFuncsError::InvalidEntryPoint(n));
        }
        // A non-module root always has a call-graph node (function or NonFuncRoot).
        Dfs::new(g, cg.node_index(h.root()).unwrap())
    };
    Ok(searcher.iter(g).map(|i| match g.node_weight(i).unwrap() {
        CallGraphNode::FuncDefn(n) | CallGraphNode::FuncDecl(n) => *n,
        CallGraphNode::NonFuncRoot => h.root(),
    }))
}

#[derive(Debug, Clone, Default)]
/// A configuration for the Dead Function Removal pass.
pub struct RemoveDeadFuncsPass {
// Validation level applied around the pass when it is run.
validation: ValidationLevel,
// Nodes from which reachability is computed; empty by default.
entry_points: Vec<Node>,
}

impl RemoveDeadFuncsPass {
    /// Sets the validation level used before and after the pass is run
    pub fn validation_level(self, level: ValidationLevel) -> Self {
        Self {
            validation: level,
            ..self
        }
    }

    /// Adds new entry points - these must be [FuncDefn] nodes
    /// that are children of the [Module] at the root of the Hugr.
    ///
    /// Entry points accumulate: those added by earlier calls are kept.
    ///
    /// [FuncDefn]: hugr_core::ops::OpType::FuncDefn
    /// [Module]: hugr_core::ops::OpType::Module
    pub fn with_module_entry_points(
        self,
        entry_points: impl IntoIterator<Item = Node>,
    ) -> Self {
        let Self {
            validation,
            entry_points: mut all_entry_points,
        } = self;
        all_entry_points.extend(entry_points);
        Self {
            validation,
            entry_points: all_entry_points,
        }
    }

    /// Runs the pass (see [remove_dead_funcs]) with this configuration
    pub fn run<H: HugrMut>(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> {
        // The closure is passed inline so its argument types are inferred from
        // `run_validated_pass`' own signature.
        self.validation
            .run_validated_pass(hugr, |target: &mut H, _| {
                remove_dead_funcs(target, self.entry_points.iter().copied())
            })
    }
}

/// Delete from the Hugr any functions that are not used by either [Call] or
/// [LoadFunction] nodes in reachable parts.
///
/// For [Module]-rooted Hugrs, `entry_points` may provide a list of entry points,
/// which must be children of the root. Note that if `entry_points` is empty, this will
/// result in all functions in the module being removed.
///
/// For non-[Module]-rooted Hugrs, `entry_points` must be empty; the root node is used.
///
/// Both function definitions and function declarations are candidates for removal;
/// each is removed together with its whole subtree.
///
/// # Errors
/// * If there are any `entry_points` but the root of the hugr is not a [Module]
/// * If any node in `entry_points` is
///     * not a [FuncDefn], or
///     * not a child of the root
///
/// [Call]: hugr_core::ops::OpType::Call
/// [FuncDefn]: hugr_core::ops::OpType::FuncDefn
/// [LoadFunction]: hugr_core::ops::OpType::LoadFunction
/// [Module]: hugr_core::ops::OpType::Module
pub fn remove_dead_funcs(
    h: &mut impl HugrMut,
    entry_points: impl IntoIterator<Item = Node>,
) -> Result<(), RemoveDeadFuncsError> {
    let reachable = reachable_funcs(&CallGraph::new(h), h, entry_points)?.collect::<HashSet<_>>();
    // Collect first: the Hugr cannot be mutated while `nodes()` is borrowing it.
    let unreachable = h
        .nodes()
        .filter(|n| OpTag::Function.is_superset(h.get_optype(*n).tag()) && !reachable.contains(n))
        .collect::<Vec<_>>();
    for n in unreachable {
        // A dead function nested inside another dead function has already gone along
        // with its parent's subtree; skip it rather than remove a node that no longer
        // exists.
        if h.contains_node(n) {
            h.remove_subtree(n);
        }
    }
    Ok(())
}

#[cfg(test)]
mod test {
use std::collections::HashMap;

use itertools::Itertools;
use rstest::rstest;

use hugr_core::builder::{
Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use hugr_core::{extension::prelude::usize_t, types::Signature, HugrView};

use super::RemoveDeadFuncsPass;

// Each case gives the entry points (by function name) and the names of the
// functions expected to survive the pass, in sorted order.
#[rstest]
#[case([], vec![])] // No entry_points removes everything!
#[case(["main"], vec!["from_main", "main"])]
#[case(["from_main"], vec!["from_main"])]
#[case(["other1"], vec!["other1", "other2"])]
#[case(["other2"], vec!["other2"])]
#[case(["other1", "other2"], vec!["other1", "other2"])]
fn remove_dead_funcs_entry_points(
#[case] entry_points: impl IntoIterator<Item = &'static str>,
#[case] retained_funcs: Vec<&'static str>,
) -> Result<(), Box<dyn std::error::Error>> {
// Build a module with two independent call chains:
// other1 -> other2, and main -> from_main.
let mut hb = ModuleBuilder::new();
let o2 = hb.define_function("other2", Signature::new_endo(usize_t()))?;
let o2inp = o2.input_wires();
let o2 = o2.finish_with_outputs(o2inp)?;
let mut o1 = hb.define_function("other1", Signature::new_endo(usize_t()))?;

let o1c = o1.call(o2.handle(), &[], o1.input_wires())?;
o1.finish_with_outputs(o1c.outputs())?;

let fm = hb.define_function("from_main", Signature::new_endo(usize_t()))?;
let f_inp = fm.input_wires();
let fm = fm.finish_with_outputs(f_inp)?;
let mut m = hb.define_function("main", Signature::new_endo(usize_t()))?;
let mc = m.call(fm.handle(), &[], m.input_wires())?;
m.finish_with_outputs(mc.outputs())?;

let mut hugr = hb.finish_hugr()?;

// Look up each defined function's node by name, to translate the case's entry points.
let avail_funcs = hugr
.nodes()
.filter_map(|n| {
hugr.get_optype(n)
.as_func_defn()
.map(|fd| (fd.name.clone(), n))
})
.collect::<HashMap<_, _>>();

RemoveDeadFuncsPass::default()
.with_module_entry_points(
entry_points
.into_iter()
.map(|name| *avail_funcs.get(name).unwrap())
.collect::<Vec<_>>(),
)
.run(&mut hugr)
.unwrap();

// Names of the function definitions left after the pass, sorted to compare with the case.
let remaining_funcs = hugr
.nodes()
.filter_map(|n| hugr.get_optype(n).as_func_defn().map(|fd| fd.name.as_str()))
.sorted()
.collect_vec();
assert_eq!(remaining_funcs, retained_funcs);
Ok(())
}
}
12 changes: 11 additions & 1 deletion hugr-passes/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
//! Compilation passes acting on the HUGR program representation.

pub mod call_graph;
pub mod const_fold;
pub mod dataflow;
mod dead_funcs;
pub use dead_funcs::{remove_dead_funcs, RemoveDeadFuncsPass};
pub mod force_order;
mod half_node;
pub mod lower;
pub mod merge_bbs;
mod monomorphize;
// TODO: Deprecated re-export. Remove on a breaking release.
// `RemoveDeadFuncsPass` is re-exported from the crate root (the `dead_funcs` module is
// private), so the note points there rather than into `call_graph`.
#[deprecated(
    since = "0.14.1",
    note = "Use `hugr::algorithms::RemoveDeadFuncsPass` instead."
)]
#[allow(deprecated)]
pub use monomorphize::remove_polyfuncs;
// TODO: Deprecated re-export. Remove on a breaking release.
#[deprecated(
since = "0.14.1",
note = "Use `hugr::algorithms::MonomorphizePass` instead."
)]
#[allow(deprecated)]
pub use monomorphize::monomorphize;
pub use monomorphize::{remove_polyfuncs, MonomorphizeError, MonomorphizePass};
pub use monomorphize::{MonomorphizeError, MonomorphizePass};
pub mod nest_cfgs;
pub mod non_local;
pub mod validation;
Expand Down
Loading
Loading