diff --git a/hugr/src/algorithm/const_fold.rs b/hugr/src/algorithm/const_fold.rs index 884002307..b32d54a8c 100644 --- a/hugr/src/algorithm/const_fold.rs +++ b/hugr/src/algorithm/const_fold.rs @@ -297,30 +297,6 @@ mod test { assert_fully_folded(&h, &expected); } - #[rstest] - #[case(NaryLogic::And, [true, true, true], true)] - #[case(NaryLogic::And, [true, false, true], false)] - #[case(NaryLogic::Or, [false, false, true], true)] - #[case(NaryLogic::Or, [false, false, false], false)] - fn test_logic_and( - #[case] op: NaryLogic, - #[case] ins: [bool; 3], - #[case] out: bool, - ) -> Result<(), Box> { - let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); - - let ins = ins.map(|b| build.add_load_const(Value::from_bool(b))); - let logic_op = build.add_dataflow_op(op.with_n_inputs(ins.len() as u64), ins)?; - - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); - let mut h = build.finish_hugr_with_outputs(logic_op.outputs(), ®)?; - constant_fold_pass(&mut h, ®); - - assert_fully_folded(&h, &Value::from_bool(out)); - Ok(()) - } - #[test] #[cfg_attr( feature = "extension_inference", diff --git a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs b/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs index e7daf7000..5241bf230 100644 --- a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs +++ b/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs @@ -1235,8 +1235,6 @@ fn test_fold_itostring_s() { } #[test] -#[should_panic] -// FIXME: https://github.com/CQCL/hugr/issues/996 fn test_fold_int_ops() { // pseudocode: // diff --git a/hugr/src/std_extensions/logic.rs b/hugr/src/std_extensions/logic.rs index 50e6921ce..01b4bec0c 100644 --- a/hugr/src/std_extensions/logic.rs +++ b/hugr/src/std_extensions/logic.rs @@ -2,8 +2,9 @@ use strum_macros::{EnumIter, EnumString, IntoStaticStr}; +use crate::extension::{ConstFold, ConstFoldResult}; use crate::ops::constant::ValueName; -use crate::ops::OpName; +use crate::ops::{OpName, Value}; use crate::{ algorithm::const_fold::sorted_consts, extension::{ @@ -25,6 +26,29 @@ pub const FALSE_NAME: ValueName = ValueName::new_inline("FALSE"); /// Name of extension true value. pub const TRUE_NAME: ValueName = ValueName::new_inline("TRUE"); +impl ConstFold for NaryLogic { + fn fold(&self, type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult { + let [TypeArg::BoundedNat { n: num_args }] = *type_args else { + panic!("impossible by validation"); + }; + match self { + Self::And => { + let inps = read_inputs(consts)?; + let res = inps.iter().all(|x| *x); + // We can only fold to true if we have a const for all our inputs. + (!res || inps.len() as u64 == num_args) + .then_some(vec![(0.into(), ops::Value::from_bool(res))]) + } + Self::Or => { + let inps = read_inputs(consts)?; + let res = inps.iter().any(|x| *x); + // We can only fold to false if we have a const for all our inputs + (res || inps.len() as u64 == num_args) + .then_some(vec![(0.into(), ops::Value::from_bool(res))]) + } + } + } +} /// Logic extension operation definitions. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] #[allow(missing_docs)] @@ -52,18 +76,7 @@ impl MakeOpDef for NaryLogic { } fn post_opdef(&self, def: &mut OpDef) { - def.set_constant_folder(match self { - NaryLogic::And => |consts: &_| { - let inps = read_inputs(consts)?; - let res = inps.into_iter().all(|x| x); - Some(vec![(0.into(), ops::Value::from_bool(res))]) - }, - NaryLogic::Or => |consts: &_| { - let inps = read_inputs(consts)?; - let res = inps.into_iter().any(|x| x); - Some(vec![(0.into(), ops::Value::from_bool(res))]) - }, - }) + def.set_constant_folder(*self); } } @@ -224,14 +237,22 @@ fn read_inputs(consts: &[(IncomingPort, ops::Value)]) -> Option> { pub(crate) mod test { use super::{extension, ConcreteLogicOp, NaryLogic, NotOp, FALSE_NAME, TRUE_NAME}; use crate::{ + algorithm::const_fold::constant_fold_pass, + builder::{handle::Outputs, DFGBuilder, Dataflow, DataflowHugr}, extension::{ - prelude::BOOL_T, + prelude::{BOOL_T, STRING_TYPE}, simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp}, + ExtensionRegistry, PRELUDE, }, - ops::NamedOp, - Extension, + ops::{handle::NodeHandle, NamedOp, Value}, + type_row, + types::FunctionType, + utils::test::assert_fully_folded, + Extension, Hugr, HugrView, Node, Wire, }; + use lazy_static::lazy_static; + use rstest::rstest; use strum::IntoEnumIterator; #[test] @@ -280,4 +301,55 @@ pub(crate) mod test { pub(crate) fn or_op() -> ConcreteLogicOp { NaryLogic::Or.with_n_inputs(2) } + + #[rstest] + #[case(NaryLogic::And, [], true)] + #[case(NaryLogic::And, [true, true, true], true)] + #[case(NaryLogic::And, [true, false, true], false)] + #[case(NaryLogic::Or, [], false)] + #[case(NaryLogic::Or, [false, false, true], true)] + #[case(NaryLogic::Or, [false, false, false], false)] + fn nary_const_fold( + #[case] op: NaryLogic, + #[case] ins: impl IntoIterator, + #[case] out: bool, + ) { + use itertools::Itertools; + + use crate::extension::ConstFold; + let in_vals = ins + .into_iter() + .enumerate() + .map(|(i, b)| (i.into(), Value::from_bool(b))) + .collect_vec(); + assert_eq!( + Some(vec![(0.into(), Value::from_bool(out))]), + op.fold(&[(in_vals.len() as u64).into()], &in_vals) + ); + } + + #[rstest] + #[case(NaryLogic::And, [Some(true), None], None)] + #[case(NaryLogic::And, [Some(false), None], Some(false))] + #[case(NaryLogic::Or, [None, Some(false)], None)] + #[case(NaryLogic::Or, [None, Some(true)], Some(true))] + fn nary_partial_const_fold( + #[case] op: NaryLogic, + #[case] ins: impl IntoIterator>, + #[case] mb_out: Option, + ) { + use itertools::Itertools; + + use crate::extension::ConstFold; + let in_vals0 = ins.into_iter().enumerate().collect_vec(); + let num_args = in_vals0.len() as u64; + let in_vals = in_vals0 + .into_iter() + .filter_map(|(i, mb_b)| mb_b.map(|b| (i.into(), Value::from_bool(b)))) + .collect_vec(); + assert_eq!( + mb_out.map(|out| vec![(0.into(), Value::from_bool(out))]), + op.fold(&[num_args.into()], &in_vals) + ); + } } diff --git a/hugr/src/utils.rs b/hugr/src/utils.rs index 15c2a0386..072f1105e 100644 --- a/hugr/src/utils.rs +++ b/hugr/src/utils.rs @@ -241,7 +241,7 @@ pub(crate) mod test { match op { OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, OpType::Const(c) if c.value() == expected_value => node_count += 1, - _ => panic!("unexpected op: {:?}", op), + _ => panic!("unexpected op: {:?}\n{}", op, h.mermaid_string()), } }