diff --git a/hugr/src/algorithm/const_fold.rs b/hugr/src/algorithm/const_fold.rs index 047665234..884002307 100644 --- a/hugr/src/algorithm/const_fold.rs +++ b/hugr/src/algorithm/const_fold.rs @@ -212,14 +212,15 @@ mod test { use super::*; use crate::extension::prelude::{sum_with_error, BOOL_T}; - use crate::extension::PRELUDE; - use crate::ops::UnpackTuple; + use crate::extension::{ExtensionRegistry, PRELUDE}; + use crate::ops::{OpType, UnpackTuple}; use crate::std_extensions::arithmetic; use crate::std_extensions::arithmetic::conversions::ConvertOpDef; use crate::std_extensions::arithmetic::float_ops::FloatOps; use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; - use crate::std_extensions::logic::{self, NaryLogic}; + use crate::std_extensions::logic::{self, NaryLogic, NotOp}; + use crate::utils::test::assert_fully_folded; use rstest::rstest; @@ -274,7 +275,7 @@ mod test { .add_dataflow_op(FloatOps::fsub, unpack.outputs()) .unwrap(); let to_int = build - .add_dataflow_op(ConvertOpDef::trunc_u.with_width(5), sub.outputs()) + .add_dataflow_op(ConvertOpDef::trunc_u.with_log_width(5), sub.outputs()) .unwrap(); let reg = ExtensionRegistry::try_new([ @@ -362,19 +363,60 @@ mod test { Ok(()) } - fn assert_fully_folded(h: &Hugr, expected_value: &Value) { - // check the hugr just loads and returns a single const - let mut node_count = 0; + #[test] + fn test_fold_and() { + // pseudocode: + // x0, x1 := bool(true), bool(true) + // x2 := and(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::true_val()); + let x1 = build.add_load_const(Value::true_val()); + let x2 = build + .add_dataflow_op(NaryLogic::And.with_n_inputs(2), [x0, x1]) + .unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); + } - for node in h.children(h.root()) { - let op = h.get_optype(node); - match op { - OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, - OpType::Const(c) if c.value() == expected_value => node_count += 1, - _ => panic!("unexpected op: {:?}", op), - } - } + #[test] + fn test_fold_or() { + // pseudocode: + // x0, x1 := bool(true), bool(false) + // x2 := or(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::true_val()); + let x1 = build.add_load_const(Value::false_val()); + let x2 = build + .add_dataflow_op(NaryLogic::Or.with_n_inputs(2), [x0, x1]) + .unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); + } - assert_eq!(node_count, 4); + #[test] + fn test_fold_not() { + // pseudocode: + // x0 := bool(true) + // x1 := not(x0) + // output x1 == false; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::true_val()); + let x1 = build.add_dataflow_op(NotOp, [x0]).unwrap(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::false_val(); + assert_fully_folded(&h, &expected); } } diff --git a/hugr/src/extension.rs b/hugr/src/extension.rs index 13794406a..b504487ed 100644 --- a/hugr/src/extension.rs +++ b/hugr/src/extension.rs @@ -37,7 +37,7 @@ mod const_fold; pub mod prelude; pub mod simple_op; pub mod validate; -pub use const_fold::{ConstFold, ConstFoldResult}; +pub use const_fold::{ConstFold, ConstFoldResult, Folder}; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; pub mod declarative; diff --git a/hugr/src/extension/const_fold.rs b/hugr/src/extension/const_fold.rs index bfa4540f9..a3aae93eb 100644 --- a/hugr/src/extension/const_fold.rs +++ b/hugr/src/extension/const_fold.rs @@ -2,8 +2,10 @@ use std::fmt::Formatter; use std::fmt::Debug; +use crate::ops::Value; use crate::types::TypeArg; +use crate::IncomingPort; use crate::OutgoingPort; use crate::ops; @@ -45,3 +47,17 @@ where self(consts) } } + +type FoldFn = dyn Fn(&[TypeArg], &[(IncomingPort, Value)]) -> ConstFoldResult + Send + Sync; + +/// Type holding a boxed const-folding function. +pub struct Folder { + /// Const-folding function. + pub folder: Box, +} + +impl ConstFold for Folder { + fn fold(&self, type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult { + (self.folder)(type_args, consts) + } +} diff --git a/hugr/src/extension/prelude.rs b/hugr/src/extension/prelude.rs index 6ab4c7247..3a6b2637b 100644 --- a/hugr/src/extension/prelude.rs +++ b/hugr/src/extension/prelude.rs @@ -130,10 +130,10 @@ pub const USIZE_T: Type = Type::new_extension(USIZE_CUSTOM_T); pub const BOOL_T: Type = Type::new_unit_sum(2); /// Initialize a new array of element type `element_ty` of length `size` -pub fn array_type(size: TypeArg, element_ty: Type) -> Type { +pub fn array_type(size: impl Into, element_ty: Type) -> Type { let array_def = PRELUDE.get_type("array").unwrap(); let custom_t = array_def - .instantiate(vec![size, TypeArg::Type { ty: element_ty }]) + .instantiate(vec![size.into(), element_ty.into()]) .unwrap(); Type::new_extension(custom_t) } diff --git a/hugr/src/hugr.rs b/hugr/src/hugr.rs index 02b5a25b5..770d56a5f 100644 --- a/hugr/src/hugr.rs +++ b/hugr/src/hugr.rs @@ -8,6 +8,8 @@ pub mod serialize; pub mod validate; pub mod views; +#[cfg(feature = "extension_inference")] +use std::collections::HashMap; use std::collections::VecDeque; use std::iter; @@ -196,8 +198,12 @@ impl Hugr { extension_registry: &ExtensionRegistry, ) -> Result<(), ValidationError> { resolve_extension_ops(self, extension_registry)?; - self.infer_extensions()?; - self.validate(extension_registry)?; + self.validate_no_extensions(extension_registry)?; + #[cfg(feature = "extension_inference")] + { + self.infer_extensions()?; + self.validate_extensions(HashMap::new())?; + } Ok(()) } diff --git a/hugr/src/hugr/rewrite/inline_dfg.rs b/hugr/src/hugr/rewrite/inline_dfg.rs index e4c1e5e37..adba1170f 100644 --- a/hugr/src/hugr/rewrite/inline_dfg.rs +++ b/hugr/src/hugr/rewrite/inline_dfg.rs @@ -206,12 +206,12 @@ mod test { )?; let [a] = inner.input_wires_arr(); let c1 = c1.unwrap_or_else(|| make_const(&mut inner))?; - let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_width(6), [a, c1])?; + let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [a, c1])?; inner.finish_with_outputs(a1.outputs())? }; let [a1] = inner.outputs_arr(); - let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_width(6), [a1, b])?; + let a1_sub_b = outer.add_dataflow_op(IntOpDef::isub.with_log_width(6), [a1, b])?; let mut outer = outer.finish_hugr_with_outputs(a1_sub_b.outputs(), ®)?; // Sanity checks diff --git a/hugr/src/hugr/validate.rs b/hugr/src/hugr/validate.rs index 26857495a..40de5d805 100644 --- a/hugr/src/hugr/validate.rs +++ b/hugr/src/hugr/validate.rs @@ -35,9 +35,6 @@ struct ValidationContext<'a, 'b> { hugr: &'a Hugr, /// Dominator tree for each CFG region, using the container node as index. dominators: HashMap>, - /// Context for the extension validation. - #[allow(dead_code)] - extension_validator: ExtensionValidator, /// Registry of available Extensions extension_registry: &'b ExtensionRegistry, } @@ -48,7 +45,51 @@ impl Hugr { /// TODO: Add a version of validation which allows for open extension /// variables (see github issue #457) pub fn validate(&self, extension_registry: &ExtensionRegistry) -> Result<(), ValidationError> { - self.validate_with_extension_closure(HashMap::new(), extension_registry) + #[cfg(feature = "extension_inference")] + self.validate_with_extension_closure(HashMap::new(), extension_registry)?; + #[cfg(not(feature = "extension_inference"))] + self.validate_no_extensions(extension_registry)?; + Ok(()) + } + + /// Check the validity of the HUGR, but don't check consistency of extension + /// requirements between connected nodes or between parents and children. + pub fn validate_no_extensions( + &self, + extension_registry: &ExtensionRegistry, + ) -> Result<(), ValidationError> { + let mut validator = ValidationContext::new(self, extension_registry); + validator.validate() + } + + /// Validate extensions on the input and output edges of nodes. Check that + /// the target ends of edges require the extensions from the sources, and + /// check extension deltas from parent nodes are reflected in their children + pub fn validate_extensions(&self, closure: ExtensionSolution) -> Result<(), ValidationError> { + let validator = ExtensionValidator::new(self, closure); + for src_node in self.nodes() { + let node_type = self.get_nodetype(src_node); + + // FuncDefns have no resources since they're static nodes, but the + // functions they define can have any extension delta. + if node_type.tag() != OpTag::FuncDefn { + // If this is a container with I/O nodes, check that the extension they + // define match the extensions of the container. + if let Some([input, output]) = self.get_io(src_node) { + validator.validate_io_extensions(src_node, input, output)?; + } + } + + for src_port in self.node_outputs(src_node) { + for (tgt_node, tgt_port) in self.linked_inputs(src_node, src_port) { + validator.check_extensions_compatible( + &(src_node, src_port.into()), + &(tgt_node, tgt_port.into()), + )?; + } + } + } + Ok(()) } /// Check the validity of a hugr, taking an argument of a closure for the @@ -58,8 +99,10 @@ impl Hugr { closure: ExtensionSolution, extension_registry: &ExtensionRegistry, ) -> Result<(), ValidationError> { - let mut validator = ValidationContext::new(self, closure, extension_registry); - validator.validate() + let mut validator = ValidationContext::new(self, extension_registry); + validator.validate()?; + self.validate_extensions(closure)?; + Ok(()) } } @@ -68,15 +111,10 @@ impl<'a, 'b> ValidationContext<'a, 'b> { // Allow unused "extension_closure" variable for when // the "extension_inference" feature is disabled. #[allow(unused_variables)] - pub fn new( - hugr: &'a Hugr, - extension_closure: ExtensionSolution, - extension_registry: &'b ExtensionRegistry, - ) -> Self { + pub fn new(hugr: &'a Hugr, extension_registry: &'b ExtensionRegistry) -> Self { Self { hugr, dominators: HashMap::new(), - extension_validator: ExtensionValidator::new(hugr, extension_closure), extension_registry, } } @@ -176,18 +214,6 @@ impl<'a, 'b> ValidationContext<'a, 'b> { // Secondly that the node has correct children self.validate_children(node, node_type)?; - // FuncDefns have no resources since they're static nodes, but the - // functions they define can have any extension delta. - #[cfg(feature = "extension_inference")] - if node_type.tag() != OpTag::FuncDefn { - // If this is a container with I/O nodes, check that the extension they - // define match the extensions of the container. - if let Some([input, output]) = self.hugr.get_io(node) { - self.extension_validator - .validate_io_extensions(node, input, output)?; - } - } - Ok(()) } @@ -247,10 +273,6 @@ impl<'a, 'b> ValidationContext<'a, 'b> { let other_node: Node = self.hugr.graph.port_node(link).unwrap().into(); let other_offset = self.hugr.graph.port_offset(link).unwrap().into(); - #[cfg(feature = "extension_inference")] - self.extension_validator - .check_extensions_compatible(&(node, port), &(other_node, other_offset))?; - let other_op = self.hugr.get_optype(other_node); let Some(other_kind) = other_op.port_kind(other_offset) else { panic!("The number of ports in {other_node} does not match the operation definition. This should have been caught by `validate_node`."); diff --git a/hugr/src/hugr/validate/test.rs b/hugr/src/hugr/validate/test.rs index 424c86513..a1518fcb2 100644 --- a/hugr/src/hugr/validate/test.rs +++ b/hugr/src/hugr/validate/test.rs @@ -3,20 +3,22 @@ use cool_asserts::assert_matches; use super::*; use crate::builder::test::closed_dfg_root_hugr; use crate::builder::{ - BuildError, Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, - HugrBuilder, ModuleBuilder, + BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, }; -use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T}; -use crate::extension::{Extension, ExtensionId, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY}; +use crate::extension::prelude::{BOOL_T, PRELUDE, PRELUDE_ID, USIZE_T}; +use crate::extension::{Extension, ExtensionSet, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::HugrMut; use crate::ops::dataflow::IOTrait; -use crate::ops::{self, Noop, Value}; +use crate::ops::handle::NodeHandle; +use crate::ops::leaf::MakeTuple; +use crate::ops::{self, Noop, OpType, Value}; use crate::std_extensions::logic::test::{and_op, or_op}; use crate::std_extensions::logic::{self, NotOp}; use crate::types::type_param::{TypeArg, TypeArgError}; use crate::types::{CustomType, FunctionType, PolyFuncType, Type, TypeBound, TypeRow}; -use crate::{type_row, IncomingPort}; +use crate::{const_extension_ids, type_row, Direction, IncomingPort, Node}; const NAT: Type = crate::extension::prelude::USIZE_T; @@ -336,10 +338,12 @@ fn unregistered_extension() { h.update_validate(&PRELUDE_REGISTRY).unwrap(); } +const_extension_ids! { + const EXT_ID: ExtensionId = "MyExt"; +} #[test] fn invalid_types() { - let name: ExtensionId = "MyExt".try_into().unwrap(); - let mut e = Extension::new(name.clone()); + let mut e = Extension::new(EXT_ID); e.add_type( "MyContainer".into(), vec![TypeBound::Copyable.into()], @@ -360,7 +364,7 @@ fn invalid_types() { let valid = Type::new_extension(CustomType::new( "MyContainer", vec![TypeArg::Type { ty: USIZE_T }], - name.clone(), + EXT_ID, TypeBound::Any, )); assert_eq!( @@ -374,7 +378,7 @@ fn invalid_types() { let element_outside_bound = CustomType::new( "MyContainer", vec![TypeArg::Type { ty: valid.clone() }], - name.clone(), + EXT_ID, TypeBound::Any, ); assert_eq!( @@ -388,7 +392,7 @@ fn invalid_types() { let bad_bound = CustomType::new( "MyContainer", vec![TypeArg::Type { ty: USIZE_T }], - name.clone(), + EXT_ID, TypeBound::Copyable, ); assert_eq!( @@ -405,7 +409,7 @@ fn invalid_types() { vec![TypeArg::Type { ty: Type::new_extension(bad_bound), }], - name.clone(), + EXT_ID, TypeBound::Any, ); assert_eq!( @@ -419,7 +423,7 @@ fn invalid_types() { let too_many_type_args = CustomType::new( "MyContainer", vec![TypeArg::Type { ty: USIZE_T }, TypeArg::BoundedNat { n: 3 }], - name.clone(), + EXT_ID, TypeBound::Any, ); assert_eq!( @@ -544,18 +548,101 @@ fn no_polymorphic_consts() -> Result<(), Box> { #[test] fn test_polymorphic_call() -> Result<(), Box> { - let mut m = ModuleBuilder::new(); - let id = m.declare( - "id", + let mut e = Extension::new(EXT_ID); + + let params: Vec = vec![ + TypeBound::Any.into(), + TypeParam::Extensions, + TypeBound::Any.into(), + ]; + let evaled_fn = Type::new_function( + FunctionType::new( + Type::new_var_use(0, TypeBound::Any), + Type::new_var_use(2, TypeBound::Any), + ) + .with_extension_delta(ExtensionSet::type_var(1)), + ); + // The higher-order "eval" operation - takes a function and its argument. + // Note the extension-delta of the eval node includes that of the input function. + e.add_op( + "eval".into(), + "".into(), PolyFuncType::new( - vec![TypeBound::Any.into()], - FunctionType::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), + params.clone(), + FunctionType::new( + vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)], + Type::new_var_use(2, TypeBound::Any), + ) + .with_extension_delta(ExtensionSet::type_var(1)), ), )?; - let mut f = m.define_function("main", FunctionType::new_endo(vec![USIZE_T]).into())?; - let c = f.call(&id, &[USIZE_T.into()], f.input_wires(), &PRELUDE_REGISTRY)?; - f.finish_with_outputs(c.outputs())?; - let _ = m.finish_prelude_hugr()?; + + fn utou(e: impl Into) -> Type { + Type::new_function(FunctionType::new_endo(USIZE_T).with_extension_delta(e.into())) + } + + let int_pair = Type::new_tuple(type_row![USIZE_T; 2]); + // Root DFG: applies a function int--PRELUDE-->int to each element of a pair of two ints + let mut d = DFGBuilder::new( + FunctionType::new( + vec![utou(PRELUDE_ID), int_pair.clone()], + vec![int_pair.clone()], + ) + .with_extension_delta(PRELUDE_ID), + )?; + // ....by calling a function parametrized (int--e-->int, int_pair) -> int_pair + let f = { + let es = ExtensionSet::type_var(0); + let mut f = d.define_function( + "two_ints", + PolyFuncType::new( + vec![TypeParam::Extensions], + FunctionType::new(vec![utou(es.clone()), int_pair.clone()], int_pair.clone()) + .with_extension_delta(es.clone()), + ), + )?; + let [func, tup] = f.input_wires_arr(); + let mut c = f.conditional_builder( + (vec![type_row![USIZE_T; 2]], tup), + vec![], + type_row![USIZE_T;2], + es.clone(), + )?; + let mut cc = c.case_builder(0)?; + let [i1, i2] = cc.input_wires_arr(); + let op = e.instantiate_extension_op( + "eval", + vec![USIZE_T.into(), TypeArg::Extensions { es }, USIZE_T.into()], + &PRELUDE_REGISTRY, + )?; + let [f1] = cc.add_dataflow_op(op.clone(), [func, i1])?.outputs_arr(); + let [f2] = cc.add_dataflow_op(op, [func, i2])?.outputs_arr(); + cc.finish_with_outputs([f1, f2])?; + let res = c.finish_sub_container()?.outputs(); + let tup = f.add_dataflow_op( + MakeTuple { + tys: type_row![USIZE_T; 2], + }, + res, + )?; + f.finish_with_outputs(tup.outputs())? + }; + + let reg = ExtensionRegistry::try_new([e, PRELUDE.to_owned()])?; + let [func, tup] = d.input_wires_arr(); + let call = d.call( + f.handle(), + &[TypeArg::Extensions { + es: ExtensionSet::singleton(&PRELUDE_ID), + }], + [func, tup], + ®, + )?; + let h = d.finish_hugr_with_outputs(call.outputs(), ®)?; + let call_ty = h.get_optype(call.node()).dataflow_signature().unwrap(); + let exp_fun_ty = FunctionType::new(vec![utou(PRELUDE_ID), int_pair.clone()], int_pair) + .with_extension_delta(PRELUDE_ID); + assert_eq!(call_ty, exp_fun_ty); Ok(()) } @@ -829,19 +916,21 @@ mod extension_tests { let all_rs = ExtensionSet::from_iter([XA, XB]); - let main_sig = FunctionType::new(type_row![], type_row![NAT]) + let main_sig = FunctionType::new(type_row![NAT], type_row![NAT]) .with_extension_delta(all_rs.clone()) .into(); let mut main = module_builder.define_function("main", main_sig)?; + let [inp_wire] = main.input_wires_arr(); + let [left_wire] = main .dfg_builder( FunctionType::new(type_row![], type_row![NAT]), Some(XA.into()), [], )? - .finish_with_outputs([])? + .finish_with_outputs([inp_wire])? .outputs_arr(); let [right_wire] = main @@ -850,7 +939,7 @@ mod extension_tests { Some(XB.into()), [], )? - .finish_with_outputs([])? + .finish_with_outputs([inp_wire])? .outputs_arr(); let builder = main.dfg_builder( @@ -858,8 +947,8 @@ mod extension_tests { Some(all_rs), [left_wire, right_wire], )?; - let [_left, _right] = builder.input_wires_arr(); - let [output] = builder.finish_with_outputs([])?.outputs_arr(); + let [left, _] = builder.input_wires_arr(); + let [output] = builder.finish_with_outputs([left])?.outputs_arr(); main.finish_with_outputs([output])?; let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY); diff --git a/hugr/src/ops/constant.rs b/hugr/src/ops/constant.rs index caae40ada..86ca0c441 100644 --- a/hugr/src/ops/constant.rs +++ b/hugr/src/ops/constant.rs @@ -14,7 +14,10 @@ use serde::{Deserialize, Serialize}; use smol_str::SmolStr; use thiserror::Error; -pub use custom::{downcast_equal_consts, CustomConst, CustomSerialized}; +pub use custom::{ + downcast_equal_consts, get_pair_of_input_values, get_single_input_value, CustomConst, + CustomSerialized, +}; #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] /// An operation returning a constant value. diff --git a/hugr/src/ops/constant/custom.rs b/hugr/src/ops/constant/custom.rs index 81f04c0b2..54d58cf5d 100644 --- a/hugr/src/ops/constant/custom.rs +++ b/hugr/src/ops/constant/custom.rs @@ -13,6 +13,9 @@ use crate::extension::ExtensionSet; use crate::macros::impl_box_clone; use crate::types::{CustomCheckFailure, Type}; +use crate::IncomingPort; + +use super::Value; use super::ValueName; @@ -452,3 +455,21 @@ mod test { ); } } + +/// Given a singleton list of constant operations, return the value. +pub fn get_single_input_value(consts: &[(IncomingPort, Value)]) -> Option<&T> { + let [(_, c)] = consts else { + return None; + }; + c.get_custom_value() +} + +/// Given a list of two constant operations, return the values. +pub fn get_pair_of_input_values( + consts: &[(IncomingPort, Value)], +) -> Option<(&T, &T)> { + let [(_, c0), (_, c1)] = consts else { + return None; + }; + Some((c0.get_custom_value()?, c1.get_custom_value()?)) +} diff --git a/hugr/src/std_extensions/arithmetic/conversions.rs b/hugr/src/std_extensions/arithmetic/conversions.rs index 817d421ca..b6dcab87b 100644 --- a/hugr/src/std_extensions/arithmetic/conversions.rs +++ b/hugr/src/std_extensions/arithmetic/conversions.rs @@ -76,18 +76,18 @@ impl MakeOpDef for ConvertOpDef { impl ConvertOpDef { /// Initialise a conversion op with an integer log width type argument. - pub fn with_width(self, log_width: u8) -> ConvertOpType { + pub fn with_log_width(self, log_width: u8) -> ConvertOpType { ConvertOpType { def: self, - log_width: log_width as u64, + log_width, } } } -/// Concrete convert operation with integer width set. +/// Concrete convert operation with integer log width set. #[derive(Debug, Clone, PartialEq)] pub struct ConvertOpType { def: ConvertOpDef, - log_width: u64, + log_width: u8, } impl NamedOp for ConvertOpType { @@ -99,18 +99,20 @@ impl NamedOp for ConvertOpType { impl MakeExtensionOp for ConvertOpType { fn from_extension_op(ext_op: &ExtensionOp) -> Result { let def = ConvertOpDef::from_def(ext_op.def())?; - let width = match *ext_op.args() { + let log_width: u64 = match *ext_op.args() { [TypeArg::BoundedNat { n }] => n, _ => return Err(SignatureError::InvalidTypeArgs.into()), }; Ok(Self { def, - log_width: width, + log_width: u8::try_from(log_width).unwrap(), }) } fn type_args(&self) -> Vec { - vec![TypeArg::BoundedNat { n: self.log_width }] + vec![TypeArg::BoundedNat { + n: self.log_width as u64, + }] } } diff --git a/hugr/src/std_extensions/arithmetic/conversions/const_fold.rs b/hugr/src/std_extensions/arithmetic/conversions/const_fold.rs index 44fd95840..69dd724ae 100644 --- a/hugr/src/std_extensions/arithmetic/conversions/const_fold.rs +++ b/hugr/src/std_extensions/arithmetic/conversions/const_fold.rs @@ -1,14 +1,15 @@ +use crate::ops::constant::get_single_input_value; use crate::ops::Value; +use crate::std_extensions::arithmetic::int_types::INT_TYPES; use crate::{ extension::{ prelude::{sum_with_error, ConstError}, ConstFold, ConstFoldResult, OpDef, }, ops, - ops::constant::CustomConst, std_extensions::arithmetic::{ float_types::ConstF64, - int_types::{get_log_width, ConstInt, INT_TYPES}, + int_types::{get_log_width, ConstInt}, }, types::ConstTypeError, IncomingPort, @@ -27,19 +28,12 @@ pub(super) fn set_fold(op: &ConvertOpDef, def: &mut OpDef) { } } -fn get_input(consts: &[(IncomingPort, ops::Value)]) -> Option<&T> { - let [(_, c)] = consts else { - return None; - }; - c.get_custom_value() -} - fn fold_trunc( type_args: &[crate::types::TypeArg], consts: &[(IncomingPort, Value)], convert: impl Fn(f64, u8) -> Result, ) -> ConstFoldResult { - let f: &ConstF64 = get_input(consts)?; + let f: &ConstF64 = get_single_input_value(consts)?; let f = f.value(); let [arg] = type_args else { return None; @@ -105,7 +99,7 @@ impl ConstFold for ConvertU { _type_args: &[crate::types::TypeArg], consts: &[(IncomingPort, ops::Value)], ) -> ConstFoldResult { - let u: &ConstInt = get_input(consts)?; + let u: &ConstInt = crate::ops::constant::get_single_input_value(consts)?; let f = u.value_u() as f64; Some(vec![(0.into(), ConstF64::new(f).into())]) } @@ -119,7 +113,7 @@ impl ConstFold for ConvertS { _type_args: &[crate::types::TypeArg], consts: &[(IncomingPort, ops::Value)], ) -> ConstFoldResult { - let u: &ConstInt = get_input(consts)?; + let u: &ConstInt = get_single_input_value(consts)?; let f = u.value_s() as f64; Some(vec![(0.into(), ConstF64::new(f).into())]) } diff --git a/hugr/src/std_extensions/arithmetic/int_ops.rs b/hugr/src/std_extensions/arithmetic/int_ops.rs index e7d2efa7b..1219084e9 100644 --- a/hugr/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr/src/std_extensions/arithmetic/int_ops.rs @@ -8,6 +8,7 @@ use crate::extension::{ }; use crate::ops::custom::ExtensionOp; use crate::ops::{NamedOp, OpName}; +use crate::std_extensions::arithmetic::int_types::int_type; use crate::type_row; use crate::types::{FunctionType, PolyFuncType}; use crate::utils::collect_array; @@ -21,6 +22,8 @@ use crate::{ use lazy_static::lazy_static; use strum_macros::{EnumIter, EnumString, IntoStaticStr}; +mod const_fold; + /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int"); @@ -113,8 +116,8 @@ impl MakeOpDef for IntOpDef { IOValidator { f_ge_s: true }, ) .into(), - itobool => int_polytype(1, vec![int_tv(0)], type_row![BOOL_T]).into(), - ifrombool => int_polytype(1, type_row![BOOL_T], vec![int_tv(0)]).into(), + itobool => int_polytype(0, vec![int_type(0)], type_row![BOOL_T]).into(), + ifrombool => int_polytype(0, type_row![BOOL_T], vec![int_type(0)]).into(), ieq | ine | ilt_u | ilt_s | igt_u | igt_s | ile_u | ile_s | ige_u | ige_s => { int_polytype(1, vec![int_tv(0); 2], type_row![BOOL_T]).into() } @@ -134,7 +137,7 @@ impl MakeOpDef for IntOpDef { .into(), idivmod_u | idivmod_s => { let intpair: TypeRow = vec![int_tv(0), int_tv(1)].into(); - int_polytype(2, intpair.clone(), vec![Type::new_tuple(intpair)]) + int_polytype(2, intpair.clone(), intpair.clone()) } .into(), idiv_u | idiv_s => int_polytype(2, vec![int_tv(0), int_tv(1)], vec![int_tv(0)]).into(), @@ -225,6 +228,10 @@ impl MakeOpDef for IntOpDef { itostring_u => "convert an unsigned integer to its string representation", }.into() } + + fn post_opdef(&self, def: &mut OpDef) { + const_fold::set_fold(self, def) + } } fn int_polytype( n_vars: usize, @@ -270,12 +277,11 @@ lazy_static! { .unwrap(); } -/// Concrete integer operation with either one or two integer widths set. +/// Concrete integer operation with integer widths set. #[derive(Debug, Clone, PartialEq)] pub struct IntOpType { def: IntOpDef, - first_width: u64, - second_width: Option, + log_widths: Vec, } impl NamedOp for IntOpType { @@ -286,24 +292,16 @@ impl NamedOp for IntOpType { impl MakeExtensionOp for IntOpType { fn from_extension_op(ext_op: &ExtensionOp) -> Result { let def = IntOpDef::from_def(ext_op.def())?; - let (first_width, second_width) = match *ext_op.args() { - [TypeArg::BoundedNat { n }] => (n, None), - [TypeArg::BoundedNat { n }, TypeArg::BoundedNat { n: n2 }] => (n, Some(n2)), - _ => return Err(SignatureError::InvalidTypeArgs.into()), - }; - Ok(Self { - def, - first_width, - second_width, - }) + let args = ext_op.args(); + let log_widths: Vec = args + .iter() + .map(|a| get_log_width(a).map_err(|_| SignatureError::InvalidTypeArgs)) + .collect::>()?; + Ok(Self { def, log_widths }) } fn type_args(&self) -> Vec { - [Some(self.first_width), self.second_width] - .iter() - .flatten() - .map(|&n| TypeArg::BoundedNat { n }) - .collect() + self.log_widths.iter().map(|&n| (n as u64).into()).collect() } } @@ -318,22 +316,28 @@ impl MakeRegisteredOp for IntOpType { } impl IntOpDef { + /// Initialize a concrete [IntOpType] from a [IntOpDef] which requires no + /// integer widths set. + pub fn without_log_width(self) -> IntOpType { + IntOpType { + def: self, + log_widths: vec![], + } + } /// Initialize a concrete [IntOpType] from a [IntOpDef] which requires one /// integer width set. - pub fn with_width(self, width: u64) -> IntOpType { + pub fn with_log_width(self, log_width: u8) -> IntOpType { IntOpType { def: self, - first_width: width, - second_width: None, + log_widths: vec![log_width], } } /// Initialize a concrete [IntOpType] from a [IntOpDef] which requires two /// integer widths set. - pub fn with_two_widths(self, first_width: u64, second_width: u64) -> IntOpType { + pub fn with_two_log_widths(self, first_log_width: u8, second_log_width: u8) -> IntOpType { IntOpType { def: self, - first_width, - second_width: Some(second_width), + log_widths: vec![first_log_width, second_log_width], } } } @@ -354,41 +358,35 @@ mod test { } } - const fn ta(n: u64) -> TypeArg { - TypeArg::BoundedNat { n } - } #[test] fn test_binary_signatures() { assert_eq!( IntOpDef::iwiden_s - .with_two_widths(3, 4) + .with_two_log_widths(3, 4) .to_extension_op() .unwrap() .signature(), - FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(4))],) + FunctionType::new(vec![int_type(3)], vec![int_type(4)],) ); assert_eq!( IntOpDef::iwiden_s - .with_two_widths(3, 3) + .with_two_log_widths(3, 3) .to_extension_op() .unwrap() .signature(), - FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(3))],) + FunctionType::new(vec![int_type(3)], vec![int_type(3)],) ); assert_eq!( IntOpDef::inarrow_s - .with_two_widths(3, 3) + .with_two_log_widths(3, 3) .to_extension_op() .unwrap() .signature(), - FunctionType::new( - vec![int_type(ta(3))], - vec![sum_with_error(int_type(ta(3))).into()], - ) + FunctionType::new(vec![int_type(3)], vec![sum_with_error(int_type(3)).into()],) ); assert!( IntOpDef::iwiden_u - .with_two_widths(4, 3) + .with_two_log_widths(4, 3) .to_extension_op() .is_none(), "type arguments invalid" @@ -396,28 +394,25 @@ mod test { assert_eq!( IntOpDef::inarrow_s - .with_two_widths(2, 1) + .with_two_log_widths(2, 1) .to_extension_op() .unwrap() .signature(), - FunctionType::new( - vec![int_type(ta(2))], - vec![sum_with_error(int_type(ta(1))).into()], - ) + FunctionType::new(vec![int_type(2)], vec![sum_with_error(int_type(1)).into()],) ); assert!(IntOpDef::inarrow_u - .with_two_widths(1, 2) + .with_two_log_widths(1, 2) .to_extension_op() .is_none()); } #[test] fn test_conversions() { - let o = IntOpDef::itobool.with_width(5); + let o = IntOpDef::itobool.without_log_width(); assert!( IntOpDef::itobool - .with_two_widths(1, 2) + .with_two_log_widths(1, 2) .to_extension_op() .is_none(), "type arguments invalid" diff --git a/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs b/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs new file mode 100644 index 000000000..0915a4737 --- /dev/null +++ b/hugr/src/std_extensions/arithmetic/int_ops/const_fold.rs @@ -0,0 +1,1204 @@ +use std::cmp::{max, min}; + +use crate::{ + extension::{ + prelude::{sum_with_error, ConstError, ConstString}, + ConstFoldResult, Folder, OpDef, + }, + ops::{ + constant::{get_pair_of_input_values, get_single_input_value}, + Value, + }, + std_extensions::arithmetic::int_types::{get_log_width, ConstInt, INT_TYPES}, + types::{SumType, Type, TypeArg}, + IncomingPort, +}; + +use super::IntOpDef; + +fn bitmask_from_width(width: u64) -> u64 { + debug_assert!(width <= 64); + if width == 64 { + u64::MAX + } else { + (1u64 << width) - 1 + } +} + +fn bitmask_from_logwidth(logwidth: u8) -> u64 { + bitmask_from_width(1u64 << logwidth) +} + +// return q, r s.t. n = qm + r, 0 <= r < m +fn divmod_s(n: i64, m: u64) -> (i64, u64) { + // This is quite hairy. + if n >= 0 { + let n_u = n as u64; + ((n_u / m) as i64, n_u % m) + } else if n != i64::MIN { + // -2^63 < n < 0 + let n_u = (-n) as u64; + let q = (n_u / m) as i64; + let r = n_u % m; + if r == 0 { + (-q, 0) + } else { + (-q - 1, m - r) + } + } else if m == 1 { + // n = -2^63, m = 1 + (n, 0) + } else if m < (1u64 << 63) { + // n = -2^63, 1 < m < 2^63 + let m_s = m as i64; + let q = n / m_s; + let r = n % m_s; + if r == 0 { + (q, 0) + } else { + (q - 1, (m_s - r) as u64) + } + } else { + // n = -2^63, m >= 2^63 + (-1, m - (1u64 << 63)) + } +} + +pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) { + def.set_constant_folder(match op { + IntOpDef::iwiden_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let n0: &ConstInt = get_single_input_value(consts)?; + if logwidth0 > logwidth1 || n0.log_width() != logwidth0 { + None + } else { + let n1 = ConstInt::new_u(logwidth1, n0.value_u()).ok()?; + Some(vec![(0.into(), n1.into())]) + } + }, + ), + }, + IntOpDef::iwiden_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let n0: &ConstInt = get_single_input_value(consts)?; + if logwidth0 > logwidth1 || n0.log_width() != logwidth0 { + None + } else { + let n1 = ConstInt::new_s(logwidth1, n0.value_s()).ok()?; + Some(vec![(0.into(), n1.into())]) + } + }, + ), + }, + IntOpDef::inarrow_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let n0: &ConstInt = get_single_input_value(consts)?; + + let int_out_type = INT_TYPES[logwidth1 as usize].to_owned(); + let sum_type = sum_with_error(int_out_type.clone()); + let err_value = || { + let err_val = ConstError { + signal: 0, + message: "Integer too large to narrow".to_string(), + }; + Value::sum(1, [err_val.into()], sum_type.clone()) + .unwrap_or_else(|e| panic!("Invalid computed sum, {}", e)) + }; + let n0val: u64 = n0.value_u(); + let out_const: Value = if n0val >> (1 << logwidth1) != 0 { + err_value() + } else { + Value::extension(ConstInt::new_u(logwidth1, n0val).unwrap()) + }; + if logwidth0 < logwidth1 || n0.log_width() != logwidth0 { + None + } else { + Some(vec![(0.into(), out_const)]) + } + }, + ), + }, + IntOpDef::inarrow_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let n0: &ConstInt = get_single_input_value(consts)?; + + let int_out_type = INT_TYPES[logwidth1 as usize].to_owned(); + let sum_type = sum_with_error(int_out_type.clone()); + let err_value = || { + let err_val = ConstError { + signal: 0, + message: "Integer too large to narrow".to_string(), + }; + Value::sum(1, [err_val.into()], sum_type.clone()) + .unwrap_or_else(|e| panic!("Invalid computed sum, {}", e)) + }; + let n0val: i64 = n0.value_s(); + let ub = 1i64 << ((1 << logwidth1) - 1); + let out_const: Value = if n0val >= ub || n0val < -ub { + err_value() + } else { + Value::extension(ConstInt::new_s(logwidth1, n0val).unwrap()) + }; + if logwidth0 < logwidth1 || n0.log_width() != logwidth0 { + None + } else { + Some(vec![(0.into(), out_const)]) + } + }, + ), + }, + IntOpDef::itobool => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + if !type_args.is_empty() { + return None; + } + let n0: &ConstInt = get_single_input_value(consts)?; + if n0.log_width() != 0 { + None + } else { + Some(vec![(0.into(), Value::from_bool(n0.value_u() == 1))]) + } + }, + ), + }, + IntOpDef::ifrombool => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + if !type_args.is_empty() { + return None; + } + let [(_, b0)] = consts else { + return None; + }; + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u( + 0, + if b0.clone() == Value::true_val() { + 1 + } else { + 0 + }, + ) + .unwrap(), + ), + )]) + }, + ), + }, + IntOpDef::ieq => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::from_bool(n0.value_u() == n1.value_u()), + )]) + } + }, + ), + }, + IntOpDef::ine => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::from_bool(n0.value_u() != n1.value_u()), + )]) + } + }, + ), + }, + IntOpDef::ilt_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::from_bool(n0.value_u() < n1.value_u()), + )]) + } + }, + ), + }, + IntOpDef::ilt_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::from_bool(n0.value_s() < n1.value_s()), + )]) + } + }, + ), + }, + IntOpDef::igt_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::from_bool(n0.value_u() > n1.value_u()), + )]) + } + }, + ), + }, + IntOpDef::igt_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::from_bool(n0.value_s() > n1.value_s()), + )]) + } + }, + ), + }, + IntOpDef::ile_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::from_bool(n0.value_u() <= n1.value_u()), + )]) + } + }, + ), + }, + IntOpDef::ile_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::from_bool(n0.value_s() <= n1.value_s()), + )]) + } + }, + ), + }, + IntOpDef::ige_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::from_bool(n0.value_u() >= n1.value_u()), + )]) + } + }, + ), + }, + IntOpDef::ige_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::from_bool(n0.value_s() >= n1.value_s()), + )]) + } + }, + ), + }, + IntOpDef::imax_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u(logwidth, max(n0.value_u(), n1.value_u())).unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::imax_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_s(logwidth, max(n0.value_s(), n1.value_s())).unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::imin_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u(logwidth, min(n0.value_u(), n1.value_u())).unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::imin_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_s(logwidth, min(n0.value_s(), n1.value_s())).unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::iadd => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u( + logwidth, + n0.value_u().overflowing_add(n1.value_u()).0 + & bitmask_from_logwidth(logwidth), + ) + .unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::isub => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u( + logwidth, + n0.value_u().overflowing_sub(n1.value_u()).0 + & bitmask_from_logwidth(logwidth), + ) + .unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::ineg => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let n0: &ConstInt = get_single_input_value(consts)?; + if n0.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u( + logwidth, + n0.value_u().overflowing_neg().0 + & bitmask_from_logwidth(logwidth), + ) + .unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::imul => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u( + logwidth, + n0.value_u().overflowing_mul(n1.value_u()).0 + & bitmask_from_logwidth(logwidth), + ) + .unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::idivmod_checked_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n, m): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n.log_width() != logwidth0 || m.log_width() != logwidth1 { + None + } else { + let q_type = INT_TYPES[logwidth0 as usize].to_owned(); + let r_type = INT_TYPES[logwidth1 as usize].to_owned(); + let qr_type: Type = Type::new_tuple(vec![q_type, r_type]); + let sum_type: SumType = sum_with_error(qr_type); + let err_value = || { + let err_val = ConstError { + signal: 0, + message: "Division by zero".to_string(), + }; + Value::sum(1, [err_val.into()], sum_type.clone()) + .unwrap_or_else(|e| panic!("Invalid computed sum, {}", e)) + }; + let nval = n.value_u(); + let mval = m.value_u(); + let out_const: Value = if mval == 0 { + err_value() + } else { + let qval = nval / mval; + let rval = nval % mval; + Value::tuple(vec![ + Value::extension(ConstInt::new_u(logwidth0, qval).unwrap()), + Value::extension(ConstInt::new_u(logwidth1, rval).unwrap()), + ]) + }; + Some(vec![(0.into(), out_const)]) + } + }, + ), + }, + IntOpDef::idivmod_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n, m): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + let nval = n.value_u(); + let mval = m.value_u(); + if n.log_width() != logwidth0 || m.log_width() != logwidth1 || mval == 0 { + None + } else { + let qval = nval / mval; + let rval = nval % mval; + let q = Value::extension(ConstInt::new_u(logwidth0, qval).unwrap()); + let r = Value::extension(ConstInt::new_u(logwidth1, rval).unwrap()); + Some(vec![(0.into(), q), (1.into(), r)]) + } + }, + ), + }, + IntOpDef::idivmod_checked_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n, m): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n.log_width() != logwidth0 || m.log_width() != logwidth1 { + None + } else { + let q_type = INT_TYPES[logwidth0 as usize].to_owned(); + let r_type = INT_TYPES[logwidth1 as usize].to_owned(); + let qr_type: Type = Type::new_tuple(vec![q_type, r_type]); + let sum_type: SumType = sum_with_error(qr_type); + let err_value = || { + let err_val = ConstError { + signal: 0, + message: "Division by zero".to_string(), + }; + Value::sum(1, [err_val.into()], sum_type.clone()) + .unwrap_or_else(|e| panic!("Invalid computed sum, {}", e)) + }; + let nval = n.value_s(); + let mval = m.value_u(); + let out_const: Value = if mval == 0 { + err_value() + } else { + let (qval, rval) = divmod_s(nval, mval); + Value::tuple(vec![ + Value::extension(ConstInt::new_s(logwidth0, qval).unwrap()), + Value::extension(ConstInt::new_u(logwidth1, rval).unwrap()), + ]) + }; + Some(vec![(0.into(), out_const)]) + } + }, + ), + }, + IntOpDef::idivmod_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n, m): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + let nval = n.value_s(); + let mval = m.value_u(); + if n.log_width() != logwidth0 || m.log_width() != logwidth1 || mval == 0 { + None + } else { + let (qval, rval) = divmod_s(nval, mval); + let q = Value::extension(ConstInt::new_s(logwidth0, qval).unwrap()); + let r = Value::extension(ConstInt::new_u(logwidth1, rval).unwrap()); + Some(vec![(0.into(), q), (1.into(), r)]) + } + }, + ), + }, + IntOpDef::idiv_checked_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n, m): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n.log_width() != logwidth0 || m.log_width() != logwidth1 { + None + } else { + let int_out_type = INT_TYPES[logwidth0 as usize].to_owned(); + let sum_type = sum_with_error(int_out_type.clone()); + let err_value = || { + let err_val = ConstError { + signal: 0, + message: "Division by zero".to_string(), + }; + Value::sum(1, [err_val.into()], sum_type.clone()) + .unwrap_or_else(|e| panic!("Invalid computed sum, {}", e)) + }; + let nval = n.value_u(); + let mval = m.value_u(); + let out_const: Value = if mval == 0 { + err_value() + } else { + Value::extension(ConstInt::new_u(logwidth0, nval / mval).unwrap()) + }; + Some(vec![(0.into(), out_const)]) + } + }, + ), + }, + IntOpDef::idiv_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n, m): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + let nval = n.value_u(); + let mval = m.value_u(); + if n.log_width() != logwidth0 || m.log_width() != logwidth1 || mval == 0 { + None + } else { + let q = Value::extension(ConstInt::new_u(logwidth0, nval / mval).unwrap()); + Some(vec![(0.into(), q)]) + } + }, + ), + }, + IntOpDef::imod_checked_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n, m): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n.log_width() != logwidth0 || m.log_width() != logwidth1 { + None + } else { + let int_out_type = INT_TYPES[logwidth1 as usize].to_owned(); + let sum_type = sum_with_error(int_out_type.clone()); + let err_value = || { + let err_val = ConstError { + signal: 0, + message: "Division by zero".to_string(), + }; + Value::sum(1, [err_val.into()], sum_type.clone()) + .unwrap_or_else(|e| panic!("Invalid computed sum, {}", e)) + }; + let nval = n.value_u(); + let mval = m.value_u(); + let out_const: Value = if mval == 0 { + err_value() + } else { + Value::extension(ConstInt::new_u(logwidth1, nval % mval).unwrap()) + }; + Some(vec![(0.into(), out_const)]) + } + }, + ), + }, + IntOpDef::imod_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n, m): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + let nval = n.value_u(); + let mval = m.value_u(); + if n.log_width() != logwidth0 || m.log_width() != logwidth1 || mval == 0 { + None + } else { + let r = Value::extension(ConstInt::new_u(logwidth1, nval % mval).unwrap()); + Some(vec![(0.into(), r)]) + } + }, + ), + }, + IntOpDef::idiv_checked_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n, m): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n.log_width() != logwidth0 || m.log_width() != logwidth1 { + None + } else { + let int_out_type = INT_TYPES[logwidth0 as usize].to_owned(); + let sum_type = sum_with_error(int_out_type.clone()); + let err_value = || { + let err_val = ConstError { + signal: 0, + message: "Division by zero".to_string(), + }; + Value::sum(1, [err_val.into()], sum_type.clone()) + .unwrap_or_else(|e| panic!("Invalid computed sum, {}", e)) + }; + let nval = n.value_s(); + let mval = m.value_u(); + let out_const: Value = if mval == 0 { + err_value() + } else { + let (qval, _) = divmod_s(nval, mval); + Value::extension(ConstInt::new_s(logwidth1, qval).unwrap()) + }; + Some(vec![(0.into(), out_const)]) + } + }, + ), + }, + IntOpDef::idiv_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n, m): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + let nval = n.value_s(); + let mval = m.value_u(); + if n.log_width() != logwidth0 || m.log_width() != logwidth1 || mval == 0 { + None + } else { + let (qval, _) = divmod_s(nval, mval); + let q = Value::extension(ConstInt::new_s(logwidth0, qval).unwrap()); + Some(vec![(0.into(), q)]) + } + }, + ), + }, + IntOpDef::imod_checked_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n, m): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n.log_width() != logwidth0 || m.log_width() != logwidth1 { + None + } else { + let int_out_type = INT_TYPES[logwidth1 as usize].to_owned(); + let sum_type = sum_with_error(int_out_type.clone()); + let err_value = || { + let err_val = ConstError { + signal: 0, + message: "Division by zero".to_string(), + }; + Value::sum(1, [err_val.into()], sum_type.clone()) + .unwrap_or_else(|e| panic!("Invalid computed sum, {}", e)) + }; + let nval = n.value_s(); + let mval = m.value_u(); + let out_const: Value = if mval == 0 { + err_value() + } else { + let (_, rval) = divmod_s(nval, mval); + Value::extension(ConstInt::new_u(logwidth1, rval).unwrap()) + }; + Some(vec![(0.into(), out_const)]) + } + }, + ), + }, + IntOpDef::imod_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n, m): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + let nval = n.value_s(); + let mval = m.value_u(); + if n.log_width() != logwidth0 || m.log_width() != logwidth1 || mval == 0 { + None + } else { + let (_, rval) = divmod_s(nval, mval); + let r = Value::extension(ConstInt::new_u(logwidth1, rval).unwrap()); + Some(vec![(0.into(), r)]) + } + }, + ), + }, + IntOpDef::iabs => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let n0: &ConstInt = get_single_input_value(consts)?; + let n0val = n0.value_s(); + if n0.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + if n0val == i64::MIN { + debug_assert!(logwidth == 6); + ConstInt::new_u(6, 1u64 << 63) + } else { + ConstInt::new_s(logwidth, n0val.abs()) + } + .unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::iand => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u(logwidth, n0.value_u() & n1.value_u()).unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::ior => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u(logwidth, n0.value_u() | n1.value_u()).unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::ixor => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u(logwidth, n0.value_u() ^ n1.value_u()).unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::inot => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let n0: &ConstInt = get_single_input_value(consts)?; + if n0.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u( + logwidth, + bitmask_from_logwidth(logwidth) & !n0.value_u(), + ) + .unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::ishl => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth0 || n1.log_width() != logwidth1 { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u( + logwidth0, + (n0.value_u() << n1.value_u()) + & bitmask_from_logwidth(logwidth0), + ) + .unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::ishr => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth0 || n1.log_width() != logwidth1 { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u(logwidth0, n0.value_u() >> n1.value_u()).unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::irotl => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth0 || n1.log_width() != logwidth1 { + None + } else { + let n = n0.value_u(); + let w = 1 << logwidth0; + let k = n1.value_u() % w; // equivalent rotation amount + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u( + logwidth0, + ((n << k) & bitmask_from_width(w)) | (n >> (w - k)), + ) + .unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::irotr => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg0, arg1] = type_args else { + return None; + }; + let logwidth0: u8 = get_log_width(arg0).ok()?; + let logwidth1: u8 = get_log_width(arg1).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth0 || n1.log_width() != logwidth1 { + None + } else { + let n = n0.value_u(); + let w = 1 << logwidth0; + let k = n1.value_u() % w; // equivalent rotation amount + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u( + logwidth0, + ((n << (w - k)) & bitmask_from_width(w)) | (n >> k), + ) + .unwrap(), + ), + )]) + } + }, + ), + }, + IntOpDef::itostring_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let n0: &ConstInt = get_single_input_value(consts)?; + if n0.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension(ConstString::new(n0.value_u().to_string())), + )]) + } + }, + ), + }, + IntOpDef::itostring_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let n0: &ConstInt = get_single_input_value(consts)?; + if n0.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension(ConstString::new(n0.value_s().to_string())), + )]) + } + }, + ), + }, + }); +} + +#[cfg(test)] +mod test; diff --git a/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs b/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs new file mode 100644 index 000000000..e7daf7000 --- /dev/null +++ b/hugr/src/std_extensions/arithmetic/int_ops/const_fold/test.rs @@ -0,0 +1,1287 @@ +use crate::algorithm::const_fold::constant_fold_pass; +use crate::builder::{DFGBuilder, Dataflow, DataflowHugr}; +use crate::extension::prelude::{sum_with_error, ConstError, ConstString, BOOL_T, STRING_TYPE}; +use crate::extension::{ExtensionRegistry, PRELUDE}; +use crate::ops::Value; +use crate::std_extensions::arithmetic; +use crate::std_extensions::arithmetic::int_ops::IntOpDef; +use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; +use crate::std_extensions::logic::{self, NaryLogic}; +use crate::type_row; +use crate::types::{FunctionType, Type, TypeRow}; +use crate::utils::test::assert_fully_folded; + +use rstest::rstest; + +#[test] +fn test_fold_iwiden_u() { + // pseudocode: + // + // x0 := int_u<4>(13); + // x1 := iwiden_u<4, 5>(x0); + // output x1 == int_u<5>(13); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(4, 13).unwrap())); + let x1 = build + .add_dataflow_op(IntOpDef::iwiden_u.with_two_log_widths(4, 5), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 13).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_iwiden_s() { + // pseudocode: + // + // x0 := int_u<4>(-3); + // x1 := iwiden_u<4, 5>(x0); + // output x1 == int_s<5>(-3); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(4, -3).unwrap())); + let x1 = build + .add_dataflow_op(IntOpDef::iwiden_s.with_two_log_widths(4, 5), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, -3).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_inarrow_u() { + // pseudocode: + // + // x0 := int_u<5>(13); + // x1 := inarrow_u<5, 4>(x0); + // output x1 == int_u<4>(13); + let sum_type = sum_with_error(INT_TYPES[4].to_owned()); + let mut build = DFGBuilder::new(FunctionType::new( + type_row![], + vec![sum_type.clone().into()], + )) + .unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 13).unwrap())); + let x1 = build + .add_dataflow_op(IntOpDef::inarrow_u.with_two_log_widths(5, 4), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(4, 13).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_inarrow_s() { + // pseudocode: + // + // x0 := int_s<5>(-3); + // x1 := inarrow_s<5, 4>(x0); + // output x1 == int_s<4>(-3); + let sum_type = sum_with_error(INT_TYPES[4].to_owned()); + let mut build = DFGBuilder::new(FunctionType::new( + type_row![], + vec![sum_type.clone().into()], + )) + .unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -3).unwrap())); + let x1 = build + .add_dataflow_op(IntOpDef::inarrow_s.with_two_log_widths(5, 4), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(4, -3).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_itobool() { + // pseudocode: + // + // x0 := int_u<0>(1); + // x1 := itobool(x0); + // output x1 == true; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(0, 1).unwrap())); + let x1 = build + .add_dataflow_op(IntOpDef::itobool.without_log_width(), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ifrombool() { + // pseudocode: + // + // x0 := false + // x1 := ifrombool(x0); + // output x1 == int_u<0>(0); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[0].clone()])).unwrap(); + let x0 = build.add_load_const(Value::false_val()); + let x1 = build + .add_dataflow_op(IntOpDef::ifrombool.without_log_width(), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(0, 0).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ieq() { + // pseudocode: + // x0, x1 := int_s<3>(-1), int_u<3>(255) + // x2 := ieq(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(3, -1).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 255).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ieq.with_log_width(3), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ine() { + // pseudocode: + // x0, x1 := int_u<5>(3), int_u<5>(4) + // x2 := ine(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ine.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ilt_u() { + // pseudocode: + // x0, x1 := int_u<5>(3), int_u<5>(4) + // x2 := ilt_u(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ilt_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ilt_s() { + // pseudocode: + // x0, x1 := int_s<5>(3), int_s<5>(-4) + // x2 := ilt_s(x0, x1) + // output x2 == false; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ilt_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::false_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_igt_u() { + // pseudocode: + // x0, x1 := int_u<5>(3), int_u<5>(4) + // x2 := ilt_u(x0, x1) + // output x2 == false; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::igt_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::false_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_igt_s() { + // pseudocode: + // x0, x1 := int_s<5>(3), int_s<5>(-4) + // x2 := ilt_s(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::igt_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ile_u() { + // pseudocode: + // x0, x1 := int_u<5>(3), int_u<5>(3) + // x2 := ile_u(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ile_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ile_s() { + // pseudocode: + // x0, x1 := int_s<5>(-4), int_s<5>(-4) + // x2 := ile_s(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ile_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ige_u() { + // pseudocode: + // x0, x1 := int_u<5>(3), int_u<5>(4) + // x2 := ilt_u(x0, x1) + // output x2 == false; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ige_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::false_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ige_s() { + // pseudocode: + // x0, x1 := int_s<5>(3), int_s<5>(-4) + // x2 := ilt_s(x0, x1) + // output x2 == true; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, -4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ige_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imax_u() { + // pseudocode: + // x0, x1 := int_u<5>(7), int_u<5>(11); + // x2 := imax_u(x0, x1); + // output x2 == int_u<5>(11); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 7).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 11).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imax_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 11).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imax_s() { + // pseudocode: + // x0, x1 := int_s<5>(-2), int_s<5>(1); + // x2 := imax_u(x0, x1); + // output x2 == int_s<5>(1); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imax_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, 1).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imin_u() { + // pseudocode: + // x0, x1 := int_u<5>(7), int_u<5>(11); + // x2 := imin_u(x0, x1); + // output x2 == int_u<5>(7); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 7).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 11).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imin_u.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 7).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imin_s() { + // pseudocode: + // x0, x1 := int_s<5>(-2), int_s<5>(1); + // x2 := imin_u(x0, x1); + // output x2 == int_s<5>(-2); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imin_s.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, -2).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_iadd() { + // pseudocode: + // x0, x1 := int_s<5>(-2), int_s<5>(1); + // x2 := iadd(x0, x1); + // output x2 == int_s<5>(-1); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::iadd.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, -1).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_isub() { + // pseudocode: + // x0, x1 := int_s<5>(-2), int_s<5>(1); + // x2 := isub(x0, x1); + // output x2 == int_s<5>(-3); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 1).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::isub.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, -3).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ineg() { + // pseudocode: + // x0 := int_s<5>(-2); + // x1 := ineg(x0); + // output x1 == int_s<5>(2); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ineg.with_log_width(5), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, 2).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imul() { + // pseudocode: + // x0, x1 := int_s<5>(-2), int_s<5>(7); + // x2 := imul(x0, x1); + // output x2 == int_s<5>(-14); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_s(5, 7).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imul.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, -14).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_idivmod_checked_u() { + // pseudocode: + // x0, x1 := int_u<5>(20), int_u<3>(0) + // x2 := idivmod_checked_u(x0, x1) + // output x2 == error + let intpair: TypeRow = vec![INT_TYPES[5].clone(), INT_TYPES[3].clone()].into(); + let sum_type = sum_with_error(Type::new_tuple(intpair)); + let mut build = DFGBuilder::new(FunctionType::new( + type_row![], + vec![sum_type.clone().into()], + )) + .unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 0).unwrap())); + let x2 = build + .add_dataflow_op( + IntOpDef::idivmod_checked_u.with_two_log_widths(5, 3), + [x0, x1], + ) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::sum( + 1, + [ConstError { + signal: 0, + message: "Division by zero".to_string(), + } + .into()], + sum_type.clone(), + ) + .unwrap(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_idivmod_u() { + // pseudocode: + // x0, x1 := int_u<5>(20), int_u<3>(3); + // x2, x3 := idivmod_u(x0, x1); // 6, 2 + // x4 := iwiden_u<3,5>(x3); // 2 + // x5 := iadd<5>(x2, x4); // 8 + // output x5 == int_u<5>(8); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap())); + let [x2, x3] = build + .add_dataflow_op(IntOpDef::idivmod_u.with_two_log_widths(5, 3), [x0, x1]) + .unwrap() + .outputs_arr(); + let [x4] = build + .add_dataflow_op(IntOpDef::iwiden_u.with_two_log_widths(3, 5), [x3]) + .unwrap() + .outputs_arr(); + let x5 = build + .add_dataflow_op(IntOpDef::iadd.with_log_width(5), [x2, x4]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x5.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 8).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_idivmod_checked_s() { + // pseudocode: + // x0, x1 := int_s<5>(-20), int_u<3>(0) + // x2 := idivmod_checked_s(x0, x1) + // output x2 == error + let intpair: TypeRow = vec![INT_TYPES[5].clone(), INT_TYPES[3].clone()].into(); + let sum_type = sum_with_error(Type::new_tuple(intpair)); + let mut build = DFGBuilder::new(FunctionType::new( + type_row![], + vec![sum_type.clone().into()], + )) + .unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 0).unwrap())); + let x2 = build + .add_dataflow_op( + IntOpDef::idivmod_checked_s.with_two_log_widths(5, 3), + [x0, x1], + ) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::sum( + 1, + [ConstError { + signal: 0, + message: "Division by zero".to_string(), + } + .into()], + sum_type.clone(), + ) + .unwrap(); + assert_fully_folded(&h, &expected); +} + +#[rstest] +#[case(20, 3, 8)] +#[case(-20, 3, -6)] +#[case(-20, 4, -5)] +#[case(i64::MIN, 1, i64::MIN)] +#[case(i64::MIN, 2, -(1i64 << 62))] +#[case(i64::MIN, 1u64 << 63, -1)] +// c = a/b + a%b +fn test_fold_idivmod_s(#[case] a: i64, #[case] b: u64, #[case] c: i64) { + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[6].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(6, a).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(6, b).unwrap())); + let [x2, x3] = build + .add_dataflow_op(IntOpDef::idivmod_s.with_two_log_widths(6, 6), [x0, x1]) + .unwrap() + .outputs_arr(); + let x4 = build + .add_dataflow_op(IntOpDef::iadd.with_log_width(6), [x2, x3]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x4.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(6, c).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_idiv_checked_u() { + // pseudocode: + // x0, x1 := int_u<5>(20), int_u<3>(0) + // x2 := idiv_checked_u(x0, x1) + // output x2 == error + let sum_type = sum_with_error(INT_TYPES[5].to_owned()); + let mut build = DFGBuilder::new(FunctionType::new( + type_row![], + vec![sum_type.clone().into()], + )) + .unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 0).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::idiv_checked_u.with_two_log_widths(5, 3), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::sum( + 1, + [ConstError { + signal: 0, + message: "Division by zero".to_string(), + } + .into()], + sum_type.clone(), + ) + .unwrap(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_idiv_u() { + // pseudocode: + // x0, x1 := int_u<5>(20), int_u<3>(3); + // x2 := idiv_u(x0, x1); + // output x2 == int_u<5>(6); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::idiv_u.with_two_log_widths(5, 3), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 6).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imod_checked_u() { + // pseudocode: + // x0, x1 := int_u<5>(20), int_u<3>(0) + // x2 := imod_checked_u(x0, x1) + // output x2 == error + let sum_type = sum_with_error(INT_TYPES[3].to_owned()); + let mut build = DFGBuilder::new(FunctionType::new( + type_row![], + vec![sum_type.clone().into()], + )) + .unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 0).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imod_checked_u.with_two_log_widths(5, 3), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::sum( + 1, + [ConstError { + signal: 0, + message: "Division by zero".to_string(), + } + .into()], + sum_type.clone(), + ) + .unwrap(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imod_u() { + // pseudocode: + // x0, x1 := int_u<5>(20), int_u<3>(3); + // x2 := imod_u(x0, x1); + // output x2 == int_u<3>(2); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[3].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imod_u.with_two_log_widths(5, 3), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(3, 2).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_idiv_checked_s() { + // pseudocode: + // x0, x1 := int_s<5>(-20), int_u<3>(0) + // x2 := idiv_checked_s(x0, x1) + // output x2 == error + let sum_type = sum_with_error(INT_TYPES[5].to_owned()); + let mut build = DFGBuilder::new(FunctionType::new( + type_row![], + vec![sum_type.clone().into()], + )) + .unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 0).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::idiv_checked_s.with_two_log_widths(5, 3), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::sum( + 1, + [ConstError { + signal: 0, + message: "Division by zero".to_string(), + } + .into()], + sum_type.clone(), + ) + .unwrap(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_idiv_s() { + // pseudocode: + // x0, x1 := int_s<5>(-20), int_u<3>(3); + // x2 := idiv_s(x0, x1); + // output x2 == int_s<5>(-7); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::idiv_s.with_two_log_widths(5, 3), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_s(5, -7).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imod_checked_s() { + // pseudocode: + // x0, x1 := int_s<5>(-20), int_u<3>(0) + // x2 := imod_checked_u(x0, x1) + // output x2 == error + let sum_type = sum_with_error(INT_TYPES[3].to_owned()); + let mut build = DFGBuilder::new(FunctionType::new( + type_row![], + vec![sum_type.clone().into()], + )) + .unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 0).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imod_checked_s.with_two_log_widths(5, 3), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::sum( + 1, + [ConstError { + signal: 0, + message: "Division by zero".to_string(), + } + .into()], + sum_type.clone(), + ) + .unwrap(); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_imod_s() { + // pseudocode: + // x0, x1 := int_s<5>(-20), int_u<3>(3); + // x2 := imod_s(x0, x1); + // output x2 == int_u<3>(1); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[3].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -20).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::imod_s.with_two_log_widths(5, 3), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(3, 1).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_iabs() { + // pseudocode: + // x0 := int_s<5>(-2); + // x1 := iabs(x0); + // output x1 == int_s<5>(2); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -2).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::iabs.with_log_width(5), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 2).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_iand() { + // pseudocode: + // x0, x1 := int_u<5>(14), int_u<5>(20); + // x2 := iand(x0, x1); + // output x2 == int_u<5>(4); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::iand.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 4).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ior() { + // pseudocode: + // x0, x1 := int_u<5>(14), int_u<5>(20); + // x2 := ior(x0, x1); + // output x2 == int_u<5>(30); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ior.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 30).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ixor() { + // pseudocode: + // x0, x1 := int_u<5>(14), int_u<5>(20); + // x2 := ixor(x0, x1); + // output x2 == int_u<5>(26); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 20).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ixor.with_log_width(5), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 26).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_inot() { + // pseudocode: + // x0 := int_u<5>(14); + // x1 := inot(x0); + // output x1 == int_u<5>(17); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 14).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::inot.with_log_width(5), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, (1u64 << 32) - 15).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ishl() { + // pseudocode: + // x0, x1 := int_u<5>(14), int_u<3>(3); + // x2 := ishl(x0, x1); + // output x2 == int_u<5>(112); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ishl.with_two_log_widths(5, 3), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 112).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_ishr() { + // pseudocode: + // x0, x1 := int_u<5>(14), int_u<3>(3); + // x2 := ishr(x0, x1); + // output x2 == int_u<5>(1); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ishr.with_two_log_widths(5, 3), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 1).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_irotl() { + // pseudocode: + // x0, x1 := int_u<5>(14), int_u<3>(61); + // x2 := irotl(x0, x1); + // output x2 == int_u<5>(2^30 + 2^31 + 1); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 61).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::irotl.with_two_log_widths(5, 3), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 3 * (1u64 << 30) + 1).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_irotr() { + // pseudocode: + // x0, x1 := int_u<5>(14), int_u<3>(3); + // x2 := irotr(x0, x1); + // output x2 == int_u<5>(2^30 + 2^31 + 1); + let mut build = + DFGBuilder::new(FunctionType::new(type_row![], vec![INT_TYPES[5].clone()])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, 14).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(3, 3).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::irotr.with_two_log_widths(5, 3), [x0, x1]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstInt::new_u(5, 3 * (1u64 << 30) + 1).unwrap()); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_itostring_u() { + // pseudocode: + // x0 := int_u<5>(17); + // x1 := itostring_u(x0); + // output x2 := "17"; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![STRING_TYPE])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 17).unwrap())); + let x1 = build + .add_dataflow_op(IntOpDef::itostring_u.with_log_width(5), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstString::new("17".into())); + assert_fully_folded(&h, &expected); +} + +#[test] +fn test_fold_itostring_s() { + // pseudocode: + // x0 := int_s<5>(-17); + // x1 := itostring_s(x0); + // output x2 := "-17"; + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![STRING_TYPE])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_s(5, -17).unwrap())); + let x1 = build + .add_dataflow_op(IntOpDef::itostring_s.with_log_width(5), [x0]) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::extension(ConstString::new("-17".into())); + assert_fully_folded(&h, &expected); +} + +#[test] +#[should_panic] +// FIXME: https://github.com/CQCL/hugr/issues/996 +fn test_fold_int_ops() { + // pseudocode: + // + // x0 := int_u<5>(3); // 3 + // x1 := int_u<5>(4); // 4 + // x2 := ine(x0, x1); // true + // x3 := ilt_u(x0, x1); // true + // x4 := and(x2, x3); // true + // x5 := int_s<5>(-10) // -10 + // x6 := ilt_s(x0, x5) // false + // x7 := or(x4, x6) // true + // output x7 + let mut build = DFGBuilder::new(FunctionType::new(type_row![], vec![BOOL_T])).unwrap(); + let x0 = build.add_load_const(Value::extension(ConstInt::new_u(5, 3).unwrap())); + let x1 = build.add_load_const(Value::extension(ConstInt::new_u(5, 4).unwrap())); + let x2 = build + .add_dataflow_op(IntOpDef::ine.with_log_width(5), [x0, x1]) + .unwrap(); + let x3 = build + .add_dataflow_op(IntOpDef::ilt_u.with_log_width(5), [x0, x1]) + .unwrap(); + let x4 = build + .add_dataflow_op( + NaryLogic::And.with_n_inputs(2), + x2.outputs().chain(x3.outputs()), + ) + .unwrap(); + let x5 = build.add_load_const(Value::extension(ConstInt::new_s(5, -10).unwrap())); + let x6 = build + .add_dataflow_op(IntOpDef::ilt_s.with_log_width(5), [x0, x5]) + .unwrap(); + let x7 = build + .add_dataflow_op( + NaryLogic::Or.with_n_inputs(2), + x4.outputs().chain(x6.outputs()), + ) + .unwrap(); + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + logic::EXTENSION.to_owned(), + arithmetic::int_types::EXTENSION.to_owned(), + ]) + .unwrap(); + let mut h = build.finish_hugr_with_outputs(x7.outputs(), ®).unwrap(); + constant_fold_pass(&mut h, ®); + let expected = Value::true_val(); + assert_fully_folded(&h, &expected); +} diff --git a/hugr/src/std_extensions/arithmetic/int_types.rs b/hugr/src/std_extensions/arithmetic/int_types.rs index 89aa66e6b..7b92f82a6 100644 --- a/hugr/src/std_extensions/arithmetic/int_types.rs +++ b/hugr/src/std_extensions/arithmetic/int_types.rs @@ -20,15 +20,15 @@ pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int /// Identifier for the integer type. pub const INT_TYPE_ID: TypeName = TypeName::new_inline("int"); -pub(crate) fn int_custom_type(width_arg: TypeArg) -> CustomType { - CustomType::new(INT_TYPE_ID, [width_arg], EXTENSION_ID, TypeBound::Eq) +pub(crate) fn int_custom_type(width_arg: impl Into) -> CustomType { + CustomType::new(INT_TYPE_ID, [width_arg.into()], EXTENSION_ID, TypeBound::Eq) } /// Integer type of a given bit width (specified by the TypeArg). /// Depending on the operation, the semantic interpretation may be unsigned integer, signed integer /// or bit string. -pub(super) fn int_type(width_arg: TypeArg) -> Type { - Type::new_extension(int_custom_type(width_arg)) +pub(super) fn int_type(width_arg: impl Into) -> Type { + Type::new_extension(int_custom_type(width_arg.into())) } lazy_static! { diff --git a/hugr/src/std_extensions/logic.rs b/hugr/src/std_extensions/logic.rs index 86de41b28..50e6921ce 100644 --- a/hugr/src/std_extensions/logic.rs +++ b/hugr/src/std_extensions/logic.rs @@ -120,6 +120,17 @@ impl MakeOpDef for NotOp { fn description(&self) -> String { "logical 'not'".into() } + + fn post_opdef(&self, def: &mut OpDef) { + def.set_constant_folder(|consts: &_| { + let inps = read_inputs(consts)?; + if inps.len() != 1 { + None + } else { + Some(vec![(0.into(), ops::Value::from_bool(!inps[0]))]) + } + }) + } } /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("logic"); diff --git a/hugr/src/types/type_param.rs b/hugr/src/types/type_param.rs index 66053a7dc..3e57fcab7 100644 --- a/hugr/src/types/type_param.rs +++ b/hugr/src/types/type_param.rs @@ -130,12 +130,6 @@ impl From for TypeParam { } } -impl From for TypeArg { - fn from(ty: Type) -> Self { - Self::Type { ty } - } -} - /// A statically-known argument value to an operation. #[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[non_exhaustive] @@ -177,6 +171,36 @@ pub enum TypeArg { }, } +impl From for TypeArg { + fn from(ty: Type) -> Self { + Self::Type { ty } + } +} + +impl From for TypeArg { + fn from(n: u64) -> Self { + Self::BoundedNat { n } + } +} + +impl From for TypeArg { + fn from(arg: CustomTypeArg) -> Self { + Self::Opaque { arg } + } +} + +impl From> for TypeArg { + fn from(elems: Vec) -> Self { + Self::Sequence { elems } + } +} + +impl From for TypeArg { + fn from(es: ExtensionSet) -> Self { + Self::Extensions { es } + } +} + /// Variable in a TypeArg, that is not a [TypeArg::Type] or [TypeArg::Extensions], #[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub struct TypeArgVariable { diff --git a/hugr/src/utils.rs b/hugr/src/utils.rs index ef5a778a9..15c2a0386 100644 --- a/hugr/src/utils.rs +++ b/hugr/src/utils.rs @@ -211,6 +211,10 @@ pub(crate) mod test_quantum_extension { pub(crate) mod test { #[allow(unused_imports)] use crate::HugrView; + use crate::{ + ops::{OpType, Value}, + Hugr, + }; /// Open a browser page to render a dot string graph. /// @@ -227,4 +231,20 @@ pub(crate) mod test { pub(crate) fn viz_hugr(hugr: &impl HugrView) { viz_dotstr(hugr.dot_string()); } + + /// Check that a hugr just loads and returns a single expected constant. + pub(crate) fn assert_fully_folded(h: &Hugr, expected_value: &Value) { + let mut node_count = 0; + + for node in h.children(h.root()) { + let op = h.get_optype(node); + match op { + OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, + OpType::Const(c) if c.value() == expected_value => node_count += 1, + _ => panic!("unexpected op: {:?}", op), + } + } + + assert_eq!(node_count, 4); + } }