Skip to content

Commit

Permalink
feat: constant folding implemented for core and float extension (#758)
Browse files Browse the repository at this point in the history
Closes #711 
CI will pass after updating MSRV to 1.75 (from end of year)

based on #757

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Alan Lawrence <[email protected]>
Co-authored-by: Luca Mondada <[email protected]>
Co-authored-by: Luca Mondada <[email protected]>
  • Loading branch information
5 people authored Jan 2, 2024
1 parent 26bc5ff commit 664fe89
Show file tree
Hide file tree
Showing 16 changed files with 592 additions and 20 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
rust: ['1.70', stable, beta, nightly]
rust: ['1.75', stable, beta, nightly]
# workaround to ignore non-stable tests when running the merge queue checks
# see: https://github.community/t/how-to-conditionally-include-exclude-items-in-matrix-eg-based-on-branch/16853/6
isMerge:
- ${{ github.event_name == 'merge_group' }}
exclude:
- rust: '1.70'
- rust: '1.75'
isMerge: true
- rust: beta
isMerge: true
Expand Down
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ description = "Hierarchical Unified Graph Representation"
#categories = [] # TODO

edition = "2021"
rust-version = "1.70"
rust-version = "1.75"

[lib]
# Using different names for the lib and for the package is supported, but may be confusing.
Expand Down Expand Up @@ -46,7 +46,7 @@ lazy_static = "1.4.0"
petgraph = { version = "0.6.3", default-features = false }
context-iterators = "0.2.0"
serde_json = "1.0.97"
delegate = "0.11.0"
delegate = "0.12.0"
rustversion = "1.0.14"
paste = "1.0"
strum = "0.25.0"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ See [DEVELOPMENT.md](DEVELOPMENT.md) for instructions on setting up the developm
This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0).

