Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: constant folding implemented for core and float extension #758

Merged
merged 40 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
bffed99
wip: constant folding
ss2165 Nov 13, 2023
1a27d54
start moving folding to op_def
ss2165 Nov 20, 2023
b84766b
thread through folding methods
ss2165 Nov 23, 2023
8ee49da
integer addition tests passing
ss2165 Nov 23, 2023
520de7c
remove FoldOutput
ss2165 Nov 24, 2023
1d656d6
Merge branch 'main' into feat/const-fold2
ss2165 Dec 18, 2023
9398d9d
refactor int folding to separate repo
ss2165 Dec 18, 2023
7b955a9
add tuple and sum constant folding
ss2165 Dec 18, 2023
6cb3c62
simplify test code
ss2165 Dec 18, 2023
0500624
wip: fold finder
ss2165 Dec 20, 2023
8f554e0
chore(deps): bump actions/upload-artifact from 3 to 4 (#751)
dependabot[bot] Dec 20, 2023
215eb40
chore(deps): bump dawidd6/action-download-artifact from 2 to 3 (#752)
dependabot[bot] Dec 20, 2023
ff26546
fix: case node should not have an external signature (#749)
ss2165 Dec 20, 2023
64b9199
refactor: move hugr equality check out for reuse
ss2165 Dec 20, 2023
6d7d440
feat: implement RemoveConst and RemoveConstIgnore
ss2165 Dec 21, 2023
cdde503
use remove rewrites while folding
ss2165 Dec 21, 2023
114524c
alllow candidate node specification in find_consts
ss2165 Dec 21, 2023
a087fbc
add exhaustive fold pass
ss2165 Dec 21, 2023
07768b2
refactor!: use enum op traits for floats + conversions
ss2165 Dec 21, 2023
9a81260
Merge branch 'refactor/fops-enum' into feat/const-fold2
ss2165 Dec 21, 2023
658adf4
add folding definitions for float ops
ss2165 Dec 21, 2023
2c0e75b
refactor: ERROR_CUSTOM_TYPE
ss2165 Dec 21, 2023
dc7ff13
refactor: const ConstF64::new
ss2165 Dec 21, 2023
aa73ab2
feat: implement folding for conversion ops
ss2165 Dec 21, 2023
a519f34
fixup! refactor: ERROR_CUSTOM_TYPE
ss2165 Dec 21, 2023
a7a4088
Merge branch 'main' into feat/const-fold2
ss2165 Dec 21, 2023
46075c2
implement bigger tests and fix unearthed bugs
ss2165 Dec 21, 2023
df854e8
Revert "refactor: move hugr equality check out for reuse"
ss2165 Dec 22, 2023
1ed42e9
feat: Custom const for ERROR_TYPE (#756)
ss2165 Dec 22, 2023
09ce1c9
remove conversion foldin
ss2165 Dec 22, 2023
5a372c7
Merge branch 'main' into feat/const-fold-floats
ss2165 Dec 22, 2023
b513ace
Merge branch 'feat/const-rewrites' into feat/const-fold-floats
ss2165 Dec 22, 2023
5a71f75
docs: add public method docstrings
ss2165 Dec 22, 2023
6fa7eb9
add some docstrings and comments
ss2165 Dec 22, 2023
7381432
remove integer folding
ss2165 Dec 22, 2023
0e0411f
remove unused imports
ss2165 Dec 22, 2023
8e88f3e
add docstrings and simplify
ss2165 Dec 22, 2023
4607d64
chore(deps): update delegate requirement from 0.11.0 to 0.12.0 (#760)
dependabot[bot] Jan 2, 2024
0edee65
chore!: hike MSRV to 1.75 (#761)
lmondada Jan 2, 2024
0c060fb
Merge branch 'main' into feat/const-fold-floats
lmondada Jan 2, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
rust: ['1.70', stable, beta, nightly]
rust: ['1.75', stable, beta, nightly]
# workaround to ignore non-stable tests when running the merge queue checks
# see: https://github.community/t/how-to-conditionally-include-exclude-items-in-matrix-eg-based-on-branch/16853/6
isMerge:
- ${{ github.event_name == 'merge_group' }}
exclude:
- rust: '1.70'
- rust: '1.75'
isMerge: true
- rust: beta
isMerge: true
Expand Down
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ description = "Hierarchical Unified Graph Representation"
#categories = [] # TODO

edition = "2021"
rust-version = "1.70"
rust-version = "1.75"

[lib]
# Using different names for the lib and for the package is supported, but may be confusing.
Expand Down Expand Up @@ -46,7 +46,7 @@ lazy_static = "1.4.0"
petgraph = { version = "0.6.3", default-features = false }
context-iterators = "0.2.0"
serde_json = "1.0.97"
delegate = "0.11.0"
delegate = "0.12.0"
rustversion = "1.0.14"
paste = "1.0"
strum = "0.25.0"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ See [DEVELOPMENT.md](DEVELOPMENT.md) for instructions on setting up the developm
This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0).

[build_status]: https://github.com/CQCL/hugr/workflows/Continuous%20integration/badge.svg?branch=main
[msrv]: https://img.shields.io/badge/rust-1.70.0%2B-blue.svg
[msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg
[codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov
[LICENSE]: LICENCE
1 change: 1 addition & 0 deletions src/algorithm.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! Algorithms using the Hugr.

pub mod const_fold;
mod half_node;
pub mod nest_cfgs;
302 changes: 302 additions & 0 deletions src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
//! Constant folding routines.

use std::collections::{BTreeSet, HashMap};

use itertools::Itertools;

use crate::{
builder::{DFGBuilder, Dataflow, DataflowHugr},
extension::{ConstFoldResult, ExtensionRegistry},
hugr::{
rewrite::consts::{RemoveConst, RemoveConstIgnore},
views::SiblingSubgraph,
HugrMut,
},
ops::{Const, LeafOp, OpType},
type_row,
types::{FunctionType, Type, TypeEnum},
values::Value,
Hugr, HugrView, IncomingPort, Node, SimpleReplacement,
};

/// Tag some output constants with [`OutgoingPort`] inferred from the ordering.
fn out_row(consts: impl IntoIterator<Item = Const>) -> ConstFoldResult {
let vec = consts
.into_iter()
.enumerate()
.map(|(i, c)| (i.into(), c))
.collect();
Some(vec)
}

/// Sort folding inputs with [`IncomingPort`] as key
fn sort_by_in_port(consts: &[(IncomingPort, Const)]) -> Vec<&(IncomingPort, Const)> {
let mut v: Vec<_> = consts.iter().collect();
v.sort_by_key(|(i, _)| i);
v
}

/// Sort some input constants by port and just return the constants.
pub(crate) fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> {
sort_by_in_port(consts)
.into_iter()
.map(|(_, c)| c)
.collect()
}
/// For a given op and consts, attempt to evaluate the op.
pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldResult {
let op = op.as_leaf_op()?;

match op {
LeafOp::Noop { .. } => out_row([consts.first()?.1.clone()]),
LeafOp::MakeTuple { .. } => {
out_row([Const::new_tuple(sorted_consts(consts).into_iter().cloned())])
}
LeafOp::UnpackTuple { .. } => {
let c = &consts.first()?.1;

if let Value::Tuple { vs } = c.value() {
if let TypeEnum::Tuple(tys) = c.const_type().as_type_enum() {
return out_row(tys.iter().zip(vs.iter()).map(|(t, v)| {
Const::new(v.clone(), t.clone())
.expect("types should already have been checked")
}));
}
}
None // could panic
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure the comment is correct, i.e. what would panic and how returning None would help. I suppose you mean that the let if clauses might fail?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i mean that None shouldn't ever be returned here, so it might be clearer to panic

}

LeafOp::Tag { tag, variants } => out_row([Const::new(
Value::sum(*tag, consts.first()?.1.value().clone()),
Type::new_sum(variants.clone()),
)
.unwrap()]),
LeafOp::CustomOp(_) => {
let ext_op = op.as_extension_op()?;

ext_op.constant_fold(consts)
}
_ => None,
}
}

/// Generate a graph that loads and outputs `consts` in order, validating
/// against `reg`.
fn const_graph(consts: Vec<Const>, reg: &ExtensionRegistry) -> Hugr {
let const_types = consts.iter().map(Const::const_type).cloned().collect_vec();
let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap();

let outputs = consts
.into_iter()
.map(|c| b.add_load_const(c).unwrap())
.collect_vec();

b.finish_hugr_with_outputs(outputs, reg).unwrap()
}

/// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`,
/// return an iterator of possible constant folding rewrites. The
/// [`SimpleReplacement`] replaces an operation with constants that result from
/// evaluating it, the extension registry `reg` is used to validate the
/// replacement HUGR. The vector of [`RemoveConstIgnore`] refer to the
/// LoadConstant nodes that could be removed.
pub fn find_consts<'a, 'r: 'a>(
hugr: &'a impl HugrView,
candidate_nodes: impl IntoIterator<Item = Node> + 'a,
reg: &'r ExtensionRegistry,
) -> impl Iterator<Item = (SimpleReplacement, Vec<RemoveConstIgnore>)> + 'a {
// track nodes for operations that have already been considered for folding
let mut used_neighbours = BTreeSet::new();

candidate_nodes
.into_iter()
.filter_map(move |n| {
// only look at LoadConstant
hugr.get_optype(n).is_load_constant().then_some(())?;

let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?;
let neighbours = hugr
.linked_inputs(n, out_p)
.filter(|(n, _)| used_neighbours.insert(*n))
.collect_vec();
if neighbours.is_empty() {
// no uses of LoadConstant that haven't already been considered.
return None;
}
let fold_iter = neighbours
.into_iter()
.filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg));
Some(fold_iter)
})
.flatten()
}

/// Attempt to evaluate and generate rewrites for the operation at `op_node`
fn fold_op(
hugr: &impl HugrView,
op_node: Node,
reg: &ExtensionRegistry,
) -> Option<(SimpleReplacement, Vec<RemoveConstIgnore>)> {
let (in_consts, removals): (Vec<_>, Vec<_>) = hugr
.node_inputs(op_node)
.filter_map(|in_p| {
let (con_op, load_n) = get_const(hugr, op_node, in_p)?;
Some(((in_p, con_op), RemoveConstIgnore(load_n)))
})
.unzip();
let neighbour_op = hugr.get_optype(op_node);
// attempt to evaluate op
let folded = fold_const(neighbour_op, &in_consts)?;
let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip();
let nu_out = op_outs
.into_iter()
.enumerate()
.filter_map(|(i, out)| {
// map from the ports the op was linked to, to the output ports of
// the replacement.
hugr.single_linked_input(op_node, out)
.map(|np| (np, i.into()))
})
.collect();
let replacement = const_graph(consts, reg);
let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr)
.expect("Operation should form valid subgraph.");

let simple_replace = SimpleReplacement::new(
sibling_graph,
replacement,
// no inputs to replacement
HashMap::new(),
nu_out,
);
Some((simple_replace, removals))
}

/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant
/// and the LoadConstant node
fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Const, Node)> {
let (load_n, _) = hugr.single_linked_output(op_node, in_p)?;
let load_op = hugr.get_optype(load_n).as_load_constant()?;
let const_node = hugr
.linked_outputs(load_n, load_op.constant_port())
.exactly_one()
.ok()?
.0;

let const_op = hugr.get_optype(const_node).as_const()?;

// TODO avoid const clone here
Some((const_op.clone(), load_n))
}

/// Exhaustively apply constant folding to a HUGR.
pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) {
loop {
// would be preferable if the candidates were updated to be just the
// neighbouring nodes of those added.
let rewrites = find_consts(h, h.nodes(), reg).collect_vec();
if rewrites.is_empty() {
break;
}
for (replace, removes) in rewrites {
h.apply_rewrite(replace).unwrap();
for rem in removes {
if let Ok(const_node) = h.apply_rewrite(rem) {
// if the LoadConst was removed, try removing the Const too.
if h.apply_rewrite(RemoveConst(const_node)).is_err() {
// const cannot be removed - no problem
continue;
Comment on lines +205 to +208
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find it surprising that this is not removed by the rewrite itself. Maybe add a comment about this in the docstring of find_consts?

}
}
}
}
}
}

