From aeb193067c84392f19189d68ad1eff803af1a572 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Agust=C3=ADn=20Borgna?=
<121866228+aborgna-q@users.noreply.github.com>
Date: Wed, 6 Nov 2024 16:58:48 +0000
Subject: [PATCH] feat: Support classical expressions (#86)
Closes #83
~~The json schema has not been updated yet (see
https://github.com/CQCL/tket/issues/1654), so the format may not be
completely correct.~~
Edit: json schema is now merged https://github.com/CQCL/tket/pull/1660
---------
Co-authored-by: Alec Edgington <54802828+cqc-alec@users.noreply.github.com>
---
Cargo.toml | 2 +
src/circuit_json.rs | 9 +
src/clexpr.rs | 87 ++++++++++
src/clexpr/op.rs | 73 ++++++++
src/clexpr/operator.rs | 78 +++++++++
src/lib.rs | 1 +
src/opbox.rs | 2 +
src/optype.rs | 11 ++
tests/data/qasm.json | 367 +++++++++++++++++++++++++++++++++++++++
tests/data/qasm.py | 32 ++++
tests/missing_optypes.rs | 60 +++++--
tests/roundtrip.rs | 28 ++-
12 files changed, 735 insertions(+), 15 deletions(-)
create mode 100644 src/clexpr.rs
create mode 100644 src/clexpr/op.rs
create mode 100644 src/clexpr/operator.rs
create mode 100644 tests/data/qasm.json
create mode 100644 tests/data/qasm.py
diff --git a/Cargo.toml b/Cargo.toml
index 0eea50d..e27bfce 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -25,6 +25,7 @@ pythonize = { workspace = true, optional = true }
strum = { workspace = true, features = ["derive"] }
[dev-dependencies]
+itertools = { workspace = true }
pyo3 = { workspace = true }
rstest = { workspace = true }
assert-json-diff = { workspace = true }
@@ -37,6 +38,7 @@ name = "integration"
path = "tests/lib.rs"
[workspace.dependencies]
+itertools = "0.13.0"
pyo3 = "0.22.2"
pythonize = "0.22.0"
rstest = "0.23.0"
diff --git a/src/circuit_json.rs b/src/circuit_json.rs
index d442abc..6f37412 100644
--- a/src/circuit_json.rs
+++ b/src/circuit_json.rs
@@ -1,6 +1,7 @@
//! Contains structs for serializing and deserializing TKET circuits to and from
//! JSON.
+use crate::clexpr::ClExpr;
use crate::opbox::OpBox;
use crate::optype::OpType;
use serde::{Deserialize, Serialize};
@@ -168,6 +169,12 @@ pub struct Operation
{
#[serde(rename = "box")]
#[serde(skip_serializing_if = "Option::is_none")]
pub op_box: Option,
+ /// Classical expression.
+ ///
+ /// Required if the operation is of type [`OpType::ClExpr`].
+ #[serde(skip_serializing_if = "Option::is_none")]
+ #[serde(rename = "expr")]
+ pub classical_expr: Option,
/// The pre-computed signature.
#[serde(skip_serializing_if = "Option::is_none")]
pub signature: Option>,
@@ -240,6 +247,7 @@ impl Default for Operation
{
data: None,
params: None,
op_box: None,
+ classical_expr: None,
signature: None,
conditional: None,
classical: None,
@@ -289,6 +297,7 @@ impl
Operation
{
.params
.map(|params| params.into_iter().map(f).collect()),
op_box: self.op_box,
+ classical_expr: self.classical_expr,
signature: self.signature,
conditional: self.conditional,
classical: self.classical,
diff --git a/src/clexpr.rs b/src/clexpr.rs
new file mode 100644
index 0000000..0e70126
--- /dev/null
+++ b/src/clexpr.rs
@@ -0,0 +1,87 @@
+//! Classical expressions
+
+pub mod op;
+pub mod operator;
+
+use operator::ClOperator;
+use serde::de::SeqAccess;
+use serde::ser::SerializeSeq;
+use serde::{Deserialize, Serialize};
+
+/// Data encoding a classical expression.
+///
+/// A classical expression operates over multi-bit registers and/or individual bits,
+/// which are identified here by their individual bit positions.
+///
+/// This is included in a [`Operation`] when the operation is a [`OpType::ClExpr`].
+///
+/// [`Operation`]: crate::circuit_json::Operation
+/// [`OpType::ClExpr`]: crate::optype::OpType::ClExpr
+#[derive(Debug, Default, PartialEq, Clone, Serialize, Deserialize)]
+#[non_exhaustive]
+pub struct ClExpr {
+ /// Mapping between bit variables in the expression and the position of the
+ /// corresponding bit in the `args` list.
+ pub bit_posn: Vec<(u32, u32)>,
+ /// The encoded expression.
+ pub expr: ClOperator,
+ /// The input bits of the expression.
+ pub reg_posn: Vec,
+ /// The output bits of the expression.
+ pub output_posn: ClRegisterBits,
+}
+
+/// An input register for a classical expression.
+///
+/// Contains the input index as well as the bits that are part of the register.
+///
+/// Serialized as a list with two elements: the index and the bits.
+#[derive(Debug, Default, PartialEq, Clone)]
+pub struct InputClRegister {
+ /// The index of the register variable in the expression.
+ pub index: u32,
+ /// The sequence of positions of bits comprising the register variable.
+ pub bits: ClRegisterBits,
+}
+
+/// The sequence of positions of bits in the output.
+///
+/// Registers are little-endian, so the first bit is the least significant.
+#[derive(Debug, Default, PartialEq, Clone, Serialize, Deserialize)]
+#[serde(transparent)]
+pub struct ClRegisterBits(pub Vec);
+
+impl Serialize for InputClRegister {
+ fn serialize(&self, serializer: S) -> Result {
+ let mut seq = serializer.serialize_seq(Some(2))?;
+ seq.serialize_element(&self.index)?;
+ seq.serialize_element(&self.bits)?;
+ seq.end()
+ }
+}
+
+impl<'de> Deserialize<'de> for InputClRegister {
+ fn deserialize>(deserializer: D) -> Result {
+ struct Visitor;
+
+ impl<'de_vis> serde::de::Visitor<'de_vis> for Visitor {
+ type Value = InputClRegister;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+ formatter.write_str("a list of two elements: the index and the bits")
+ }
+
+ fn visit_seq>(self, mut seq: A) -> Result {
+ let index = seq
+ .next_element::()?
+ .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
+ let bits = seq
+ .next_element::()?
+ .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
+ Ok(InputClRegister { index, bits })
+ }
+ }
+
+ deserializer.deserialize_seq(Visitor)
+ }
+}
diff --git a/src/clexpr/op.rs b/src/clexpr/op.rs
new file mode 100644
index 0000000..1843810
--- /dev/null
+++ b/src/clexpr/op.rs
@@ -0,0 +1,73 @@
+//! Classical expression operations.
+
+use serde::{Deserialize, Serialize};
+use strum::EnumString;
+
+/// List of supported classical expressions.
+///
+/// Corresponds to `pytket.circuit.ClOp`.
+#[derive(Deserialize, Serialize, Clone, Debug, Default, PartialEq, Eq, Hash, EnumString)]
+#[non_exhaustive]
+pub enum ClOp {
+ /// Invalid operation
+ #[default]
+ INVALID,
+
+ /// Bitwise AND
+ BitAnd,
+ /// Bitwise OR
+ BitOr,
+ /// Bitwise XOR
+ BitXor,
+ /// Bitwise equality
+ BitEq,
+ /// Bitwise inequality
+ BitNeq,
+ /// Bitwise NOT
+ BitNot,
+ /// Constant zero bit
+ BitZero,
+ /// Constant one bit
+ BitOne,
+
+ /// Registerwise AND
+ RegAnd,
+ /// Registerwise OR
+ RegOr,
+ /// Registerwise XOR
+ RegXor,
+ /// Registerwise equality
+ RegEq,
+ /// Registerwise inequality
+ RegNeq,
+ /// Registerwise NOT
+ RegNot,
+ /// Constant all-zeros register
+ RegZero,
+ /// Constant all-ones register
+ RegOne,
+ /// Integer less-than comparison
+ RegLt,
+ /// Integer greater-than comparison
+ RegGt,
+ /// Integer less-than-or-equal comparison
+ RegLeq,
+ /// Integer greater-than-or-equal comparison
+ RegGeq,
+ /// Integer addition
+ RegAdd,
+ /// Integer subtraction
+ RegSub,
+ /// Integer multiplication
+ RegMul,
+ /// Integer division
+ RegDiv,
+ /// Integer exponentiation
+ RegPow,
+ /// Left shift
+ RegLsh,
+ /// Right shift
+ RegRsh,
+ /// Integer negation
+ RegNeg,
+}
diff --git a/src/clexpr/operator.rs b/src/clexpr/operator.rs
new file mode 100644
index 0000000..9d7cc52
--- /dev/null
+++ b/src/clexpr/operator.rs
@@ -0,0 +1,78 @@
+//! A tree of operators forming a classical expression.
+
+use serde::{Deserialize, Serialize};
+
+use super::op::ClOp;
+
+/// A node in a classical expression tree.
+#[derive(Debug, Default, PartialEq, Clone, Serialize, Deserialize)]
+#[non_exhaustive]
+pub struct ClOperator {
+ /// The operation to be performed.
+ pub op: ClOp,
+ /// The arguments to the operation.
+ pub args: Vec,
+}
+
+/// An argument to a classical expression operation.
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+#[non_exhaustive]
+#[serde(tag = "type", content = "input")]
+pub enum ClArgument {
+ /// A terminal argument.
+ #[serde(rename = "term")]
+ Terminal(ClTerminal),
+ /// A sub-expression.
+ #[serde(rename = "expr")]
+ Expression(Box),
+}
+
+/// A terminal argument in a classical expression operation.
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+#[non_exhaustive]
+#[serde(tag = "type", content = "term")]
+pub enum ClTerminal {
+ /// A terminal argument.
+ #[serde(rename = "var")]
+ Variable(ClVariable),
+ /// A constant integer.
+ #[serde(rename = "int")]
+ Int(u64),
+}
+
+/// A variable terminal argument in a classical expression operation.
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, Hash)]
+#[non_exhaustive]
+#[serde(tag = "type", content = "var")]
+pub enum ClVariable {
+ /// A register variable.
+ #[serde(rename = "reg")]
+ Register {
+ /// The register index.
+ index: u32,
+ },
+ /// A constant bit.
+ #[serde(rename = "bit")]
+ Bit {
+ /// The bit index.
+ index: u32,
+ },
+}
+
+impl Default for ClArgument {
+ fn default() -> Self {
+ ClArgument::Terminal(ClTerminal::default())
+ }
+}
+
+impl Default for ClTerminal {
+ fn default() -> Self {
+ ClTerminal::Int(0)
+ }
+}
+
+impl Default for ClVariable {
+ fn default() -> Self {
+ ClVariable::Register { index: 0 }
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 5410829..3238f1f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,6 +3,7 @@
//! [TKET](https://github.com/CQCL/tket) quantum compiler.
pub mod circuit_json;
+pub mod clexpr;
pub mod opbox;
pub mod optype;
#[cfg(feature = "pyo3")]
diff --git a/src/opbox.rs b/src/opbox.rs
index 794e5a2..9ad7522 100644
--- a/src/opbox.rs
+++ b/src/opbox.rs
@@ -147,6 +147,8 @@ pub enum OpBox {
control_state: u32,
},
/// Holding box for abstract expressions on Bits.
+ ///
+ /// Deprecated in favour of [`OpType::ClExpr`].
ClassicalExpBox {
id: BoxID,
n_i: u32,
diff --git a/src/optype.rs b/src/optype.rs
index 851c98c..ed5577e 100644
--- a/src/optype.rs
+++ b/src/optype.rs
@@ -500,6 +500,8 @@ pub enum OpType {
/// See [`ClassicalExpBox`]
///
+ /// Deprecated. Use [`OpType::ClExpr`] instead.
+ ///
/// [`ClassicalExpBox`]: crate::opbox::OpBox::ClassicalExpBox
ClassicalExpBox,
@@ -547,4 +549,13 @@ pub enum OpType {
///
/// [`DiagonalBox`]: crate::opbox::OpBox::DiagonalBox
DiagonalBox,
+
+ /// Classical expression.
+ ///
+ /// An operation of this type is accompanied by a [`ClExpr`] object.
+ ///
+ /// This is a replacement of the deprecated [`OpType::ClassicalExpBox`].
+ ///
+ /// [`ClExpr`]: crate::clexpr::ClExpr
+ ClExpr,
}
diff --git a/tests/data/qasm.json b/tests/data/qasm.json
new file mode 100644
index 0000000..3ea4f58
--- /dev/null
+++ b/tests/data/qasm.json
@@ -0,0 +1,367 @@
+{
+ "bits": [
+ [
+ "a",
+ [
+ 0
+ ]
+ ],
+ [
+ "a",
+ [
+ 1
+ ]
+ ],
+ [
+ "a",
+ [
+ 2
+ ]
+ ],
+ [
+ "b",
+ [
+ 0
+ ]
+ ],
+ [
+ "b",
+ [
+ 1
+ ]
+ ],
+ [
+ "b",
+ [
+ 2
+ ]
+ ],
+ [
+ "c",
+ [
+ 0
+ ]
+ ],
+ [
+ "c",
+ [
+ 1
+ ]
+ ],
+ [
+ "c",
+ [
+ 2
+ ]
+ ],
+ [
+ "d",
+ [
+ 0
+ ]
+ ],
+ [
+ "d",
+ [
+ 1
+ ]
+ ],
+ [
+ "d",
+ [
+ 2
+ ]
+ ]
+ ],
+ "commands": [
+ {
+ "args": [
+ [
+ "a",
+ [
+ 0
+ ]
+ ],
+ [
+ "a",
+ [
+ 1
+ ]
+ ],
+ [
+ "a",
+ [
+ 2
+ ]
+ ],
+ [
+ "b",
+ [
+ 0
+ ]
+ ],
+ [
+ "b",
+ [
+ 1
+ ]
+ ],
+ [
+ "b",
+ [
+ 2
+ ]
+ ],
+ [
+ "c",
+ [
+ 0
+ ]
+ ],
+ [
+ "c",
+ [
+ 1
+ ]
+ ],
+ [
+ "c",
+ [
+ 2
+ ]
+ ],
+ [
+ "d",
+ [
+ 0
+ ]
+ ],
+ [
+ "d",
+ [
+ 1
+ ]
+ ],
+ [
+ "d",
+ [
+ 2
+ ]
+ ]
+ ],
+ "op": {
+ "expr": {
+ "bit_posn": [],
+ "expr": {
+ "args": [
+ {
+ "input": {
+ "args": [
+ {
+ "input": {
+ "args": [
+ {
+ "input": {
+ "term": {
+ "type": "reg",
+ "var": {
+ "index": 0
+ }
+ },
+ "type": "var"
+ },
+ "type": "term"
+ },
+ {
+ "input": {
+ "term": {
+ "type": "reg",
+ "var": {
+ "index": 1
+ }
+ },
+ "type": "var"
+ },
+ "type": "term"
+ }
+ ],
+ "op": "RegAdd"
+ },
+ "type": "expr"
+ },
+ {
+ "input": {
+ "term": 2,
+ "type": "int"
+ },
+ "type": "term"
+ }
+ ],
+ "op": "RegDiv"
+ },
+ "type": "expr"
+ },
+ {
+ "input": {
+ "term": {
+ "type": "reg",
+ "var": {
+ "index": 2
+ }
+ },
+ "type": "var"
+ },
+ "type": "term"
+ }
+ ],
+ "op": "RegSub"
+ },
+ "output_posn": [
+ 9,
+ 10,
+ 11
+ ],
+ "reg_posn": [
+ [
+ 0,
+ [
+ 0,
+ 1,
+ 2
+ ]
+ ],
+ [
+ 1,
+ [
+ 3,
+ 4,
+ 5
+ ]
+ ],
+ [
+ 2,
+ [
+ 6,
+ 7,
+ 8
+ ]
+ ]
+ ]
+ },
+ "type": "ClExpr"
+ }
+ },
+ {
+ "args": [
+ [
+ "q",
+ [
+ 0
+ ]
+ ]
+ ],
+ "op": {
+ "type": "H"
+ }
+ },
+ {
+ "args": [
+ [
+ "q",
+ [
+ 2
+ ]
+ ]
+ ],
+ "op": {
+ "type": "Z"
+ }
+ },
+ {
+ "args": [
+ [
+ "q",
+ [
+ 2
+ ]
+ ],
+ [
+ "q",
+ [
+ 1
+ ]
+ ]
+ ],
+ "op": {
+ "type": "CX"
+ }
+ }
+ ],
+ "created_qubits": [],
+ "discarded_qubits": [],
+ "implicit_permutation": [
+ [
+ [
+ "q",
+ [
+ 0
+ ]
+ ],
+ [
+ "q",
+ [
+ 0
+ ]
+ ]
+ ],
+ [
+ [
+ "q",
+ [
+ 1
+ ]
+ ],
+ [
+ "q",
+ [
+ 1
+ ]
+ ]
+ ],
+ [
+ [
+ "q",
+ [
+ 2
+ ]
+ ],
+ [
+ "q",
+ [
+ 2
+ ]
+ ]
+ ]
+ ],
+ "phase": "0.0",
+ "qubits": [
+ [
+ "q",
+ [
+ 0
+ ]
+ ],
+ [
+ "q",
+ [
+ 1
+ ]
+ ],
+ [
+ "q",
+ [
+ 2
+ ]
+ ]
+ ]
+}
diff --git a/tests/data/qasm.py b/tests/data/qasm.py
new file mode 100644
index 0000000..205a022
--- /dev/null
+++ b/tests/data/qasm.py
@@ -0,0 +1,32 @@
+# /// script
+# requires-python = ">=3.13"
+# dependencies = [
+# "pytket>=1.34",
+# ]
+# ///
+
+import json
+
+from pytket import Circuit
+from pytket.qasm import circuit_from_qasm_str
+
+
+def qasm_circuit() -> Circuit:
+ qasm = """OPENQASM 2.0;
+ include "hqslib1.inc";
+ qreg q[3];
+ creg a[3];
+ creg b[3];
+ creg c[3];
+ creg d[3];
+ d = (((a + b) / 2) - c);
+
+ h q[0];
+ z q[2];
+ cx q[2], q[1];
+ """
+ return circuit_from_qasm_str(qasm, use_clexpr=True)
+
+
+if __name__ == "__main__":
+ print(json.dumps(qasm_circuit().to_dict(), indent=2))
diff --git a/tests/missing_optypes.rs b/tests/missing_optypes.rs
index 80a274d..1da1242 100644
--- a/tests/missing_optypes.rs
+++ b/tests/missing_optypes.rs
@@ -4,10 +4,29 @@
use std::str::FromStr;
+use itertools::Itertools;
use pyo3::prelude::*;
use pyo3::types::PyDict;
+use tket_json_rs::clexpr::op::ClOp;
use tket_json_rs::OpType;
+/// Given a python enum, lists the enum variants that cannot be converted into a `T` using `FromStr`.
+fn find_missing_variants<'py, T>(py_enum: &Bound<'py, PyAny>) -> impl Iterator- + 'py
+where
+ T: FromStr,
+{
+ let py_members = py_enum.getattr("__members__").unwrap();
+ let py_members = py_members.downcast::().unwrap();
+
+ py_members.into_iter().filter_map(|(name, _class)| {
+ let name = name.extract::().unwrap();
+ match T::from_str(&name) {
+ Err(_) => Some(name),
+ Ok(_) => None,
+ }
+ })
+}
+
#[test]
#[ignore = "Requires a python environment with `pytket` installed."]
fn missing_optypes() -> PyResult<()> {
@@ -19,19 +38,7 @@ fn missing_optypes() -> PyResult<()> {
panic!("Failed to import `pytket`. Make sure the python library is installed.");
};
let py_enum = pytket.getattr("OpType")?;
- let py_members = py_enum.getattr("__members__")?;
- let py_members = py_members.downcast::()?;
-
- let missing: Vec = py_members
- .into_iter()
- .filter_map(|(name, _class)| {
- let name = name.extract::().unwrap();
- match OpType::from_str(&name) {
- Err(_) => Some(name),
- Ok(_) => None,
- }
- })
- .collect();
+ let missing = find_missing_variants::(&py_enum).collect_vec();
if !missing.is_empty() {
let msg = "\nMissing optypes in `tket_json_rs`:\n".to_string();
@@ -46,3 +53,30 @@ fn missing_optypes() -> PyResult<()> {
Ok(())
})
}
+
+#[test]
+#[ignore = "Requires a python environment with `pytket` installed."]
+fn missing_classical_optypes() -> PyResult<()> {
+ println!("Checking missing classical ops");
+
+ pyo3::prepare_freethreaded_python();
+ Python::with_gil(|py| {
+ let Ok(pytket) = PyModule::import_bound(py, "pytket") else {
+ panic!("Failed to import `pytket`. Make sure the python library is installed.");
+ };
+ let py_enum = pytket.getattr("circuit")?.getattr("ClOp")?;
+ let missing = find_missing_variants::(&py_enum).collect_vec();
+
+ if !missing.is_empty() {
+ let msg = "\nMissing classical ops in `tket_json_rs`:\n".to_string();
+ let msg = missing
+ .into_iter()
+ .fold(msg, |msg, s| msg + " - " + &s + "\n");
+ let msg =
+ msg + "Please add them to the `ClOp` enum in `tket_json_rs/src/clexpr/op.rs`.\n";
+ panic!("{msg}");
+ }
+
+ Ok(())
+ })
+}
diff --git a/tests/roundtrip.rs b/tests/roundtrip.rs
index 41aa4ea..1ca5713 100644
--- a/tests/roundtrip.rs
+++ b/tests/roundtrip.rs
@@ -7,20 +7,44 @@ use tket_json_rs::SerialCircuit;
const SIMPLE: &str = include_str!("data/simple.json");
const CLASSICAL: &str = include_str!("data/classical.json");
const DIAGONAL: &str = include_str!("data/diagonal-box.json");
+const QASM: &str = include_str!("data/qasm.json");
const WASM: &str = include_str!("data/wasm.json");
+/// Cleanup some fields in the JSON so that we can compare them.
+fn normalize_json(json: &mut Value) {
+ if let Value::Object(obj) = json {
+ // Some versions of pytket include the `created_qubits` and `discarded_qubits` fields
+ // even if they are empty. Some other versions do not include them at all.
+ //
+ // We remove them here.
+ if let Some(Value::Array(registers)) = obj.get_mut("created_qubits") {
+ if registers.is_empty() {
+ obj.remove("created_qubits");
+ }
+ }
+ if let Some(Value::Array(registers)) = obj.get_mut("discarded_qubits") {
+ if registers.is_empty() {
+ obj.remove("discarded_qubits");
+ }
+ }
+ }
+}
+
#[rstest]
#[case::simple(SIMPLE, 4)]
#[case::classical(CLASSICAL, 3)]
#[case::diagonal_box(DIAGONAL, 1)]
+#[case::qasm_box(QASM, 4)]
#[case::wasm_box(WASM, 1)]
fn roundtrip(#[case] json: &str, #[case] num_commands: usize) {
- let initial_json: Value = serde_json::from_str(json).unwrap();
+ let mut initial_json: Value = serde_json::from_str(json).unwrap();
+ normalize_json(&mut initial_json);
let ser: SerialCircuit = serde_json::from_value(initial_json.clone()).unwrap();
assert_eq!(ser.commands.len(), num_commands);
- let reencoded_json = serde_json::to_value(&ser).unwrap();
+ let mut reencoded_json = serde_json::to_value(&ser).unwrap();
+ normalize_json(&mut reencoded_json);
assert_json_eq!(reencoded_json, initial_json);
let reser: SerialCircuit = serde_json::from_value(reencoded_json).unwrap();