Skip to content

Commit

Permalink
feat: constant folding for arithmetic conversion operations (#720)
Browse files Browse the repository at this point in the history
Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Alan Lawrence <[email protected]>
Co-authored-by: Luca Mondada <[email protected]>
Co-authored-by: Alec Edgington <[email protected]>
Co-authored-by: Alan Lawrence <[email protected]>
Co-authored-by: Agustín Borgna <[email protected]>
Co-authored-by: Luca Mondada <[email protected]>
  • Loading branch information
8 people authored Jan 3, 2024
1 parent cf69e01 commit 968c8b0
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 12 deletions.
45 changes: 33 additions & 12 deletions src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,24 @@ pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) {
#[cfg(test)]
mod test {

use super::*;
use crate::extension::prelude::sum_with_error;
use crate::extension::{ExtensionRegistry, PRELUDE};
use crate::std_extensions::arithmetic;

use crate::std_extensions::arithmetic::conversions::ConvertOpDef;
use crate::std_extensions::arithmetic::float_ops::FloatOps;
use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE};

use crate::std_extensions::arithmetic::int_types::{ConstIntU, INT_TYPES};
use rstest::rstest;

use super::*;
/// int to constant
fn i2c(b: u64) -> Const {
Const::new(
ConstIntU::new(5, b).unwrap().into(),
INT_TYPES[5].to_owned(),
)
.unwrap()
}

/// float to constant
fn f2c(f: f64) -> Const {
Expand All @@ -244,19 +253,19 @@ mod test {

assert_eq!(&out[..], &[(0.into(), f2c(c))]);
}

#[test]
fn test_big() {
/*
Test hugr approximately calculates
let x = (5.5, 3.25);
x.0 - x.1 == 2.25
Test approximately calculates
let x = (5.6, 3.2);
int(x.0 - x.1) == 2
*/
let sum_type = sum_with_error(INT_TYPES[5].to_owned());
let mut build =
DFGBuilder::new(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])).unwrap();
DFGBuilder::new(FunctionType::new(type_row![], vec![sum_type.clone()])).unwrap();

let tup = build
.add_load_const(Const::new_tuple([f2c(5.5), f2c(3.25)]))
.add_load_const(Const::new_tuple([f2c(5.6), f2c(3.2)]))
.unwrap();

let unpack = build
Expand All @@ -271,19 +280,31 @@ mod test {
let sub = build
.add_dataflow_op(FloatOps::fsub, unpack.outputs())
.unwrap();
let to_int = build
.add_dataflow_op(ConvertOpDef::trunc_u.with_width(5), sub.outputs())
.unwrap();

let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
arithmetic::int_types::EXTENSION.to_owned(),
arithmetic::float_types::EXTENSION.to_owned(),
arithmetic::float_ops::EXTENSION.to_owned(),
arithmetic::conversions::EXTENSION.to_owned(),
])
.unwrap();
let mut h = build.finish_hugr_with_outputs(sub.outputs(), &reg).unwrap();
assert_eq!(h.node_count(), 7);
let mut h = build
.finish_hugr_with_outputs(to_int.outputs(), &reg)
.unwrap();
assert_eq!(h.node_count(), 8);

constant_fold_pass(&mut h, &reg);

assert_fully_folded(&h, &f2c(2.25));
let expected = Value::Sum {
tag: 0,
value: Box::new(i2c(2).value().clone()),
};
let expected = Const::new(expected, sum_type).unwrap();
assert_fully_folded(&h, &expected);
}
fn assert_fully_folded(h: &Hugr, expected_const: &Const) {
// check the hugr just loads and returns a single const
Expand Down
14 changes: 14 additions & 0 deletions src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use crate::{
use super::int_types::int_tv;
use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM};
use lazy_static::lazy_static;
mod const_fold;
/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");

Expand Down Expand Up @@ -63,8 +64,21 @@ impl MakeOpDef for ConvertOpDef {
}
.to_string()
}

