-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Utilities for loading compiled guppy circuits (#393)
Adds a series of `load_guppy_json*` methods similar to `load_tk1_json`. Given a hugr json and a function name it returns a `Circuit` for the function. Currently this only supports guppy functions with no control flow. Adds a test adapted from #382, with the following hugr: ```mermaid graph LR subgraph 0 ["(0) Module"] direction LR subgraph 7 ["(7) FuncDefn"] direction LR 3["(3) Input"] 3--"0:0<br>qubit"-->8 3--"1:1<br>qubit"-->8 6["(6) Output"] subgraph 8 ["(8) CFG"] direction LR subgraph 1 ["(1) DataflowBlock"] direction LR 4["(4) Input"] 4--"0:0<br>qubit"-->13 4--"1:0<br>qubit"-->21 5["(5) Output"] 9["(9) const:custom:f64(1.5707963267948966)"] 9--"0:0<br>float64"-->10 10["(10) LoadConstant"] 10--"0:1<br>float64"-->13 11["(11) const:custom:f64(-1.5707963267948966)"] 11--"0:0<br>float64"-->12 12["(12) LoadConstant"] 12--"0:2<br>float64"-->13 13["(13) quantum.tket2.PhasedX"] 13--"0:0<br>qubit"-->16 14["(14) const:custom:f64(3.141592653589793)"] 14--"0:0<br>float64"-->15 15["(15) LoadConstant"] 15--"0:1<br>float64"-->16 16["(16) quantum.tket2.RzF64"] 16--"0:0<br>qubit"-->25 17["(17) const:custom:f64(1.5707963267948966)"] 17--"0:0<br>float64"-->18 18["(18) LoadConstant"] 18--"0:1<br>float64"-->21 19["(19) const:custom:f64(-1.5707963267948966)"] 19--"0:0<br>float64"-->20 20["(20) LoadConstant"] 20--"0:2<br>float64"-->21 21["(21) quantum.tket2.PhasedX"] 21--"0:0<br>qubit"-->24 22["(22) const:custom:f64(3.141592653589793)"] 22--"0:0<br>float64"-->23 23["(23) LoadConstant"] 23--"0:1<br>float64"-->24 24["(24) quantum.tket2.RzF64"] 24--"0:1<br>qubit"-->25 25["(25) quantum.tket2.ZZMax"] 25--"0:0<br>qubit"-->26 25--"1:1<br>qubit"-->26 26["(26) MakeTuple"] 26--"0:0<br>[qubit, qubit]"-->27 27["(27) UnpackTuple"] 27--"0:0<br>qubit"-->28 27--"1:0<br>qubit"-->30 28["(28) quantum.tket2.Measure"] 28--"0:0<br>qubit"-->29 29["(29) quantum.tket2.QFree"] 30["(30) quantum.tket2.Measure"] 30--"0:0<br>qubit"-->31 30--"1:0<br>[]+[]"-->32 31["(31) quantum.tket2.QFree"] 32["(32) MakeTuple"] 32--"0:0<br>[[]+[]]"-->33 33["(33) UnpackTuple"] 33--"0:1<br>[]+[]"-->5 34["(34) Tag"] 34--"0:0<br>[]"-->5 end 1-."0:0".->2 2["(2) ExitBlock"] end 8--"0:0<br>[]+[]"-->6 end end ``` drive-by: Drop deprecated `stringreader` dependency drive-by: Bind `Tk2Circuit.num_operations`, used in the python test
- Loading branch information
Showing
12 changed files
with
447 additions
and
92 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from typing import no_type_check | ||
from tket2.circuit import Tk2Circuit | ||
|
||
import math | ||
|
||
from guppylang import guppy | ||
from guppylang.module import GuppyModule | ||
from guppylang.prelude import quantum | ||
from guppylang.prelude.builtins import py | ||
from guppylang.prelude.quantum import measure, phased_x, qubit, rz, zz_max | ||
|
||
|
||
def test_load_compiled_module(): | ||
module = GuppyModule("test") | ||
module.load(quantum) | ||
|
||
@guppy(module) | ||
@no_type_check | ||
def my_func( | ||
q0: qubit, | ||
q1: qubit, | ||
) -> tuple[bool,]: | ||
q0 = phased_x(q0, py(math.pi / 2), py(-math.pi / 2)) | ||
q0 = rz(q0, py(math.pi)) | ||
q1 = phased_x(q1, py(math.pi / 2), py(-math.pi / 2)) | ||
q1 = rz(q1, py(math.pi)) | ||
q0, q1 = zz_max(q0, q1) | ||
_ = measure(q0) | ||
return (measure(q1),) | ||
|
||
# Compile the module, and convert it to a JSON string | ||
hugr = module.compile() | ||
json = hugr.to_raw().to_json() | ||
|
||
# Load the module from the JSON string | ||
circ = Tk2Circuit.from_guppy_json(json, "my_func") | ||
|
||
# The 7 operations in the function, plus two implicit QFree | ||
assert circ.num_operations() == 9 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
//! Load pre-compiled guppy functions. | ||
use std::path::Path; | ||
use std::{fs, io}; | ||
|
||
use hugr::ops::{NamedOp, OpTag, OpTrait, OpType}; | ||
use hugr::{Hugr, HugrView}; | ||
use itertools::Itertools; | ||
use thiserror::Error; | ||
|
||
use crate::{Circuit, CircuitError}; | ||
|
||
/// Loads a pre-compiled guppy file. | ||
pub fn load_guppy_json_file( | ||
path: impl AsRef<Path>, | ||
function: &str, | ||
) -> Result<Circuit, CircuitLoadError> { | ||
let file = fs::File::open(path)?; | ||
let reader = io::BufReader::new(file); | ||
load_guppy_json_reader(reader, function) | ||
} | ||
|
||
/// Loads a pre-compiled guppy file from a json string. | ||
pub fn load_guppy_json_str(json: &str, function: &str) -> Result<Circuit, CircuitLoadError> { | ||
let reader = json.as_bytes(); | ||
load_guppy_json_reader(reader, function) | ||
} | ||
|
||
/// Loads a pre-compiled guppy file from a reader. | ||
pub fn load_guppy_json_reader( | ||
reader: impl io::Read, | ||
function: &str, | ||
) -> Result<Circuit, CircuitLoadError> { | ||
let hugr: Hugr = serde_json::from_reader(reader)?; | ||
find_function(hugr, function) | ||
} | ||
|
||
/// Looks for the required function in a HUGR compiled from a guppy module. | ||
/// | ||
/// Guppy functions are compiled into a root module, with each function as a `FuncDecl` child. | ||
/// Each `FuncDecl` contains a `CFG` operation that defines the function. | ||
/// | ||
/// Currently we only support functions where the CFG operation has a single `DataflowBlock` child, | ||
/// which we use as the root of the circuit. We (currently) do not support control flow primitives. | ||
/// | ||
/// # Errors | ||
/// | ||
/// - If the root of the HUGR is not a module operation. | ||
/// - If the function is not found in the module. | ||
/// - If the function has control flow primitives. | ||
fn find_function(hugr: Hugr, function_name: &str) -> Result<Circuit, CircuitLoadError> { | ||
// Find the root module. | ||
let module = hugr.root(); | ||
if !OpTag::ModuleRoot.is_superset(hugr.get_optype(module).tag()) { | ||
return Err(CircuitLoadError::NonModuleRoot { | ||
root_op: hugr.get_optype(module).clone(), | ||
}); | ||
} | ||
|
||
// Find the function declaration. | ||
fn func_name(op: &OpType) -> &str { | ||
match op { | ||
OpType::FuncDefn(decl) => &decl.name, | ||
_ => "", | ||
} | ||
} | ||
|
||
let Some(function) = hugr | ||
.children(module) | ||
.find(|&n| func_name(hugr.get_optype(n)) == function_name) | ||
else { | ||
let available_functions = hugr | ||
.children(module) | ||
.map(|n| func_name(hugr.get_optype(n)).to_string()) | ||
.collect(); | ||
return Err(CircuitLoadError::FunctionNotFound { | ||
function: function_name.to_string(), | ||
available_functions, | ||
}); | ||
}; | ||
|
||
// Find the CFG operation. | ||
let invalid_cfg = CircuitLoadError::InvalidControlFlow { | ||
function: function_name.to_string(), | ||
}; | ||
let Ok(cfg) = hugr.children(function).skip(2).exactly_one() else { | ||
return Err(invalid_cfg); | ||
}; | ||
|
||
// Find the single dataflow block to use as the root of the circuit. | ||
// The cfg node should only have the dataflow block and an exit node as children. | ||
let mut cfg_children = hugr.children(cfg); | ||
let Some(dataflow) = cfg_children.next() else { | ||
return Err(invalid_cfg); | ||
}; | ||
if cfg_children.nth(1).is_some() { | ||
return Err(invalid_cfg); | ||
} | ||
|
||
let circ = Circuit::try_new(hugr, dataflow)?; | ||
Ok(circ) | ||
} | ||
|
||
/// Error type for conversion between `Op` and `OpType`. | ||
#[derive(Debug, Error)] | ||
pub enum CircuitLoadError { | ||
/// Cannot load the circuit file. | ||
#[error("Cannot load the circuit file: {0}")] | ||
InvalidFile(#[from] io::Error), | ||
/// Invalid JSON | ||
#[error("Invalid JSON. {0}")] | ||
InvalidJson(#[from] serde_json::Error), | ||
/// The root node is not a module operation. | ||
#[error( | ||
"Expected a HUGR with a module at the root, but found a {} instead.", | ||
root_op.name() | ||
)] | ||
NonModuleRoot { | ||
/// The root operation. | ||
root_op: OpType, | ||
}, | ||
/// The function is not found in the module. | ||
#[error( | ||
"Function '{function}' not found in the loaded module. Available functions: [{}]", | ||
available_functions.join(", ") | ||
)] | ||
FunctionNotFound { | ||
/// The function name. | ||
function: String, | ||
/// The available functions. | ||
available_functions: Vec<String>, | ||
}, | ||
/// The function has an invalid control flow structure. | ||
#[error("Function '{function}' has an invalid control flow structure. Currently only flat functions with no control flow primitives are supported.")] | ||
InvalidControlFlow { | ||
/// The function name. | ||
function: String, | ||
}, | ||
/// Error loading the circuit. | ||
#[error("Error loading the circuit: {0}")] | ||
CircuitLoadError(#[from] CircuitError), | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters