Skip to content

Commit

Permalink
feat: Constant-folding of integer and logic operations (#1009)
Browse files Browse the repository at this point in the history
Add constant folders for all integer and logical operations, and add
tests. Includes a few small refactors and fixes.

One test (which I'd intended to expand into a larger one involving many
operations) is marked "should panic" because of #996 .

Closes #773 .
  • Loading branch information
cqc-alec authored May 9, 2024
1 parent 3370d88 commit b0eb9d3
Show file tree
Hide file tree
Showing 13 changed files with 2,682 additions and 87 deletions.
74 changes: 58 additions & 16 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,15 @@ mod test {

use super::*;
use crate::extension::prelude::{sum_with_error, BOOL_T};
use crate::extension::PRELUDE;
use crate::ops::UnpackTuple;
use crate::extension::{ExtensionRegistry, PRELUDE};
use crate::ops::{OpType, UnpackTuple};
use crate::std_extensions::arithmetic;
use crate::std_extensions::arithmetic::conversions::ConvertOpDef;
use crate::std_extensions::arithmetic::float_ops::FloatOps;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES};
use crate::std_extensions::logic::{self, NaryLogic};
use crate::std_extensions::logic::{self, NaryLogic, NotOp};
use crate::utils::test::assert_fully_folded;

use rstest::rstest;

Expand Down Expand Up @@ -274,7 +275,7 @@ mod test {
.add_dataflow_op(FloatOps::fsub, unpack.outputs())
.unwrap();
let to_int = build
.add_dataflow_op(ConvertOpDef::trunc_u.with_width(5), sub.outputs())
.add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs())
.unwrap();

let reg = ExtensionRegistry::try_new([
Expand Down Expand Up @@ -362,19 +363,60 @@ mod test {
Ok(())
}

fn assert_fully_folded(h: &Hugr, expected_value: &Value) {
// check the hugr just loads and returns a single const
let mut node_count = 0;
#[test]
fn test_fold_and() {
// pseudocode:
// x0, x1 := bool(true), bool(true)
// x2 := and(x0, x1)
// output x2 == true;
let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
let x0 = build.add_load_const(Value::true_val());
let x1 = build.add_load_const(Value::true_val());
let x2 = build
.add_dataflow_op(NaryLogic::And.with_n_inputs(2), [x0, x1])
.unwrap();
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
let mut h = build.finish_hugr_with_outputs(x2.outputs(), &reg).unwrap();
constant_fold_pass(&mut h, &reg);
let expected = Value::true_val();
assert_fully_folded(&h, &expected);
}

for node in h.children(h.root()) {
let op = h.get_optype(node);
match op {
OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1,
OpType::Const(c) if c.value() == expected_value => node_count += 1,
_ => panic!("unexpected op: {:?}", op),
}
}
#[test]
fn test_fold_or() {
// pseudocode:
// x0, x1 := bool(true), bool(false)
// x2 := or(x0, x1)
// output x2 == true;
let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
let x0 = build.add_load_const(Value::true_val());
let x1 = build.add_load_const(Value::false_val());
let x2 = build
.add_dataflow_op(NaryLogic::Or.with_n_inputs(2), [x0, x1])
.unwrap();
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
let mut h = build.finish_hugr_with_outputs(x2.outputs(), &reg).unwrap();
constant_fold_pass(&mut h, &reg);
let expected = Value::true_val();
assert_fully_folded(&h, &expected);
}

assert_eq!(node_count, 4);
#[test]
fn test_fold_not() {
// pseudocode:
// x0 := bool(true)
// x1 := not(x0)
// output x1 == false;
let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
let x0 = build.add_load_const(Value::true_val());
let x1 = build.add_dataflow_op(NotOp, [x0]).unwrap();
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
let mut h = build.finish_hugr_with_outputs(x1.outputs(), &reg).unwrap();
constant_fold_pass(&mut h, &reg);
let expected = Value::false_val();
assert_fully_folded(&h, &expected);
}
}
2 changes: 1 addition & 1 deletion hugr/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ mod const_fold;
pub mod prelude;
pub mod simple_op;
pub mod validate;
pub use const_fold::{ConstFold, ConstFoldResult};
pub use const_fold::{ConstFold, ConstFoldResult, Folder};
pub use prelude::{PRELUDE, PRELUDE_REGISTRY};

pub mod declarative;
Expand Down
16 changes: 16 additions & 0 deletions hugr/src/extension/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ use std::fmt::Formatter;

use std::fmt::Debug;

use crate::ops::Value;
use crate::types::TypeArg;

use crate::IncomingPort;
use crate::OutgoingPort;

use crate::ops;
Expand Down Expand Up @@ -45,3 +47,17 @@ where
self(consts)
}
}

type FoldFn = dyn Fn(&[TypeArg], &[(IncomingPort, Value)]) -> ConstFoldResult + Send + Sync;

/// Type holding a boxed const-folding function.
pub struct Folder {
/// Const-folding function.
pub folder: Box<FoldFn>,
}

impl ConstFold for Folder {
fn fold(&self, type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
(self.folder)(type_args, consts)
}
}
4 changes: 2 additions & 2 deletions hugr/src/hugr/rewrite/inline_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,12 @@ mod test {
)?;
let [a] = inner.input_wires_arr();
let c1 = c1.unwrap_or_else(|| make_const(&mut inner))?;
let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_width(6), [a, c1])?;
let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [a, c1])?;
inner.finish_with_outputs(a1.outputs())?
};
let [a1] = inner.outputs_arr();

let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_width(6), [a1, b])?;
let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_log_width(6), [a1, b])?;
let mut outer = outer.finish_hugr_with_outputs(a1_sub_b.outputs(), &reg)?;

// Sanity checks
Expand Down
5 changes: 4 additions & 1 deletion hugr/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use itertools::Itertools;
use smol_str::SmolStr;
use thiserror::Error;

pub use custom::{downcast_equal_consts, CustomConst, CustomSerialized};
pub use custom::{
downcast_equal_consts, get_pair_of_input_values, get_single_input_value, CustomConst,
CustomSerialized,
};

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
/// An operation returning a constant value.
Expand Down
21 changes: 21 additions & 0 deletions hugr/src/ops/constant/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ use crate::extension::ExtensionSet;
use crate::macros::impl_box_clone;

use crate::types::{CustomCheckFailure, Type};
use crate::IncomingPort;

use super::Value;

use super::ValueName;

Expand Down Expand Up @@ -118,3 +121,21 @@ impl PartialEq for dyn CustomConst {
(*self).equal_consts(other)
}
}

/// Given a singleton list of constant operations, return the value.
pub fn get_single_input_value<T: CustomConst>(consts: &[(IncomingPort, Value)]) -> Option<&T> {
let [(_, c)] = consts else {
return None;
};
c.get_custom_value()
}

/// Given a list of two constant operations, return the values.
pub fn get_pair_of_input_values<T: CustomConst>(
consts: &[(IncomingPort, Value)],
) -> Option<(&T, &T)> {
let [(_, c0), (_, c1)] = consts else {
return None;
};
Some((c0.get_custom_value()?, c1.get_custom_value()?))
}
16 changes: 9 additions & 7 deletions hugr/src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,18 @@ impl MakeOpDef for ConvertOpDef {

impl ConvertOpDef {
/// Initialise a conversion op with an integer log width type argument.
pub fn with_width(self, log_width: u8) -> ConvertOpType {
pub fn with_log_width(self, log_width: u8) -> ConvertOpType {
ConvertOpType {
def: self,
log_width: log_width as u64,
log_width,
}
}
}
/// Concrete convert operation with integer width set.
/// Concrete convert operation with integer log width set.
#[derive(Debug, Clone, PartialEq)]
pub struct ConvertOpType {
def: ConvertOpDef,
log_width: u64,
log_width: u8,
}

impl NamedOp for ConvertOpType {
Expand All @@ -99,18 +99,20 @@ impl NamedOp for ConvertOpType {
impl MakeExtensionOp for ConvertOpType {
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
let def = ConvertOpDef::from_def(ext_op.def())?;
let width = match *ext_op.args() {
let log_width: u64 = match *ext_op.args() {
[TypeArg::BoundedNat { n }] => n,
_ => return Err(SignatureError::InvalidTypeArgs.into()),
};
Ok(Self {
def,
log_width: width,
log_width: u8::try_from(log_width).unwrap(),
})
}

fn type_args(&self) -> Vec<crate::types::TypeArg> {
vec![TypeArg::BoundedNat { n: self.log_width }]
vec![TypeArg::BoundedNat {
n: self.log_width as u64,
}]
}
}

Expand Down
18 changes: 6 additions & 12 deletions hugr/src/std_extensions/arithmetic/conversions/const_fold.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use crate::ops::constant::get_single_input_value;
use crate::ops::Value;
use crate::std_extensions::arithmetic::int_types::INT_TYPES;
use crate::{
extension::{
prelude::{sum_with_error, ConstError},
ConstFold, ConstFoldResult, OpDef,
},
ops,
ops::constant::CustomConst,
std_extensions::arithmetic::{
float_types::ConstF64,
int_types::{get_log_width, ConstInt, INT_TYPES},
int_types::{get_log_width, ConstInt},
},
types::ConstTypeError,
IncomingPort,
Expand All @@ -27,19 +28,12 @@ pub(super) fn set_fold(op: &ConvertOpDef, def: &mut OpDef) {
}
}

fn get_input<T: CustomConst>(consts: &[(IncomingPort, ops::Value)]) -> Option<&T> {
let [(_, c)] = consts else {
return None;
};
c.get_custom_value()
}

fn fold_trunc(
type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, Value)],
convert: impl Fn(f64, u8) -> Result<Value, ConstTypeError>,
) -> ConstFoldResult {
let f: &ConstF64 = get_input(consts)?;
let f: &ConstF64 = get_single_input_value(consts)?;
let f = f.value();
let [arg] = type_args else {
return None;
Expand Down Expand Up @@ -105,7 +99,7 @@ impl ConstFold for ConvertU {
_type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Value)],
) -> ConstFoldResult {
let u: &ConstInt = get_input(consts)?;
let u: &ConstInt = crate::ops::constant::get_single_input_value(consts)?;
let f = u.value_u() as f64;
Some(vec![(0.into(), ConstF64::new(f).into())])
}
Expand All @@ -119,7 +113,7 @@ impl ConstFold for ConvertS {
_type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Value)],
) -> ConstFoldResult {
let u: &ConstInt = get_input(consts)?;
let u: &ConstInt = get_single_input_value(consts)?;
let f = u.value_s() as f64;
Some(vec![(0.into(), ConstF64::new(f).into())])
}
Expand Down
Loading

0 comments on commit b0eb9d3

Please sign in to comment.