From 31b3193ae3666fc688d84fc2f7d2eddf7d9f6fef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 9 Jul 2024 14:16:31 +0100 Subject: [PATCH] feat!: Classical op params (#56) Closes #53 drive-by: Add missing `data` field to Operation ([schema](https://github.com/CQCL/tket/blob/a47578ec5d79bb5caf23ef2edee3f587bc3c7d14/schemas/circuit_v1.json#L283-L286)). drive-by: Remove `created_qubits`/`discarded_qubits` from the test files. Those fields are not in the schema. drive-by: Mark more structs as `non_exhaustive`, so in the future a change like this does not have to be breaking. BREAKING CHANGE: Added `data` and `classical` fields to `Operation`. Marked some structs/enums as non_exhaustive. --- src/circuit_json.rs | 86 +++++++++++++++++++-- src/optype.rs | 3 +- tests/data/classical.json | 144 +++++++++++++++++++++++++++++++++++ tests/data/diagonal-box.json | 2 - tests/data/simple.json | 3 +- tests/roundtrip.rs | 7 +- 6 files changed, 233 insertions(+), 12 deletions(-) create mode 100644 tests/data/classical.json diff --git a/src/circuit_json.rs b/src/circuit_json.rs index 60d44a7..35b48fb 100644 --- a/src/circuit_json.rs +++ b/src/circuit_json.rs @@ -48,6 +48,7 @@ pub struct Matrix { /// The units used in a [`ClassicalExp`]. #[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq, Hash)] #[serde(untagged)] +#[non_exhaustive] pub enum ClassicalExpUnit { /// Unsigned 32-bit integer. U32(u32), @@ -79,8 +80,66 @@ pub struct Conditional { pub value: u32, } +/// Additional fields for classical operations, +/// which only act on Bits classically. +// +// Note: The order of the variants here is important. +// Serde will return the first matching variant when deserializing, +// so CopyBits and SetBits must come after other variants that +// define `values` and `n_i`. +#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] +#[serde(untagged)] +#[non_exhaustive] +pub enum Classical { + /// Multi-bit operation. + MultiBit { + /// The inner operation. + op: Box, + /// Multiplier on underlying op for MultiBitOp. + n: u32, + }, + /// A range predicate. + RangePredicate { + /// Number of pure input wires to the RangePredicate. + n_i: u32, + /// The inclusive minimum of the RangePredicate. + lower: u64, + /// The inclusive maximum of the RangePredicate. + upper: u64, + }, + /// ExplicitModifierOp/ExplicitPredicateOp. + Explicit { + /// Number of pure input wires to the ExplicitModifierOp/ExplicitPredicateOp. + n_i: u32, + /// Name of classical ExplicitModifierOp/ExplicitPredicateOp (e.g. AND). + name: String, + /// Truth table of ExplicitModifierOp/ExplicitPredicateOp. + values: Vec, + }, + /// ClassicalTransformOp + ClassicalTransform { + /// Number of input/output wires. + n_io: u32, + /// Name of classical ClassicalTransformOp (e.g. ClassicalCX). + name: String, + /// Truth table of ClassicalTransformOp. + values: Vec, + }, + /// CopyBitsOp. + CopyBits { + /// Number of input wires to the CopyBitsOp. + n_i: u32, + }, + /// SetBitsOp. + SetBits { + /// List of bools that SetBitsOp sets bits to. + values: Vec, + }, +} + /// Serializable operation descriptor. #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] +#[non_exhaustive] pub struct Operation

{ /// The type of operation. #[serde(rename = "type")] @@ -88,6 +147,9 @@ pub struct Operation

{ /// Number of input and output qubits. #[serde(skip_serializing_if = "Option::is_none")] pub n_qb: Option, + /// Additional string stored in the op + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, /// Expressions for the parameters of the operation. #[serde(skip_serializing_if = "Option::is_none")] pub params: Option>, @@ -101,6 +163,9 @@ pub struct Operation

{ /// A QASM-style classical condition for the operation. #[serde(skip_serializing_if = "Option::is_none")] pub conditional: Option, + /// Data for commands which only act on Bits classically. + #[serde(skip_serializing_if = "Option::is_none")] + pub classical: Option>, } /// Operation applied in a circuit, with defined arguments. @@ -137,6 +202,21 @@ pub struct SerialCircuit

{ pub implicit_permutation: Vec, } +impl

Default for Operation

{ + fn default() -> Self { + Self { + op_type: Default::default(), + n_qb: None, + data: None, + params: None, + op_box: None, + signature: None, + conditional: None, + classical: None, + } + } +} + impl

Operation

{ /// Returns a default-initialized Operation with the given type. /// @@ -145,11 +225,7 @@ impl

Operation

{ pub fn from_optype(op_type: OpType) -> Self { Self { op_type, - n_qb: None, - params: None, - op_box: None, - signature: None, - conditional: None, + ..Operation::default() } } } diff --git a/src/optype.rs b/src/optype.rs index 70cdaca..5b66da5 100644 --- a/src/optype.rs +++ b/src/optype.rs @@ -10,7 +10,7 @@ use strum::EnumString; /// Operation types in a quantum circuit. #[cfg_attr(feature = "pyo3", pyclass(name = "RsOpType"))] -#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq, Hash, EnumString)] +#[derive(Deserialize, Serialize, Clone, Debug, Default, PartialEq, Eq, Hash, EnumString)] #[non_exhaustive] pub enum OpType { /// Quantum input node of the circuit @@ -280,6 +280,7 @@ pub enum OpType { /// Identity #[allow(non_camel_case_types)] + #[default] noop, /// Measure a qubit, producing a classical output diff --git a/tests/data/classical.json b/tests/data/classical.json new file mode 100644 index 0000000..200ccfc --- /dev/null +++ b/tests/data/classical.json @@ -0,0 +1,144 @@ +{ + "bits": [ + [ + "c", + [ + 0 + ] + ], + [ + "c", + [ + 1 + ] + ] + ], + "commands": [ + { + "args": [ + [ + "c", + [ + 0 + ] + ], + [ + "c", + [ + 1 + ] + ] + ], + "op": { + "type": "MultiBit", + "classical": { + "op": { + "type": "SetBits", + "classical": { + "values": [ + true + ] + } + }, + "n": 2 + } + } + }, + { + "args": [ + [ + "c", + [ + 0 + ] + ], + [ + "c", + [ + 1 + ] + ] + ], + "op": { + "type": "ClassicalTransform", + "classical": { + "name": "ClassicalCX", + "n_io": 2, + "values": [ + 0, + 1, + 3, + 2 + ] + } + } + }, + { + "args": [ + [ + "c", + [ + 0 + ] + ], + [ + "c", + [ + 1 + ] + ] + ], + "op": { + "type": "CopyBits", + "classical": { + "n_i": 2 + } + } + } + ], + "implicit_permutation": [ + [ + [ + "q", + [ + 0 + ] + ], + [ + "q", + [ + 0 + ] + ] + ], + [ + [ + "q", + [ + 1 + ] + ], + [ + "q", + [ + 1 + ] + ] + ] + ], + "phase": "0.0", + "qubits": [ + [ + "q", + [ + 0 + ] + ], + [ + "q", + [ + 1 + ] + ] + ] +} \ No newline at end of file diff --git a/tests/data/diagonal-box.json b/tests/data/diagonal-box.json index b6e1c2b..f5f84ea 100644 --- a/tests/data/diagonal-box.json +++ b/tests/data/diagonal-box.json @@ -67,8 +67,6 @@ } } ], - "created_qubits": [], - "discarded_qubits": [], "implicit_permutation": [ [ [ diff --git a/tests/data/simple.json b/tests/data/simple.json index 855dbaa..6320c5d 100644 --- a/tests/data/simple.json +++ b/tests/data/simple.json @@ -24,7 +24,8 @@ ] ], "op": { - "type": "H" + "type": "H", + "data": "Custom data" } }, { diff --git a/tests/roundtrip.rs b/tests/roundtrip.rs index d7d85cf..e3c1466 100644 --- a/tests/roundtrip.rs +++ b/tests/roundtrip.rs @@ -1,14 +1,16 @@ //! Roundtrip tests -use assert_json_diff::assert_json_include; +use assert_json_diff::assert_json_eq; use rstest::rstest; use serde_json::Value; use tket_json_rs::SerialCircuit; const SIMPLE: &str = include_str!("data/simple.json"); +const CLASSICAL: &str = include_str!("data/classical.json"); const DIAGONAL: &str = include_str!("data/diagonal-box.json"); #[rstest] #[case::simple(SIMPLE, 4)] +#[case::classical(CLASSICAL, 3)] #[case::diagonal_box(DIAGONAL, 1)] fn roundtrip(#[case] json: &str, #[case] num_commands: usize) { let initial_json: Value = serde_json::from_str(json).unwrap(); @@ -17,8 +19,7 @@ fn roundtrip(#[case] json: &str, #[case] num_commands: usize) { assert_eq!(ser.commands.len(), num_commands); let reencoded_json = serde_json::to_value(&ser).unwrap(); - // Do a partial comparison. The re-encoded circuit does not include "created_qubits" nor "discarded_qubits". - assert_json_include!(actual: initial_json, expected: reencoded_json); + assert_json_eq!(reencoded_json, initial_json); let reser: SerialCircuit = serde_json::from_value(reencoded_json).unwrap();