Skip to content

Commit

Permalink
fix: NaryLogicOp constant folding
Browse files Browse the repository at this point in the history
  • Loading branch information
doug-q committed May 13, 2024
1 parent b9c3ee4 commit 21b5ad6
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 43 deletions.
24 changes: 0 additions & 24 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,30 +297,6 @@ mod test {
assert_fully_folded(&h, &expected);
}

#[rstest]
#[case(NaryLogic::And, [true, true, true], true)]
#[case(NaryLogic::And, [true, false, true], false)]
#[case(NaryLogic::Or, [false, false, true], true)]
#[case(NaryLogic::Or, [false, false, false], false)]
fn test_logic_and(
#[case] op: NaryLogic,
#[case] ins: [bool; 3],
#[case] out: bool,
) -> Result<(), Box<dyn std::error::Error>> {
let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();

let ins = ins.map(|b| build.add_load_const(Value::from_bool(b)));
let logic_op = build.add_dataflow_op(op.with_n_inputs(ins.len() as u64), ins)?;

let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
let mut h = build.finish_hugr_with_outputs(logic_op.outputs(), &reg)?;
constant_fold_pass(&mut h, &reg);

assert_fully_folded(&h, &Value::from_bool(out));
Ok(())
}

#[test]
#[cfg_attr(
feature = "extension_inference",
Expand Down
2 changes: 0 additions & 2 deletions hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1235,8 +1235,6 @@ fn test_fold_itostring_s() {
}

#[test]
#[should_panic]
// FIXME: https://github.com/CQCL/hugr/issues/996
fn test_fold_int_ops() {
// pseudocode:
//
Expand Down
104 changes: 88 additions & 16 deletions hugr/src/std_extensions/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
use strum_macros::{EnumIter, EnumString, IntoStaticStr};

use crate::extension::{ConstFold, ConstFoldResult};
use crate::ops::constant::ValueName;
use crate::ops::OpName;
use crate::ops::{OpName, Value};
use crate::{
algorithm::const_fold::sorted_consts,
extension::{
Expand All @@ -25,6 +26,29 @@ pub const FALSE_NAME: ValueName = ValueName::new_inline("FALSE");
/// Name of extension true value.
pub const TRUE_NAME: ValueName = ValueName::new_inline("TRUE");

impl ConstFold for NaryLogic {
fn fold(&self, type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
let [TypeArg::BoundedNat { n: num_args }] = *type_args else {
panic!("impossible by validation");
};
match self {
Self::And => {
let inps = read_inputs(consts)?;
let res = inps.iter().all(|x| *x);
// We can only fold to true if we have a const for all our inputs.
(!res || inps.len() as u64 == num_args)
.then_some(vec![(0.into(), ops::Value::from_bool(res))])
}
Self::Or => {
let inps = read_inputs(consts)?;
let res = inps.iter().any(|x| *x);
// We can only fold to false if we have a const for all our inputs
(res || inps.len() as u64 == num_args)
.then_some(vec![(0.into(), ops::Value::from_bool(res))])
}
}
}
}
/// Logic extension operation definitions.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
#[allow(missing_docs)]
Expand Down Expand Up @@ -52,18 +76,7 @@ impl MakeOpDef for NaryLogic {
}

fn post_opdef(&self, def: &mut OpDef) {
def.set_constant_folder(match self {
NaryLogic::And => |consts: &_| {
let inps = read_inputs(consts)?;
let res = inps.into_iter().all(|x| x);
Some(vec![(0.into(), ops::Value::from_bool(res))])
},
NaryLogic::Or => |consts: &_| {
let inps = read_inputs(consts)?;
let res = inps.into_iter().any(|x| x);
Some(vec![(0.into(), ops::Value::from_bool(res))])
},
})
def.set_constant_folder(*self);
}
}

Expand Down Expand Up @@ -224,14 +237,22 @@ fn read_inputs(consts: &[(IncomingPort, ops::Value)]) -> Option<Vec<bool>> {
pub(crate) mod test {
use super::{extension, ConcreteLogicOp, NaryLogic, NotOp, FALSE_NAME, TRUE_NAME};
use crate::{
algorithm::const_fold::constant_fold_pass,
builder::{handle::Outputs, DFGBuilder, Dataflow, DataflowHugr},
extension::{
prelude::BOOL_T,
prelude::{BOOL_T, STRING_TYPE},
simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp},
ExtensionRegistry, PRELUDE,
},
ops::NamedOp,
Extension,
ops::{handle::NodeHandle, NamedOp, Value},
type_row,
types::FunctionType,
utils::test::assert_fully_folded,
Extension, Hugr, HugrView, Node, Wire,
};

use lazy_static::lazy_static;
use rstest::rstest;
use strum::IntoEnumIterator;

#[test]
Expand Down Expand Up @@ -280,4 +301,55 @@ pub(crate) mod test {
pub(crate) fn or_op() -> ConcreteLogicOp {
NaryLogic::Or.with_n_inputs(2)
}

#[rstest]
#[case(NaryLogic::And, [], true)]
#[case(NaryLogic::And, [true, true, true], true)]
#[case(NaryLogic::And, [true, false, true], false)]
#[case(NaryLogic::Or, [], false)]
#[case(NaryLogic::Or, [false, false, true], true)]
#[case(NaryLogic::Or, [false, false, false], false)]
fn nary_const_fold(
#[case] op: NaryLogic,
#[case] ins: impl IntoIterator<Item = bool>,
#[case] out: bool,
) {
use itertools::Itertools;

use crate::extension::ConstFold;
let in_vals = ins
.into_iter()
.enumerate()
.map(|(i, b)| (i.into(), Value::from_bool(b)))
.collect_vec();
assert_eq!(
Some(vec![(0.into(), Value::from_bool(out))]),
op.fold(&[(in_vals.len() as u64).into()], &in_vals)
);
}

#[rstest]
#[case(NaryLogic::And, [Some(true), None], None)]
#[case(NaryLogic::And, [Some(false), None], Some(false))]
#[case(NaryLogic::Or, [None, Some(false)], None)]
#[case(NaryLogic::Or, [None, Some(true)], Some(true))]
fn nary_partial_const_fold(
#[case] op: NaryLogic,
#[case] ins: impl IntoIterator<Item = Option<bool>>,
#[case] mb_out: Option<bool>,
) {
use itertools::Itertools;

use crate::extension::ConstFold;
let in_vals0 = ins.into_iter().enumerate().collect_vec();
let num_args = in_vals0.len() as u64;
let in_vals = in_vals0
.into_iter()
.filter_map(|(i, mb_b)| mb_b.map(|b| (i.into(), Value::from_bool(b))))
.collect_vec();
assert_eq!(
mb_out.map(|out| vec![(0.into(), Value::from_bool(out))]),
op.fold(&[num_args.into()], &in_vals)
);
}
}
2 changes: 1 addition & 1 deletion hugr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ pub(crate) mod test {
match op {
OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1,
OpType::Const(c) if c.value() == expected_value => node_count += 1,
_ => panic!("unexpected op: {:?}", op),
_ => panic!("unexpected op: {:?}\n{}", op, h.mermaid_string()),
}
}

Expand Down

0 comments on commit 21b5ad6

Please sign in to comment.