Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Constant-folding of integer and logic operations #1009

Merged
merged 31 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
9c96a58
Rename function.
cqc-alec Apr 22, 2024
449b235
Move function and make public.
cqc-alec Apr 22, 2024
751d374
Define convenience method to get a pair of input values.
cqc-alec Apr 29, 2024
2f4899f
Fix signature of divmod operations (output typerow not tuple).
cqc-alec May 8, 2024
824e404
Define const-folders for iwiden_u and iwiden_s.
cqc-alec Apr 22, 2024
d5a57dc
Define const folders for inarrow_u and inarrow_s.
cqc-alec Apr 26, 2024
7996833
Define const folder for not.
cqc-alec Apr 26, 2024
4fb6a1f
Refactor IntOpType to store an arbitrary vector of log-widths.
cqc-alec Apr 26, 2024
57151d0
Fix signatures of itobool and ifrombool.
cqc-alec Apr 29, 2024
4a34abb
Define const folders for remaining integer operations and add some te…
cqc-alec Apr 26, 2024
fc6083f
Extend test (panics because of #996 ).
cqc-alec May 2, 2024
02f4592
Temporarily mark test as "should panic".
cqc-alec May 2, 2024
57fa958
Refactor.
cqc-alec May 6, 2024
cace68a
Fix inot folding.
cqc-alec May 8, 2024
e68a422
Move `Folder` into `extension::const_fold`.
cqc-alec May 7, 2024
daf78d6
Add tests for all integer const-folding functions.
cqc-alec May 7, 2024
4917d92
Remove a lot of unneeded `into()`.
cqc-alec May 8, 2024
f067549
Fix const folder for `not`.
cqc-alec May 8, 2024
9f3f1bd
Add test for const-folding `not`.
cqc-alec May 8, 2024
cd34aa9
Use a better name.
cqc-alec May 8, 2024
197aec5
Use better names.
cqc-alec May 8, 2024
befb826
Move assert_fully_folded() into crate::utils::test.
cqc-alec May 8, 2024
8a96f2e
Move integer const-folding tests to a more suitable place.
cqc-alec May 8, 2024
df16791
Improve test coverage for idivmod_s.
cqc-alec May 8, 2024
ec39e69
Merge branch 'main' into constfoldint
cqc-alec May 8, 2024
e43d84a
Rename file to avoid clippy warning.
cqc-alec May 8, 2024
c6bf5c7
Merge branch 'main' into constfoldint
cqc-alec May 8, 2024
3b326bd
Merge branch 'main' into constfoldint
cqc-alec May 9, 2024
bc41bfd
use new Into<TypeArg> impls
cqc-alec May 9, 2024
a6feddd
Remove redundant inner module.
cqc-alec May 9, 2024
0efe5a4
Rename file.
cqc-alec May 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 58 additions & 16 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,15 @@ mod test {

use super::*;
use crate::extension::prelude::{sum_with_error, BOOL_T};
use crate::extension::PRELUDE;
use crate::ops::UnpackTuple;
use crate::extension::{ExtensionRegistry, PRELUDE};
use crate::ops::{OpType, UnpackTuple};
use crate::std_extensions::arithmetic;
use crate::std_extensions::arithmetic::conversions::ConvertOpDef;
use crate::std_extensions::arithmetic::float_ops::FloatOps;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};
use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES};
use crate::std_extensions::logic::{self, NaryLogic};
use crate::std_extensions::logic::{self, NaryLogic, NotOp};
use crate::utils::test::assert_fully_folded;

use rstest::rstest;

Expand Down Expand Up @@ -274,7 +275,7 @@ mod test {
.add_dataflow_op(FloatOps::fsub, unpack.outputs())
.unwrap();
let to_int = build
.add_dataflow_op(ConvertOpDef::trunc_u.with_width(5), sub.outputs())
.add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs())
.unwrap();