[build_status]: https://github.com/CQCL/hugr/workflows/Continuous%20integration/badge.svg?branch=main
[msrv]: https://img.shields.io/badge/rust-1.70.0%2B-blue.svg
[msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg
[codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov
[LICENSE]: LICENCE
1 change: 1 addition & 0 deletions src/algorithm.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! Algorithms using the Hugr.
pub mod const_fold;
mod half_node;
pub mod nest_cfgs;
302 changes: 302 additions & 0 deletions src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
//! Constant folding routines.
use std::collections::{BTreeSet, HashMap};

use itertools::Itertools;

use crate::{
builder::{DFGBuilder, Dataflow, DataflowHugr},
extension::{ConstFoldResult, ExtensionRegistry},
hugr::{
rewrite::consts::{RemoveConst, RemoveConstIgnore},
views::SiblingSubgraph,
HugrMut,
},
ops::{Const, LeafOp, OpType},
type_row,
types::{FunctionType, Type, TypeEnum},
values::Value,
Hugr, HugrView, IncomingPort, Node, SimpleReplacement,
};

/// Tag some output constants with [`OutgoingPort`] inferred from the ordering.
fn out_row(consts: impl IntoIterator<Item = Const>) -> ConstFoldResult {
let vec = consts
.into_iter()
.enumerate()
.map(|(i, c)| (i.into(), c))
.collect();
Some(vec)
}

/// Sort folding inputs with [`IncomingPort`] as key
fn sort_by_in_port(consts: &[(IncomingPort, Const)]) -> Vec<&(IncomingPort, Const)> {
let mut v: Vec<_> = consts.iter().collect();
v.sort_by_key(|(i, _)| i);
v
}

/// Sort some input constants by port and just return the constants.
pub(crate) fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> {
sort_by_in_port(consts)
.into_iter()
.map(|(_, c)| c)
.collect()
}
/// For a given op and consts, attempt to evaluate the op.
pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldResult {
let op = op.as_leaf_op()?;

match op {
LeafOp::Noop { .. } => out_row([consts.first()?.1.clone()]),
LeafOp::MakeTuple { .. } => {
out_row([Const::new_tuple(sorted_consts(consts).into_iter().cloned())])
}
LeafOp::UnpackTuple { .. } => {
let c = &consts.first()?.1;

if let Value::Tuple { vs } = c.value() {
if let TypeEnum::Tuple(tys) = c.const_type().as_type_enum() {
return out_row(tys.iter().zip(vs.iter()).map(|(t, v)| {
Const::new(v.clone(), t.clone())
.expect("types should already have been checked")
}));
}
}
None // could panic
}

LeafOp::Tag { tag, variants } => out_row([Const::new(
Value::sum(*tag, consts.first()?.1.value().clone()),
Type::new_sum(variants.clone()),
)
.unwrap()]),
LeafOp::CustomOp(_) => {
let ext_op = op.as_extension_op()?;

ext_op.constant_fold(consts)
}
_ => None,
}
}

/// Generate a graph that loads and outputs `consts` in order, validating
/// against `reg`.
fn const_graph(consts: Vec<Const>, reg: &ExtensionRegistry) -> Hugr {
let const_types = consts.iter().map(Const::const_type).cloned().collect_vec();
let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap();

let outputs = consts
.into_iter()
.map(|c| b.add_load_const(c).unwrap())
.collect_vec();

b.finish_hugr_with_outputs(outputs, reg).unwrap()
}

/// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`,
/// return an iterator of possible constant folding rewrites. The
/// [`SimpleReplacement`] replaces an operation with constants that result from
/// evaluating it, the extension registry `reg` is used to validate the
/// replacement HUGR. The vector of [`RemoveConstIgnore`] refer to the
/// LoadConstant nodes that could be removed.
pub fn find_consts<'a, 'r: 'a>(
hugr: &'a impl HugrView,
candidate_nodes: impl IntoIterator<Item = Node> + 'a,
reg: &'r ExtensionRegistry,
) -> impl Iterator<Item = (SimpleReplacement, Vec<RemoveConstIgnore>)> + 'a {
// track nodes for operations that have already been considered for folding
let mut used_neighbours = BTreeSet::new();

candidate_nodes
.into_iter()
.filter_map(move |n| {
// only look at LoadConstant
hugr.get_optype(n).is_load_constant().then_some(())?;

let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?;
let neighbours = hugr
.linked_inputs(n, out_p)
.filter(|(n, _)| used_neighbours.insert(*n))
.collect_vec();
if neighbours.is_empty() {
// no uses of LoadConstant that haven't already been considered.
return None;
}
let fold_iter = neighbours
.into_iter()
.filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg));
Some(fold_iter)
})
.flatten()
}

/// Attempt to evaluate and generate rewrites for the operation at `op_node`
fn fold_op(
hugr: &impl HugrView,
op_node: Node,
reg: &ExtensionRegistry,
) -> Option<(SimpleReplacement, Vec<RemoveConstIgnore>)> {
let (in_consts, removals): (Vec<_>, Vec<_>) = hugr
.node_inputs(op_node)
.filter_map(|in_p| {
let (con_op, load_n) = get_const(hugr, op_node, in_p)?;
Some(((in_p, con_op), RemoveConstIgnore(load_n)))
})
.unzip();
let neighbour_op = hugr.get_optype(op_node);
// attempt to evaluate op
let folded = fold_const(neighbour_op, &in_consts)?;
let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip();
let nu_out = op_outs
.into_iter()
.enumerate()
.filter_map(|(i, out)| {
// map from the ports the op was linked to, to the output ports of
// the replacement.
hugr.single_linked_input(op_node, out)
.map(|np| (np, i.into()))
})
.collect();
let replacement = const_graph(consts, reg);
let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr)
.expect("Operation should form valid subgraph.");

let simple_replace = SimpleReplacement::new(
sibling_graph,
replacement,
// no inputs to replacement
HashMap::new(),
nu_out,
);
Some((simple_replace, removals))
}