fn post_opdef(&self, def: &mut OpDef) {
const_fold::set_fold(self, def)
}
}

impl ConvertOpDef {
/// Initialise a conversion op with an integer log width type argument.
pub fn with_width(self, log_width: u8) -> ConvertOpType {
ConvertOpType {
def: self,
log_width: log_width as u64,
}
}
}
/// Concrete convert operation with integer width set.
#[derive(Debug, Clone, PartialEq)]
pub struct ConvertOpType {
Expand Down
134 changes: 134 additions & 0 deletions src/std_extensions/arithmetic/conversions/const_fold.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use crate::{
extension::{
prelude::{sum_with_error, ConstError},
ConstFold, ConstFoldResult, OpDef,
},
ops,
std_extensions::arithmetic::{
float_types::ConstF64,
int_types::{get_log_width, ConstIntS, ConstIntU, INT_TYPES},
},
types::ConstTypeError,
values::{CustomConst, Value},
IncomingPort,
};

use super::ConvertOpDef;

pub(super) fn set_fold(op: &ConvertOpDef, def: &mut OpDef) {
use ConvertOpDef::*;

match op {
trunc_u => def.set_constant_folder(TruncU),
trunc_s => def.set_constant_folder(TruncS),
convert_u => def.set_constant_folder(ConvertU),
convert_s => def.set_constant_folder(ConvertS),
}
}

fn get_input<T: CustomConst>(consts: &[(IncomingPort, ops::Const)]) -> Option<&T> {
let [(_, c)] = consts else {
return None;
};
c.get_custom_value()
}

fn fold_trunc(
type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Const)],
convert: impl Fn(f64, u8) -> Result<Value, ConstTypeError>,
) -> ConstFoldResult {
let f: &ConstF64 = get_input(consts)?;
let f = f.value();
let [arg] = type_args else {
return None;
};
let log_width = get_log_width(arg).ok()?;
let int_type = INT_TYPES[log_width as usize].to_owned();
let sum_type = sum_with_error(int_type.clone());
let err_value = || {
let err_val = ConstError {
signal: 0,
message: "Can't truncate non-finite float".to_string(),
};
let sum_val = Value::Sum {
tag: 1,
value: Box::new(err_val.into()),
};

ops::Const::new(sum_val, sum_type.clone()).unwrap()
};
let out_const: ops::Const = if !f.is_finite() {
err_value()
} else {
let cv = convert(f, log_width);
if let Ok(cv) = cv {
let sum_val = Value::Sum {
tag: 0,
value: Box::new(cv),
};

ops::Const::new(sum_val, sum_type).unwrap()
} else {
err_value()
}
};

Some(vec![(0.into(), out_const)])
}

struct TruncU;

impl ConstFold for TruncU {
fn fold(
&self,
type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Const)],
) -> ConstFoldResult {
fold_trunc(type_args, consts, |f, log_width| {
ConstIntU::new(log_width, f.trunc() as u64).map(Into::into)
})
}
}

struct TruncS;

impl ConstFold for TruncS {
fn fold(
&self,
type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Const)],
) -> ConstFoldResult {
fold_trunc(type_args, consts, |f, log_width| {
ConstIntS::new(log_width, f.trunc() as i64).map(Into::into)
})
}
}

struct ConvertU;

impl ConstFold for ConvertU {
fn fold(
&self,
_type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Const)],
) -> ConstFoldResult {
let u: &ConstIntU = get_input(consts)?;
let f = u.value() as f64;
Some(vec![(0.into(), ConstF64::new(f).into())])
}
}

struct ConvertS;

impl ConstFold for ConvertS {
fn fold(
&self,
_type_args: &[crate::types::TypeArg],
consts: &[(IncomingPort, ops::Const)],
) -> ConstFoldResult {
let u: &ConstIntS = get_input(consts)?;
let f = u.value() as f64;
Some(vec![(0.into(), ConstF64::new(f).into())])
}
}

0 comments on commit 968c8b0

Please sign in to comment.