#[cfg(test)]
mod test {

use crate::extension::{ExtensionRegistry, PRELUDE};
use crate::std_extensions::arithmetic;

use crate::std_extensions::arithmetic::float_ops::FloatOps;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};

use rstest::rstest;

use super::*;

/// float to constant
fn f2c(f: f64) -> Const {
ConstF64::new(f).into()
}

#[rstest]
#[case(0.0, 0.0, 0.0)]
#[case(0.0, 1.0, 1.0)]
#[case(23.5, 435.5, 459.0)]
// c = a + b
fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) {
let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))];
let add_op: OpType = FloatOps::fadd.into();
let out = fold_const(&add_op, &consts).unwrap();

assert_eq!(&out[..], &[(0.into(), f2c(c))]);
}

#[test]
fn test_big() {
/*
Test hugr approximately calculates
let x = (5.5, 3.25);
x.0 - x.1 == 2.25
*/
let mut build =
DFGBuilder::new(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])).unwrap();

let tup = build
.add_load_const(Const::new_tuple([f2c(5.5), f2c(3.25)]))
.unwrap();

let unpack = build
.add_dataflow_op(
LeafOp::UnpackTuple {
tys: type_row![FLOAT64_TYPE, FLOAT64_TYPE],
},
[tup],
)
.unwrap();

