Skip to content

Commit

Permalink
feat: add verification to constant folding (#1030)
Browse files Browse the repository at this point in the history
Fixes #996. Fixes `test_fold_inarrow` tests which were shown to be wrong
by verification, and `test_fold_int_ops` test which was causing constant
folding to panic.
  • Loading branch information
doug-q authored May 14, 2024
1 parent 3e16e42 commit 81c4465
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 94 deletions.
179 changes: 148 additions & 31 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
use std::collections::{BTreeSet, HashMap};

use itertools::Itertools;
use thiserror::Error;

use crate::hugr::{SimpleReplacementError, ValidationError};
use crate::types::SumType;
use crate::Direction;
use crate::{
builder::{DFGBuilder, Dataflow, DataflowHugr},
extension::{ConstFoldResult, ExtensionRegistry},
Expand All @@ -19,6 +22,19 @@ use crate::{
Hugr, HugrView, IncomingPort, Node, SimpleReplacement,
};

#[derive(Error, Debug)]
#[allow(missing_docs)]
pub enum ConstFoldError {
#[error("Failed to verify {label} HUGR: {err}")]
VerifyError {
label: String,
#[source]
err: ValidationError,
},
#[error(transparent)]
SimpleReplaceError(#[from] SimpleReplacementError),
}

/// Tag some output constants with [`OutgoingPort`] inferred from the ordering.
fn out_row(consts: impl IntoIterator<Item = Value>) -> ConstFoldResult {
let vec = consts
Expand All @@ -43,9 +59,10 @@ pub(crate) fn sorted_consts(consts: &[(IncomingPort, Value)]) -> Vec<&Value> {
.map(|(_, c)| c)
.collect()
}

/// For a given op and consts, attempt to evaluate the op.
pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
match op {
let fold_result = match op {
OpType::Noop { .. } => out_row([consts.first()?.1.clone()]),
OpType::MakeTuple { .. } => {
out_row([Value::tuple(sorted_consts(consts).into_iter().cloned())])
Expand All @@ -69,7 +86,10 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR
ext_op.constant_fold(consts)
}
_ => None,
}
};
debug_assert!(fold_result.as_ref().map_or(true, |x| x.len()
== op.value_port_count(Direction::Outgoing)));
fold_result
}

/// Generate a graph that loads and outputs `consts` in order, validating
Expand Down Expand Up @@ -140,18 +160,16 @@ fn fold_op(
})
.unzip();
// attempt to evaluate op
let folded = fold_leaf_op(neighbour_op, &in_consts)?;
let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip();
let nu_out = op_outs
let (nu_out, consts): (HashMap<_, _>, Vec<_>) = fold_leaf_op(neighbour_op, &in_consts)?
.into_iter()
.enumerate()
.filter_map(|(i, out)| {
// map from the ports the op was linked to, to the output ports of
// the replacement.
hugr.single_linked_input(op_node, out)
.map(|np| (np, i.into()))
.filter_map(|(i, (op_out, konst))| {
// for each used port of the op give the nu_out entry and the
// corresponding Value
hugr.single_linked_input(op_node, op_out)
.map(|np| ((np, i.into()), konst))
})
.collect();
.unzip();
let replacement = const_graph(consts, reg);
let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr)
.expect("Operation should form valid subgraph.");
Expand All @@ -172,39 +190,54 @@ fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<
let (load_n, _) = hugr.single_linked_output(op_node, in_p)?;
let load_op = hugr.get_optype(load_n).as_load_constant()?;
let const_node = hugr
.linked_outputs(load_n, load_op.constant_port())
.exactly_one()
.ok()?
.single_linked_output(load_n, load_op.constant_port())?
.0;

let const_op = hugr.get_optype(const_node).as_const()?;

// TODO avoid const clone here
Some((const_op.as_ref().clone(), load_n))
}

/// Exhaustively apply constant folding to a HUGR.
pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) {
pub fn constant_fold_pass<H: HugrMut>(h: &mut H, reg: &ExtensionRegistry) {
#[cfg(test)]
let verify = |label, h: &H| {
h.validate_no_extensions(reg).unwrap_or_else(|err| {
panic!(
"constant_fold_pass: failed to verify {label} HUGR: {err}\n{}",
h.mermaid_string()
)
})
};
#[cfg(test)]
verify("input", h);
loop {
// would be preferable if the candidates were updated to be just the
// neighbouring nodes of those added.
let rewrites = find_consts(h, h.nodes(), reg).collect_vec();
if rewrites.is_empty() {
// We can only safely apply a single replacement. Applying a
// replacement removes nodes and edges which may be referenced by
// further replacements returned by find_consts. Even worse, if we
// attempted to apply those replacements, expecting them to fail if
// the nodes and edges they reference had been deleted, they may
// succeed because new nodes and edges reused the ids.
//
// We could be a lot smarter here, keeping track of `LoadConstant`
// nodes and only looking at their out neighbours.
let Some((replace, removes)) = find_consts(h, h.nodes(), reg).next() else {
break;
}
for (replace, removes) in rewrites {
h.apply_rewrite(replace).unwrap();
for rem in removes {
if let Ok(const_node) = h.apply_rewrite(rem) {
// if the LoadConst was removed, try removing the Const too.
if h.apply_rewrite(RemoveConst(const_node)).is_err() {
// const cannot be removed - no problem
continue;
}
}
};
h.apply_rewrite(replace).unwrap();
for rem in removes {
// We are optimistically applying these [RemoveLoadConstant] and
// [RemoveConst] rewrites without checking whether the nodes
// they attempt to remove have remaining uses. If they do, then
// the rewrite fails and we move on.
if let Ok(const_node) = h.apply_rewrite(rem) {
// if the LoadConst was removed, try removing the Const too.
let _ = h.apply_rewrite(RemoveConst(const_node));
}
}
}
#[cfg(test)]
verify("output", h);
}

#[cfg(test)]
Expand Down Expand Up @@ -395,4 +428,88 @@ mod test {
let expected = Value::false_val();
assert_fully_folded(&h, &expected);
}

#[test]
fn orphan_output() {
// pseudocode:
// x0 := bool(true)
// x1 := not(x0)
// x2 := or(x0,x1)
// output x2 == true;
//
// We arange things so that the `or` folds away first, leaving the not
// with no outputs.
use crate::hugr::NodeType;
use crate::ops::handle::NodeHandle;

let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
let true_wire = build.add_load_value(Value::true_val());
// this Not will be manually replaced
let orig_not = build.add_dataflow_op(NotOp, [true_wire]).unwrap();
let r = build
.add_dataflow_op(
NaryLogic::Or.with_n_inputs(2),
[true_wire, orig_not.out_wire(0)],
)
.unwrap();
let or_node = r.node();
let parent = build.dfg_node;
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
let mut h = build.finish_hugr_with_outputs(r.outputs(), &reg).unwrap();

// we delete the original Not and create a new One. This means it will be
// traversed by `constant_fold_pass` after the Or.
let new_not = h.add_node_with_parent(parent, NodeType::new_auto(NotOp));
h.connect(true_wire.node(), true_wire.source(), new_not, 0);
h.disconnect(or_node, IncomingPort::from(1));
h.connect(new_not, 0, or_node, 1);
h.remove_node(orig_not.node());
constant_fold_pass(&mut h, &reg);
assert_fully_folded(&h, &Value::true_val())
}

#[test]
fn test_folding_pass_issue_996() {
// pseudocode:
//
// x0 := 3.0
// x1 := 4.0
// x2 := fne(x0, x1); // true
// x3 := flt(x0, x1); // true
// x4 := and(x2, x3); // true
// x5 := -10.0
// x6 := flt(x0, x5) // false
// x7 := or(x4, x6) // true
// output x7
let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
let x0 = build.add_load_const(Value::extension(ConstF64::new(3.0)));
let x1 = build.add_load_const(Value::extension(ConstF64::new(4.0)));
let x2 = build.add_dataflow_op(FloatOps::fne, [x0, x1]).unwrap();
let x3 = build.add_dataflow_op(FloatOps::flt, [x0, x1]).unwrap();
let x4 = build
.add_dataflow_op(
NaryLogic::And.with_n_inputs(2),
x2.outputs().chain(x3.outputs()),
)
.unwrap();
let x5 = build.add_load_const(Value::extension(ConstF64::new(-10.0)));
let x6 = build.add_dataflow_op(FloatOps::flt, [x0, x5]).unwrap();
let x7 = build
.add_dataflow_op(
NaryLogic::Or.with_n_inputs(2),
x4.outputs().chain(x6.outputs()),
)
.unwrap();
let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
logic::EXTENSION.to_owned(),
arithmetic::float_types::EXTENSION.to_owned(),
])
.unwrap();
let mut h = build.finish_hugr_with_outputs(x7.outputs(), &reg).unwrap();
constant_fold_pass(&mut h, &reg);
let expected = Value::true_val();
assert_fully_folded(&h, &expected);
}
}
17 changes: 16 additions & 1 deletion hugr/src/hugr/views.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ use itertools::{Itertools, MapInto};
use portgraph::render::{DotFormat, MermaidFormat};
use portgraph::{multiportgraph, LinkView, MultiPortGraph, PortView};

use super::{Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, DEFAULT_NODETYPE};
use super::{
Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, ValidationError, DEFAULT_NODETYPE,
};
use crate::extension::ExtensionRegistry;
use crate::ops::handle::NodeHandle;
use crate::ops::{OpParent, OpTag, OpTrait, OpType};

Expand Down Expand Up @@ -460,6 +463,18 @@ pub trait HugrView: sealed::HugrInternals {
self.value_types(node, Direction::Outgoing)
.map(|(p, t)| (p.as_outgoing().unwrap(), t))
}

/// Check the validity of the underlying HUGR.
fn validate(&self, reg: &ExtensionRegistry) -> Result<(), ValidationError> {
self.base_hugr().validate(reg)
}

/// Check the validity of the underlying HUGR, but don't check consistency
/// of extension requirements between connected nodes or between parents and
/// children.
fn validate_no_extensions(&self, reg: &ExtensionRegistry) -> Result<(), ValidationError> {
self.base_hugr().validate_no_extensions(reg)
}
}