let reg = ExtensionRegistry::try_new([
Expand Down Expand Up @@ -362,19 +363,60 @@ mod test {
Ok(())
}

fn assert_fully_folded(h: &Hugr, expected_value: &Value) {
// check the hugr just loads and returns a single const
let mut node_count = 0;
#[test]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think all these tests should be moved to a new file src/std_extensions/arithmetic/int_ops/const_fold/test.rs`

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have a redundant mod test inside the new file. See my PR

fn test_fold_and() {
// pseudocode:
// x0, x1 := bool(true), bool(true)
// x2 := and(x0, x1)
// output x2 == true;
let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
let x0 = build.add_load_const(Value::true_val());
let x1 = build.add_load_const(Value::true_val());
let x2 = build
.add_dataflow_op(NaryLogic::And.with_n_inputs(2), [x0, x1])
.unwrap();
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
let mut h = build.finish_hugr_with_outputs(x2.outputs(), &reg).unwrap();
constant_fold_pass(&mut h, &reg);
let expected = Value::true_val();
assert_fully_folded(&h, &expected);
}

for node in h.children(h.root()) {
let op = h.get_optype(node);
match op {
OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1,
OpType::Const(c) if c.value() == expected_value => node_count += 1,
_ => panic!("unexpected op: {:?}", op),
}
}
#[test]
fn test_fold_or() {
// pseudocode:
// x0, x1 := bool(true), bool(false)
// x2 := or(x0, x1)
// output x2 == true;
let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
let x0 = build.add_load_const(Value::true_val());
let x1 = build.add_load_const(Value::false_val());
let x2 = build
.add_dataflow_op(NaryLogic::Or.with_n_inputs(2), [x0, x1])
.unwrap();
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
let mut h = build.finish_hugr_with_outputs(x2.outputs(), &reg).unwrap();
constant_fold_pass(&mut h, &reg);
let expected = Value::true_val();
assert_fully_folded(&h, &expected);
}

assert_eq!(node_count, 4);
#[test]
fn test_fold_not() {
// pseudocode:
// x0 := bool(true)
// x1 := not(x0)
// output x1 == false;
let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap();
let x0 = build.add_load_const(Value::true_val());
let x1 = build.add_dataflow_op(NotOp, [x0]).unwrap();
let reg =
ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap();
let mut h = build.finish_hugr_with_outputs(x1.outputs(), &reg).unwrap();
constant_fold_pass(&mut h, &reg);
let expected = Value::false_val();
assert_fully_folded(&h, &expected);
}
}
2 changes: 1 addition & 1 deletion hugr/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ mod const_fold;
pub mod prelude;
pub mod simple_op;
pub mod validate;
pub use const_fold::{ConstFold, ConstFoldResult};
pub use const_fold::{ConstFold, ConstFoldResult, Folder};
pub use prelude::{PRELUDE, PRELUDE_REGISTRY};

pub mod declarative;
Expand Down
16 changes: 16 additions & 0 deletions hugr/src/extension/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ use std::fmt::Formatter;

use std::fmt::Debug;

use crate::ops::Value;
use crate::types::TypeArg;

use crate::IncomingPort;
use crate::OutgoingPort;

use crate::ops;
Expand Down Expand Up @@ -45,3 +47,17 @@ where
self(consts)
}
}

type FoldFn = dyn Fn(&[TypeArg], &[(IncomingPort, Value)]) -> ConstFoldResult + Send + Sync;

/// Type holding a boxed const-folding function.
pub struct Folder {
/// Const-folding function.
pub folder: Box<FoldFn>,
}

impl ConstFold for Folder {
fn fold(&self, type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult {
(self.folder)(type_args, consts)
}
}
4 changes: 2 additions & 2 deletions hugr/src/hugr/rewrite/inline_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,12 @@ mod test {
)?;
let [a] = inner.input_wires_arr();
let c1 = c1.unwrap_or_else(|| make_const(&mut inner))?;
let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_width(6), [a, c1])?;
let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [a, c1])?;
inner.finish_with_outputs(a1.outputs())?
};
let [a1] = inner.outputs_arr();

let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_width(6), [a1, b])?;
let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_log_width(6), [a1, b])?;
let mut outer = outer.finish_hugr_with_outputs(a1_sub_b.outputs(), &reg)?;

// Sanity checks
Expand Down
5 changes: 4 additions & 1 deletion hugr/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use itertools::Itertools;
use smol_str::SmolStr;
use thiserror::Error;

pub use custom::{downcast_equal_consts, CustomConst, CustomSerialized};
pub use custom::{
downcast_equal_consts, get_pair_of_input_values, get_single_input_value, CustomConst,
CustomSerialized,
};

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
/// An operation returning a constant value.
Expand Down
21 changes: 21 additions & 0 deletions hugr/src/ops/constant/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ use crate::extension::ExtensionSet;
use crate::macros::impl_box_clone;

use crate::types::{CustomCheckFailure, Type};
use crate::IncomingPort;

use super::Value;

