diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index e260a89b9..1071de1ca 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -9,4 +9,5 @@ # The release PRs that trigger publication to crates.io or PyPI always modify the changelog. # We require those PRs to be approved by someone with release permissions. hugr/CHANGELOG.md @aborgna-q @ss2165 +hugr-passes/CHANGELOG.md @aborgna-q @ss2165 hugr-py/CHANGELOG.md @aborgna-q @ss2165 diff --git a/.github/change-filters.yml b/.github/change-filters.yml index 9fb256c3d..ebf6b45e7 100644 --- a/.github/change-filters.yml +++ b/.github/change-filters.yml @@ -3,6 +3,7 @@ rust: - "hugr/**" + - "hugr-passes/**" - "Cargo.toml" - "specification/schema/**" diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 120000 index 729dc652a..000000000 --- a/CHANGELOG.md +++ /dev/null @@ -1 +0,0 @@ -hugr/CHANGELOG.md \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index ad4eec3e5..b7adbd9b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,8 +3,8 @@ lto = "thin" [workspace] resolver = "2" -members = ["hugr"] -default-members = ["hugr"] +members = ["hugr", "hugr-passes"] +default-members = ["hugr", "hugr-passes"] [workspace.package] rust-version = "1.75" diff --git a/hugr-passes/CHANGELOG.md b/hugr-passes/CHANGELOG.md new file mode 100644 index 000000000..4818de914 --- /dev/null +++ b/hugr-passes/CHANGELOG.md @@ -0,0 +1,5 @@ +# Changelog + +## 0.1.0 (2024-05-23) + +Initial release, with functions ported from the `hugr::algorithms` module. diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml new file mode 100644 index 000000000..688d32e48 --- /dev/null +++ b/hugr-passes/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "hugr-passes" +version = "0.1.0" +edition = { workspace = true } +rust-version = { workspace = true } +license = { workspace = true } +readme = "README.md" +documentation = "https://docs.rs/hugr-passes/" +homepage = { workspace = true } +repository = { workspace = true } +description = "Compiler passes for Quantinuum's HUGR" +keywords = ["Quantum", "Quantinuum"] +categories = ["compilers"] + +[dependencies] +hugr = { path = "../hugr", version = "0.4.0" } +itertools = { workspace = true } +lazy_static = { workspace = true } +paste = { workspace = true } +thiserror = { workspace = true } + +[features] +extension_inference = ["hugr/extension_inference"] + +[dev-dependencies] +rstest = "0.19.0" diff --git a/hugr-passes/README.md b/hugr-passes/README.md new file mode 100644 index 000000000..bdb566181 --- /dev/null +++ b/hugr-passes/README.md @@ -0,0 +1,58 @@ +![](/hugr/assets/hugr_logo.svg) + +hugr-passes +=============== + +[![build_status][]](https://github.com/CQCL/hugr/actions) +[![crates][]](https://crates.io/crates/hugr-passes) +[![msrv][]](https://github.com/CQCL/hugr) +[![codecov][]](https://codecov.io/gh/CQCL/hugr) + +The Hierarchical Unified Graph Representation (HUGR, pronounced _hugger_) is the +common representation of quantum circuits and operations in the Quantinuum +ecosystem. + +It provides a high-fidelity representation of operations, that facilitates +compilation and encodes runnable programs. + +The HUGR specification is [here](https://github.com/CQCL/hugr/blob/main/specification/hugr.md). + +This crate provides compilation passes that act on HUGR programs. + +## Usage + +Add the dependency to your project: + +```bash +cargo add hugr-passes +``` + +Please read the [API documentation here][]. + +## Experimental Features + +- `extension_inference`: + Experimental feature which allows automatic inference of extension usages and + requirements in a HUGR and validation that extensions are correctly specified. + Not enabled by default. + +## Recent Changes + +See [CHANGELOG][] for a list of changes. The minimum supported rust +version will only change on major releases. + +## Development + +See [DEVELOPMENT.md](https://github.com/CQCL/hugr/blob/main/DEVELOPMENT.md) for instructions on setting up the development environment. + +## License + +This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). + + [API documentation here]: https://docs.rs/hugr-passes/ + [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main + [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [crates]: https://img.shields.io/crates/v/hugr-passes + [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov + [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE + [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-passes/CHANGELOG.md diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs new file mode 100644 index 000000000..4a124c1a4 --- /dev/null +++ b/hugr-passes/src/const_fold.rs @@ -0,0 +1,230 @@ +//! Constant folding routines. + +use std::collections::{BTreeSet, HashMap}; + +use itertools::Itertools; +use thiserror::Error; + +use hugr::hugr::{SimpleReplacementError, ValidationError}; +use hugr::types::SumType; +use hugr::Direction; +use hugr::{ + builder::{DFGBuilder, Dataflow, DataflowHugr}, + extension::{ConstFoldResult, ExtensionRegistry}, + hugr::{ + hugrmut::HugrMut, + rewrite::consts::{RemoveConst, RemoveLoadConstant}, + views::SiblingSubgraph, + }, + ops::{OpType, Value}, + type_row, + types::FunctionType, + utils::sorted_consts, + Hugr, HugrView, IncomingPort, Node, SimpleReplacement, +}; + +#[derive(Error, Debug)] +#[allow(missing_docs)] +pub enum ConstFoldError { + #[error("Failed to verify {label} HUGR: {err}")] + VerifyError { + label: String, + #[source] + err: ValidationError, + }, + #[error(transparent)] + SimpleReplaceError(#[from] SimpleReplacementError), +} + +/// Tag some output constants with [`OutgoingPort`] inferred from the ordering. +fn out_row(consts: impl IntoIterator) -> ConstFoldResult { + let vec = consts + .into_iter() + .enumerate() + .map(|(i, c)| (i.into(), c)) + .collect(); + Some(vec) +} + +/// For a given op and consts, attempt to evaluate the op. +pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldResult { + let fold_result = match op { + OpType::Noop { .. } => out_row([consts.first()?.1.clone()]), + OpType::MakeTuple { .. } => { + out_row([Value::tuple(sorted_consts(consts).into_iter().cloned())]) + } + OpType::UnpackTuple { .. } => { + let c = &consts.first()?.1; + let Value::Tuple { vs } = c else { + panic!("This op always takes a Tuple input."); + }; + out_row(vs.iter().cloned()) + } + + OpType::Tag(t) => out_row([Value::sum( + t.tag, + consts.iter().map(|(_, konst)| konst.clone()), + SumType::new(t.variants.clone()), + ) + .unwrap()]), + OpType::CustomOp(op) => { + let ext_op = op.as_extension_op()?; + ext_op.constant_fold(consts) + } + _ => None, + }; + debug_assert!(fold_result.as_ref().map_or(true, |x| x.len() + == op.value_port_count(Direction::Outgoing))); + fold_result +} + +/// Generate a graph that loads and outputs `consts` in order, validating +/// against `reg`. +fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { + let const_types = consts.iter().map(Value::get_type).collect_vec(); + let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap(); + + let outputs = consts + .into_iter() + .map(|c| b.add_load_const(c)) + .collect_vec(); + + b.finish_hugr_with_outputs(outputs, reg).unwrap() +} + +/// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`, +/// return an iterator of possible constant folding rewrites. The +/// [`SimpleReplacement`] replaces an operation with constants that result from +/// evaluating it, the extension registry `reg` is used to validate the +/// replacement HUGR. The vector of [`RemoveLoadConstant`] refer to the +/// LoadConstant nodes that could be removed - they are not automatically +/// removed as they may be used by other operations. +pub fn find_consts<'a, 'r: 'a>( + hugr: &'a impl HugrView, + candidate_nodes: impl IntoIterator + 'a, + reg: &'r ExtensionRegistry, +) -> impl Iterator)> + 'a { + // track nodes for operations that have already been considered for folding + let mut used_neighbours = BTreeSet::new(); + + candidate_nodes + .into_iter() + .filter_map(move |n| { + // only look at LoadConstant + hugr.get_optype(n).is_load_constant().then_some(())?; + + let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?; + let neighbours = hugr + .linked_inputs(n, out_p) + .filter(|(n, _)| used_neighbours.insert(*n)) + .collect_vec(); + if neighbours.is_empty() { + // no uses of LoadConstant that haven't already been considered. + return None; + } + let fold_iter = neighbours + .into_iter() + .filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg)); + Some(fold_iter) + }) + .flatten() +} + +/// Attempt to evaluate and generate rewrites for the operation at `op_node` +fn fold_op( + hugr: &impl HugrView, + op_node: Node, + reg: &ExtensionRegistry, +) -> Option<(SimpleReplacement, Vec)> { + // only support leaf folding for now. + let neighbour_op = hugr.get_optype(op_node); + let (in_consts, removals): (Vec<_>, Vec<_>) = hugr + .node_inputs(op_node) + .filter_map(|in_p| { + let (con_op, load_n) = get_const(hugr, op_node, in_p)?; + Some(((in_p, con_op), RemoveLoadConstant(load_n))) + }) + .unzip(); + // attempt to evaluate op + let (nu_out, consts): (HashMap<_, _>, Vec<_>) = fold_leaf_op(neighbour_op, &in_consts)? + .into_iter() + .enumerate() + .filter_map(|(i, (op_out, konst))| { + // for each used port of the op give the nu_out entry and the + // corresponding Value + hugr.single_linked_input(op_node, op_out) + .map(|np| ((np, i.into()), konst)) + }) + .unzip(); + let replacement = const_graph(consts, reg); + let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr) + .expect("Operation should form valid subgraph."); + + let simple_replace = SimpleReplacement::new( + sibling_graph, + replacement, + // no inputs to replacement + HashMap::new(), + nu_out, + ); + Some((simple_replace, removals)) +} + +/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant +/// and the LoadConstant node +fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Value, Node)> { + let (load_n, _) = hugr.single_linked_output(op_node, in_p)?; + let load_op = hugr.get_optype(load_n).as_load_constant()?; + let const_node = hugr + .single_linked_output(load_n, load_op.constant_port())? + .0; + let const_op = hugr.get_optype(const_node).as_const()?; + + // TODO avoid const clone here + Some((const_op.as_ref().clone(), load_n)) +} + +/// Exhaustively apply constant folding to a HUGR. +pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { + #[cfg(test)] + let verify = |label, h: &H| { + h.validate_no_extensions(reg).unwrap_or_else(|err| { + panic!( + "constant_fold_pass: failed to verify {label} HUGR: {err}\n{}", + h.mermaid_string() + ) + }) + }; + #[cfg(test)] + verify("input", h); + loop { + // We can only safely apply a single replacement. Applying a + // replacement removes nodes and edges which may be referenced by + // further replacements returned by find_consts. Even worse, if we + // attempted to apply those replacements, expecting them to fail if + // the nodes and edges they reference had been deleted, they may + // succeed because new nodes and edges reused the ids. + // + // We could be a lot smarter here, keeping track of `LoadConstant` + // nodes and only looking at their out neighbours. + let Some((replace, removes)) = find_consts(h, h.nodes(), reg).next() else { + break; + }; + h.apply_rewrite(replace).unwrap(); + for rem in removes { + // We are optimistically applying these [RemoveLoadConstant] and + // [RemoveConst] rewrites without checking whether the nodes + // they attempt to remove have remaining uses. If they do, then + // the rewrite fails and we move on. + if let Ok(const_node) = h.apply_rewrite(rem) { + // if the LoadConst was removed, try removing the Const too. + let _ = h.apply_rewrite(RemoveConst(const_node)); + } + } + } + #[cfg(test)] + verify("output", h); +} + +#[cfg(test)] +mod test; diff --git a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs similarity index 79% rename from hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs rename to hugr-passes/src/const_fold/test.rs index 959240a51..07b881bc7 100644 --- a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -1,18 +1,339 @@ -use crate::algorithm::const_fold::constant_fold_pass; -use crate::builder::{DFGBuilder, Dataflow, DataflowHugr}; -use crate::extension::prelude::{sum_with_error, ConstError, ConstString, BOOL_T, STRING_TYPE}; -use crate::extension::{ExtensionRegistry, PRELUDE}; -use crate::ops::Value; -use crate::std_extensions::arithmetic; -use crate::std_extensions::arithmetic::int_ops::IntOpDef; -use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; -use crate::std_extensions::logic::{self, NaryLogic}; -use crate::type_row; -use crate::types::{FunctionType, Type, TypeRow}; -use crate::utils::test::assert_fully_folded; +use crate::const_fold::constant_fold_pass; +use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr}; +use hugr::extension::prelude::{sum_with_error, ConstError, ConstString, BOOL_T, STRING_TYPE}; +use hugr::extension::{ExtensionRegistry, PRELUDE}; +use hugr::ops::Value; +use hugr::std_extensions::arithmetic; +use hugr::std_extensions::arithmetic::int_ops::IntOpDef; +use hugr::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; +use hugr::std_extensions::logic::{self, NaryLogic, NotOp}; +use hugr::type_row; +use hugr::types::{FunctionType, Type, TypeRow}; use rstest::rstest; +use lazy_static::lazy_static; + +use super::*; +use hugr::builder::Container; +use hugr::ops::{OpType, UnpackTuple}; +use hugr::std_extensions::arithmetic::conversions::ConvertOpDef; +use hugr::std_extensions::arithmetic::float_ops::FloatOps; +use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; + +/// Check that a hugr just loads and returns a single expected constant. +pub fn assert_fully_folded(h: &Hugr, expected_value: &Value) { + assert_fully_folded_with(h, |v| v == expected_value) +} + +/// Check that a hugr just loads and returns a single constant, and validate +/// that constant using `check_value`. +/// +/// [CustomConst::equals_const] is not required to be implemented. Use this +/// function for Values containing such a `CustomConst`. +fn assert_fully_folded_with(h: &Hugr, check_value: impl Fn(&Value) -> bool) { + let mut node_count = 0; + + for node in h.children(h.root()) { + let op = h.get_optype(node); + match op { + OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, + OpType::Const(c) if check_value(c.value()) => node_count += 1, + _ => panic!("unexpected op: {:?}\n{}", op, h.mermaid_string()), + } + } + + assert_eq!(node_count, 4); +} + +/// int to constant +fn i2c(b: u64) -> Value { + Value::extension(ConstInt::new_u(5, b).unwrap()) +} + +/// float to constant +fn f2c(f: f64) -> Value { + ConstF64::new(f).into() +} + +#[rstest] +#[case(0.0, 0.0, 0.0)] +#[case(0.0, 1.0, 1.0)] +#[case(23.5, 435.5, 459.0)] +// c = a + b +fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { + let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; + let add_op: OpType = FloatOps::fadd.into(); + let outs = fold_leaf_op(&add_op, &consts) + .unwrap() + .into_iter() + .map(|(p, v)| (p, v.get_custom_value::().unwrap().value())) + .collect_vec(); + + assert_eq!(outs.as_slice(), &[(0.into(), c)]); +} +#[test] +fn test_big() { + /* + Test approximately calculates + let x = (5.6, 3.2); + int(x.0 - x.1) == 2 + */ + let sum_type = sum_with_error(INT_TYPES[5].to_owned()); + let mut build = DFGBuilder::new(FunctionType::new( + type_row![], + vec![sum_type.clone().into()], + )) + .unwrap(); + + let tup = build.add_load_const(Value::tuple([f2c(5.6), f2c(3.2)])); + + let unpack = build + .add_dataflow_op( + UnpackTuple::new(type_row![FLOAT64_TYPE, FLOAT64_TYPE]), + [tup], + ) + .unwrap(); + + let sub = build + .add_dataflow_op(FloatOps::fsub, unpack.outputs()) + .unwrap(); + let to_int = build + .add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs()) + .unwrap(); + + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + arithmetic::float_ops::EXTENSION.to_owned(), + arithmetic::conversions::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build + .finish_hugr_with_outputs(to_int.outputs(), ®) + .unwrap(); + assert_eq!(h.node_count(), 8); + + constant_fold_pass(&mut h, ®); + + let expected = Value::sum(0, [i2c(2).clone()], sum_type).unwrap(); + assert_fully_folded(&h, &expected); +} + +#[test] +#[cfg_attr( + feature = "extension_inference", + ignore = "inference fails for test graph, it shouldn't" +)] +fn test_list_ops() -> Result<(), Box> { + use hugr::std_extensions::collections::{self, ListOp, ListValue}; + + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + logic::EXTENSION.to_owned(), + collections::EXTENSION.to_owned(), + ]) + .unwrap(); + let list: Value = ListValue::new(BOOL_T, [Value::unit_sum(0, 1).unwrap()]).into(); + let mut build = DFGBuilder::new(FunctionType::new( + type_row![], + vec![list.get_type().clone()], + )) + .unwrap(); + + let list_wire = build.add_load_const(list.clone()); + + let pop = build.add_dataflow_op( + ListOp::Pop.with_type(BOOL_T).to_extension_op(®).unwrap(), + [list_wire], + )?; + + let push = build.add_dataflow_op( + ListOp::Push + .with_type(BOOL_T) + .to_extension_op(®) + .unwrap(), + pop.outputs(), + )?; + let mut h = build.finish_hugr_with_outputs(push.outputs(), ®)?; + constant_fold_pass(&mut h, ®); + + assert_fully_folded(&h, &list); + Ok(()) +} + +#[test] +fn test_fold_and() { + // pseudocode: + // x0, x1 := bool(true), bool(true) + // x2 := and(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::true_val()); + let x1 = build.add_load_const(Value::true_val()); + let x2 = build + .add_dataflow_op(NaryLogic::And.with_n_inputs(2), [x0, x1]) + .unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_or() { + // pseudocode: + // x0, x1 := bool(true), bool(false) + // x2 := or(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::true_val()); + let x1 = build.add_load_const(Value::false_val()); + let x2 = build + .add_dataflow_op(NaryLogic::Or.with_n_inputs(2), [x0, x1]) + .unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_not() { + // pseudocode: + // x0 := bool(true) + // x1 := not(x0) + // output x1 == false; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::true_val()); + let x1 = build.add_dataflow_op(NotOp, [x0]).unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::false_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn orphan_output() { + // pseudocode: + // x0 := bool(true) + // x1 := not(x0) + // x2 := or(x0,x1) + // output x2 == true; + // + // We arange things so that the `or` folds away first, leaving the not + // with no outputs. + use hugr::hugr::NodeType; + use hugr::ops::handle::NodeHandle; + + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let true_wire = build.add_load_value(Value::true_val()); + // this Not will be manually replaced + let orig_not = build.add_dataflow_op(NotOp, [true_wire]).unwrap(); + let r = build + .add_dataflow_op( + NaryLogic::Or.with_n_inputs(2), + [true_wire, orig_not.out_wire(0)], + ) + .unwrap(); + let or_node = r.node(); + let parent = build.container_node(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(r.outputs(), ®).unwrap(); + + // we delete the original Not and create a new One. This means it will be + // traversed by `constant_fold_pass` after the Or. + let new_not = h.add_node_with_parent(parent, NodeType::new_auto(NotOp)); + h.connect(true_wire.node(), true_wire.source(), new_not, 0); + h.disconnect(or_node, IncomingPort::from(1)); + h.connect(new_not, 0, or_node, 1); + h.remove_node(orig_not.node()); + constant_fold_pass(&mut h, ®); + assert_fully_folded(&h, &Value::true_val()) +} + +#[test] +fn test_folding_pass_issue_996() { + // pseudocode: + // + // x0 := 3.0 + // x1 := 4.0 + // x2 := fne(x0, x1); // true + // x3 := flt(x0, x1); // true + // x4 := and(x2, x3); // true + // x5 := -10.0 + // x6 := flt(x0, x5) // false + // x7 := or(x4, x6) // true + // output x7 + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstF64::new(3.0))); + let x1 = build.add_load_const(Value::extension(ConstF64::new(4.0))); + let x2 = build.add_dataflow_op(FloatOps::fne, [x0, x1]).unwrap(); + let x3 = build.add_dataflow_op(FloatOps::flt, [x0, x1]).unwrap(); + let x4 = build + .add_dataflow_op( + NaryLogic::And.with_n_inputs(2), + x2.outputs().chain(x3.outputs()), + ) + .unwrap(); + let x5 = build.add_load_const(Value::extension(ConstF64::new(-10.0))); + let x6 = build.add_dataflow_op(FloatOps::flt, [x0, x5]).unwrap(); + let x7 = build + .add_dataflow_op( + NaryLogic::Or.with_n_inputs(2), + x4.outputs().chain(x6.outputs()), + ) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + logic::EXTENSION.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x7.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_const_fold_to_nonfinite() { + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::float_types::EXTENSION.to_owned(), + ]) + .unwrap(); + + // HUGR computing 1.0 / 1.0 + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); + let x1 = build.add_load_const(Value::extension(ConstF64::new(1.0))); + let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); + let mut h0 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h0, ®); + assert_fully_folded_with(&h0, |v| { + v.get_custom_value::().unwrap().value() == 1.0 + }); + assert_eq!(h0.node_count(), 5); + + // HUGR computing 1.0 / 0.0 + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); + let x1 = build.add_load_const(Value::extension(ConstF64::new(0.0))); + let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); + let mut h1 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h1, ®); + assert_eq!(h1.node_count(), 8); +} + #[test] fn test_fold_iwiden_u() { // pseudocode: @@ -109,10 +430,17 @@ fn test_fold_inarrow, E: std::fmt::Debug>( .unwrap(); let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); constant_fold_pass(&mut h, ®); + lazy_static! { + static ref INARROW_ERROR_VALUE: Value = ConstError { + signal: 0, + message: "Integer too large to narrow".to_string(), + } + .into(); + } let expected = if succeeds { Value::sum(0, [mk_const(to_log_width, val).unwrap().into()], sum_type).unwrap() } else { - Value::sum(1, [super::INARROW_ERROR_VALUE.clone()], sum_type).unwrap() + Value::sum(1, [INARROW_ERROR_VALUE.clone()], sum_type).unwrap() }; assert_fully_folded(&h, &expected); } diff --git a/hugr/src/algorithm/half_node.rs b/hugr-passes/src/half_node.rs similarity index 96% rename from hugr/src/algorithm/half_node.rs rename to hugr-passes/src/half_node.rs index 4edd35d34..cb9c6e55e 100644 --- a/hugr/src/algorithm/half_node.rs +++ b/hugr-passes/src/half_node.rs @@ -2,12 +2,12 @@ use std::hash::Hash; use super::nest_cfgs::CfgNodeMap; -use crate::hugr::RootTagged; +use hugr::hugr::RootTagged; -use crate::ops::handle::CfgID; -use crate::ops::{OpTag, OpTrait}; +use hugr::ops::handle::CfgID; +use hugr::ops::{OpTag, OpTrait}; -use crate::{Direction, Node}; +use hugr::{Direction, Node}; /// We provide a view of a cfg where every node has at most one of /// (multiple predecessors, multiple successors). @@ -97,9 +97,9 @@ impl> CfgNodeMap for HalfNodeView mod test { use super::super::nest_cfgs::{test::*, EdgeClassifier}; use super::{HalfNode, HalfNodeView}; - use crate::builder::BuildError; - use crate::hugr::views::RootChecked; - use crate::ops::handle::NodeHandle; + use hugr::builder::BuildError; + use hugr::hugr::views::RootChecked; + use hugr::ops::handle::NodeHandle; use itertools::Itertools; use std::collections::HashSet; diff --git a/hugr/src/algorithm.rs b/hugr-passes/src/lib.rs similarity index 52% rename from hugr/src/algorithm.rs rename to hugr-passes/src/lib.rs index 585e25a01..803196144 100644 --- a/hugr/src/algorithm.rs +++ b/hugr-passes/src/lib.rs @@ -1,4 +1,4 @@ -//! Algorithms using the Hugr. +//! Compilation passes acting on the HUGR program representation. pub mod const_fold; mod half_node; diff --git a/hugr/src/algorithm/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs similarity index 90% rename from hugr/src/algorithm/merge_bbs.rs rename to hugr-passes/src/merge_bbs.rs index 06d3b3bfe..17adc4e57 100644 --- a/hugr/src/algorithm/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -2,15 +2,16 @@ //! and the target BB has no other predecessors. use std::collections::HashMap; +use hugr::hugr::hugrmut::HugrMut; use itertools::Itertools; -use crate::hugr::rewrite::inline_dfg::InlineDFG; -use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec, Replacement}; -use crate::hugr::{HugrMut, RootTagged}; -use crate::ops::handle::CfgID; -use crate::ops::leaf::UnpackTuple; -use crate::ops::{DataflowBlock, DataflowParent, Input, Output, DFG}; -use crate::{Hugr, HugrView, Node}; +use hugr::hugr::rewrite::inline_dfg::InlineDFG; +use hugr::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec, Replacement}; +use hugr::hugr::RootTagged; +use hugr::ops::handle::CfgID; +use hugr::ops::leaf::UnpackTuple; +use hugr::ops::{DataflowBlock, DataflowParent, Input, Output, DFG}; +use hugr::{Hugr, HugrView, Node}; /// Merge any basic blocks that are direct children of the specified CFG /// i.e. where a basic block B has a single successor B' whose only predecessor @@ -52,7 +53,10 @@ fn mk_rep( let pred_ty = cfg.get_optype(pred).as_dataflow_block().unwrap(); let succ_ty = cfg.get_optype(succ).as_dataflow_block().unwrap(); let succ_sig = succ_ty.inner_signature(); + + // Make a Hugr with just a single CFG root node having the same signature. let mut replacement: Hugr = Hugr::new(cfg.root_type().clone()); + let merged = replacement.add_node_with_parent(replacement.root(), { let mut merged_block = DataflowBlock { inputs: pred_ty.inputs.clone(), @@ -98,12 +102,7 @@ fn mk_rep( // At the junction, must unpack the first (tuple, branch predicate) output let tuple_elems = pred_ty.sum_rows.clone().into_iter().exactly_one().unwrap(); - let unp = replacement.add_node_with_parent( - merged, - UnpackTuple { - tys: tuple_elems.clone(), - }, - ); + let unp = replacement.add_node_with_parent(merged, UnpackTuple::new(tuple_elems.clone())); replacement.connect(dfg1, 0, unp, 0); let other_start = tuple_elems.len(); for (i, _) in tuple_elems.iter().enumerate() { @@ -162,15 +161,15 @@ mod test { use itertools::Itertools; use rstest::rstest; - use crate::builder::{CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; - use crate::extension::prelude::{ConstUsize, PRELUDE_ID, QB_T, USIZE_T}; - use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE, PRELUDE_REGISTRY}; - use crate::hugr::views::sibling::SiblingMut; - use crate::ops::constant::Value; - use crate::ops::handle::CfgID; - use crate::ops::{Lift, LoadConstant, Noop, OpTrait, OpType}; - use crate::types::{FunctionType, Type, TypeRow}; - use crate::{const_extension_ids, type_row, Extension, Hugr, HugrView, Wire}; + use hugr::builder::{CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; + use hugr::extension::prelude::{ConstUsize, PRELUDE_ID, QB_T, USIZE_T}; + use hugr::extension::{ExtensionRegistry, ExtensionSet, PRELUDE, PRELUDE_REGISTRY}; + use hugr::hugr::views::sibling::SiblingMut; + use hugr::ops::constant::Value; + use hugr::ops::handle::CfgID; + use hugr::ops::{Lift, LoadConstant, Noop, OpTrait, OpType}; + use hugr::types::{FunctionType, Type, TypeRow}; + use hugr::{const_extension_ids, type_row, Extension, Hugr, HugrView, Wire}; use super::merge_basic_blocks; @@ -198,13 +197,7 @@ mod test { fn lifted_unary_unit_sum + AsRef, T>(b: &mut DFGWrapper) -> Wire { let lc = b.add_load_value(Value::unary_unit_sum()); let lift = b - .add_dataflow_op( - Lift { - type_row: Type::new_unit_sum(1).into(), - new_extension: PRELUDE_ID, - }, - [lc], - ) + .add_dataflow_op(Lift::new(Type::new_unit_sum(1).into(), PRELUDE_ID), [lc]) .unwrap(); let [w] = lift.outputs_arr(); w @@ -235,7 +228,7 @@ mod test { .with_extension_delta(ExtensionSet::singleton(&PRELUDE_ID)), )?; let mut no_b1 = h.simple_entry_builder(loop_variants.clone(), 1, ExtensionSet::new())?; - let n = no_b1.add_dataflow_op(Noop { ty: QB_T }, no_b1.input_wires())?; + let n = no_b1.add_dataflow_op(Noop::new(QB_T), no_b1.input_wires())?; let br = lifted_unary_unit_sum(&mut no_b1); let no_b1 = no_b1.finish_with_outputs(br, n.outputs())?; let mut test_block = h.block_builder( @@ -254,7 +247,7 @@ mod test { no_b1 } else { let mut no_b2 = h.simple_block_builder(FunctionType::new_endo(loop_variants), 1)?; - let n = no_b2.add_dataflow_op(Noop { ty: QB_T }, no_b2.input_wires())?; + let n = no_b2.add_dataflow_op(Noop::new(QB_T), no_b2.input_wires())?; let br = lifted_unary_unit_sum(&mut no_b2); let nid = no_b2.finish_with_outputs(br, n.outputs())?; h.branch(&nid, 0, &no_b1)?; diff --git a/hugr/src/algorithm/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs similarity index 96% rename from hugr/src/algorithm/nest_cfgs.rs rename to hugr-passes/src/nest_cfgs.rs index feae6470b..4231e2583 100644 --- a/hugr/src/algorithm/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -44,14 +44,14 @@ use std::hash::Hash; use itertools::Itertools; use thiserror::Error; -use crate::hugr::rewrite::outline_cfg::OutlineCfg; -use crate::hugr::views::sibling::SiblingMut; -use crate::hugr::views::{HierarchyView, HugrView, SiblingGraph}; -use crate::hugr::{HugrMut, Rewrite, RootTagged}; -use crate::ops::handle::{BasicBlockID, CfgID}; -use crate::ops::OpTag; -use crate::ops::OpTrait; -use crate::{Direction, Hugr, Node}; +use hugr::hugr::rewrite::outline_cfg::OutlineCfg; +use hugr::hugr::views::sibling::SiblingMut; +use hugr::hugr::views::{HierarchyView, HugrView, SiblingGraph}; +use hugr::hugr::{hugrmut::HugrMut, Rewrite, RootTagged}; +use hugr::ops::handle::{BasicBlockID, CfgID}; +use hugr::ops::OpTag; +use hugr::ops::OpTrait; +use hugr::{Direction, Hugr, Node}; /// A "view" of a CFG in a Hugr which allows basic blocks in the underlying CFG to be split into /// multiple blocks in the view (or merged together). @@ -574,15 +574,17 @@ impl EdgeClassifier { #[cfg(test)] pub(crate) mod test { use super::*; - use crate::builder::{BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder}; - use crate::extension::PRELUDE_REGISTRY; - use crate::extension::{prelude::USIZE_T, ExtensionSet}; - - use crate::hugr::views::RootChecked; - use crate::ops::handle::{ConstID, NodeHandle}; - use crate::ops::Value; - use crate::type_row; - use crate::types::{FunctionType, Type}; + use hugr::builder::{BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder}; + use hugr::extension::PRELUDE_REGISTRY; + use hugr::extension::{prelude::USIZE_T, ExtensionSet}; + + use hugr::hugr::rewrite::insert_identity::{IdentityInsertion, IdentityInsertionError}; + use hugr::hugr::views::RootChecked; + use hugr::ops::handle::{ConstID, NodeHandle}; + use hugr::ops::Value; + use hugr::type_row; + use hugr::types::{EdgeKind, FunctionType, Type}; + use hugr::utils::depth; const NAT: Type = USIZE_T; pub fn group_by(h: HashMap) -> HashSet> { @@ -814,6 +816,25 @@ pub(crate) mod test { Ok(()) } + #[test] + fn incorrect_insertion() { + let (mut h, _, tail) = build_conditional_in_loop_cfg(false).unwrap(); + + let final_node = tail.node(); + + let final_node_input = h.node_inputs(final_node).next().unwrap(); + + let rw = IdentityInsertion::new(final_node, final_node_input); + + let apply_result = h.apply_rewrite(rw); + assert_eq!( + apply_result, + Err(IdentityInsertionError::InvalidPortKind(Some( + EdgeKind::ControlFlow + ))) + ); + } + fn n_identity( mut dataflow_builder: T, pred_const: &ConstID, @@ -936,11 +957,4 @@ pub(crate) mod test { Ok((head, tail)) } - - pub fn depth(h: &Hugr, n: Node) -> u32 { - match h.get_parent(n) { - Some(p) => 1 + depth(h, p), - None => 0, - } - } } diff --git a/hugr/README.md b/hugr/README.md index 96c823016..88a61035b 100644 --- a/hugr/README.md +++ b/hugr/README.md @@ -53,4 +53,4 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [crates]: https://img.shields.io/crates/v/hugr [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE - [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/CHANGELOG.md + [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr/CHANGELOG.md diff --git a/hugr/src/algorithm/const_fold.rs b/hugr/src/algorithm/const_fold.rs deleted file mode 100644 index 5d4cffed3..000000000 --- a/hugr/src/algorithm/const_fold.rs +++ /dev/null @@ -1,551 +0,0 @@ -//! Constant folding routines. - -use std::collections::{BTreeSet, HashMap}; - -use itertools::Itertools; -use thiserror::Error; - -use crate::hugr::{SimpleReplacementError, ValidationError}; -use crate::types::SumType; -use crate::Direction; -use crate::{ - builder::{DFGBuilder, Dataflow, DataflowHugr}, - extension::{ConstFoldResult, ExtensionRegistry}, - hugr::{ - rewrite::consts::{RemoveConst, RemoveLoadConstant}, - views::SiblingSubgraph, - HugrMut, - }, - ops::{OpType, Value}, - type_row, - types::FunctionType, - Hugr, HugrView, IncomingPort, Node, SimpleReplacement, -}; - -#[derive(Error, Debug)] -#[allow(missing_docs)] -pub enum ConstFoldError { - #[error("Failed to verify {label} HUGR: {err}")] - VerifyError { - label: String, - #[source] - err: ValidationError, - }, - #[error(transparent)] - SimpleReplaceError(#[from] SimpleReplacementError), -} - -/// Tag some output constants with [`OutgoingPort`] inferred from the ordering. -fn out_row(consts: impl IntoIterator) -> ConstFoldResult { - let vec = consts - .into_iter() - .enumerate() - .map(|(i, c)| (i.into(), c)) - .collect(); - Some(vec) -} - -/// Sort folding inputs with [`IncomingPort`] as key -fn sort_by_in_port(consts: &[(IncomingPort, Value)]) -> Vec<&(IncomingPort, Value)> { - let mut v: Vec<_> = consts.iter().collect(); - v.sort_by_key(|(i, _)| i); - v -} - -/// Sort some input constants by port and just return the constants. -pub(crate) fn sorted_consts(consts: &[(IncomingPort, Value)]) -> Vec<&Value> { - sort_by_in_port(consts) - .into_iter() - .map(|(_, c)| c) - .collect() -} - -/// For a given op and consts, attempt to evaluate the op. -pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldResult { - let fold_result = match op { - OpType::Noop { .. } => out_row([consts.first()?.1.clone()]), - OpType::MakeTuple { .. } => { - out_row([Value::tuple(sorted_consts(consts).into_iter().cloned())]) - } - OpType::UnpackTuple { .. } => { - let c = &consts.first()?.1; - let Value::Tuple { vs } = c else { - panic!("This op always takes a Tuple input."); - }; - out_row(vs.iter().cloned()) - } - - OpType::Tag(t) => out_row([Value::sum( - t.tag, - consts.iter().map(|(_, konst)| konst.clone()), - SumType::new(t.variants.clone()), - ) - .unwrap()]), - OpType::CustomOp(op) => { - let ext_op = op.as_extension_op()?; - ext_op.constant_fold(consts) - } - _ => None, - }; - debug_assert!(fold_result.as_ref().map_or(true, |x| x.len() - == op.value_port_count(Direction::Outgoing))); - fold_result -} - -/// Generate a graph that loads and outputs `consts` in order, validating -/// against `reg`. -fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { - let const_types = consts.iter().map(Value::get_type).collect_vec(); - let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap(); - - let outputs = consts - .into_iter() - .map(|c| b.add_load_const(c)) - .collect_vec(); - - b.finish_hugr_with_outputs(outputs, reg).unwrap() -} - -/// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`, -/// return an iterator of possible constant folding rewrites. The -/// [`SimpleReplacement`] replaces an operation with constants that result from -/// evaluating it, the extension registry `reg` is used to validate the -/// replacement HUGR. The vector of [`RemoveLoadConstant`] refer to the -/// LoadConstant nodes that could be removed - they are not automatically -/// removed as they may be used by other operations. -pub fn find_consts<'a, 'r: 'a>( - hugr: &'a impl HugrView, - candidate_nodes: impl IntoIterator + 'a, - reg: &'r ExtensionRegistry, -) -> impl Iterator)> + 'a { - // track nodes for operations that have already been considered for folding - let mut used_neighbours = BTreeSet::new(); - - candidate_nodes - .into_iter() - .filter_map(move |n| { - // only look at LoadConstant - hugr.get_optype(n).is_load_constant().then_some(())?; - - let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?; - let neighbours = hugr - .linked_inputs(n, out_p) - .filter(|(n, _)| used_neighbours.insert(*n)) - .collect_vec(); - if neighbours.is_empty() { - // no uses of LoadConstant that haven't already been considered. - return None; - } - let fold_iter = neighbours - .into_iter() - .filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg)); - Some(fold_iter) - }) - .flatten() -} - -/// Attempt to evaluate and generate rewrites for the operation at `op_node` -fn fold_op( - hugr: &impl HugrView, - op_node: Node, - reg: &ExtensionRegistry, -) -> Option<(SimpleReplacement, Vec)> { - // only support leaf folding for now. - let neighbour_op = hugr.get_optype(op_node); - let (in_consts, removals): (Vec<_>, Vec<_>) = hugr - .node_inputs(op_node) - .filter_map(|in_p| { - let (con_op, load_n) = get_const(hugr, op_node, in_p)?; - Some(((in_p, con_op), RemoveLoadConstant(load_n))) - }) - .unzip(); - // attempt to evaluate op - let (nu_out, consts): (HashMap<_, _>, Vec<_>) = fold_leaf_op(neighbour_op, &in_consts)? - .into_iter() - .enumerate() - .filter_map(|(i, (op_out, konst))| { - // for each used port of the op give the nu_out entry and the - // corresponding Value - hugr.single_linked_input(op_node, op_out) - .map(|np| ((np, i.into()), konst)) - }) - .unzip(); - let replacement = const_graph(consts, reg); - let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr) - .expect("Operation should form valid subgraph."); - - let simple_replace = SimpleReplacement::new( - sibling_graph, - replacement, - // no inputs to replacement - HashMap::new(), - nu_out, - ); - Some((simple_replace, removals)) -} - -/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant -/// and the LoadConstant node -fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Value, Node)> { - let (load_n, _) = hugr.single_linked_output(op_node, in_p)?; - let load_op = hugr.get_optype(load_n).as_load_constant()?; - let const_node = hugr - .single_linked_output(load_n, load_op.constant_port())? - .0; - let const_op = hugr.get_optype(const_node).as_const()?; - - // TODO avoid const clone here - Some((const_op.as_ref().clone(), load_n)) -} - -/// Exhaustively apply constant folding to a HUGR. -pub fn constant_fold_pass(h: &mut H, reg: &ExtensionRegistry) { - #[cfg(test)] - let verify = |label, h: &H| { - h.validate_no_extensions(reg).unwrap_or_else(|err| { - panic!( - "constant_fold_pass: failed to verify {label} HUGR: {err}\n{}", - h.mermaid_string() - ) - }) - }; - #[cfg(test)] - verify("input", h); - loop { - // We can only safely apply a single replacement. Applying a - // replacement removes nodes and edges which may be referenced by - // further replacements returned by find_consts. Even worse, if we - // attempted to apply those replacements, expecting them to fail if - // the nodes and edges they reference had been deleted, they may - // succeed because new nodes and edges reused the ids. - // - // We could be a lot smarter here, keeping track of `LoadConstant` - // nodes and only looking at their out neighbours. - let Some((replace, removes)) = find_consts(h, h.nodes(), reg).next() else { - break; - }; - h.apply_rewrite(replace).unwrap(); - for rem in removes { - // We are optimistically applying these [RemoveLoadConstant] and - // [RemoveConst] rewrites without checking whether the nodes - // they attempt to remove have remaining uses. If they do, then - // the rewrite fails and we move on. - if let Ok(const_node) = h.apply_rewrite(rem) { - // if the LoadConst was removed, try removing the Const too. - let _ = h.apply_rewrite(RemoveConst(const_node)); - } - } - } - #[cfg(test)] - verify("output", h); -} - -#[cfg(test)] -mod test { - - use super::*; - use crate::extension::prelude::{sum_with_error, BOOL_T}; - use crate::extension::{ExtensionRegistry, PRELUDE}; - use crate::ops::{OpType, UnpackTuple}; - use crate::std_extensions::arithmetic; - use crate::std_extensions::arithmetic::conversions::ConvertOpDef; - use crate::std_extensions::arithmetic::float_ops::FloatOps; - use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; - use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; - use crate::std_extensions::logic::{self, NaryLogic, NotOp}; - use crate::utils::test::{assert_fully_folded, assert_fully_folded_with}; - - use rstest::rstest; - - /// int to constant - fn i2c(b: u64) -> Value { - Value::extension(ConstInt::new_u(5, b).unwrap()) - } - - /// float to constant - fn f2c(f: f64) -> Value { - ConstF64::new(f).into() - } - - #[rstest] - #[case(0.0, 0.0, 0.0)] - #[case(0.0, 1.0, 1.0)] - #[case(23.5, 435.5, 459.0)] - // c = a + b - fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { - let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; - let add_op: OpType = FloatOps::fadd.into(); - let outs = fold_leaf_op(&add_op, &consts) - .unwrap() - .into_iter() - .map(|(p, v)| (p, v.get_custom_value::().unwrap().value())) - .collect_vec(); - - assert_eq!(outs.as_slice(), &[(0.into(), c)]); - } - #[test] - fn test_big() { - /* - Test approximately calculates - let x = (5.6, 3.2); - int(x.0 - x.1) == 2 - */ - let sum_type = sum_with_error(INT_TYPES[5].to_owned()); - let mut build = DFGBuilder::new(FunctionType::new( - type_row![], - vec![sum_type.clone().into()], - )) - .unwrap(); - - let tup = build.add_load_const(Value::tuple([f2c(5.6), f2c(3.2)])); - - let unpack = build - .add_dataflow_op( - UnpackTuple { - tys: type_row![FLOAT64_TYPE, FLOAT64_TYPE], - }, - [tup], - ) - .unwrap(); - - let sub = build - .add_dataflow_op(FloatOps::fsub, unpack.outputs()) - .unwrap(); - let to_int = build - .add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs()) - .unwrap(); - - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - arithmetic::float_ops::EXTENSION.to_owned(), - arithmetic::conversions::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build - .finish_hugr_with_outputs(to_int.outputs(), ®) - .unwrap(); - assert_eq!(h.node_count(), 8); - - constant_fold_pass(&mut h, ®); - - let expected = Value::sum(0, [i2c(2).clone()], sum_type).unwrap(); - assert_fully_folded(&h, &expected); - } - - #[test] - #[cfg_attr( - feature = "extension_inference", - ignore = "inference fails for test graph, it shouldn't" - )] - fn test_list_ops() -> Result<(), Box> { - use crate::std_extensions::collections::{self, ListOp, ListValue}; - - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - logic::EXTENSION.to_owned(), - collections::EXTENSION.to_owned(), - ]) - .unwrap(); - let list: Value = ListValue::new(BOOL_T, [Value::unit_sum(0, 1).unwrap()]).into(); - let mut build = DFGBuilder::new(FunctionType::new( - type_row![], - vec![list.get_type().clone()], - )) - .unwrap(); - - let list_wire = build.add_load_const(list.clone()); - - let pop = build.add_dataflow_op( - ListOp::Pop.with_type(BOOL_T).to_extension_op(®).unwrap(), - [list_wire], - )?; - - let push = build.add_dataflow_op( - ListOp::Push - .with_type(BOOL_T) - .to_extension_op(®) - .unwrap(), - pop.outputs(), - )?; - let mut h = build.finish_hugr_with_outputs(push.outputs(), ®)?; - constant_fold_pass(&mut h, ®); - - assert_fully_folded(&h, &list); - Ok(()) - } - - #[test] - fn test_fold_and() { - // pseudocode: - // x0, x1 := bool(true), bool(true) - // x2 := and(x0, x1) - // output x2 == true; - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::true_val()); - let x1 = build.add_load_const(Value::true_val()); - let x2 = build - .add_dataflow_op(NaryLogic::And.with_n_inputs(2), [x0, x1]) - .unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); - } - - #[test] - fn test_fold_or() { - // pseudocode: - // x0, x1 := bool(true), bool(false) - // x2 := or(x0, x1) - // output x2 == true; - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::true_val()); - let x1 = build.add_load_const(Value::false_val()); - let x2 = build - .add_dataflow_op(NaryLogic::Or.with_n_inputs(2), [x0, x1]) - .unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); - } - - #[test] - fn test_fold_not() { - // pseudocode: - // x0 := bool(true) - // x1 := not(x0) - // output x1 == false; - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::true_val()); - let x1 = build.add_dataflow_op(NotOp, [x0]).unwrap(); - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::false_val(); - assert_fully_folded(&h, &expected); - } - - #[test] - fn orphan_output() { - // pseudocode: - // x0 := bool(true) - // x1 := not(x0) - // x2 := or(x0,x1) - // output x2 == true; - // - // We arange things so that the `or` folds away first, leaving the not - // with no outputs. - use crate::hugr::NodeType; - use crate::ops::handle::NodeHandle; - - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); - let true_wire = build.add_load_value(Value::true_val()); - // this Not will be manually replaced - let orig_not = build.add_dataflow_op(NotOp, [true_wire]).unwrap(); - let r = build - .add_dataflow_op( - NaryLogic::Or.with_n_inputs(2), - [true_wire, orig_not.out_wire(0)], - ) - .unwrap(); - let or_node = r.node(); - let parent = build.dfg_node; - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(r.outputs(), ®).unwrap(); - - // we delete the original Not and create a new One. This means it will be - // traversed by `constant_fold_pass` after the Or. - let new_not = h.add_node_with_parent(parent, NodeType::new_auto(NotOp)); - h.connect(true_wire.node(), true_wire.source(), new_not, 0); - h.disconnect(or_node, IncomingPort::from(1)); - h.connect(new_not, 0, or_node, 1); - h.remove_node(orig_not.node()); - constant_fold_pass(&mut h, ®); - assert_fully_folded(&h, &Value::true_val()) - } - - #[test] - fn test_folding_pass_issue_996() { - // pseudocode: - // - // x0 := 3.0 - // x1 := 4.0 - // x2 := fne(x0, x1); // true - // x3 := flt(x0, x1); // true - // x4 := and(x2, x3); // true - // x5 := -10.0 - // x6 := flt(x0, x5) // false - // x7 := or(x4, x6) // true - // output x7 - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstF64::new(3.0))); - let x1 = build.add_load_const(Value::extension(ConstF64::new(4.0))); - let x2 = build.add_dataflow_op(FloatOps::fne, [x0, x1]).unwrap(); - let x3 = build.add_dataflow_op(FloatOps::flt, [x0, x1]).unwrap(); - let x4 = build - .add_dataflow_op( - NaryLogic::And.with_n_inputs(2), - x2.outputs().chain(x3.outputs()), - ) - .unwrap(); - let x5 = build.add_load_const(Value::extension(ConstF64::new(-10.0))); - let x6 = build.add_dataflow_op(FloatOps::flt, [x0, x5]).unwrap(); - let x7 = build - .add_dataflow_op( - NaryLogic::Or.with_n_inputs(2), - x4.outputs().chain(x6.outputs()), - ) - .unwrap(); - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - logic::EXTENSION.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(x7.outputs(), ®).unwrap(); - constant_fold_pass(&mut h, ®); - let expected = Value::true_val(); - assert_fully_folded(&h, &expected); - } - - #[test] - fn test_const_fold_to_nonfinite() { - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - ]) - .unwrap(); - - // HUGR computing 1.0 / 1.0 - let mut build = - DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); - let x1 = build.add_load_const(Value::extension(ConstF64::new(1.0))); - let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); - let mut h0 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h0, ®); - assert_fully_folded_with(&h0, |v| { - v.get_custom_value::().unwrap().value() == 1.0 - }); - assert_eq!(h0.node_count(), 5); - - // HUGR computing 1.0 / 0.0 - let mut build = - DFGBuilder::new(FunctionType::new(type_row![], vec![FLOAT64_TYPE])).unwrap(); - let x0 = build.add_load_const(Value::extension(ConstF64::new(1.0))); - let x1 = build.add_load_const(Value::extension(ConstF64::new(0.0))); - let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); - let mut h1 = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); - constant_fold_pass(&mut h1, ®); - assert_eq!(h1.node_count(), 8); - } -} diff --git a/hugr/src/hugr.rs b/hugr/src/hugr.rs index 4d25807b3..f53dfe482 100644 --- a/hugr/src/hugr.rs +++ b/hugr/src/hugr.rs @@ -192,6 +192,11 @@ pub type NodeMetadataMap = serde_json::Map; /// Public API for HUGRs. impl Hugr { + /// Create a new Hugr, with a single root node. + pub fn new(root_node: NodeType) -> Self { + Self::with_capacity(root_node, 0, 0) + } + /// Resolve extension ops, infer extensions used, and pass the closure into validation pub fn update_validate( &mut self, @@ -237,11 +242,6 @@ impl Hugr { /// Internal API for HUGRs, not intended for use by users. impl Hugr { - /// Create a new Hugr, with a single root node. - pub(crate) fn new(root_node: NodeType) -> Self { - Self::with_capacity(root_node, 0, 0) - } - /// Create a new Hugr, with a single root node and preallocated capacity. // TODO: Make this take a NodeType pub(crate) fn with_capacity(root_node: NodeType, nodes: usize, ports: usize) -> Self { diff --git a/hugr/src/hugr/rewrite/insert_identity.rs b/hugr/src/hugr/rewrite/insert_identity.rs index 43cb35142..7dbde932b 100644 --- a/hugr/src/hugr/rewrite/insert_identity.rs +++ b/hugr/src/hugr/rewrite/insert_identity.rs @@ -100,9 +100,7 @@ mod tests { use super::super::simple_replace::test::dfg_hugr; use super::*; use crate::{ - algorithm::nest_cfgs::test::build_conditional_in_loop_cfg, extension::{prelude::QB_T, PRELUDE_REGISTRY}, - ops::handle::NodeHandle, Hugr, }; @@ -131,23 +129,4 @@ mod tests { h.update_validate(&PRELUDE_REGISTRY).unwrap(); } - - #[test] - fn incorrect_insertion() { - let (mut h, _, tail) = build_conditional_in_loop_cfg(false).unwrap(); - - let final_node = tail.node(); - - let final_node_input = h.node_inputs(final_node).next().unwrap(); - - let rw = IdentityInsertion::new(final_node, final_node_input); - - let apply_result = h.apply_rewrite(rw); - assert_eq!( - apply_result, - Err(IdentityInsertionError::InvalidPortKind(Some( - EdgeKind::ControlFlow - ))) - ); - } } diff --git a/hugr/src/hugr/rewrite/replace.rs b/hugr/src/hugr/rewrite/replace.rs index b06fcfe62..5b2ed2948 100644 --- a/hugr/src/hugr/rewrite/replace.rs +++ b/hugr/src/hugr/rewrite/replace.rs @@ -445,7 +445,6 @@ mod test { use cool_asserts::assert_matches; use itertools::Itertools; - use crate::algorithm::nest_cfgs::test::depth; use crate::builder::{ BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, SubContainer, @@ -463,6 +462,7 @@ mod test { use crate::ops::{self, Case, DataflowBlock, OpTag, OpType, DFG}; use crate::std_extensions::collections; use crate::types::{FunctionType, Type, TypeArg, TypeRow}; + use crate::utils::depth; use crate::{type_row, Direction, Hugr, HugrView, OutgoingPort}; use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement}; diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index 7c1cdca5f..c093ba2aa 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -140,7 +140,6 @@ // https://github.com/proptest-rs/proptest/issues/447 #![cfg_attr(test, allow(non_local_definitions))] -pub mod algorithm; pub mod builder; pub mod core; pub mod extension; @@ -149,7 +148,7 @@ pub mod macros; pub mod ops; pub mod std_extensions; pub mod types; -mod utils; +pub mod utils; pub use crate::core::{ CircuitUnit, Direction, IncomingPort, Node, NodeIndex, OutgoingPort, Port, PortIndex, Wire, diff --git a/hugr/src/std_extensions/arithmetic/float_ops/const_fold.rs b/hugr/src/std_extensions/arithmetic/float_ops/const_fold.rs index a97a2d1c7..974dbe9b6 100644 --- a/hugr/src/std_extensions/arithmetic/float_ops/const_fold.rs +++ b/hugr/src/std_extensions/arithmetic/float_ops/const_fold.rs @@ -1,8 +1,8 @@ use crate::{ - algorithm::const_fold::sorted_consts, extension::{prelude::ConstString, ConstFold, ConstFoldResult, OpDef}, ops, std_extensions::arithmetic::float_types::ConstF64, + utils::sorted_consts, IncomingPort, }; diff --git a/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs b/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs index 8738e1872..4c520963d 100644 --- a/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs +++ b/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs @@ -1196,6 +1196,3 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) { }, }); } - -#[cfg(test)] -mod test; diff --git a/hugr/src/std_extensions/collections.rs b/hugr/src/std_extensions/collections.rs index 747841a6b..392417dff 100644 --- a/hugr/src/std_extensions/collections.rs +++ b/hugr/src/std_extensions/collections.rs @@ -8,7 +8,6 @@ use crate::ops::constant::ValueName; use crate::ops::{OpName, Value}; use crate::types::TypeName; use crate::{ - algorithm::const_fold::sorted_consts, extension::{ simple_op::{MakeExtensionOp, OpLoadError}, ConstFold, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, TypeDef, @@ -20,6 +19,7 @@ use crate::{ type_param::{TypeArg, TypeParam}, CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeBound, }, + utils::sorted_consts, Extension, }; diff --git a/hugr/src/std_extensions/logic.rs b/hugr/src/std_extensions/logic.rs index f747428c8..d6e51811f 100644 --- a/hugr/src/std_extensions/logic.rs +++ b/hugr/src/std_extensions/logic.rs @@ -6,7 +6,6 @@ use crate::extension::{ConstFold, ConstFoldResult}; use crate::ops::constant::ValueName; use crate::ops::{OpName, Value}; use crate::{ - algorithm::const_fold::sorted_consts, extension::{ prelude::BOOL_T, simple_op::{try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}, @@ -18,6 +17,7 @@ use crate::{ type_param::{TypeArg, TypeParam}, FunctionType, }, + utils::sorted_consts, Extension, IncomingPort, }; use lazy_static::lazy_static; diff --git a/hugr/src/utils.rs b/hugr/src/utils.rs index d7ce0e101..863cfa7a9 100644 --- a/hugr/src/utils.rs +++ b/hugr/src/utils.rs @@ -1,7 +1,11 @@ +//! General utilities. + use std::fmt::{self, Debug, Display}; use itertools::Itertools; +use crate::{ops::Value, Hugr, HugrView, IncomingPort, Node}; + /// Write a comma separated list of of some types. /// Like debug_list, but using the Display instance rather than Debug, /// and not adding surrounding square brackets. @@ -205,6 +209,29 @@ pub(crate) mod test_quantum_extension { } } +/// Sort folding inputs with [`IncomingPort`] as key +fn sort_by_in_port(consts: &[(IncomingPort, Value)]) -> Vec<&(IncomingPort, Value)> { + let mut v: Vec<_> = consts.iter().collect(); + v.sort_by_key(|(i, _)| i); + v +} + +/// Sort some input constants by port and just return the constants. +pub fn sorted_consts(consts: &[(IncomingPort, Value)]) -> Vec<&Value> { + sort_by_in_port(consts) + .into_iter() + .map(|(_, c)| c) + .collect() +} + +/// Calculate the depth of a node in the hierarchy. +pub fn depth(h: &Hugr, n: Node) -> u32 { + match h.get_parent(n) { + Some(p) => 1 + depth(h, p), + None => 0, + } +} + #[allow(dead_code)] // Test only utils #[cfg(test)] diff --git a/release-plz.toml b/release-plz.toml index f6cfe39de..c17abd90a 100644 --- a/release-plz.toml +++ b/release-plz.toml @@ -14,8 +14,16 @@ changelog_update = false # (This would normally only be enabled once there are multiple packages in the workspace) git_tag_name = "{{ package }}-v{{ version }}" +git_release_name = "{{ package }}: v{{ version }}" + [[package]] name = "hugr" # Enable the changelog for this package changelog_update = true + +[[package]] +name = "hugr-passes" + +# Enable the changelog for this package +changelog_update = true