From f7a4cf7614381cc17a66dec1c6a214e786978e99 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 3 Jan 2024 11:09:38 +0000 Subject: [PATCH] Revert "feat: constant folding implemented for core and float extension (#758)" This reverts commit 664fe89bce99ff5d02678c8ad0c57f4e4e1d3fb9. --- .github/workflows/ci.yml | 4 +- Cargo.toml | 4 +- README.md | 2 +- src/algorithm.rs | 1 - src/algorithm/const_fold.rs | 302 ------------------ src/extension.rs | 2 - src/extension/const_fold.rs | 53 --- src/extension/op_def.rs | 23 +- src/extension/prelude.rs | 66 +--- src/ops/custom.rs | 9 +- src/std_extensions/arithmetic/conversions.rs | 12 +- src/std_extensions/arithmetic/float_ops.rs | 6 +- .../arithmetic/float_ops/fold.rs | 124 ------- src/std_extensions/arithmetic/float_types.rs | 2 +- src/std_extensions/arithmetic/int_ops.rs | 1 - src/values.rs | 1 - 16 files changed, 20 insertions(+), 592 deletions(-) delete mode 100644 src/algorithm/const_fold.rs delete mode 100644 src/extension/const_fold.rs delete mode 100644 src/std_extensions/arithmetic/float_ops/fold.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5514dedee..f895bc87c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,13 +56,13 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - rust: ['1.75', stable, beta, nightly] + rust: ['1.70', stable, beta, nightly] # workaround to ignore non-stable tests when running the merge queue checks # see: https://github.community/t/how-to-conditionally-include-exclude-items-in-matrix-eg-based-on-branch/16853/6 isMerge: - ${{ github.event_name == 'merge_group' }} exclude: - - rust: '1.75' + - rust: '1.70' isMerge: true - rust: beta isMerge: true diff --git a/Cargo.toml b/Cargo.toml index ea4861062..b1acc0bd4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ description = "Hierarchical Unified Graph Representation" #categories = [] # TODO edition = "2021" -rust-version = "1.75" +rust-version = "1.70" [lib] # Using different names for the lib and for the package is supported, but may be confusing. @@ -46,7 +46,7 @@ lazy_static = "1.4.0" petgraph = { version = "0.6.3", default-features = false } context-iterators = "0.2.0" serde_json = "1.0.97" -delegate = "0.12.0" +delegate = "0.11.0" rustversion = "1.0.14" paste = "1.0" strum = "0.25.0" diff --git a/README.md b/README.md index 14adf32ed..91ab2c617 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,6 @@ See [DEVELOPMENT.md](DEVELOPMENT.md) for instructions on setting up the developm This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). [build_status]: https://github.com/CQCL/hugr/workflows/Continuous%20integration/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/badge/rust-1.70.0%2B-blue.svg [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: LICENCE diff --git a/src/algorithm.rs b/src/algorithm.rs index 633231504..0023b5916 100644 --- a/src/algorithm.rs +++ b/src/algorithm.rs @@ -1,5 +1,4 @@ //! Algorithms using the Hugr. -pub mod const_fold; mod half_node; pub mod nest_cfgs; diff --git a/src/algorithm/const_fold.rs b/src/algorithm/const_fold.rs deleted file mode 100644 index 49474c43e..000000000 --- a/src/algorithm/const_fold.rs +++ /dev/null @@ -1,302 +0,0 @@ -//! Constant folding routines. - -use std::collections::{BTreeSet, HashMap}; - -use itertools::Itertools; - -use crate::{ - builder::{DFGBuilder, Dataflow, DataflowHugr}, - extension::{ConstFoldResult, ExtensionRegistry}, - hugr::{ - rewrite::consts::{RemoveConst, RemoveConstIgnore}, - views::SiblingSubgraph, - HugrMut, - }, - ops::{Const, LeafOp, OpType}, - type_row, - types::{FunctionType, Type, TypeEnum}, - values::Value, - Hugr, HugrView, IncomingPort, Node, SimpleReplacement, -}; - -/// Tag some output constants with [`OutgoingPort`] inferred from the ordering. -fn out_row(consts: impl IntoIterator) -> ConstFoldResult { - let vec = consts - .into_iter() - .enumerate() - .map(|(i, c)| (i.into(), c)) - .collect(); - Some(vec) -} - -/// Sort folding inputs with [`IncomingPort`] as key -fn sort_by_in_port(consts: &[(IncomingPort, Const)]) -> Vec<&(IncomingPort, Const)> { - let mut v: Vec<_> = consts.iter().collect(); - v.sort_by_key(|(i, _)| i); - v -} - -/// Sort some input constants by port and just return the constants. -pub(crate) fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> { - sort_by_in_port(consts) - .into_iter() - .map(|(_, c)| c) - .collect() -} -/// For a given op and consts, attempt to evaluate the op. -pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldResult { - let op = op.as_leaf_op()?; - - match op { - LeafOp::Noop { .. } => out_row([consts.first()?.1.clone()]), - LeafOp::MakeTuple { .. } => { - out_row([Const::new_tuple(sorted_consts(consts).into_iter().cloned())]) - } - LeafOp::UnpackTuple { .. } => { - let c = &consts.first()?.1; - - if let Value::Tuple { vs } = c.value() { - if let TypeEnum::Tuple(tys) = c.const_type().as_type_enum() { - return out_row(tys.iter().zip(vs.iter()).map(|(t, v)| { - Const::new(v.clone(), t.clone()) - .expect("types should already have been checked") - })); - } - } - None // could panic - } - - LeafOp::Tag { tag, variants } => out_row([Const::new( - Value::sum(*tag, consts.first()?.1.value().clone()), - Type::new_sum(variants.clone()), - ) - .unwrap()]), - LeafOp::CustomOp(_) => { - let ext_op = op.as_extension_op()?; - - ext_op.constant_fold(consts) - } - _ => None, - } -} - -/// Generate a graph that loads and outputs `consts` in order, validating -/// against `reg`. -fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { - let const_types = consts.iter().map(Const::const_type).cloned().collect_vec(); - let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap(); - - let outputs = consts - .into_iter() - .map(|c| b.add_load_const(c).unwrap()) - .collect_vec(); - - b.finish_hugr_with_outputs(outputs, reg).unwrap() -} - -/// Given some `candidate_nodes` to search for LoadConstant operations in `hugr`, -/// return an iterator of possible constant folding rewrites. The -/// [`SimpleReplacement`] replaces an operation with constants that result from -/// evaluating it, the extension registry `reg` is used to validate the -/// replacement HUGR. The vector of [`RemoveConstIgnore`] refer to the -/// LoadConstant nodes that could be removed. -pub fn find_consts<'a, 'r: 'a>( - hugr: &'a impl HugrView, - candidate_nodes: impl IntoIterator + 'a, - reg: &'r ExtensionRegistry, -) -> impl Iterator)> + 'a { - // track nodes for operations that have already been considered for folding - let mut used_neighbours = BTreeSet::new(); - - candidate_nodes - .into_iter() - .filter_map(move |n| { - // only look at LoadConstant - hugr.get_optype(n).is_load_constant().then_some(())?; - - let (out_p, _) = hugr.out_value_types(n).exactly_one().ok()?; - let neighbours = hugr - .linked_inputs(n, out_p) - .filter(|(n, _)| used_neighbours.insert(*n)) - .collect_vec(); - if neighbours.is_empty() { - // no uses of LoadConstant that haven't already been considered. - return None; - } - let fold_iter = neighbours - .into_iter() - .filter_map(|(neighbour, _)| fold_op(hugr, neighbour, reg)); - Some(fold_iter) - }) - .flatten() -} - -/// Attempt to evaluate and generate rewrites for the operation at `op_node` -fn fold_op( - hugr: &impl HugrView, - op_node: Node, - reg: &ExtensionRegistry, -) -> Option<(SimpleReplacement, Vec)> { - let (in_consts, removals): (Vec<_>, Vec<_>) = hugr - .node_inputs(op_node) - .filter_map(|in_p| { - let (con_op, load_n) = get_const(hugr, op_node, in_p)?; - Some(((in_p, con_op), RemoveConstIgnore(load_n))) - }) - .unzip(); - let neighbour_op = hugr.get_optype(op_node); - // attempt to evaluate op - let folded = fold_const(neighbour_op, &in_consts)?; - let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip(); - let nu_out = op_outs - .into_iter() - .enumerate() - .filter_map(|(i, out)| { - // map from the ports the op was linked to, to the output ports of - // the replacement. - hugr.single_linked_input(op_node, out) - .map(|np| (np, i.into())) - }) - .collect(); - let replacement = const_graph(consts, reg); - let sibling_graph = SiblingSubgraph::try_from_nodes([op_node], hugr) - .expect("Operation should form valid subgraph."); - - let simple_replace = SimpleReplacement::new( - sibling_graph, - replacement, - // no inputs to replacement - HashMap::new(), - nu_out, - ); - Some((simple_replace, removals)) -} - -/// If `op_node` is connected to a LoadConstant at `in_p`, return the constant -/// and the LoadConstant node -fn get_const(hugr: &impl HugrView, op_node: Node, in_p: IncomingPort) -> Option<(Const, Node)> { - let (load_n, _) = hugr.single_linked_output(op_node, in_p)?; - let load_op = hugr.get_optype(load_n).as_load_constant()?; - let const_node = hugr - .linked_outputs(load_n, load_op.constant_port()) - .exactly_one() - .ok()? - .0; - - let const_op = hugr.get_optype(const_node).as_const()?; - - // TODO avoid const clone here - Some((const_op.clone(), load_n)) -} - -/// Exhaustively apply constant folding to a HUGR. -pub fn constant_fold_pass(h: &mut impl HugrMut, reg: &ExtensionRegistry) { - loop { - // would be preferable if the candidates were updated to be just the - // neighbouring nodes of those added. - let rewrites = find_consts(h, h.nodes(), reg).collect_vec(); - if rewrites.is_empty() { - break; - } - for (replace, removes) in rewrites { - h.apply_rewrite(replace).unwrap(); - for rem in removes { - if let Ok(const_node) = h.apply_rewrite(rem) { - // if the LoadConst was removed, try removing the Const too. - if h.apply_rewrite(RemoveConst(const_node)).is_err() { - // const cannot be removed - no problem - continue; - } - } - } - } - } -} - -#[cfg(test)] -mod test { - - use crate::extension::{ExtensionRegistry, PRELUDE}; - use crate::std_extensions::arithmetic; - - use crate::std_extensions::arithmetic::float_ops::FloatOps; - use crate::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; - - use rstest::rstest; - - use super::*; - - /// float to constant - fn f2c(f: f64) -> Const { - ConstF64::new(f).into() - } - - #[rstest] - #[case(0.0, 0.0, 0.0)] - #[case(0.0, 1.0, 1.0)] - #[case(23.5, 435.5, 459.0)] - // c = a + b - fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { - let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))]; - let add_op: OpType = FloatOps::fadd.into(); - let out = fold_const(&add_op, &consts).unwrap(); - - assert_eq!(&out[..], &[(0.into(), f2c(c))]); - } - - #[test] - fn test_big() { - /* - Test hugr approximately calculates - let x = (5.5, 3.25); - x.0 - x.1 == 2.25 - */ - let mut build = - DFGBuilder::new(FunctionType::new(type_row![], type_row![FLOAT64_TYPE])).unwrap(); - - let tup = build - .add_load_const(Const::new_tuple([f2c(5.5), f2c(3.25)])) - .unwrap(); - - let unpack = build - .add_dataflow_op( - LeafOp::UnpackTuple { - tys: type_row![FLOAT64_TYPE, FLOAT64_TYPE], - }, - [tup], - ) - .unwrap(); - - let sub = build - .add_dataflow_op(FloatOps::fsub, unpack.outputs()) - .unwrap(); - - let reg = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - arithmetic::float_ops::EXTENSION.to_owned(), - ]) - .unwrap(); - let mut h = build.finish_hugr_with_outputs(sub.outputs(), ®).unwrap(); - assert_eq!(h.node_count(), 7); - - constant_fold_pass(&mut h, ®); - - assert_fully_folded(&h, &f2c(2.25)); - } - fn assert_fully_folded(h: &Hugr, expected_const: &Const) { - // check the hugr just loads and returns a single const - let mut node_count = 0; - - for node in h.children(h.root()) { - let op = h.get_optype(node); - match op { - OpType::Input(_) | OpType::Output(_) | OpType::LoadConstant(_) => node_count += 1, - OpType::Const(c) if c == expected_const => node_count += 1, - _ => panic!("unexpected op: {:?}", op), - } - } - - assert_eq!(node_count, 4); - } -} diff --git a/src/extension.rs b/src/extension.rs index 5519e456b..95b0474ea 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -28,11 +28,9 @@ pub use op_def::{ }; mod type_def; pub use type_def::{TypeDef, TypeDefBound}; -mod const_fold; pub mod prelude; pub mod simple_op; pub mod validate; -pub use const_fold::{ConstFold, ConstFoldResult}; pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; /// Extension Registries store extensions to be looked up e.g. during validation. diff --git a/src/extension/const_fold.rs b/src/extension/const_fold.rs deleted file mode 100644 index 29cf4b5a9..000000000 --- a/src/extension/const_fold.rs +++ /dev/null @@ -1,53 +0,0 @@ -use std::fmt::Formatter; - -use std::fmt::Debug; - -use crate::types::TypeArg; - -use crate::OutgoingPort; - -use crate::ops; - -/// Output of constant folding an operation, None indicates folding was either -/// not possible or unsuccessful. An empty vector indicates folding was -/// successful and no values are output. -pub type ConstFoldResult = Option>; - -/// Trait implemented by extension operations that can perform constant folding. -pub trait ConstFold: Send + Sync { - /// Given type arguments `type_args` and - /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s, - /// try to evaluate the operation. - fn fold( - &self, - type_args: &[TypeArg], - consts: &[(crate::IncomingPort, crate::ops::Const)], - ) -> ConstFoldResult; -} - -impl Debug for Box { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "") - } -} - -impl Default for Box { - fn default() -> Self { - Box::new(|&_: &_| None) - } -} - -/// Blanket implementation for functions that only require the constants to -/// evaluate - type arguments are not relevant. -impl ConstFold for T -where - T: Fn(&[(crate::IncomingPort, crate::ops::Const)]) -> ConstFoldResult + Send + Sync, -{ - fn fold( - &self, - _type_args: &[TypeArg], - consts: &[(crate::IncomingPort, crate::ops::Const)], - ) -> ConstFoldResult { - self(consts) - } -} diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 2ea686ab4..143426123 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -7,8 +7,7 @@ use std::sync::Arc; use smol_str::SmolStr; use super::{ - ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, - ExtensionSet, SignatureError, + Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, }; use crate::types::type_param::{check_type_args, TypeArg, TypeParam}; @@ -308,9 +307,6 @@ pub struct OpDef { // can only treat them as opaque/black-box ops. #[serde(flatten)] lower_funcs: Vec, - - #[serde(skip)] - constant_folder: Box, } impl OpDef { @@ -416,22 +412,6 @@ impl OpDef { ) -> Option { self.misc.insert(k.to_string(), v) } - - /// Set the constant folding function for this Op, which can evaluate it - /// given constant inputs. - pub fn set_constant_folder(&mut self, fold: impl ConstFold + 'static) { - self.constant_folder = Box::new(fold) - } - - /// Evaluate an instance of this [`OpDef`] defined by the `type_args`, given - /// [`crate::ops::Const`] values for inputs at [`crate::IncomingPort`]s. - pub fn constant_fold( - &self, - type_args: &[TypeArg], - consts: &[(crate::IncomingPort, crate::ops::Const)], - ) -> ConstFoldResult { - self.constant_folder.fold(type_args, consts) - } } impl Extension { @@ -452,7 +432,6 @@ impl Extension { signature_func: signature_func.into(), misc: Default::default(), lower_funcs: Default::default(), - constant_folder: Default::default(), }; match self.operations.entry(op.name.clone()) { diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index b411c3667..f96046ba8 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -137,11 +137,12 @@ pub fn new_array_op(element_ty: Type, size: u64) -> LeafOp { .into() } -/// The custom type for Errors. -pub const ERROR_CUSTOM_TYPE: CustomType = - CustomType::new_simple(ERROR_TYPE_NAME, PRELUDE_ID, TypeBound::Eq); /// Unspecified opaque error type. -pub const ERROR_TYPE: Type = Type::new_extension(ERROR_CUSTOM_TYPE); +pub const ERROR_TYPE: Type = Type::new_extension(CustomType::new_simple( + ERROR_TYPE_NAME, + PRELUDE_ID, + TypeBound::Eq, +)); /// The string name of the error type. pub const ERROR_TYPE_NAME: SmolStr = SmolStr::new_inline("error"); @@ -190,48 +191,6 @@ impl KnownTypeConst for ConstUsize { const TYPE: CustomType = USIZE_CUSTOM_T; } -#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] -/// Structure for holding constant usize values. -pub struct ConstError { - /// Integer tag/signal for the error. - pub signal: u32, - /// Error message. - pub message: String, -} - -impl ConstError { - /// Define a new error value. - pub fn new(signal: u32, message: impl ToString) -> Self { - Self { - signal, - message: message.to_string(), - } - } -} - -#[typetag::serde] -impl CustomConst for ConstError { - fn name(&self) -> SmolStr { - format!("ConstError({:?}, {:?})", self.signal, self.message).into() - } - - fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure> { - self.check_known_type(typ) - } - - fn equal_consts(&self, other: &dyn CustomConst) -> bool { - crate::values::downcast_equal_consts(self, other) - } - - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(&PRELUDE_ID) - } -} - -impl KnownTypeConst for ConstError { - const TYPE: CustomType = ERROR_CUSTOM_TYPE; -} - #[cfg(test)] mod test { use crate::{ @@ -260,7 +219,7 @@ mod test { } #[test] - /// test the prelude error type. + /// Test building a HUGR involving a new_array operation. fn test_error_type() { let ext_def = PRELUDE .get_type(&ERROR_TYPE_NAME) @@ -270,18 +229,5 @@ mod test { let ext_type = Type::new_extension(ext_def); assert_eq!(ext_type, ERROR_TYPE); - - let error_val = ConstError::new(2, "my message"); - - assert_eq!(error_val.name(), "ConstError(2, \"my message\")"); - - assert!(error_val.check_custom_type(&ERROR_CUSTOM_TYPE).is_ok()); - - assert_eq!( - error_val.extension_reqs(), - ExtensionSet::singleton(&PRELUDE_ID) - ); - assert!(error_val.equal_consts(&ConstError::new(2, "my message"))); - assert!(!error_val.equal_consts(&ConstError::new(3, "my message"))); } } diff --git a/src/ops/custom.rs b/src/ops/custom.rs index d4e4c88a6..9d2c0ab00 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -4,11 +4,11 @@ use smol_str::SmolStr; use std::sync::Arc; use thiserror::Error; -use crate::extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError}; +use crate::extension::{ExtensionId, ExtensionRegistry, OpDef, SignatureError}; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrView, NodeType}; use crate::types::{type_param::TypeArg, FunctionType}; -use crate::{ops, Hugr, IncomingPort, Node}; +use crate::{Hugr, Node}; use super::dataflow::DataflowOpTrait; use super::tag::OpTag; @@ -137,11 +137,6 @@ impl ExtensionOp { pub fn def(&self) -> &OpDef { self.def.as_ref() } - - /// Attempt to evaluate this operation. See [`OpDef::constant_fold`]. - pub fn constant_fold(&self, consts: &[(IncomingPort, ops::Const)]) -> ConstFoldResult { - self.def().constant_fold(self.args(), consts) - } } impl From for OpaqueOp { diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 23b457f7c..98e5df887 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -8,7 +8,6 @@ use crate::{ prelude::sum_with_error, simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}, ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError, SignatureFunc, - PRELUDE, }, ops::{custom::ExtensionOp, OpName}, type_row, @@ -19,6 +18,7 @@ use crate::{ use super::int_types::int_tv; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; use lazy_static::lazy_static; + /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); @@ -69,7 +69,7 @@ impl MakeOpDef for ConvertOpDef { #[derive(Debug, Clone, PartialEq)] pub struct ConvertOpType { def: ConvertOpDef, - log_width: u64, + width: u64, } impl OpName for ConvertOpType { @@ -85,14 +85,11 @@ impl MakeExtensionOp for ConvertOpType { [TypeArg::BoundedNat { n }] => n, _ => return Err(SignatureError::InvalidTypeArgs.into()), }; - Ok(Self { - def, - log_width: width, - }) + Ok(Self { def, width }) } fn type_args(&self) -> Vec { - vec![TypeArg::BoundedNat { n: self.log_width }] + vec![TypeArg::BoundedNat { n: self.width }] } } @@ -114,7 +111,6 @@ lazy_static! { /// Registry of extensions required to validate integer operations. pub static ref CONVERT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ - PRELUDE.to_owned(), super::int_types::EXTENSION.to_owned(), super::float_types::EXTENSION.to_owned(), EXTENSION.to_owned(), diff --git a/src/std_extensions/arithmetic/float_ops.rs b/src/std_extensions/arithmetic/float_ops.rs index 672b8a1ed..87c87751b 100644 --- a/src/std_extensions/arithmetic/float_ops.rs +++ b/src/std_extensions/arithmetic/float_ops.rs @@ -14,7 +14,7 @@ use crate::{ Extension, }; use lazy_static::lazy_static; -mod fold; + /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float"); @@ -82,10 +82,6 @@ impl MakeOpDef for FloatOps { } .to_string() } - - fn post_opdef(&self, def: &mut OpDef) { - fold::set_fold(self, def) - } } lazy_static! { diff --git a/src/std_extensions/arithmetic/float_ops/fold.rs b/src/std_extensions/arithmetic/float_ops/fold.rs deleted file mode 100644 index 34d162f4d..000000000 --- a/src/std_extensions/arithmetic/float_ops/fold.rs +++ /dev/null @@ -1,124 +0,0 @@ -use crate::{ - algorithm::const_fold::sorted_consts, - extension::{ConstFold, ConstFoldResult, OpDef}, - ops, - std_extensions::arithmetic::float_types::ConstF64, - IncomingPort, -}; - -use super::FloatOps; - -pub(super) fn set_fold(op: &FloatOps, def: &mut OpDef) { - use FloatOps::*; - - match op { - fmax | fmin | fadd | fsub | fmul | fdiv => def.set_constant_folder(BinaryFold::from_op(op)), - feq | fne | flt | fgt | fle | fge => def.set_constant_folder(CmpFold::from_op(*op)), - fneg | fabs | ffloor | fceil => def.set_constant_folder(UnaryFold::from_op(op)), - } -} - -/// Extract float values from constants in port order. -fn get_floats(consts: &[(IncomingPort, ops::Const)]) -> Option<[f64; N]> { - let consts: [&ops::Const; N] = sorted_consts(consts).try_into().ok()?; - - Some(consts.map(|c| { - let const_f64: &ConstF64 = c - .get_custom_value() - .expect("This function assumes all incoming constants are floats."); - const_f64.value() - })) -} - -/// Fold binary operations -struct BinaryFold(Box f64 + Send + Sync>); -impl BinaryFold { - fn from_op(op: &FloatOps) -> Self { - use FloatOps::*; - Self(Box::new(match op { - fmax => f64::max, - fmin => f64::min, - fadd => std::ops::Add::add, - fsub => std::ops::Sub::sub, - fmul => std::ops::Mul::mul, - fdiv => std::ops::Div::div, - _ => panic!("not binary op"), - })) - } -} -impl ConstFold for BinaryFold { - fn fold( - &self, - _type_args: &[crate::types::TypeArg], - consts: &[(IncomingPort, ops::Const)], - ) -> ConstFoldResult { - let [f1, f2] = get_floats(consts)?; - - let res = ConstF64::new((self.0)(f1, f2)); - Some(vec![(0.into(), res.into())]) - } -} - -/// Fold comparisons. -struct CmpFold(Box bool + Send + Sync>); -impl CmpFold { - fn from_op(op: FloatOps) -> Self { - use FloatOps::*; - Self(Box::new(move |x, y| { - (match op { - feq => f64::eq, - fne => f64::lt, - flt => f64::lt, - fgt => f64::gt, - fle => f64::le, - fge => f64::ge, - _ => panic!("not cmp op"), - })(&x, &y) - })) - } -} - -impl ConstFold for CmpFold { - fn fold( - &self, - _type_args: &[crate::types::TypeArg], - consts: &[(IncomingPort, ops::Const)], - ) -> ConstFoldResult { - let [f1, f2] = get_floats(consts)?; - - let res = if (self.0)(f1, f2) { - ops::Const::true_val() - } else { - ops::Const::false_val() - }; - - Some(vec![(0.into(), res)]) - } -} - -/// Fold unary operations -struct UnaryFold(Box f64 + Send + Sync>); -impl UnaryFold { - fn from_op(op: &FloatOps) -> Self { - use FloatOps::*; - Self(Box::new(match op { - fneg => std::ops::Neg::neg, - fabs => f64::abs, - ffloor => f64::floor, - fceil => f64::ceil, - _ => panic!("not unary op."), - })) - } -} - -impl ConstFold for UnaryFold { - fn fold( - &self, - _type_args: &[crate::types::TypeArg], - consts: &[(IncomingPort, ops::Const)], - ) -> ConstFoldResult { - let [f1] = get_floats(consts)?; - let res = ConstF64::new((self.0)(f1)); - Some(vec![(0.into(), res.into())]) - } -} diff --git a/src/std_extensions/arithmetic/float_types.rs b/src/std_extensions/arithmetic/float_types.rs index ba5fe0956..71f91bf87 100644 --- a/src/std_extensions/arithmetic/float_types.rs +++ b/src/std_extensions/arithmetic/float_types.rs @@ -40,7 +40,7 @@ impl std::ops::Deref for ConstF64 { impl ConstF64 { /// Create a new [`ConstF64`] - pub const fn new(value: f64) -> Self { + pub fn new(value: f64) -> Self { Self { value } } diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index e1cea6c49..ae5160ffd 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -11,7 +11,6 @@ use crate::ops::OpName; use crate::type_row; use crate::types::{FunctionType, PolyFuncType}; use crate::utils::collect_array; - use crate::{ extension::{ExtensionId, ExtensionSet, SignatureError}, types::{type_param::TypeArg, Type, TypeRow}, diff --git a/src/values.rs b/src/values.rs index 46d2778ee..17d173a00 100644 --- a/src/values.rs +++ b/src/values.rs @@ -10,7 +10,6 @@ use smol_str::SmolStr; use crate::extension::ExtensionSet; use crate::macros::impl_box_clone; - use crate::{Hugr, HugrView}; use crate::types::{CustomCheckFailure, CustomType};