feat: Add CallGraph struct, and dead-function-removal pass #1796

Merged
merged 24 commits into from
Dec 24, 2024
Changes from 9 commits
Commits
3a0ee56
Add call_graph.rs, start writing docs
acl-cqc Dec 17, 2024
fcd5321
roots == Some(empty) meaningless => make non-Opt; pub CallGraphEdge; …
acl-cqc Dec 17, 2024
c497e4d
Remove remove_polyfuncs
acl-cqc Dec 17, 2024
c3dd939
Warn on missing docs
acl-cqc Dec 17, 2024
3bc33bc
Reinstate remove_polyfuncs but deprecate: guess next version number, …
acl-cqc Dec 17, 2024
1e95bc6
Test module entry_points
acl-cqc Dec 17, 2024
9061dc9
Move reachable_funcs outside of CallGraph
acl-cqc Dec 17, 2024
e29ffa2
Rename entry_points<->roots, use extend + assert
acl-cqc Dec 17, 2024
5f89cac
Merge branch 'main' into acl/remove_dead_funcs
acl-cqc Dec 17, 2024
4ee87aa
Merge 'origin/main' into acl/remove_dead_funcs, deprecation msgs
acl-cqc Dec 18, 2024
220bf67
Add RemoveDeadFuncsPass. TODO make remove_dead_funcs use ValidationLe…
acl-cqc Dec 18, 2024
466123d
enclosing{=>_func}, switch order, comment
acl-cqc Dec 18, 2024
f8008d9
Use Pass in tests
acl-cqc Dec 18, 2024
7ba818d
Add CallGraphNode enum and accessors
acl-cqc Dec 20, 2024
03cac78
Move remove_dead_funcs stuff into separate file
acl-cqc Dec 20, 2024
e39c279
Add (rather useless atm) error type
acl-cqc Dec 20, 2024
3f1caa8
switch from Bfs to Dfs
acl-cqc Dec 20, 2024
c47a99e
Don't auto-insert 'main'; error not panic on bad entry-point
acl-cqc Dec 20, 2024
4f36e56
Sneakily-without-tests remove FuncDecls too
acl-cqc Dec 20, 2024
eaca2e7
Use petgraph::visit::Walker rather than std::iter::from_fn
acl-cqc Dec 20, 2024
393a476
dead_func_removal -> dead_funcs
acl-cqc Dec 23, 2024
53389c7
Reinstate monomorphize calling remove_polyfuncs with note re. planned…
acl-cqc Dec 23, 2024
6b496f1
fmt
acl-cqc Dec 23, 2024
4a07dee
Also deprecate remove_polyfuncs_ref; fix docs
acl-cqc Dec 23, 2024
196 changes: 196 additions & 0 deletions hugr-passes/src/call_graph.rs
@@ -0,0 +1,196 @@
#![warn(missing_docs)]
//! Data structure for call graphs of a Hugr, and some transformations using them.
use std::collections::{HashMap, HashSet};

use hugr_core::{
    hugr::hugrmut::HugrMut,
    ops::{OpTag, OpTrait, OpType},
    HugrView, Node,
};
use itertools::Itertools;
use petgraph::{graph::NodeIndex, visit::Bfs, Graph};

/// Weight for an edge in a [CallGraph]
pub enum CallGraphEdge {
    /// Edge corresponds to a [Call](OpType::Call) node (specified) in the Hugr
    Call(Node),
    /// Edge corresponds to a [LoadFunction](OpType::LoadFunction) node (specified) in the Hugr
    LoadFunction(Node),
}

/// Details the [Call]s and [LoadFunction]s in a Hugr.
/// Each node in the `CallGraph` corresponds to a [FuncDefn] in the Hugr; each edge corresponds
/// to a [Call]/[LoadFunction] of the edge's target, contained in the edge's source.
///
/// For Hugrs whose root is neither a [Module](OpType::Module) nor a [FuncDefn], the call graph
/// will have an additional node corresponding to the Hugr's root, with no incoming edges.
///
/// [Call]: OpType::Call
/// [FuncDefn]: OpType::FuncDefn
/// [LoadFunction]: OpType::LoadFunction
pub struct CallGraph {
    g: Graph<Node, CallGraphEdge>,
    node_to_g: HashMap<Node, NodeIndex<u32>>,
}

impl CallGraph {
    /// Makes a new CallGraph for a specified Hugr (or a subview of one).
    /// Calls to functions outside the view will be dropped.
    pub fn new(hugr: &impl HugrView) -> Self {
        let mut g = Graph::default();
        // For non-Module-rooted Hugrs, make sure we include the root
        let root = (!hugr.get_optype(hugr.root()).is_module()).then_some(hugr.root());
        let node_to_g = hugr
            .nodes()
            .filter(|&n| Some(n) == root || OpTag::Function.is_superset(hugr.get_optype(n).tag()))
            .map(|n| (n, g.add_node(n)))
            .collect::<HashMap<_, _>>();
        for (func, cg_node) in node_to_g.iter() {
            traverse(hugr, *func, *cg_node, &mut g, &node_to_g)
        }
        fn traverse(
            h: &impl HugrView,
            node: Node,
            enclosing: NodeIndex<u32>,
Contributor

Just for clarification - is `enclosing` the node in the call graph corresponding to the node `node` in the hugr, or something else?

Contributor Author

It's the (callgraph representation of the) nearest enclosing FuncDefn, so it may be the node's parent or a more distant ancestor. I'll rename...
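
To illustrate the answer above, here is a minimal, std-only sketch of attributing each call to its nearest enclosing function definition. `Op`, `Tree`, `call_edges` and `enclosing_func` are hypothetical names for this illustration, not the hugr_core API; nodes are plain `usize` indices.

```rust
/// Hypothetical, simplified stand-in for a Hugr operation (not the hugr_core API).
#[derive(Debug, Clone, PartialEq)]
pub enum Op {
    /// A function definition, which may contain nested operations.
    FuncDefn,
    /// A call to the function stored at the given node index.
    Call(usize),
    /// Any other container or leaf operation.
    Other,
}

/// A toy hierarchy: `ops[i]` is the operation at node `i`, `children[i]` its child nodes.
pub struct Tree {
    pub ops: Vec<Op>,
    pub children: Vec<Vec<usize>>,
}

/// Collects `(caller, callee)` edges, attributing each call to its nearest enclosing `FuncDefn`.
pub fn call_edges(tree: &Tree) -> Vec<(usize, usize)> {
    let mut edges = Vec::new();
    for (func, op) in tree.ops.iter().enumerate() {
        if *op == Op::FuncDefn {
            traverse(tree, func, func, &mut edges);
        }
    }
    edges
}

fn traverse(tree: &Tree, node: usize, enclosing_func: usize, edges: &mut Vec<(usize, usize)>) {
    for &child in &tree.children[node] {
        // A nested FuncDefn starts its own scope; `call_edges` visits it separately.
        if tree.ops[child] == Op::FuncDefn {
            continue;
        }
        // Descend through non-function containers, keeping the same enclosing function.
        traverse(tree, child, enclosing_func, edges);
        if let Op::Call(target) = tree.ops[child] {
            edges.push((enclosing_func, target));
        }
    }
}
```

In this sketch a call nested two levels down inside a non-function container is still recorded against the outer function, while a call inside a nested `FuncDefn` is recorded against that nested function.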

            g: &mut Graph<Node, CallGraphEdge>,
            node_to_g: &HashMap<Node, NodeIndex<u32>>,
        ) {
            for ch in h.children(node) {
                if h.get_optype(ch).is_func_defn() {
                    continue;
                };
                traverse(h, ch, enclosing, g, node_to_g);
                let weight = match h.get_optype(ch) {
                    OpType::Call(_) => CallGraphEdge::Call(ch),
                    OpType::LoadFunction(_) => CallGraphEdge::LoadFunction(ch),
                    _ => continue,
                };
                if let Some(target) = h.static_source(ch) {
                    g.add_edge(enclosing, *node_to_g.get(&target).unwrap(), weight);
                }
            }
        }
        CallGraph { g, node_to_g }
    }
}

fn reachable_funcs<'a>(
    cg: &'a CallGraph,
    h: &impl HugrView,
    entry_points: impl IntoIterator<Item = Node>,
) -> impl Iterator<Item = Node> + 'a {
    let mut roots = entry_points.into_iter().collect_vec();
    let mut b = if h.get_optype(h.root()).is_module() {
Collaborator

I love single letter variable names but I think this deserves a few more letters

        if roots.is_empty() {
            roots.extend(h.children(h.root()).filter(|n| {
                h.get_optype(*n)
                    .as_func_defn()
                    .is_some_and(|fd| fd.name == "main")
            }));
            assert_eq!(roots.len(), 1, "No entry_points for Module and no `main`");
        }
        let mut roots = roots.into_iter().map(|i| cg.node_to_g.get(&i).unwrap());
        let mut b = Bfs::new(&cg.g, *roots.next().unwrap());
        b.stack.extend(roots);
        b
    } else {
        assert!(roots.is_empty());
        Bfs::new(&cg.g, *cg.node_to_g.get(&h.root()).unwrap())
    };
    std::iter::from_fn(move || b.next(&cg.g)).map(|i| *cg.g.node_weight(i).unwrap())
}
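
The reachability step can be sketched without petgraph as a breadth-first walk seeded from several roots at once. This is a simplified illustration under stated assumptions: the adjacency-list encoding and the name `reachable_from` are hypothetical, and functions are plain `usize` indices rather than Hugr nodes.

```rust
use std::collections::{HashSet, VecDeque};

/// Returns every function reachable from any of `roots`.
/// `edges[n]` lists the callees of function `n` (a hypothetical encoding, not petgraph's).
pub fn reachable_from(edges: &[Vec<usize>], roots: &[usize]) -> HashSet<usize> {
    let mut seen: HashSet<usize> = HashSet::new();
    let mut queue: VecDeque<usize> = VecDeque::new();
    // Seed the queue with all roots, marking each as seen when it is queued.
    for &root in roots {
        if seen.insert(root) {
            queue.push_back(root);
        }
    }
    while let Some(func) = queue.pop_front() {
        for &callee in &edges[func] {
            // `insert` returns false for functions already discovered.
            if seen.insert(callee) {
                queue.push_back(callee);
            }
        }
    }
    seen
}
```

For computing the reachable set, the visiting order does not matter, so a depth-first walk (a stack instead of a queue) would give the same result.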

/// Delete from the Hugr any functions that are not used by either [Call](OpType::Call) or
/// [LoadFunction](OpType::LoadFunction) nodes in reachable parts.
///
/// For [Module](OpType::Module)-rooted Hugrs, `entry_points` may provide a list of entry points;
/// these are expected to be children of the root although this is not enforced. If `entry_points`
/// is empty, then the root must have exactly one child being a function called `main`,
/// which is used as sole entry point.
///
/// For non-Module-rooted Hugrs, `entry_points` must be empty; the root node is used.
///
/// # Panics
/// * If the Hugr is non-Module-rooted and `entry_points` is non-empty
/// * If the Hugr is Module-rooted, but does not define `main`, and `entry_points` is empty
/// * If the Hugr is Module-rooted, and `entry_points` is non-empty but contains nodes that
///   are not [FuncDefn](OpType::FuncDefn)s
Collaborator

As above, I don't think the second should panic.

I think the interface would be cleaner if `entry_points` (maybe with a different name) must be FuncDefn or FuncDecl nodes that are immediate children of the root.

Now the first panic goes away, and the third would be an error with the offending nodes.

Contributor Author

Changed to error, indeed. But FuncDefns inside a non-Module are invisible from outside (unless you're gonna add new stuff inside the root, which you can do, but that's not linking, that's... arbitrary editing), so I've not allowed those as `entry_points`. I could be persuaded; it'd be easier from a code perspective not to check, but it feels wrong :-!
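
A rough sketch of what "an error with the offending nodes" could look like, as opposed to a panic. The names `EntryPointError`, `InvalidEntryPoints` and `validate_entry_points` are assumptions made for this illustration and are not taken from the PR; nodes are plain `usize` indices.

```rust
use std::fmt;

/// Hypothetical error type for illustration; the name and variant are assumptions.
#[derive(Debug, PartialEq, Eq)]
pub enum EntryPointError {
    /// The listed nodes are not functions that are immediate children of the root.
    InvalidEntryPoints(Vec<usize>),
}

impl fmt::Display for EntryPointError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            EntryPointError::InvalidEntryPoints(nodes) => {
                write!(f, "invalid entry points: {:?}", nodes)
            }
        }
    }
}

impl std::error::Error for EntryPointError {}

/// Checks that every entry point is one of `module_funcs` (the functions that are
/// immediate children of the module root), reporting all offenders at once.
pub fn validate_entry_points(
    module_funcs: &[usize],
    entry_points: &[usize],
) -> Result<(), EntryPointError> {
    let bad: Vec<usize> = entry_points
        .iter()
        .copied()
        .filter(|n| !module_funcs.contains(n))
        .collect();
    if bad.is_empty() {
        Ok(())
    } else {
        Err(EntryPointError::InvalidEntryPoints(bad))
    }
}
```

Returning the full list of offending nodes lets a caller report every bad entry point in one go rather than fixing them one panic at a time.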

pub fn remove_dead_funcs(h: &mut impl HugrMut, entry_points: impl IntoIterator<Item = Node>) {
    let reachable = reachable_funcs(&CallGraph::new(h), h, entry_points).collect::<HashSet<_>>();
    let unreachable = h
        .nodes()
        .filter(|n| h.get_optype(*n).is_func_defn() && !reachable.contains(n))
Collaborator

We should remove dead FuncDecls too

Contributor Author

Done (no tests I admit)

        .collect::<Vec<_>>();
    for n in unreachable {
        h.remove_subtree(n);
    }
}
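
The collect-then-remove pattern used here, extended to declarations as requested in review, can be sketched on a toy node table. `Kind` and `remove_dead` are hypothetical names for this illustration only; the real pass works on Hugr nodes and `remove_subtree`.

```rust
use std::collections::{HashMap, HashSet};

/// Hypothetical node classification for this sketch (not hugr_core's `OpType`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Kind {
    FuncDefn,
    FuncDecl,
    Other,
}

/// Removes every `FuncDefn`/`FuncDecl` whose id is not in `reachable`.
/// Returns the removed ids in ascending order.
pub fn remove_dead(nodes: &mut HashMap<usize, Kind>, reachable: &HashSet<usize>) -> Vec<usize> {
    // Collect first: the table cannot be mutated while it is being iterated.
    let mut dead: Vec<usize> = nodes
        .iter()
        .filter(|(id, kind)| {
            matches!(**kind, Kind::FuncDefn | Kind::FuncDecl) && !reachable.contains(*id)
        })
        .map(|(id, _)| *id)
        .collect();
    dead.sort_unstable();
    for id in &dead {
        nodes.remove(id);
    }
    dead
}
```

Non-function nodes are never candidates for removal in this sketch, even when they are absent from `reachable`.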

#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use itertools::Itertools;
    use rstest::rstest;

    use hugr_core::builder::{
        Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder,
    };
    use hugr_core::{extension::prelude::usize_t, types::Signature, HugrView};

    use super::remove_dead_funcs;

    #[rstest]
    #[case([], vec!["from_main", "main"])]
    #[case(["main"], vec!["from_main", "main"])]
    #[case(["from_main"], vec!["from_main"])]
    #[case(["other1"], vec!["other1", "other2"])]
    #[case(["other2"], vec!["other2"])]
    #[case(["other1", "other2"], vec!["other1", "other2"])]
    fn remove_dead_funcs_entry_points(
        #[case] entry_points: impl IntoIterator<Item = &'static str>,
        #[case] retained_funcs: Vec<&'static str>,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let mut hb = ModuleBuilder::new();
        let o2 = hb.define_function("other2", Signature::new_endo(usize_t()))?;
        let o2inp = o2.input_wires();
        let o2 = o2.finish_with_outputs(o2inp)?;
        let mut o1 = hb.define_function("other1", Signature::new_endo(usize_t()))?;

        let o1c = o1.call(o2.handle(), &[], o1.input_wires())?;
        o1.finish_with_outputs(o1c.outputs())?;

        let fm = hb.define_function("from_main", Signature::new_endo(usize_t()))?;
        let f_inp = fm.input_wires();
        let fm = fm.finish_with_outputs(f_inp)?;
        let mut m = hb.define_function("main", Signature::new_endo(usize_t()))?;
        let mc = m.call(fm.handle(), &[], m.input_wires())?;
        m.finish_with_outputs(mc.outputs())?;

        let mut hugr = hb.finish_hugr()?;

        let avail_funcs = hugr
            .nodes()
            .filter_map(|n| {
                hugr.get_optype(n)
                    .as_func_defn()
                    .map(|fd| (fd.name.clone(), n))
            })
            .collect::<HashMap<_, _>>();

        remove_dead_funcs(
            &mut hugr,
            entry_points
                .into_iter()
                .map(|name| *avail_funcs.get(name).unwrap())
                .collect::<Vec<_>>(),
        );
        let remaining_funcs = hugr
            .nodes()
            .filter_map(|n| hugr.get_optype(n).as_func_defn().map(|fd| fd.name.as_str()))
            .sorted()
            .collect_vec();
        assert_eq!(remaining_funcs, retained_funcs);
        Ok(())
    }
}
5 changes: 4 additions & 1 deletion hugr-passes/src/lib.rs
@@ -1,13 +1,16 @@
 //! Compilation passes acting on the HUGR program representation.

+pub mod call_graph;
 pub mod const_fold;
 pub mod dataflow;
 pub mod force_order;
 mod half_node;
 pub mod lower;
 pub mod merge_bbs;
 mod monomorphize;
-pub use monomorphize::{monomorphize, remove_polyfuncs};
+pub use monomorphize::monomorphize;
+#[allow(deprecated)]
+pub use monomorphize::remove_polyfuncs;
 pub mod nest_cfgs;
 pub mod non_local;
 pub mod validation;
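
The split of the `pub use` into two statements follows from how Rust's `deprecated` lint works: re-exporting a deprecated item counts as a use of it, so the re-export needs its own `#[allow(deprecated)]`. A minimal sketch of the pattern, with hypothetical names (`monomorphize_sketch`, `old_api`, `new_api`) and a placeholder version string:

```rust
mod monomorphize_sketch {
    /// Old entry point, kept for backwards compatibility.
    #[deprecated(since = "0.0.0", note = "use `new_api` instead")]
    pub fn old_api(x: u32) -> u32 {
        new_api(x)
    }

    /// Replacement entry point.
    pub fn new_api(x: u32) -> u32 {
        x + 1
    }
}

pub use monomorphize_sketch::new_api;
// Without this attribute the re-export itself would trigger the `deprecated` lint.
#[allow(deprecated)]
pub use monomorphize_sketch::old_api;
```

Downstream callers of `old_api` still see the deprecation warning with the `note` text; only the re-export site is silenced.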
23 changes: 16 additions & 7 deletions hugr-passes/src/monomorphize.rs
@@ -13,12 +13,14 @@ use hugr_core::{
 use hugr_core::hugr::{hugrmut::HugrMut, internal::HugrMutInternals, Hugr, HugrView, OpType};
 use itertools::Itertools as _;

+use crate::call_graph::remove_dead_funcs;
+
 /// Replaces calls to polymorphic functions with calls to new monomorphic
 /// instantiations of the polymorphic ones.
 ///
 /// If the Hugr is [Module](OpType::Module)-rooted,
 /// * then the original polymorphic [FuncDefn]s are left untouched (including Calls inside them)
-/// - call [remove_polyfuncs] when no other Hugr will be linked in that might instantiate these
+/// - [remove_dead_funcs] can be used when no other Hugr will be linked in that might instantiate these
 /// * else, the originals are removed (they are invisible from outside the Hugr).
 ///
 /// If the Hugr is [FuncDefn](OpType::FuncDefn)-rooted with polymorphic
@@ -42,7 +44,7 @@ pub fn monomorphize(mut h: Hugr) -> Hugr {
     if !is_polymorphic_funcdefn(h.get_optype(root)) {
         mono_scan(&mut h, root, None, &mut HashMap::new());
         if !h.get_optype(root).is_module() {
-            return remove_polyfuncs(h);
+            remove_dead_funcs(&mut h, []);
         }
     }
     #[cfg(debug_assertions)]
@@ -54,8 +56,11 @@ pub fn monomorphize(mut h: Hugr) -> Hugr {
 /// calls from *monomorphic* code, this will make the Hugr invalid (call [monomorphize]
 /// first).
 ///
-/// TODO replace this with a more general remove-unused-functions pass
-/// <https://github.com/CQCL/hugr/issues/1753>
+/// Deprecated: use [remove_dead_funcs] instead.
+#[deprecated(
+    since = "0.14.1",
+    note = "Use hugr_passes::call_graph::remove_dead_funcs instead"
+)]
 pub fn remove_polyfuncs(mut h: Hugr) -> Hugr {
     let mut pfs_to_delete = Vec::new();
     let mut to_scan = Vec::from_iter(h.children(h.root()));
@@ -322,7 +327,9 @@ mod test {
     use hugr_core::{Hugr, HugrView, Node};
     use rstest::rstest;

-    use super::{is_polymorphic, mangle_inner_func, mangle_name, monomorphize, remove_polyfuncs};
+    use crate::call_graph::remove_dead_funcs;
+
+    use super::{is_polymorphic, mangle_inner_func, mangle_name, monomorphize};

     fn pair_type(ty: Type) -> Type {
         Type::new_tuple(vec![ty.clone(), ty])
@@ -426,7 +433,8 @@ mod test {

         assert_eq!(monomorphize(mono.clone()), mono); // Idempotent

-        let nopoly = remove_polyfuncs(mono);
+        let mut nopoly = mono;
+        remove_dead_funcs(&mut nopoly, []);
         let mut funcs = list_funcs(&nopoly);

         assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));
@@ -645,7 +653,8 @@ mod test {
             module_builder.finish_hugr().unwrap()
         };

-        let mono_hugr = remove_polyfuncs(monomorphize(hugr));
+        let mut mono_hugr = monomorphize(hugr);
+        remove_dead_funcs(&mut mono_hugr, []);

         let funcs = list_funcs(&mono_hugr);
         assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));