Skip to content

Commit

Permalink
remove integer folding
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Dec 22, 2023
1 parent 6fa7eb9 commit 7381432
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 98 deletions.
62 changes: 7 additions & 55 deletions src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,70 +220,22 @@ mod test {

use super::*;

/// int to constant
fn i2c(b: u64) -> Const {
Const::new(
ConstIntU::new(5, b).unwrap().into(),
INT_TYPES[5].to_owned(),
)
.unwrap()
}

/// float to constant
fn f2c(f: f64) -> Const {
ConstF64::new(f).into()
}

#[rstest]
#[case(0, 0, 0)]
#[case(0, 1, 1)]
#[case(23, 435, 458)]
#[case(0.0, 0.0, 0.0)]
#[case(0.0, 1.0, 1.0)]
#[case(23.5, 435.5, 459.0)]
// c = a + b
fn test_add(#[case] a: u64, #[case] b: u64, #[case] c: u64) {
let consts = vec![(0.into(), i2c(a)), (1.into(), i2c(b))];
let add_op: OpType = IntOpDef::iadd.with_width(5).into();
fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) {
let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))];
let add_op: OpType = FloatOps::fadd.into();
let out = fold_const(&add_op, &consts).unwrap();

assert_eq!(&out[..], &[(0.into(), i2c(c))]);
}

#[test]
fn test_fold() {
/*
Test hugr calculates
1 + 2 == 3
*/
let mut b = DFGBuilder::new(FunctionType::new(
type_row![],
vec![INT_TYPES[5].to_owned()],
))
.unwrap();

let one = b.add_load_const(i2c(1)).unwrap();
let two = b.add_load_const(i2c(2)).unwrap();

let add = b
.add_dataflow_op(IntOpDef::iadd.with_width(5), [one, two])
.unwrap();
let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
arithmetic::int_types::EXTENSION.to_owned(),
arithmetic::int_ops::EXTENSION.to_owned(),
])
.unwrap();
let mut h = b.finish_hugr_with_outputs(add.outputs(), &reg).unwrap();
assert_eq!(h.node_count(), 8);

let (repl, removes) = find_consts(&h, h.nodes(), &reg).exactly_one().ok().unwrap();
let [remove_1, remove_2] = removes.try_into().unwrap();

h.apply_rewrite(repl).unwrap();
for rem in [remove_1, remove_2] {
let const_node = h.apply_rewrite(rem).unwrap();
h.apply_rewrite(RemoveConst(const_node)).unwrap();
}

assert_fully_folded(&h, &i2c(3));
assert_eq!(&out[..], &[(0.into(), f2c(c))]);
}

#[test]
Expand Down
5 changes: 0 additions & 5 deletions src/std_extensions/arithmetic/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use lazy_static::lazy_static;
use smol_str::SmolStr;
use strum_macros::{EnumIter, EnumString, IntoStaticStr};

mod fold;
/// The extension identifier.
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int");

Expand Down Expand Up @@ -217,10 +216,6 @@ impl MakeOpDef for IntOpDef {
(rightmost bits replace leftmost bits)",
}.into()
}

fn post_opdef(&self, def: &mut OpDef) {
fold::set_fold(self, def)
}
}
fn int_polytype(
n_vars: usize,
Expand Down
38 changes: 0 additions & 38 deletions src/std_extensions/arithmetic/int_ops/fold.rs

This file was deleted.

0 comments on commit 7381432

Please sign in to comment.