Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add MonomorphizePass and deprecate monomorphize #1809

Merged
merged 7 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion hugr-passes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@ mod half_node;
pub mod lower;
pub mod merge_bbs;
mod monomorphize;
pub use monomorphize::{monomorphize, remove_polyfuncs};
// TODO: Deprecated re-export. Remove on a breaking release.
#[deprecated(
since = "0.14.1",
note = "Use `hugr::algorithms::MonomorphizePass` instead."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is hugr::algorithms ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hugr/hugr/src/lib.rs

Lines 134 to 135 in 5dc24c1

pub use hugr_passes as algorithms;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😮

)]
#[allow(deprecated)]
pub use monomorphize::monomorphize;
pub use monomorphize::{remove_polyfuncs, MonomorphizePass};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why remove_polyfuncs is treated any differently to monomorphize here.
Please export MonomorphizeError

pub mod nest_cfgs;
pub mod non_local;
pub mod validation;
Expand Down
128 changes: 100 additions & 28 deletions hugr-passes/src/monomorphize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ use hugr_core::{
Node,
};

use hugr_core::hugr::{hugrmut::HugrMut, internal::HugrMutInternals, Hugr, HugrView, OpType};
use hugr_core::hugr::{hugrmut::HugrMut, Hugr, HugrView, OpType};
use itertools::Itertools as _;
use thiserror::Error;

/// Replaces calls to polymorphic functions with calls to new monomorphic
/// instantiations of the polymorphic ones.
Expand All @@ -28,26 +29,33 @@ use itertools::Itertools as _;
/// children of the root node. We make best effort to ensure that names (derived
/// from parent function names and concrete type args) of new functions are unique
/// whenever the names of their parents are unique, but this is not guaranteed.
#[deprecated(
since = "0.14.1",
note = "Use `hugr::algorithms::MonomorphizePass` instead."
)]
// TODO: Deprecated. Remove on a breaking release.
pub fn monomorphize(mut h: Hugr) -> Hugr {
let validate = |h: &Hugr| h.validate().unwrap_or_else(|e| panic!("{e}"));

// We clone the extension registry because we will need a reference to
// create our mutable substitutions. This is cannot cause a problem because
// we will not be adding any new types or extension ops to the HUGR.
#[cfg(debug_assertions)]
validate(&h);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#[cfg(debug_assertions)]
validate(&h);

kill these?


monomorphize_ref(&mut h);

#[cfg(debug_assertions)]
validate(&h);
h
}

fn monomorphize_ref(h: &mut impl HugrMut) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not like the name. Is the plan to rename this in the next breaking release, and to add a deprecated alias monomorphize_ref? If so, could you add a TODO comment?
Similar for remove_polyfuncs_ref

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a private function so we can rename to monomorphize when we remove the public one, will add comment

let root = h.root();
// If the root is a polymorphic function, then there are no external calls, so nothing to do
if !is_polymorphic_funcdefn(h.get_optype(root)) {
mono_scan(&mut h, root, None, &mut HashMap::new());
mono_scan(h, root, None, &mut HashMap::new());
if !h.get_optype(root).is_module() {
return remove_polyfuncs(h);
remove_polyfuncs_ref(h);
}
}
#[cfg(debug_assertions)]
validate(&h);
h
}

