diff --git a/src/json.rs b/src/json.rs index c08afba5..c27fdc9b 100644 --- a/src/json.rs +++ b/src/json.rs @@ -7,6 +7,7 @@ pub mod op; #[cfg(test)] mod tests; +use hugr::hugr::CircuitUnit; #[cfg(feature = "pyo3")] use pyo3::{create_exception, exceptions::PyException, PyErr}; @@ -14,7 +15,7 @@ use std::path::Path; use std::{fs, io}; use hugr::ops::OpType; -use hugr::std_extensions::arithmetic::float_types::ConstF64; +use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; use hugr::values::Value; use hugr::Hugr; @@ -72,6 +73,15 @@ impl TKETDecode for SerialCircuit { fn encode(circ: &impl Circuit) -> Result { let mut encoder = JsonEncoder::new(circ); + let f64_inputs = circ.units().filter_map(|(wire, _, t)| match (wire, t) { + (CircuitUnit::Wire(wire), t) if t == FLOAT64_TYPE => Some(wire), + (CircuitUnit::Linear(_), _) => None, + _ => unimplemented!("Non-float64 input wires not supported"), + }); + for (i, wire) in f64_inputs.enumerate() { + let param = format!("f{i}"); + encoder.add_parameter(wire, param); + } for com in circ.commands() { let optype = com.optype(); encoder.add_command(com.clone(), optype)?; diff --git a/src/json/encoder.rs b/src/json/encoder.rs index cd026005..e6a0d136 100644 --- a/src/json/encoder.rs +++ b/src/json/encoder.rs @@ -13,7 +13,8 @@ use tket_json_rs::circuit_json::{self, Permutation, Register, SerialCircuit}; use crate::circuit::command::{CircuitUnit, Command}; use crate::circuit::Circuit; use crate::extension::LINEAR_BIT; -use crate::ops::match_symb_const_op; +use crate::ops::{match_symb_const_op, op_matches}; +use crate::T2Op; use super::op::JsonOp; use super::{ @@ -207,6 +208,9 @@ impl JsonEncoder { // Re-use the parameter from the input. inputs[0].clone() } + op if op_matches(op, T2Op::AngleAdd) => { + format!("{} + {}", inputs[0], inputs[1]) + } OpType::LeafOp(_) => { if let Some(s) = match_symb_const_op(optype) { s.to_string() @@ -215,16 +219,13 @@ impl JsonEncoder { } } _ => { - // In the future we may want to support arithmetic operations. - // (Just concatenating the inputs and the operation symbol, no - // need for evaluation). return false; } }; for (unit, _, _) in command.outputs() { if let CircuitUnit::Wire(wire) = unit { - self.parameters.insert(wire, param.clone()); + self.add_parameter(wire, param.clone()); } } true @@ -237,4 +238,8 @@ impl JsonEncoder { .or_else(|| self.bit_to_reg.get(&unit)) .cloned() } + + pub(super) fn add_parameter(&mut self, wire: Wire, param: String) { + self.parameters.insert(wire, param); + } } diff --git a/src/json/op.rs b/src/json/op.rs index 5ffcf3e7..86e6832d 100644 --- a/src/json/op.rs +++ b/src/json/op.rs @@ -225,7 +225,9 @@ impl TryFrom<&OpType> for JsonOp { T2Op::RzF64 => JsonOpType::Rz, T2Op::RxF64 => JsonOpType::Rx, // TODO: Use a TK2 opaque op once we update the tket-json-rs dependency. - T2Op::AngleAdd => JsonOpType::AngleAdd, + T2Op::AngleAdd => { + unimplemented!("Serialising AngleAdd not supported. Are all constant folded?") + } T2Op::TK1 => JsonOpType::TK1, T2Op::PhasedX => JsonOpType::PhasedX, T2Op::ZZPhase => JsonOpType::ZZPhase, diff --git a/src/json/tests.rs b/src/json/tests.rs index 90fb3253..195ee444 100644 --- a/src/json/tests.rs +++ b/src/json/tests.rs @@ -2,12 +2,20 @@ use std::io::BufReader; +use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr}; +use hugr::extension::prelude::QB_T; +use hugr::extension::ExtensionSet; +use hugr::std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}; +use hugr::types::FunctionType; use hugr::Hugr; -use rstest::rstest; +use rstest::{fixture, rstest}; use tket_json_rs::circuit_json::{self, SerialCircuit}; +use tket_json_rs::optype; use crate::circuit::Circuit; +use crate::extension::REGISTRY; use crate::json::TKETDecode; +use crate::T2Op; const SIMPLE_JSON: &str = r#"{ "phase": "0", @@ -75,6 +83,68 @@ fn json_file_roundtrip(#[case] circ: impl AsRef) { compare_serial_circs(&ser, &reser); } +#[fixture] +fn circ_add_angles_symbolic() -> Hugr { + let input_t = vec![QB_T, FLOAT64_TYPE, FLOAT64_TYPE]; + let output_t = vec![QB_T]; + let mut h = DFGBuilder::new(FunctionType::new(input_t, output_t)).unwrap(); + + let mut inps = h.input_wires(); + let qb = inps.next().unwrap(); + let f1 = inps.next().unwrap(); + let f2 = inps.next().unwrap(); + + let res = h.add_dataflow_op(T2Op::AngleAdd, [f1, f2]).unwrap(); + let f12 = res.outputs().next().unwrap(); + let res = h.add_dataflow_op(T2Op::RxF64, [qb, f12]).unwrap(); + let qb = res.outputs().next().unwrap(); + + h.finish_hugr_with_outputs([qb], ®ISTRY).unwrap() +} + +#[fixture] +fn circ_add_angles_constants() -> Hugr { + let qb_row = vec![QB_T]; + let mut h = DFGBuilder::new(FunctionType::new(qb_row.clone(), qb_row)).unwrap(); + + let qb = h.input_wires().next().unwrap(); + let f64_ext = hugr::std_extensions::arithmetic::float_types::EXTENSION_ID; + + let point2 = h + .add_load_const(ConstF64::new(0.2).into(), ExtensionSet::singleton(&f64_ext)) + .unwrap(); + let point3 = h + .add_load_const(ConstF64::new(0.3).into(), ExtensionSet::singleton(&f64_ext)) + .unwrap(); + let point5 = h + .add_dataflow_op(T2Op::AngleAdd, [point2, point3]) + .unwrap() + .out_wire(0); + + let qbs = h + .add_dataflow_op(T2Op::RxF64, [qb, point5]) + .unwrap() + .outputs(); + h.finish_hugr_with_outputs(qbs, ®ISTRY).unwrap() +} + +#[rstest] +#[case::symbolic(circ_add_angles_symbolic(), "f0 + f1")] +#[case::constants(circ_add_angles_constants(), "0.2 + 0.3")] +fn test_add_angle_serialise(#[case] circ_add_angles: Hugr, #[case] param_str: &str) { + let ser: SerialCircuit = SerialCircuit::encode(&circ_add_angles).unwrap(); + assert_eq!(ser.commands.len(), 1); + assert_eq!(ser.commands[0].op.op_type, optype::OpType::Rx); + assert_eq!(ser.commands[0].op.params, Some(vec![param_str.into()])); + + // Note: this is not a proper roundtrip as the symbols f0 and f1 are not + // converted back to circuit inputs. This would require parsing symbolic + // expressions. + let deser: Hugr = ser.clone().decode().unwrap(); + let reser = SerialCircuit::encode(&deser).unwrap(); + compare_serial_circs(&ser, &reser); +} + fn compare_serial_circs(a: &SerialCircuit, b: &SerialCircuit) { assert_eq!(a.name, b.name); assert_eq!(a.phase, b.phase);