/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant
/// and the LoadConstant node
fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Const, Node)> {
let (load_n, _) = hugr.single_linked_output(op_node, in_p)?;
let load_op = hugr.get_optype(load_n).as_load_constant()?;
let const_node = hugr
.linked_outputs(load_n, load_op.constant_port())
.exactly_one()
.ok()?
.0;

let const_op = hugr.get_optype(const_node).as_const()?;

// TODO avoid const clone here
Some((const_op.clone(), load_n))
}

/// Exhaustively apply constant folding to a HUGR.
pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) {
loop {
// would be preferable if the candidates were updated to be just the
// neighbouring nodes of those added.
let rewrites = find_consts(h, h.nodes(), reg).collect_vec();
if rewrites.is_empty() {
break;
}
for (replace, removes) in rewrites {
h.apply_rewrite(replace).unwrap();
for rem in removes {
if let Ok(const_node) = h.apply_rewrite(rem) {
// if the LoadConst was removed, try removing the Const too.
if h.apply_rewrite(RemoveConst(const_node)).is_err() {
// const cannot be removed - no problem
continue;
}
}
}
}
}
}

#[cfg(test)]
mod test {

use crate::extension::{ExtensionRegistry, PRELUDE};
use crate::std_extensions::arithmetic;

use crate::std_extensions::arithmetic::float_ops::FloatOps;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};

use rstest::rstest;

use super::*;

/// float to constant
fn f2c(f: f64) -> Const {
ConstF64::new(f).into()
}

#[rstest]
#[case(0.0, 0.0, 0.0)]
#[case(0.0, 1.0, 1.0)]
#[case(23.5, 435.5, 459.0)]
// c = a + b
fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) {
let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))];
let add_op: OpType = FloatOps::fadd.into();
let out = fold_const(&add_op, &consts).unwrap();

assert_eq!(&out[..], &[(0.into(), f2c(c))]);
}

#[test]
fn test_big() {
/*
Test hugr approximately calculates
let x = (5.5, 3.25);
x.0 - x.1 == 2.25
*/
let mut build =
DFGBuilder::new(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])).unwrap();

let tup = build
.add_load_const(Const::new_tuple([f2c(5.5), f2c(3.25)]))
.unwrap();

let unpack = build
.add_dataflow_op(
LeafOp::UnpackTuple {
tys: type_row![FLOAT64_TYPE, FLOAT64_TYPE],
},
[tup],
)
.unwrap();

let sub = build
.add_dataflow_op(FloatOps::fsub, unpack.outputs())
.unwrap();

let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
arithmetic::float_types::EXTENSION.to_owned(),
arithmetic::float_ops::EXTENSION.to_owned(),
])
.unwrap();
let mut h = build.finish_hugr_with_outputs(sub.outputs(), &reg).unwrap();
assert_eq!(h.node_count(), 7);

constant_fold_pass(&mut h, &reg);

assert_fully_folded(&h, &f2c(2.25));
}
fn assert_fully_folded(h: &Hugr, expected_const: &Const) {
// check the hugr just loads and returns a single const
let mut node_count = 0;

for node in h.children(h.root()) {
let op = h.get_optype(node);
match op {
OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1,
OpType::Const(c) if c == expected_const => node_count += 1,
_ => panic!("unexpected op: {:?}", op),
}
}

assert_eq!(node_count, 4);
}
}
2 changes: 2 additions & 0 deletions src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ pub use op_def::{
};
mod type_def;
pub use type_def::{TypeDef, TypeDefBound};
mod const_fold;
pub mod prelude;
pub mod simple_op;
pub mod validate;
pub use const_fold::{ConstFold, ConstFoldResult};
pub use prelude::{PRELUDE, PRELUDE_REGISTRY};

/// Extension Registries store extensions to be looked up e.g. during validation.
Expand Down
Loading

0 comments on commit 664fe89

Please sign in to comment.