let sub = build
.add_dataflow_op(FloatOps::fsub, unpack.outputs())
.unwrap();

let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
arithmetic::float_types::EXTENSION.to_owned(),
arithmetic::float_ops::EXTENSION.to_owned(),
])
.unwrap();
let mut h = build.finish_hugr_with_outputs(sub.outputs(), &reg).unwrap();
assert_eq!(h.node_count(), 7);

constant_fold_pass(&mut h, &reg);

assert_fully_folded(&h, &f2c(2.25));
}
fn assert_fully_folded(h: &Hugr, expected_const: &Const) {
// check the hugr just loads and returns a single const
let mut node_count = 0;

for node in h.children(h.root()) {
let op = h.get_optype(node);
match op {
OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1,
OpType::Const(c) if c == expected_const => node_count += 1,
_ => panic!("unexpected op: {:?}", op),
}
}

assert_eq!(node_count, 4);
}
}
2 changes: 2 additions & 0 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ pub use op_def::{
};
mod type_def;
pub use type_def::{TypeDef, TypeDefBound};
mod const_fold;
pub mod prelude;
pub mod simple_op;
pub mod validate;
pub use const_fold::{ConstFold, ConstFoldResult};
pub use prelude::{PRELUDE, PRELUDE_REGISTRY};

/// Extension Registries store extensions to be looked up e.g. during validation.
Expand Down
Loading