diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f895bc87c..5514dedee 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,13 +56,13 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - rust: ['1.70', stable, beta, nightly] + rust: ['1.75', stable, beta, nightly] # workaround to ignore non-stable tests when running the merge queue checks # see: https://github.community/t/how-to-conditionally-include-exclude-items-in-matrix-eg-based-on-branch/16853/6 isMerge: - ${{ github.event_name == 'merge_group' }} exclude: - - rust: '1.70' + - rust: '1.75' isMerge: true - rust: beta isMerge: true diff --git a/.github/workflows/release-plz.yml b/.github/workflows/release-plz.yml new file mode 100644 index 000000000..436bb2f4d --- /dev/null +++ b/.github/workflows/release-plz.yml @@ -0,0 +1,28 @@ +name: Release-plz + +permissions: + pull-requests: write + contents: write + +on: + push: + branches: + - main + +jobs: + release-plz: + name: Release-plz + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + - name: Run release-plz + uses: MarcoIeni/release-plz-action@v0.5 + with: + command: release-pr + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/Cargo.toml b/Cargo.toml index b1acc0bd4..c633a1222 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ description = "Hierarchical Unified Graph Representation" #categories = [] # TODO edition = "2021" -rust-version = "1.70" +rust-version = "1.75" [lib] # Using different names for the lib and for the package is supported, but may be confusing. @@ -46,8 +46,7 @@ lazy_static = "1.4.0" petgraph = { version = "0.6.3", default-features = false } context-iterators = "0.2.0" serde_json = "1.0.97" -delegate = "0.11.0" -rustversion = "1.0.14" +delegate = "0.12.0" paste = "1.0" strum = "0.25.0" strum_macros = "0.25.3" @@ -68,4 +67,3 @@ harness = false [profile.dev.package] insta.opt-level = 3 -similar.opt-level = 3 diff --git a/README.md b/README.md index 91ab2c617..4428eac1b 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ quantinuum-hugr =============== [![build_status][]](https://github.com/CQCL/hugr/actions) +[![crates][]](https://crates.io/crates/quantinuum-hugr) [![msrv][]](https://github.com/CQCL/hugr) [![codecov][]](https://codecov.io/gh/CQCL/hugr) @@ -16,15 +17,21 @@ The HUGR specification is [here](specification/hugr.md). ## Usage -Add this to your `Cargo.toml`: +Add the dependency to your project: -```toml -[dependencies] -quantinuum-hugr = "0.1" +```bash +cargo add quantinuum-hugr ``` The library crate is called `hugr`. +Please read the [API documentation here][]. + +## Recent Changes + +See [CHANGELOG][] for a list of changes. The minimum supported rust +version will only change on major releases. + ## Development See [DEVELOPMENT.md](DEVELOPMENT.md) for instructions on setting up the development environment. @@ -33,7 +40,9 @@ See [DEVELOPMENT.md](DEVELOPMENT.md) for instructions on setting up the developm This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). 
+ [API documentation here]: https://docs.rs/quantinuum-hugr/ [build_status]: https://github.com/CQCL/hugr/workflows/Continuous%20integration/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.70.0%2B-blue.svg + [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: LICENCE + [CHANGELOG]: CHANGELOG.md diff --git a/cliff.toml b/cliff.toml new file mode 100644 index 000000000..8c3dba1ca --- /dev/null +++ b/cliff.toml @@ -0,0 +1,73 @@ +# git-cliff ~ default configuration file +# https://git-cliff.org/docs/configuration +# +# Lines starting with "#" are comments. +# Configuration options are organized into tables and keys. +# See documentation for more information on available options. + +[changelog] +# changelog header +header = """ +# Changelog\n +""" +# template for the changelog body +# https://tera.netlify.app/docs +body = """ +{% if version %}\ + ## {{ version }} ({{ timestamp | date(format="%Y-%m-%d") }}) +{% else %}\ + ## Unreleased (XXXX-XX-XX) +{% endif %}\ +{% for group, commits in commits | group_by(attribute="group") %} + ### {{ group | upper_first }} + {% for commit in commits %} + - {% if commit.breaking %}[**breaking**] {% endif %}{{ commit.message | upper_first }}\ + {% endfor %} +{% endfor %}\n +""" +# remove the leading and trailing whitespace from the template +trim = true +# changelog footer +footer = "" + +[git] +# parse the commits based on https://www.conventionalcommits.org +conventional_commits = true +# filter out the commits that are not conventional +filter_unconventional = true +# process each line of a commit as an individual commit +split_commits = false +# regex for preprocessing the commit messages +commit_preprocessors = [ + { pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](https://github.com/CQCL/hugr/issues/${2}))"}, # replace issue numbers +] +# regex for parsing and grouping commits +commit_parsers = [ + { message = "^feat", group = "Features" }, + { message = "^fix", group = "Bug Fixes" }, + { message = "^docs", group = "Documentation" }, + { message = "^style", group = "Styling" }, + { message = "^refactor", group = "Refactor" }, + { message = "^perf", group = "Performance" }, + { message = "^test", group = "Testing" }, + { message = "^chore\\(release\\): prepare for", skip = true }, + { message = "^chore", group = "Miscellaneous Tasks", skip = true }, + { message = "^revert", group = "Reverted changes", skip = true }, + { message = "^ci", group = "CI", skip = true }, +] +# protect breaking changes from being skipped due to matching a skipping commit_parser +protect_breaking_commits = true +# filter out the commits that are not matched by commit parsers +filter_commits = false +# glob pattern for matching git tags +tag_pattern = "v[0-9.]*" +# regex for skipping tags +skip_tags = "v0.1.0-beta.1" +# regex for ignoring tags +ignore_tags = "" +# sort the tags topologically +topo_order = false +# sort the commits inside sections by oldest-first/newest-first +sort_commits = "oldest" +# limit the number of commits included in the changelog.
+# limit_commits = 42 diff --git a/devenv.lock b/devenv.lock index 3d6e2de22..7750dfcfa 100644 --- a/devenv.lock +++ b/devenv.lock @@ -3,11 +3,11 @@ "devenv": { "locked": { "dir": "src/modules", - "lastModified": 1700140236, - "narHash": "sha256-OpukFO0rRG2hJzD+pCQq+nSWuT9dBL6DSvADQaUlmFg=", + "lastModified": 1703939110, + "narHash": "sha256-GgjYWkkHQ8pUBwXX++ah+4d07DqOeCDaaQL6Ab86C50=", "owner": "cachix", "repo": "devenv", - "rev": "525d60c44de848a6b2dd468f6efddff078eb2af2", + "rev": "7354096fc026f79645fdac73e9aeea71a09412c3", "type": "github" }, "original": { @@ -25,11 +25,11 @@ "rust-analyzer-src": "rust-analyzer-src" }, "locked": { - "lastModified": 1700461394, - "narHash": "sha256-lBpjEshdBxeuJwc4+vh4jbO3AmhXbiFrkdWy2pABAAc=", + "lastModified": 1704262971, + "narHash": "sha256-3HB1yaMBBox3z9oXEiQuZzQhXegOc9P3FR6/XNsJGn0=", "owner": "nix-community", "repo": "fenix", - "rev": "5ad1b10123ca40c9d983fb0863403fd97a06c0f8", + "rev": "38aaea4e54dc3874a6355c10861bd8316a6f09f3", "type": "github" }, "original": { @@ -95,11 +95,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1700444282, - "narHash": "sha256-s/+tgT+Iz0LZO+nBvSms+xsMqvHt2LqYniG9r+CYyJc=", + "lastModified": 1704008649, + "narHash": "sha256-rGPSWjXTXTurQN9beuHdyJhB8O761w1Zc5BqSSmHvoM=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "3f21a22b5aafefa1845dec6f4a378a8f53d8681c", + "rev": "d44d59d2b5bd694cd9d996fd8c51d03e3e9ba7f7", "type": "github" }, "original": { @@ -111,11 +111,11 @@ }, "nixpkgs-stable": { "locked": { - "lastModified": 1700403855, - "narHash": "sha256-Q0Uzjik9kUTN9pd/kp52XJi5kletBhy29ctBlAG+III=", + "lastModified": 1704018918, + "narHash": "sha256-erjg/HrpC9liEfm7oLqb8GXCqsxaFwIIPqCsknW5aFY=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "0c5678df521e1407884205fe3ce3cf1d7df297db", + "rev": "2c9c58e98243930f8cb70387934daa4bc8b00373", "type": "github" }, "original": { @@ -152,11 +152,11 @@ "nixpkgs-stable": "nixpkgs-stable_2" }, "locked": { - "lastModified": 1700064067, - "narHash": "sha256-1ZWNDzhu8UlVCK7+DUN9dVQfiHX1bv6OQP9VxstY/gs=", + "lastModified": 1703939133, + "narHash": "sha256-Gxe+mfOT6bL7wLC/tuT2F+V+Sb44jNr8YsJ3cyIl4Mo=", "owner": "cachix", "repo": "pre-commit-hooks.nix", - "rev": "e558068cba67b23b4fbc5537173dbb43748a17e8", + "rev": "9d3d7e18c6bc4473d7520200d4ddab12f8402d38", "type": "github" }, "original": { @@ -177,11 +177,11 @@ "rust-analyzer-src": { "flake": false, "locked": { - "lastModified": 1700247620, - "narHash": "sha256-+Xg0qZLbC9dZx0Z6JbaVHR/BklAr2I83dzKLB8r41c8=", + "lastModified": 1704207973, + "narHash": "sha256-VEWsjIKtdinx5iyhfxuTHRijYBKSbO/8Gw1HPoWD9mQ=", "owner": "rust-lang", "repo": "rust-analyzer", - "rev": "255eed40c45fcf108ba844b4ad126bdc4e7a18df", + "rev": "426d2842c1f0e5cc5e34bb37c7ac3ee0945f9746", "type": "github" }, "original": { diff --git a/devenv.nix b/devenv.nix index f4a85aca5..bc9fdba1c 100644 --- a/devenv.nix +++ b/devenv.nix @@ -31,7 +31,7 @@ in # https://devenv.sh/languages/ # https://devenv.sh/reference/options/#languagesrustversion languages.rust = { - channel = "beta"; + channel = "stable"; enable = true; components = [ "rustc" "cargo" "clippy" "rustfmt" "rust-analyzer" ]; }; diff --git a/release-plz.toml b/release-plz.toml new file mode 100644 index 000000000..0d7dd7e3e --- /dev/null +++ b/release-plz.toml @@ -0,0 +1,2 @@ +[workspace] +changelog_config = "cliff.toml" # use a custom git-cliff configuration diff --git a/specification/hugr.md b/specification/hugr.md index 3c6517dcf..fc5f6f65b 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -267,19 
+267,13 @@ the following basic dataflow operations are available (in addition to the - `Input/Output`: input/output nodes, the outputs of `Input` node are the inputs to the function, and the inputs to `Output` are the - outputs of the function. In a data dependency subgraph, a valid - ordering of operations can be achieved by topologically sorting the - nodes starting from `Input` with respect to the Value and Order - edges. + outputs of the function. - `Call`: Call a statically defined function. There is an incoming `Static<Function>` edge to specify the graph being called. The signature of the node (defined by its incoming and outgoing `Value` edges) matches the function being called. - `LoadConstant<T>`: has an incoming `Static<T>` edge, where `T` is a `CopyableType`, and a `Value<Local,T>` output, used to load a static constant into the local - dataflow graph. They also have an incoming `Order` edge connecting - them to the `Input` node, as should all operations that - take no dataflow input, to ensure they lie in the causal cone of the - `Input` node when traversing. + dataflow graph. - `identity<T>`: pass-through, no operation is performed. - `DFG`: A nested dataflow graph. These nodes are parents in the hierarchy. @@ -515,10 +509,11 @@ graph: cycles. The common parent is a CFG-node. **Dataflow Sibling Graph (DSG)**: nodes are operations, `CFG`, -`Conditional`, `TailLoop` and `DFG` nodes; edges are `Value`, `Order` and `Static`; -and must be acyclic. There is a unique Input node and Output node. All nodes must be -reachable from the Input node, and must reach the Output node. The common parent -may be a `FuncDefn`, `TailLoop`, `DFG`, `Case` or `DFB` node. +`Conditional`, `TailLoop` and `DFG` nodes; edges are `Value`, `Order` and `Static`, and must be acyclic. +(Thus a valid ordering of operations can be achieved by topologically sorting the +nodes.) +There is a unique Input node and Output node. +The common parent may be a `FuncDefn`, `TailLoop`, `DFG`, `Case` or `DFB` node. | **Edge Kind** | **Locality** | | -------------- | ------------ | @@ -1191,7 +1186,19 @@ The new hugr is then derived as follows: ###### `Replace` -This is the general subgraph-replacement method. +This is the general subgraph-replacement method. Intuitively, it takes a set of +sibling nodes to remove and replace with a new set of nodes. The new set of +nodes is itself a HUGR with some "holes" (edges and nodes that get "filled in" +by the `Replace` operation). To fully specify the operation, some further data +are needed: + + - The replacement may include container nodes with no children, which adopt + the children of removed container nodes and prevent those children being + removed. + - All new incoming edges from the retained nodes to the new nodes, all + outgoing edges from the new nodes to the retained nodes, and any new edges + that bypass the replacement (going between retained nodes) must be + specified. Given a set $S$ of nodes in a hugr, let $S^\*$ be the set of all nodes descended from nodes in $S$ (i.e. 
reachable from $S$ by following hierarchy edges), @@ -1234,7 +1241,9 @@ Note that considering all three $\mu$ lists together, - the `TgtNode` + `TgtPos`s of all `NewEdgeSpec`s with `EdgeKind` == `Value` will be unique - and similarly for `EdgeKind` == `Static` -The well-formedness requirements of Hugr imply that $\mu\_\textrm{inp}$ and $\mu\_\textrm{out}$ may only contain `NewEdgeSpec`s with certain `EdgeKind`s, depending on $P$: +The well-formedness requirements of Hugr imply that $\mu\_\textrm{inp}$, +$\mu\_\textrm{out}$ and $\mu\_\textrm{new}$ may only contain `NewEdgeSpec`s with +certain `EdgeKind`s, depending on $P$: - if $P$ is a dataflow container, `EdgeKind`s may be `Order`, `Value` or `Static` only (no `ControlFlow`) - if $P$ is a CFG node, `EdgeKind`s may be `ControlFlow`, `Value`, or `Static` only (no `Order`) - if $P$ is a Module node, there may be `Value` or `Static` only (no `Order`). @@ -1262,7 +1271,8 @@ The new hugr is then derived as follows: 6. For each node $(n, b = B(n))$ and for each child $m$ of $b$, replace the hierarchy edge from $b$ to $m$ with a hierarchy edge from the new copy of $n$ to $m$ (preserving the order). -7. Remove all nodes in $R$ and edges adjoining them. +7. Remove all nodes in $R$ and edges adjoining them. (Reindexing may be + necessary after this step.) ##### Outlining methods @@ -1325,8 +1335,8 @@ successor. Insert an Order edge from `n0` to `n1` where `n0` and `n1` are distinct siblings in a DSG such that there is no path in the DSG from `n1` to -`n0`. If there is already an order edge from `n0` to `n1` this does -nothing (but is not an error). +`n0`. (Thus acyclicity is preserved.) If there is already an order edge from +`n0` to `n1` this does nothing (but is not an error). ###### `RemoveOrder` @@ -1340,8 +1350,7 @@ remove it. (If there is an non-local edge from `n0` to a descendent of Given a `Const<T>` node `c`, and optionally `P`, a parent of a DSG, add a new `LoadConstant<T>` node `n` as a child of `P` with a `Static<T>` edge -from `c` to `n` and no outgoing edges from `n`. Also add an Order edge -from the Input node under `P` to `n`. Return the ID of `n`. If `P` is +from `c` to `n` and no outgoing edges from `n`. Return the ID of `n`. If `P` is omitted it defaults to the parent of `c` (in this case said `c` will have to be in a DSG or CSG rather than under the Module Root.) If `P` is provided, it must be a descendent of the parent of `c`. @@ -1349,7 +1358,7 @@ provided, it must be a descendent of the parent of `c`. ###### `RemoveConstIgnore` Given a `LoadConstant<T>` node `n` that has no outgoing edges, remove -it (and its incoming value and Order edges) from the hugr. +it (and its incoming Static edge and any Order edges) from the hugr. ##### Insertion and removal of const nodes @@ -1371,7 +1380,7 @@ nodes. The most basic case – replacing a convex set of Op nodes in a DSG with another graph of Op nodes having the same signature – is implemented by -having T map everything to the parent node, and bot(G) is empty. +`SimpleReplace`. If one of the nodes in the region is a complex container node that we wish to preserve in the replacement without doing a deep copy, we can diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 641ef1ae2..c85a02eb0 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -70,8 +70,8 @@ pub trait Container { /// /// This function will return an error if there is an error in adding the /// [`OpType::Const`] node. 
- fn add_constant(&mut self, constant: ops::Const) -> Result<ConstID, BuildError> { - let const_n = self.add_child_node(NodeType::new(constant, ExtensionSet::new()))?; + fn add_constant(&mut self, constant: impl Into<ops::Const>) -> Result<ConstID, BuildError> { + let const_n = self.add_child_node(NodeType::new(constant.into(), ExtensionSet::new()))?; Ok(const_n.into()) } @@ -374,7 +374,7 @@ pub trait Dataflow: Container { /// # Errors /// /// This function will return an error if there is an error when adding the node. - fn add_load_const(&mut self, constant: ops::Const) -> Result<Wire, BuildError> { + fn add_load_const(&mut self, constant: impl Into<ops::Const>) -> Result<Wire, BuildError> { let cid = self.add_constant(constant)?; self.load_const(&cid) } diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index bbcddade7..8f99ee512 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -109,7 +109,7 @@ mod test { let build_result: Result<Hugr, ValidationError> = { let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?; let [i1] = loop_b.input_wires_arr(); - let const_wire = loop_b.add_load_const(ConstUsize::new(1).into())?; + let const_wire = loop_b.add_load_const(ConstUsize::new(1))?; let break_wire = loop_b.make_break(loop_b.loop_signature()?.clone(), [const_wire])?; loop_b.set_outputs(break_wire, [i1])?; @@ -173,7 +173,7 @@ mod test { let mut branch_1 = conditional_b.case_builder(1)?; let [_b1] = branch_1.input_wires_arr(); - let wire = branch_1.add_load_const(ConstUsize::new(2).into())?; + let wire = branch_1.add_load_const(ConstUsize::new(2))?; let break_wire = branch_1.make_break(signature, [wire])?; branch_1.finish_with_outputs([break_wire])?; diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index f96046ba8..b411c3667 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -137,12 +137,11 @@ pub fn new_array_op(element_ty: Type, size: u64) -> LeafOp { .into() } +/// The custom type for Errors. +pub const ERROR_CUSTOM_TYPE: CustomType = + CustomType::new_simple(ERROR_TYPE_NAME, PRELUDE_ID, TypeBound::Eq); /// Unspecified opaque error type. -pub const ERROR_TYPE: Type = Type::new_extension(CustomType::new_simple( - ERROR_TYPE_NAME, - PRELUDE_ID, - TypeBound::Eq, -)); +pub const ERROR_TYPE: Type = Type::new_extension(ERROR_CUSTOM_TYPE); /// The string name of the error type. pub const ERROR_TYPE_NAME: SmolStr = SmolStr::new_inline("error"); @@ -191,6 +190,48 @@ impl KnownTypeConst for ConstUsize { const TYPE: CustomType = USIZE_CUSTOM_T; } +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +/// Structure for holding constant error values. +pub struct ConstError { + /// Integer tag/signal for the error. + pub signal: u32, + /// Error message. + pub message: String, +} + +impl ConstError { + /// Define a new error value.
+ pub fn new(signal: u32, message: impl ToString) -> Self { + Self { + signal, + message: message.to_string(), + } + } +} + +#[typetag::serde] +impl CustomConst for ConstError { + fn name(&self) -> SmolStr { + format!("ConstError({:?}, {:?})", self.signal, self.message).into() + } + + fn check_custom_type(&self, typ: &CustomType) -> Result<(), CustomCheckFailure> { + self.check_known_type(typ) + } + + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + crate::values::downcast_equal_consts(self, other) + } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::singleton(&PRELUDE_ID) + } +} + +impl KnownTypeConst for ConstError { + const TYPE: CustomType = ERROR_CUSTOM_TYPE; +} + #[cfg(test)] mod test { use crate::{ @@ -219,7 +260,7 @@ mod test { } #[test] - /// Test building a HUGR involving a new_array operation. + /// Test the prelude error type. fn test_error_type() { let ext_def = PRELUDE .get_type(&ERROR_TYPE_NAME) @@ -229,5 +270,18 @@ mod test { let ext_type = Type::new_extension(ext_def); assert_eq!(ext_type, ERROR_TYPE); + + let error_val = ConstError::new(2, "my message"); + + assert_eq!(error_val.name(), "ConstError(2, \"my message\")"); + + assert!(error_val.check_custom_type(&ERROR_CUSTOM_TYPE).is_ok()); + + assert_eq!( + error_val.extension_reqs(), + ExtensionSet::singleton(&PRELUDE_ID) + ); + assert!(error_val.equal_consts(&ConstError::new(2, "my message"))); + assert!(!error_val.equal_consts(&ConstError::new(3, "my message"))); } } diff --git a/src/hugr/rewrite.rs b/src/hugr/rewrite.rs index 3f524e2db..05f1f48d9 100644 --- a/src/hugr/rewrite.rs +++ b/src/hugr/rewrite.rs @@ -1,5 +1,6 @@ //! Rewrite operations on the HUGR - replacement, outlining, etc. +pub mod consts; pub mod insert_identity; pub mod outline_cfg; pub mod replace; diff --git a/src/hugr/rewrite/consts.rs b/src/hugr/rewrite/consts.rs new file mode 100644 index 000000000..61be178d5 --- /dev/null +++ b/src/hugr/rewrite/consts.rs @@ -0,0 +1,214 @@ +//! Rewrite operations involving Const and LoadConst operations + +use std::iter; + +use crate::{ + hugr::{HugrError, HugrMut}, + HugrView, Node, +}; + +use itertools::Itertools; +use thiserror::Error; + +use super::Rewrite; + +/// Remove a [`crate::ops::LoadConstant`] node with no consumers. +#[derive(Debug, Clone)] +pub struct RemoveConstIgnore(pub Node); + +/// Error from a [`RemoveConst`] or [`RemoveConstIgnore`] operation. +#[derive(Debug, Clone, Error, PartialEq, Eq)] +pub enum RemoveError { + /// Invalid node. + #[error("Node is invalid (either not in HUGR or not correct operation).")] + InvalidNode(Node), + /// Node in use. + #[error("Node: {0:?} has non-zero outgoing connections.")] + ValueUsed(Node), + /// Removal error + #[error("Removing node caused error: {0:?}.")] + RemoveFail(#[from] HugrError), +} + +impl Rewrite for RemoveConstIgnore { + type Error = RemoveError; + + // The Const node the LoadConstant was connected to.
+ type ApplyResult = Node; + + type InvalidationSet<'a> = iter::Once<Node>; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + let node = self.0; + + if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) { + return Err(RemoveError::InvalidNode(node)); + } + + if h.out_value_types(node) + .next() + .is_some_and(|(p, _)| h.linked_inputs(node, p).next().is_some()) + { + return Err(RemoveError::ValueUsed(node)); + } + + Ok(()) + } + + fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> { + self.verify(h)?; + let node = self.0; + let source = h + .input_neighbours(node) + .exactly_one() + .ok() + .expect("Validation should check a Const is connected to LoadConstant."); + h.remove_node(node)?; + + Ok(source) + } + + fn invalidation_set(&self) -> Self::InvalidationSet<'_> { + iter::once(self.0) + } +} + +/// Remove a [`crate::ops::Const`] node with no outputs. +#[derive(Debug, Clone)] +pub struct RemoveConst(pub Node); + +impl Rewrite for RemoveConst { + type Error = RemoveError; + + // The parent of the Const node. + type ApplyResult = Node; + + type InvalidationSet<'a> = iter::Once<Node>; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + let node = self.0; + + if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) { + return Err(RemoveError::InvalidNode(node)); + } + + if h.output_neighbours(node).next().is_some() { + return Err(RemoveError::ValueUsed(node)); + } + + Ok(()) + } + + fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> { + self.verify(h)?; + let node = self.0; + let parent = h + .get_parent(node) + .expect("Const node without a parent shouldn't happen."); + h.remove_node(node)?; + + Ok(parent) + } + + fn invalidation_set(&self) -> Self::InvalidationSet<'_> { + iter::once(self.0) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{ + builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer}, + extension::{ + prelude::{ConstUsize, USIZE_T}, + PRELUDE_REGISTRY, + }, + hugr::HugrMut, + ops::{handle::NodeHandle, LeafOp}, + type_row, + types::FunctionType, + }; + #[test] + fn test_const_remove() -> Result<(), Box<dyn std::error::Error>> { + let mut build = ModuleBuilder::new(); + let con_node = build.add_constant(ConstUsize::new(2))?; + + let mut dfg_build = + build.define_function("main", FunctionType::new_endo(type_row![]).into())?; + let load_1 = dfg_build.load_const(&con_node)?; + let load_2 = dfg_build.load_const(&con_node)?; + let tup = dfg_build.add_dataflow_op( + LeafOp::MakeTuple { + tys: type_row![USIZE_T, USIZE_T], + }, + [load_1, load_2], + )?; + dfg_build.finish_sub_container()?; + + let mut h = build.finish_prelude_hugr()?; + // nodes are Module, Function, Input, Output, Const, LoadConstant*2, MakeTuple + assert_eq!(h.node_count(), 8); + let tup_node = tup.node(); + // can't remove invalid node + assert_eq!( + h.apply_rewrite(RemoveConst(tup_node)), + Err(RemoveError::InvalidNode(tup_node)) + ); + + assert_eq!( + h.apply_rewrite(RemoveConstIgnore(tup_node)), + Err(RemoveError::InvalidNode(tup_node)) + ); + let load_1_node = load_1.node(); + let load_2_node = load_2.node(); + let con_node = con_node.node(); + + let remove_1 = RemoveConstIgnore(load_1_node); + assert_eq!( + remove_1.invalidation_set().exactly_one().ok(), + Some(load_1_node) + ); + + let remove_2 = RemoveConstIgnore(load_2_node); + + let remove_con = RemoveConst(con_node); 
+ assert_eq!( + remove_con.invalidation_set().exactly_one().ok(), + Some(con_node) + ); + + // can't remove nodes in use + assert_eq!( + h.apply_rewrite(remove_1.clone()), + Err(RemoveError::ValueUsed(load_1_node)) + ); + + // remove the use + h.remove_node(tup_node)?; + + // remove first load + let reported_con_node = h.apply_rewrite(remove_1)?; + assert_eq!(reported_con_node, con_node); + + // still can't remove const, in use by second load + assert_eq!( + h.apply_rewrite(remove_con.clone()), + Err(RemoveError::ValueUsed(con_node)) + ); + + // remove second use + let reported_con_node = h.apply_rewrite(remove_2)?; + assert_eq!(reported_con_node, con_node); + // remove const + assert_eq!(h.apply_rewrite(remove_con)?, h.root()); + + assert_eq!(h.node_count(), 4); + assert!(h.validate(&PRELUDE_REGISTRY).is_ok()); + Ok(()) + } +} diff --git a/src/hugr/views.rs b/src/hugr/views.rs index 769432cc0..b74e186cb 100644 --- a/src/hugr/views.rs +++ b/src/hugr/views.rs @@ -26,11 +26,11 @@ use super::{Hugr, HugrError, NodeMetadata, NodeMetadataMap, NodeType, DEFAULT_NO use crate::ops::dataflow::DataflowParent; use crate::ops::handle::NodeHandle; use crate::ops::{OpName, OpTag, OpTrait, OpType}; -#[rustversion::since(1.75)] // uses impl in return position + use crate::types::Type; use crate::types::{EdgeKind, FunctionType}; use crate::{Direction, IncomingPort, Node, OutgoingPort, Port}; -#[rustversion::since(1.75)] // uses impl in return position + use itertools::Either; /// A trait for inspecting HUGRs. @@ -184,7 +184,6 @@ pub trait HugrView: sealed::HugrInternals { /// Iterator over the nodes and ports connected to a port. fn linked_ports(&self, node: Node, port: impl Into<Port>) -> Self::PortLinks<'_>; - #[rustversion::since(1.75)] // uses impl in return position /// Iterator over all the nodes and ports connected to a node in a given direction. fn all_linked_ports( &self, @@ -206,7 +205,6 @@ pub trait HugrView: sealed::HugrInternals { } } - #[rustversion::since(1.75)] // uses impl in return position /// Iterator over all the nodes and ports connected to a node's inputs. fn all_linked_outputs(&self, node: Node) -> impl Iterator<Item = (Node, OutgoingPort)> { self.all_linked_ports(node, Direction::Incoming) @@ -214,7 +212,6 @@ pub trait HugrView: sealed::HugrInternals { .unwrap() } - #[rustversion::since(1.75)] // uses impl in return position /// Iterator over all the nodes and ports connected to a node's outputs. fn all_linked_inputs(&self, node: Node) -> impl Iterator<Item = (Node, IncomingPort)> { self.all_linked_ports(node, Direction::Outgoing) @@ -414,7 +411,6 @@ pub trait HugrView: sealed::HugrInternals { .map(|(n, _)| n) } - #[rustversion::since(1.75)] // uses impl in return position /// If a node has a static output, return the targets. fn static_targets(&self, node: Node) -> Option<impl Iterator<Item = (Node, IncomingPort)>> { Some(self.linked_inputs(node, self.get_optype(node).static_output_port()?)) @@ -426,7 +422,6 @@ pub trait HugrView: sealed::HugrInternals { self.get_optype(node).dataflow_signature() } - #[rustversion::since(1.75)] // uses impl in return position /// Iterator over all outgoing ports that have Value type, along /// with corresponding types. 
fn value_types(&self, node: Node, dir: Direction) -> impl Iterator<Item = (Port, Type)> { @@ -435,7 +430,6 @@ pub trait HugrView: sealed::HugrInternals { .flat_map(move |port| sig.port_type(port).map(|typ| (port, typ.clone()))) } - #[rustversion::since(1.75)] // uses impl in return position /// Iterator over all incoming ports that have Value type, along /// with corresponding types. fn in_value_types(&self, node: Node) -> impl Iterator<Item = (IncomingPort, Type)> { @@ -443,7 +437,6 @@ pub trait HugrView: sealed::HugrInternals { .map(|(p, t)| (p.as_incoming().unwrap(), t)) } - #[rustversion::since(1.75)] // uses impl in return position /// Iterator over all incoming ports that have Value type, along /// with corresponding types. fn out_value_types(&self, node: Node) -> impl Iterator<Item = (OutgoingPort, Type)> { @@ -621,7 +614,6 @@ impl<T: AsRef<Hugr>> HugrView for T { } } -#[rustversion::since(1.75)] // uses impl in return position /// Trait implementing methods on port iterators. pub trait PortIterator<P>: Iterator<Item = (Node, P)> where @@ -639,7 +631,7 @@ where }) } } -#[rustversion::since(1.75)] // uses impl in return position + impl<I, P> PortIterator<P> for I where I: Iterator<Item = (Node, P)>, diff --git a/src/hugr/views/tests.rs b/src/hugr/views/tests.rs index 97fb50861..ce0353d48 100644 --- a/src/hugr/views/tests.rs +++ b/src/hugr/views/tests.rs @@ -69,7 +69,6 @@ fn dot_string(sample_hugr: (Hugr, BuildHandle<DataflowOpID>, BuildHandle<Dataflo insta::assert_yaml_snapshot!(h.dot_string()); } -#[rustversion::since(1.75)] // uses impl in return position #[rstest] fn all_ports(sample_hugr: (Hugr, BuildHandle<DataflowOpID>, BuildHandle<DataflowOpID>)) { use itertools::Itertools; @@ -97,7 +96,6 @@ fn all_ports(sample_hugr: (Hugr, BuildHandle<DataflowOpID>, BuildHandle<Dataflow ); } -#[rustversion::since(1.75)] // uses impl in return position #[test] fn value_types() { use crate::builder::Container; @@ -129,7 +127,6 @@ fn value_types() { assert_eq!(&out_types[..], &[(0.into(), BOOL_T), (1.into(), QB_T)]); } -#[rustversion::since(1.75)] // uses impl in return position #[test] fn static_targets() { use crate::extension::{ @@ -143,7 +140,7 @@ fn static_targets() { ) .unwrap(); - let c = dfg.add_constant(ConstUsize::new(1).into()).unwrap(); + let c = dfg.add_constant(ConstUsize::new(1)).unwrap(); let load = dfg.load_const(&c).unwrap(); @@ -157,7 +154,6 @@ fn static_targets() { ) } -#[rustversion::since(1.75)] // uses impl in return position #[test] fn test_dataflow_ports_only() { use crate::builder::DataflowSubContainer; diff --git a/src/ops/constant.rs b/src/ops/constant.rs index 6db0006b3..c66ab6f76 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -148,7 +148,7 @@ mod test { use super::*; fn test_registry() -> ExtensionRegistry { - ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::extension()]).unwrap() + ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]).unwrap() } #[test] diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 4ae262b77..98e5df887 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -1,63 +1,131 @@ //! Conversions between integer and floating-point values. 
+use smol_str::SmolStr; +use strum_macros::{EnumIter, EnumString, IntoStaticStr}; + use crate::{ - extension::{prelude::sum_with_error, ExtensionId, ExtensionSet}, + extension::{ + prelude::sum_with_error, + simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}, + ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError, SignatureFunc, + }, + ops::{custom::ExtensionOp, OpName}, type_row, - types::{FunctionType, PolyFuncType}, + types::{FunctionType, PolyFuncType, TypeArg}, Extension, }; use super::int_types::int_tv; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; +use lazy_static::lazy_static; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); -/// Extension for basic arithmetic operations. -pub fn extension() -> Extension { - let ftoi_sig = PolyFuncType::new( - vec![LOG_WIDTH_TYPE_PARAM], - FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_tv(0))]), - ); - - let itof_sig = PolyFuncType::new( - vec![LOG_WIDTH_TYPE_PARAM], - FunctionType::new(vec![int_tv(0)], type_row![FLOAT64_TYPE]), - ); - - let mut extension = Extension::new_with_reqs( - EXTENSION_ID, - ExtensionSet::from_iter(vec![ - super::int_types::EXTENSION_ID, - super::float_types::EXTENSION_ID, - ]), - ); - extension - .add_op( - "trunc_u".into(), - "float to unsigned int".to_owned(), - ftoi_sig.clone(), - ) - .unwrap(); - extension - .add_op("trunc_s".into(), "float to signed int".to_owned(), ftoi_sig) - .unwrap(); - extension - .add_op( - "convert_u".into(), - "unsigned int to float".to_owned(), - itof_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "convert_s".into(), - "signed int to float".to_owned(), - itof_sig, - ) - .unwrap(); - - extension +/// Definitions of conversion operations between floats and integers. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] +#[allow(missing_docs, non_camel_case_types)] +pub enum ConvertOpDef { + trunc_u, + trunc_s, + convert_u, + convert_s, +} + +impl MakeOpDef for ConvertOpDef { + fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> { + crate::extension::simple_op::try_from_name(op_def.name()) + } + + fn signature(&self) -> SignatureFunc { + use ConvertOpDef::*; + match self { + trunc_s | trunc_u => PolyFuncType::new( + vec![LOG_WIDTH_TYPE_PARAM], + FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_tv(0))]), + ), + + convert_s | convert_u => PolyFuncType::new( + vec![LOG_WIDTH_TYPE_PARAM], + FunctionType::new(vec![int_tv(0)], type_row![FLOAT64_TYPE]), + ), + } + .into() + } + + fn description(&self) -> String { + use ConvertOpDef::*; + match self { + trunc_u => "float to unsigned int", + trunc_s => "float to signed int", + convert_u => "unsigned int to float", + convert_s => "signed int to float", + } + .to_string() + } +} + +/// Concrete convert operation with integer width set.
+#[derive(Debug, Clone, PartialEq)] +pub struct ConvertOpType { + def: ConvertOpDef, + width: u64, +} + +impl OpName for ConvertOpType { + fn name(&self) -> SmolStr { + self.def.name() + } +} + +impl MakeExtensionOp for ConvertOpType { + fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> { + let def = ConvertOpDef::from_def(ext_op.def())?; + let width = match *ext_op.args() { + [TypeArg::BoundedNat { n }] => n, + _ => return Err(SignatureError::InvalidTypeArgs.into()), + }; + Ok(Self { def, width }) + } + + fn type_args(&self) -> Vec<crate::types::TypeArg> { + vec![TypeArg::BoundedNat { n: self.width }] + } +} + +lazy_static! { + /// Extension for conversions between integers and floats. + pub static ref EXTENSION: Extension = { + let mut extension = Extension::new_with_reqs( + EXTENSION_ID, + ExtensionSet::from_iter(vec![ + super::int_types::EXTENSION_ID, + super::float_types::EXTENSION_ID, + ]), + ); + + ConvertOpDef::load_all_ops(&mut extension).unwrap(); + + extension + }; + + /// Registry of extensions required to validate conversion operations. + pub static ref CONVERT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + super::int_types::EXTENSION.to_owned(), + super::float_types::EXTENSION.to_owned(), + EXTENSION.to_owned(), + ]) + .unwrap(); +} + +impl MakeRegisteredOp for ConvertOpType { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + &CONVERT_OPS_REGISTRY + } } #[cfg(test)] @@ -66,7 +134,7 @@ mod test { #[test] fn test_conversions_extension() { - let r = extension(); + let r = &EXTENSION; assert_eq!(r.name() as &str, "arithmetic.conversions"); assert_eq!(r.types().count(), 0); for (name, _) in r.operations() { diff --git a/src/std_extensions/arithmetic/float_ops.rs b/src/std_extensions/arithmetic/float_ops.rs index 5cef5d19a..87c87751b 100644 --- a/src/std_extensions/arithmetic/float_ops.rs +++ b/src/std_extensions/arithmetic/float_ops.rs @@ -1,103 +1,119 @@ //! Basic floating-point operations. +use strum_macros::{EnumIter, EnumString, IntoStaticStr}; + +use super::float_types::FLOAT64_TYPE; use crate::{ - extension::{ExtensionId, ExtensionSet}, + extension::{ + prelude::BOOL_T, + simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError}, + ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureFunc, PRELUDE, + }, type_row, - types::{FunctionType, PolyFuncType}, + types::FunctionType, Extension, }; - -use super::float_types::FLOAT64_TYPE; +use lazy_static::lazy_static; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float"); -/// Extension for basic arithmetic operations.
-pub fn extension() -> Extension { - let mut extension = Extension::new_with_reqs( - EXTENSION_ID, - ExtensionSet::singleton(&super::float_types::EXTENSION_ID), - ); - - let fcmp_sig: PolyFuncType = FunctionType::new( - type_row![FLOAT64_TYPE; 2], - type_row![crate::extension::prelude::BOOL_T], - ) - .into(); - let fbinop_sig: PolyFuncType = - FunctionType::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE]).into(); - let funop_sig: PolyFuncType = - FunctionType::new(type_row![FLOAT64_TYPE], type_row![FLOAT64_TYPE]).into(); - extension - .add_op("feq".into(), "equality test".to_owned(), fcmp_sig.clone()) - .unwrap(); - extension - .add_op("fne".into(), "inequality test".to_owned(), fcmp_sig.clone()) - .unwrap(); - extension - .add_op("flt".into(), "\"less than\"".to_owned(), fcmp_sig.clone()) - .unwrap(); - extension - .add_op( - "fgt".into(), - "\"greater than\"".to_owned(), - fcmp_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "fle".into(), - "\"less than or equal\"".to_owned(), - fcmp_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "fge".into(), - "\"greater than or equal\"".to_owned(), - fcmp_sig, - ) - .unwrap(); - extension - .add_op("fmax".into(), "maximum".to_owned(), fbinop_sig.clone()) - .unwrap(); - extension - .add_op("fmin".into(), "minimum".to_owned(), fbinop_sig.clone()) - .unwrap(); - extension - .add_op("fadd".into(), "addition".to_owned(), fbinop_sig.clone()) - .unwrap(); - extension - .add_op("fsub".into(), "subtraction".to_owned(), fbinop_sig.clone()) - .unwrap(); - extension - .add_op("fneg".into(), "negation".to_owned(), funop_sig.clone()) - .unwrap(); - extension - .add_op( - "fabs".into(), - "absolute value".to_owned(), - funop_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "fmul".into(), - "multiplication".to_owned(), - fbinop_sig.clone(), - ) - .unwrap(); - extension - .add_op("fdiv".into(), "division".to_owned(), fbinop_sig) - .unwrap(); - extension - .add_op("ffloor".into(), "floor".to_owned(), funop_sig.clone()) - .unwrap(); - extension - .add_op("fceil".into(), "ceiling".to_owned(), funop_sig) - .unwrap(); - - extension +/// Float extension operation definitions. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] +#[allow(missing_docs, non_camel_case_types)] +pub enum FloatOps { + feq, + fne, + flt, + fgt, + fle, + fge, + fmax, + fmin, + fadd, + fsub, + fneg, + fabs, + fmul, + fdiv, + ffloor, + fceil, +} + +impl MakeOpDef for FloatOps { + fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> { + crate::extension::simple_op::try_from_name(op_def.name()) + } + + fn signature(&self) -> SignatureFunc { + use FloatOps::*; + + match self { + feq | fne | flt | fgt | fle | fge => { + FunctionType::new(type_row![FLOAT64_TYPE; 2], type_row![BOOL_T]) + } + fmax | fmin | fadd | fsub | fmul | fdiv => { + FunctionType::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE]) + } + fneg | fabs | ffloor | fceil => FunctionType::new_endo(type_row![FLOAT64_TYPE]), + } + .into() + } + + fn description(&self) -> String { + use FloatOps::*; + match self { + feq => "equality test", + fne => "inequality test", + flt => "\"less than\"", + fgt => "\"greater than\"", + fle => "\"less than or equal\"", + fge => "\"greater than or equal\"", + fmax => "maximum", + fmin => "minimum", + fadd => "addition", + fsub => "subtraction", + fneg => "negation", + fabs => "absolute value", + fmul => "multiplication", + fdiv => "division", + ffloor => "floor", + fceil => "ceiling", + } + .to_string() + } +} + +lazy_static!
{ + /// Extension for basic float operations. + pub static ref EXTENSION: Extension = { + let mut extension = Extension::new_with_reqs( + EXTENSION_ID, + ExtensionSet::singleton(&super::float_types::EXTENSION_ID), + ); + + FloatOps::load_all_ops(&mut extension).unwrap(); + + extension + }; + + /// Registry of extensions required to validate float operations. + pub static ref FLOAT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + super::float_types::EXTENSION.to_owned(), + EXTENSION.to_owned(), + ]) + .unwrap(); +} + +impl MakeRegisteredOp for FloatOps { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + &FLOAT_OPS_REGISTRY + } } #[cfg(test)] @@ -106,7 +122,7 @@ mod test { #[test] fn test_float_ops_extension() { - let r = extension(); + let r = &EXTENSION; assert_eq!(r.name() as &str, "arithmetic.float"); assert_eq!(r.types().count(), 0); for (name, _) in r.operations() { diff --git a/src/std_extensions/arithmetic/float_types.rs b/src/std_extensions/arithmetic/float_types.rs index 582c0eee7..71f91bf87 100644 --- a/src/std_extensions/arithmetic/float_types.rs +++ b/src/std_extensions/arithmetic/float_types.rs @@ -8,6 +8,7 @@ use crate::{ values::{CustomConst, KnownTypeConst}, Extension, }; +use lazy_static::lazy_static; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float.types"); @@ -72,29 +73,30 @@ impl CustomConst for ConstF64 { } } -/// Extension for basic floating-point types. -pub fn extension() -> Extension { - let mut extension = Extension::new(EXTENSION_ID); - - extension - .add_type( - FLOAT_TYPE_ID, - vec![], - "64-bit IEEE 754-2019 floating-point value".to_owned(), - TypeBound::Copyable.into(), - ) - .unwrap(); - - extension +lazy_static! { + /// Extension defining the float type. + pub static ref EXTENSION: Extension = { + let mut extension = Extension::new(EXTENSION_ID); + + extension + .add_type( + FLOAT_TYPE_ID, + vec![], + "64-bit IEEE 754-2019 floating-point value".to_owned(), + TypeBound::Copyable.into(), + ) + .unwrap(); + + extension + }; } - #[cfg(test)] mod test { use super::*; #[test] fn test_float_types_extension() { - let r = extension(); + let r = &EXTENSION; assert_eq!(r.name() as &str, "arithmetic.float.types"); assert_eq!(r.types().count(), 1); assert_eq!(r.operations().count(), 0); diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 267fde902..ae5160ffd 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -41,7 +41,7 @@ impl ValidateJustArgs for IOValidator { Ok(()) } } -/// Logic extension operation definitions. +/// Integer extension operation definitions.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] #[allow(missing_docs, non_camel_case_types)] pub enum IntOpDef { diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index a78f4793a..ebec9bda7 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -177,7 +177,7 @@ mod test { let reg = ExtensionRegistry::try_new([ EXTENSION.to_owned(), PRELUDE.to_owned(), - float_types::extension(), + float_types::EXTENSION.to_owned(), ]) .unwrap(); let pop_sig = get_op(&POP_NAME) diff --git a/src/utils.rs b/src/utils.rs index 6c297b4bd..f62cee50f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -86,7 +86,7 @@ pub(crate) mod test_quantum_extension { lazy_static! { /// Quantum extension definition. pub static ref EXTENSION: Extension = extension(); - static ref REG: ExtensionRegistry = ExtensionRegistry::try_new([EXTENSION.to_owned(), PRELUDE.to_owned(), float_types::extension()]).unwrap(); + static ref REG: ExtensionRegistry = ExtensionRegistry::try_new([EXTENSION.to_owned(), PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]).unwrap(); } fn get_gate(gate_name: &str) -> LeafOp {
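For reviewers, here is a minimal usage sketch of how the new const-removal rewrites combine with the relaxed `impl Into<ops::Const>` builder arguments. It is adapted from the test added in `src/hugr/rewrite/consts.rs` above; the `hugr::...` import paths are assumed from this crate layout and are not part of the diff.

```rust
use hugr::builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer};
use hugr::extension::prelude::ConstUsize;
use hugr::hugr::rewrite::consts::{RemoveConst, RemoveConstIgnore};
use hugr::hugr::HugrMut;
use hugr::type_row;
use hugr::types::FunctionType;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut module = ModuleBuilder::new();
    // `add_constant` now takes `impl Into<ops::Const>`; no explicit `.into()` needed.
    let cst = module.add_constant(ConstUsize::new(2))?;

    let mut func = module.define_function("main", FunctionType::new_endo(type_row![]).into())?;
    let loaded = func.load_const(&cst)?; // LoadConstant node fed by the Const
    func.finish_sub_container()?;
    let mut h = module.finish_prelude_hugr()?;

    // Remove the unused LoadConstant; the rewrite returns the Const node that fed it.
    let const_node = h.apply_rewrite(RemoveConstIgnore(loaded.node()))?;
    // The Const now has no users, so it can be removed as well.
    h.apply_rewrite(RemoveConst(const_node))?;
    Ok(())
}
```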