use super::ValueName;

Expand Down Expand Up @@ -118,3 +121,21 @@ impl PartialEq for dyn CustomConst {
(*self).equal_consts(other)
}
}

/// Given a singleton list of constant operations, return the value.
pub fn get_single_input_value<T: CustomConst>(consts: &[(IncomingPort, Value)]) -> Option<&T> {
let [(_, c)] = consts else {
return None;
};
c.get_custom_value()
}

/// Given a list of two constant operations, return the values.
pub fn get_pair_of_input_values<T: CustomConst>(
consts: &[(IncomingPort, Value)],
) -> Option<(&T, &T)> {
let [(_, c0), (_, c1)] = consts else {
return None;
};
Some((c0.get_custom_value()?, c1.get_custom_value()?))
}
16 changes: 9 additions & 7 deletions hugr/src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,18 @@ impl MakeOpDef for ConvertOpDef {

impl ConvertOpDef {
/// Initialise a conversion op with an integer log width type argument.
pub fn with_width(self, log_width: u8) -> ConvertOpType {
pub fn with_log_width(self, log_width: u8) -> ConvertOpType {
ConvertOpType {
def: self,
log_width: log_width as u64,
log_width,
}
}
}
/// Concrete convert operation with integer width set.
/// Concrete convert operation with integer log width set.
#[derive(Debug, Clone, PartialEq)]
pub struct ConvertOpType {
def: ConvertOpDef,
log_width: u64,
log_width: u8,
}

impl NamedOp for ConvertOpType {
Expand All @@ -99,18 +99,20 @@ impl NamedOp for ConvertOpType {
impl MakeExtensionOp for ConvertOpType {
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
let def = ConvertOpDef::from_def(ext_op.def())?;
let width = match *ext_op.args() {
let log_width: u64 = match *ext_op.args() {
[TypeArg::BoundedNat { n }] => n,
_ => return Err(SignatureError::InvalidTypeArgs.into()),
};
Ok(Self {
def,
log_width: width,
log_width: u8::try_from(log_width).unwrap(),
})
}

fn type_args(&self) -> Vec<crate::types::TypeArg> {
vec![TypeArg::BoundedNat { n: self.log_width }]
vec![TypeArg::BoundedNat {
n: self.log_width as u64,
}]
}
}

Expand Down
18 changes: 6 additions & 12 deletions hugr/src/std_extensions/arithmetic/conversions/const_fold.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use crate::ops::constant::get_single_input_value;
use crate::ops::Value;
use crate::std_extensions::arithmetic::int_types::INT_TYPES;
use crate::{
extension::{
prelude::{sum_with_error, ConstError},
ConstFold, ConstFoldResult, OpDef,
},
ops,
ops::constant::CustomConst,
std_extensions::arithmetic::{
float_types::ConstF64,
int_types::{get_log_width, ConstInt, INT_TYPES},
int_types::{get_log_width, ConstInt},
},
types::ConstTypeError,
IncomingPort,
Expand All @@ -27,19 +28,12 @@ pub(super) fn set_fold(op: &ConvertOpDef, def: &mut OpDef) {
}
}

fn get_input<T: CustomConst>(consts: &[(IncomingPort, ops::Value)]) -> Option<&T> {
let [(_, c)] = consts else {
return None;
};
c.get_custom_value()
}

fn fold_trunc(
type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, Value)],
convert: impl Fn(f64, u8) -> Result<Value, ConstTypeError>,
) -> ConstFoldResult {
let f: &ConstF64 = get_input(consts)?;
let f: &ConstF64 = get_single_input_value(consts)?;
let f = f.value();
let [arg] = type_args else {
return None;
Expand Down Expand Up @@ -105,7 +99,7 @@ impl ConstFold for ConvertU {
_type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Value)],
) -> ConstFoldResult {
let u: &ConstInt = get_input(consts)?;
let u: &ConstInt = crate::ops::constant::get_single_input_value(consts)?;
let f = u.value_u() as f64;
Some(vec![(0.into(), ConstF64::new(f).into())])
}
Expand All @@ -119,7 +113,7 @@ impl ConstFold for ConvertS {
_type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Value)],
) -> ConstFoldResult {
let u: &ConstInt = get_input(consts)?;
let u: &ConstInt = get_single_input_value(consts)?;
let f = u.value_s() as f64;
Some(vec![(0.into(), ConstF64::new(f).into())])
}
Expand Down
Loading