-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: constant folding implemented for core and float extension #758
Merged
Merged
Changes from all commits
Commits
Show all changes
40 commits
Select commit
Hold shift + click to select a range
bffed99
wip: constant folding
ss2165 1a27d54
start moving folding to op_def
ss2165 b84766b
thread through folding methods
ss2165 8ee49da
integer addition tests passing
ss2165 520de7c
remove FoldOutput
ss2165 1d656d6
Merge branch 'main' into feat/const-fold2
ss2165 9398d9d
refactor int folding to separate repo
ss2165 7b955a9
add tuple and sum constant folding
ss2165 6cb3c62
simplify test code
ss2165 0500624
wip: fold finder
ss2165 8f554e0
chore(deps): bump actions/upload-artifact from 3 to 4 (#751)
dependabot[bot] 215eb40
chore(deps): bump dawidd6/action-download-artifact from 2 to 3 (#752)
dependabot[bot] ff26546
fix: case node should not have an external signature (#749)
ss2165 64b9199
refactor: move hugr equality check out for reuse
ss2165 6d7d440
feat: implement RemoveConst and RemoveConstIgnore
ss2165 cdde503
use remove rewrites while folding
ss2165 114524c
alllow candidate node specification in find_consts
ss2165 a087fbc
add exhaustive fold pass
ss2165 07768b2
refactor!: use enum op traits for floats + conversions
ss2165 9a81260
Merge branch 'refactor/fops-enum' into feat/const-fold2
ss2165 658adf4
add folding definitions for float ops
ss2165 2c0e75b
refactor: ERROR_CUSTOM_TYPE
ss2165 dc7ff13
refactor: const ConstF64::new
ss2165 aa73ab2
feat: implement folding for conversion ops
ss2165 a519f34
fixup! refactor: ERROR_CUSTOM_TYPE
ss2165 a7a4088
Merge branch 'main' into feat/const-fold2
ss2165 46075c2
implement bigger tests and fix unearthed bugs
ss2165 df854e8
Revert "refactor: move hugr equality check out for reuse"
ss2165 1ed42e9
feat: Custom const for ERROR_TYPE (#756)
ss2165 09ce1c9
remove conversion foldin
ss2165 5a372c7
Merge branch 'main' into feat/const-fold-floats
ss2165 b513ace
Merge branch 'feat/const-rewrites' into feat/const-fold-floats
ss2165 5a71f75
docs: add public method docstrings
ss2165 6fa7eb9
add some docstrings and comments
ss2165 7381432
remove integer folding
ss2165 0e0411f
remove unused imports
ss2165 8e88f3e
add docstrings and simplify
ss2165 4607d64
chore(deps): update delegate requirement from 0.11.0 to 0.12.0 (#760)
dependabot[bot] 0edee65
chore!: hike MSRV to 1.75 (#761)
lmondada 0c060fb
Merge branch 'main' into feat/const-fold-floats
lmondada File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
//! Algorithms using the Hugr. | ||
|
||
pub mod const_fold; | ||
mod half_node; | ||
pub mod nest_cfgs; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,302 @@ | ||
//! Constant folding routines. | ||
|
||
use std::collections::{BTreeSet, HashMap}; | ||
|
||
use itertools::Itertools; | ||
|
||
use crate::{ | ||
builder::{DFGBuilder, Dataflow, DataflowHugr}, | ||
extension::{ConstFoldResult, ExtensionRegistry}, | ||
hugr::{ | ||
rewrite::consts::{RemoveConst, RemoveConstIgnore}, | ||
views::SiblingSubgraph, | ||
HugrMut, | ||
}, | ||
ops::{Const, LeafOp, OpType}, | ||
type_row, | ||
types::{FunctionType, Type, TypeEnum}, | ||
values::Value, | ||
Hugr, HugrView, IncomingPort, Node, SimpleReplacement, | ||
}; | ||
|
||
/// Tag some output constants with [`OutgoingPort`] inferred from the ordering. | ||
fn out_row(consts: impl IntoIterator<Item = Const>) -> ConstFoldResult { | ||
let vec = consts | ||
.into_iter() | ||
.enumerate() | ||
.map(|(i, c)| (i.into(), c)) | ||
.collect(); | ||
Some(vec) | ||
} | ||
|
||
/// Sort folding inputs with [`IncomingPort`] as key | ||
fn sort_by_in_port(consts: &[(IncomingPort, Const)]) -> Vec<&(IncomingPort, Const)> { | ||
let mut v: Vec<_> = consts.iter().collect(); | ||
v.sort_by_key(|(i, _)| i); | ||
v | ||
} | ||
|
||
/// Sort some input constants by port and just return the constants. | ||
pub(crate) fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> { | ||
sort_by_in_port(consts) | ||
.into_iter() | ||
.map(|(_, c)| c) | ||
.collect() | ||
} | ||
/// For a given op and consts, attempt to evaluate the op. | ||
pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldResult { | ||
let op = op.as_leaf_op()?; | ||
|
||
match op { | ||
LeafOp::Noop { .. } => out_row([consts.first()?.1.clone()]), | ||
LeafOp::MakeTuple { .. } => { | ||
out_row([Const::new_tuple(sorted_consts(consts).into_iter().cloned())]) | ||
} | ||
LeafOp::UnpackTuple { .. } => { | ||
let c = &consts.first()?.1; | ||
|
||
if let Value::Tuple { vs } = c.value() { | ||
if let TypeEnum::Tuple(tys) = c.const_type().as_type_enum() { | ||
return out_row(tys.iter().zip(vs.iter()).map(|(t, v)| { | ||
Const::new(v.clone(), t.clone()) | ||
.expect("types should already have been checked") | ||
})); | ||
} | ||
} | ||
None // could panic | ||
} | ||
|
||
LeafOp::Tag { tag, variants } => out_row([Const::new( | ||
Value::sum(*tag, consts.first()?.1.value().clone()), | ||
Type::new_sum(variants.clone()), | ||
) | ||
.unwrap()]), | ||
LeafOp::CustomOp(_) => { | ||
let ext_op = op.as_extension_op()?; | ||
|
||
ext_op.constant_fold(consts) | ||
} | ||
_ => None, | ||
} | ||
} | ||
|
||
/// Generate a graph that loads and outputs `consts` in order, validating | ||
/// against `reg`. | ||
fn const_graph(consts: Vec<Const>, reg: &ExtensionRegistry) -> Hugr { | ||
let const_types = consts.iter().map(Const::const_type).cloned().collect_vec(); | ||
let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap(); | ||
|
||
let outputs = consts | ||
.into_iter() | ||
.map(|c| b.add_load_const(c).unwrap()) | ||
.collect_vec(); | ||
|
||
b.finish_hugr_with_outputs(outputs, reg).unwrap() | ||
} | ||
|
||
/// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`, | ||
/// return an iterator of possible constant folding rewrites. The | ||
/// [`SimpleReplacement`] replaces an operation with constants that result from | ||
/// evaluating it, the extension registry `reg` is used to validate the | ||
/// replacement HUGR. The vector of [`RemoveConstIgnore`] refer to the | ||
/// LoadConstant nodes that could be removed. | ||
pub fn find_consts<'a, 'r: 'a>( | ||
hugr: &'a impl HugrView, | ||
candidate_nodes: impl IntoIterator<Item = Node> + 'a, | ||
reg: &'r ExtensionRegistry, | ||
) -> impl Iterator<Item = (SimpleReplacement, Vec<RemoveConstIgnore>)> + 'a { | ||
// track nodes for operations that have already been considered for folding | ||
let mut used_neighbours = BTreeSet::new(); | ||
|
||
candidate_nodes | ||
.into_iter() | ||
.filter_map(move |n| { | ||
// only look at LoadConstant | ||
hugr.get_optype(n).is_load_constant().then_some(())?; | ||
|
||
let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?; | ||
let neighbours = hugr | ||
.linked_inputs(n, out_p) | ||
.filter(|(n, _)| used_neighbours.insert(*n)) | ||
.collect_vec(); | ||
if neighbours.is_empty() { | ||
// no uses of LoadConstant that haven't already been considered. | ||
return None; | ||
} | ||
let fold_iter = neighbours | ||
.into_iter() | ||
.filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg)); | ||
Some(fold_iter) | ||
}) | ||
.flatten() | ||
} | ||
|
||
/// Attempt to evaluate and generate rewrites for the operation at `op_node` | ||
fn fold_op( | ||
hugr: &impl HugrView, | ||
op_node: Node, | ||
reg: &ExtensionRegistry, | ||
) -> Option<(SimpleReplacement, Vec<RemoveConstIgnore>)> { | ||
let (in_consts, removals): (Vec<_>, Vec<_>) = hugr | ||
.node_inputs(op_node) | ||
.filter_map(|in_p| { | ||
let (con_op, load_n) = get_const(hugr, op_node, in_p)?; | ||
Some(((in_p, con_op), RemoveConstIgnore(load_n))) | ||
}) | ||
.unzip(); | ||
let neighbour_op = hugr.get_optype(op_node); | ||
// attempt to evaluate op | ||
let folded = fold_const(neighbour_op, &in_consts)?; | ||
let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip(); | ||
let nu_out = op_outs | ||
.into_iter() | ||
.enumerate() | ||
.filter_map(|(i, out)| { | ||
// map from the ports the op was linked to, to the output ports of | ||
// the replacement. | ||
hugr.single_linked_input(op_node, out) | ||
.map(|np| (np, i.into())) | ||
}) | ||
.collect(); | ||
let replacement = const_graph(consts, reg); | ||
let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr) | ||
.expect("Operation should form valid subgraph."); | ||
|
||
let simple_replace = SimpleReplacement::new( | ||
sibling_graph, | ||
replacement, | ||
// no inputs to replacement | ||
HashMap::new(), | ||
nu_out, | ||
); | ||
Some((simple_replace, removals)) | ||
} | ||
|
||
/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant | ||
/// and the LoadConstant node | ||
fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Const, Node)> { | ||
let (load_n, _) = hugr.single_linked_output(op_node, in_p)?; | ||
let load_op = hugr.get_optype(load_n).as_load_constant()?; | ||
let const_node = hugr | ||
.linked_outputs(load_n, load_op.constant_port()) | ||
.exactly_one() | ||
.ok()? | ||
.0; | ||
|
||
let const_op = hugr.get_optype(const_node).as_const()?; | ||
|
||
// TODO avoid const clone here | ||
Some((const_op.clone(), load_n)) | ||
} | ||
|
||
/// Exhaustively apply constant folding to a HUGR. | ||
pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) { | ||
loop { | ||
// would be preferable if the candidates were updated to be just the | ||
// neighbouring nodes of those added. | ||
let rewrites = find_consts(h, h.nodes(), reg).collect_vec(); | ||
if rewrites.is_empty() { | ||
break; | ||
} | ||
for (replace, removes) in rewrites { | ||
h.apply_rewrite(replace).unwrap(); | ||
for rem in removes { | ||
if let Ok(const_node) = h.apply_rewrite(rem) { | ||
// if the LoadConst was removed, try removing the Const too. | ||
if h.apply_rewrite(RemoveConst(const_node)).is_err() { | ||
// const cannot be removed - no problem | ||
continue; | ||
Comment on lines
+205
to
+208
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I find it surprising that this is not removed by the rewrite itself. Maybe add a comment about this in the docstring of |
||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
|
||
use crate::extension::{ExtensionRegistry, PRELUDE}; | ||
use crate::std_extensions::arithmetic; | ||
|
||
use crate::std_extensions::arithmetic::float_ops::FloatOps; | ||
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; | ||
|
||
use rstest::rstest; | ||
|
||
use super::*; | ||
|
||
/// float to constant | ||
fn f2c(f: f64) -> Const { | ||
ConstF64::new(f).into() | ||
} | ||
|
||
#[rstest] | ||
#[case(0.0, 0.0, 0.0)] | ||
#[case(0.0, 1.0, 1.0)] | ||
#[case(23.5, 435.5, 459.0)] | ||
// c = a + b | ||
fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { | ||
let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; | ||
let add_op: OpType = FloatOps::fadd.into(); | ||
let out = fold_const(&add_op, &consts).unwrap(); | ||
|
||
assert_eq!(&out[..], &[(0.into(), f2c(c))]); | ||
} | ||
|
||
#[test] | ||
fn test_big() { | ||
/* | ||
Test hugr approximately calculates | ||
let x = (5.5, 3.25); | ||
x.0 - x.1 == 2.25 | ||
*/ | ||
let mut build = | ||
DFGBuilder::new(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])).unwrap(); | ||
|
||
let tup = build | ||
.add_load_const(Const::new_tuple([f2c(5.5), f2c(3.25)])) | ||
.unwrap(); | ||
|
||
let unpack = build | ||
.add_dataflow_op( | ||
LeafOp::UnpackTuple { | ||
tys: type_row![FLOAT64_TYPE, FLOAT64_TYPE], | ||
}, | ||
[tup], | ||
) | ||
.unwrap(); | ||
|
||
let sub = build | ||
.add_dataflow_op(FloatOps::fsub, unpack.outputs()) | ||
.unwrap(); | ||
|
||
let reg = ExtensionRegistry::try_new([ | ||
PRELUDE.to_owned(), | ||
arithmetic::float_types::EXTENSION.to_owned(), | ||
arithmetic::float_ops::EXTENSION.to_owned(), | ||
]) | ||
.unwrap(); | ||
let mut h = build.finish_hugr_with_outputs(sub.outputs(), ®).unwrap(); | ||
assert_eq!(h.node_count(), 7); | ||
|
||
constant_fold_pass(&mut h, ®); | ||
|
||
assert_fully_folded(&h, &f2c(2.25)); | ||
} | ||
fn assert_fully_folded(h: &Hugr, expected_const: &Const) { | ||
// check the hugr just loads and returns a single const | ||
let mut node_count = 0; | ||
|
||
for node in h.children(h.root()) { | ||
let op = h.get_optype(node); | ||
match op { | ||
OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, | ||
OpType::Const(c) if c == expected_const => node_count += 1, | ||
_ => panic!("unexpected op: {:?}", op), | ||
} | ||
} | ||
|
||
assert_eq!(node_count, 4); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure the comment is correct, i.e. what would
panic
and how returning None would help. I suppose you mean that thelet if
clauses might fail?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i mean that
None
shouldn't ever be returned here, so it might be clearer to panic