/// Wraps an iterator over [Port]s that are known to be [OutgoingPort]s
Expand Down
49 changes: 23 additions & 26 deletions hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ use crate::{

use super::IntOpDef;

use lazy_static::lazy_static;

lazy_static! {
static ref INARROW_ERROR_VALUE: Value = ConstError {
signal: 0,
message: "Integer too large to narrow".to_string(),
}
.into();
}

fn bitmask_from_width(width: u64) -> u64 {
debug_assert!(width <= 64);
if width == 64 {
Expand Down Expand Up @@ -111,28 +121,22 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
let logwidth0: u8 = get_log_width(arg0).ok()?;
let logwidth1: u8 = get_log_width(arg1).ok()?;
let n0: &ConstInt = get_single_input_value(consts)?;
(logwidth0 >= logwidth1 && n0.log_width() == logwidth0).then_some(())?;

let int_out_type = INT_TYPES[logwidth1 as usize].to_owned();
let sum_type = sum_with_error(int_out_type.clone());
let err_value = || {
let err_val = ConstError {
signal: 0,
message: "Integer too large to narrow".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())

let mk_out_const = |i, mb_v: Result<Value, _>| {
mb_v.and_then(|v| Value::sum(i, [v], sum_type))
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
};
let n0val: u64 = n0.value_u();
let out_const: Value = if n0val >> (1 << logwidth1) != 0 {
err_value()
mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone()))
} else {
Value::extension(ConstInt::new_u(logwidth1, n0val).unwrap())
mk_out_const(0, ConstInt::new_u(logwidth1, n0val).map(Into::into))
};
if logwidth0 < logwidth1 || n0.log_width() != logwidth0 {
None
} else {
Some(vec![(0.into(), out_const)])
}
Some(vec![(0.into(), out_const)])
},
),
},
Expand All @@ -145,29 +149,22 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) {
let logwidth0: u8 = get_log_width(arg0).ok()?;
let logwidth1: u8 = get_log_width(arg1).ok()?;
let n0: &ConstInt = get_single_input_value(consts)?;
(logwidth0 >= logwidth1 && n0.log_width() == logwidth0).then_some(())?;

let int_out_type = INT_TYPES[logwidth1 as usize].to_owned();
let sum_type = sum_with_error(int_out_type.clone());
let err_value = || {
let err_val = ConstError {
signal: 0,
message: "Integer too large to narrow".to_string(),
};
Value::sum(1, [err_val.into()], sum_type.clone())
let mk_out_const = |i, mb_v: Result<Value, _>| {
mb_v.and_then(|v| Value::sum(i, [v], sum_type))
.unwrap_or_else(|e| panic!("Invalid computed sum, {}", e))
};
let n0val: i64 = n0.value_s();
let ub = 1i64 << ((1 << logwidth1) - 1);
let out_const: Value = if n0val >= ub || n0val < -ub {
err_value()
mk_out_const(1, Ok(INARROW_ERROR_VALUE.clone()))
} else {
Value::extension(ConstInt::new_s(logwidth1, n0val).unwrap())
mk_out_const(0, ConstInt::new_s(logwidth1, n0val).map(Into::into))
};
if logwidth0 < logwidth1 || n0.log_width() != logwidth0 {
None
} else {
Some(vec![(0.into(), out_const)])
}
Some(vec![(0.into(), out_const)])
},
),
},
Expand Down
Loading

0 comments on commit 81c4465

Please sign in to comment.