/// Removes any polymorphic [FuncDefn]s from the Hugr. Note that if these have
Expand All @@ -57,6 +65,11 @@ pub fn monomorphize(mut h: Hugr) -> Hugr {
/// TODO replace this with a more general remove-unused-functions pass
/// <https://github.com/CQCL/hugr/issues/1753>
pub fn remove_polyfuncs(mut h: Hugr) -> Hugr {
remove_polyfuncs_ref(&mut h);
h
}

fn remove_polyfuncs_ref(h: &mut impl HugrMut) {
let mut pfs_to_delete = Vec::new();
let mut to_scan = Vec::from_iter(h.children(h.root()));
while let Some(n) = to_scan.pop() {
Expand All @@ -69,7 +82,6 @@ pub fn remove_polyfuncs(mut h: Hugr) -> Hugr {
for n in pfs_to_delete {
h.remove_subtree(n);
}
h
}

fn is_polymorphic(fd: &FuncDefn) -> bool {
Expand All @@ -93,7 +105,7 @@ type Instantiations = HashMap<Node, HashMap<Vec<TypeArg>, Node>>;
/// Optionally copies the subtree into a new location whilst applying a substitution.
/// The subtree should be monomorphic after the substitution (if provided) has been applied.
fn mono_scan(
h: &mut Hugr,
h: &mut impl HugrMut,
parent: Node,
mut subst_into: Option<&mut Instantiating>,
cache: &mut Instantiations,
Expand Down Expand Up @@ -161,7 +173,7 @@ fn mono_scan(
}

fn instantiate(
h: &mut Hugr,
h: &mut impl HugrMut,
poly_func: Node,
type_args: Vec<TypeArg>,
mono_sig: Signature,
Expand Down Expand Up @@ -218,20 +230,20 @@ fn instantiate(
// 'ext' edges by copying every node before recursing on any of them,
// 'dom' edges would *also* require recursing in dominator-tree preorder.
for (&old_ch, &new_ch) in node_map.iter() {
for inport in h.node_inputs(old_ch).collect::<Vec<_>>() {
for in_port in h.node_inputs(old_ch).collect::<Vec<_>>() {
// Edges from monomorphized functions to their calls already added during mono_scan()
// as these depend not just on the original FuncDefn but also the TypeArgs
if h.linked_outputs(new_ch, inport).next().is_some() {
if h.linked_outputs(new_ch, in_port).next().is_some() {
continue;
};
let srcs = h.linked_outputs(old_ch, inport).collect::<Vec<_>>();
let srcs = h.linked_outputs(old_ch, in_port).collect::<Vec<_>>();
for (src, outport) in srcs {
// Sources could be a mixture of within this polymorphic FuncDefn, and Static edges from outside
h.connect(
node_map.get(&src).copied().unwrap_or(src),
outport,
new_ch,
inport,
in_port,
);
}
}
Expand All @@ -240,6 +252,57 @@ fn instantiate(
mono_tgt
}

use crate::validation::{ValidatePassError, ValidationLevel};

/// Replaces calls to polymorphic functions with calls to new monomorphic
/// instantiations of the polymorphic ones.
///
/// If the Hugr is [Module](OpType::Module)-rooted,
/// * then the original polymorphic [FuncDefn]s are left untouched (including Calls inside them)
/// - call [remove_polyfuncs] when no other Hugr will be linked in that might instantiate these
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
/// * else, the originals are removed (they are invisible from outside the Hugr).
///
/// If the Hugr is [FuncDefn](OpType::FuncDefn)-rooted with polymorphic
/// signature then the HUGR will not be modified.
///
/// Monomorphic copies of polymorphic functions will be added to the HUGR as
/// children of the root node. We make best effort to ensure that names (derived
/// from parent function names and concrete type args) of new functions are unique
/// whenever the names of their parents are unique, but this is not guaranteed.
#[derive(Debug, Clone, Default)]
pub struct MonomorphizePass {
validation: ValidationLevel,
}

#[derive(Debug, Error)]
#[non_exhaustive]
/// Errors produced by [MonomorphizePass].
pub enum MonomorphizeError {
#[error(transparent)]
#[allow(missing_docs)]
ValidationError(#[from] ValidatePassError),
}

impl MonomorphizePass {
/// Sets the validation level used before and after the pass is run.
pub fn validation_level(mut self, level: ValidationLevel) -> Self {
self.validation = level;
self
}

/// Run the Monomorphization pass.
fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), MonomorphizeError> {
monomorphize_ref(hugr);
Ok(())
}

/// Run the pass using specified configuration.
pub fn run<H: HugrMut>(&self, hugr: &mut H) -> Result<(), MonomorphizeError> {
self.validation
.run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr))
}
}

struct TypeArgsList<'a>(&'a [TypeArg]);

impl std::fmt::Display for TypeArgsList<'_> {
Expand Down Expand Up @@ -322,7 +385,9 @@ mod test {
use hugr_core::{Hugr, HugrView, Node};
use rstest::rstest;

use super::{is_polymorphic, mangle_inner_func, mangle_name, monomorphize, remove_polyfuncs};
use crate::monomorphize::{remove_polyfuncs_ref, MonomorphizePass};

use super::{is_polymorphic, mangle_inner_func, mangle_name, remove_polyfuncs};

fn pair_type(ty: Type) -> Type {
Type::new_tuple(vec![ty.clone(), ty])
Expand All @@ -342,7 +407,8 @@ mod test {
DFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()])).unwrap();
let [i1] = dfg_builder.input_wires_arr();
let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap();
let hugr2 = monomorphize(hugr.clone());
let mut hugr2 = hugr.clone();
MonomorphizePass::default().run(&mut hugr2).unwrap();
assert_eq!(hugr, hugr2);
}

Expand Down Expand Up @@ -397,14 +463,15 @@ mod test {
let [res2] = fb.call(tr.handle(), &[pty], pair.outputs())?.outputs_arr();
fb.finish_with_outputs([res1, res2])?;
}
let hugr = mb.finish_hugr()?;
let mut hugr = mb.finish_hugr()?;
assert_eq!(
hugr.nodes()
.filter(|n| hugr.get_optype(*n).is_func_defn())
.count(),
3
);
let mono = monomorphize(hugr);
MonomorphizePass::default().run(&mut hugr)?;
let mono = hugr;
mono.validate()?;

let mut funcs = list_funcs(&mono);
Expand All @@ -423,8 +490,10 @@ mod test {
funcs.into_keys().sorted().collect_vec(),
["double", "main", "triple"]
);
let mut mono2 = mono.clone();
MonomorphizePass::default().run(&mut mono2)?;

assert_eq!(monomorphize(mono.clone()), mono); // Idempotent
assert_eq!(mono2, mono); // Idempotent

let nopoly = remove_polyfuncs(mono);
let mut funcs = list_funcs(&nopoly);
Expand Down Expand Up @@ -527,9 +596,10 @@ mod test {
.call(pf1.handle(), &[sa(n - 1)], [ar2_unwrapped])
.unwrap()
.outputs_arr();
let hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap();
let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap();

let mono_hugr = monomorphize(hugr);
MonomorphizePass::default().run(&mut hugr).unwrap();
let mono_hugr = hugr;
mono_hugr.validate().unwrap();
let funcs = list_funcs(&mono_hugr);
let pf2_name = mangle_inner_func("pf1", "pf2");
Expand Down Expand Up @@ -588,8 +658,9 @@ mod test {
.outputs_arr();
let mono = mono.finish_with_outputs([a, b]).unwrap();
let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap();
let hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap();
let mono_hugr = monomorphize(hugr);
let mut hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap();
MonomorphizePass::default().run(&mut hugr)?;
let mono_hugr = hugr;

let mut funcs = list_funcs(&mono_hugr);
assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));
Expand All @@ -606,7 +677,7 @@ mod test {

#[test]
fn load_function() {
let hugr = {
let mut hugr = {
let mut module_builder = ModuleBuilder::new();
let foo = {
let builder = module_builder
Expand Down Expand Up @@ -645,9 +716,10 @@ mod test {
module_builder.finish_hugr().unwrap()
};

let mono_hugr = remove_polyfuncs(monomorphize(hugr));
MonomorphizePass::default().run(&mut hugr).unwrap();
remove_polyfuncs_ref(&mut hugr);

let funcs = list_funcs(&mono_hugr);
let funcs = list_funcs(&hugr);
assert!(funcs.values().all(|(_, fd)| !is_polymorphic(fd)));
}

Expand Down
Loading