Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Have CustomTypes reference their Extension definition #1723

Merged
merged 8 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions hugr-cli/tests/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ use hugr::builder::{DFGBuilder, DataflowSubContainer, ModuleBuilder};
use hugr::types::Type;
use hugr::{
builder::{Container, Dataflow},
extension::prelude::{BOOL_T, QB_T},
std_extensions::arithmetic::float_types::FLOAT64_TYPE,
type_row,
extension::prelude::{bool_t, qb_t},
std_extensions::arithmetic::float_types::float64_type,
types::Signature,
Hugr,
};
Expand All @@ -41,7 +40,7 @@ const FLOAT_EXT_FILE: &str = concat!(

/// A test package, containing a module-rooted HUGR.
#[fixture]
fn test_package(#[default(BOOL_T)] id_type: Type) -> Package {
fn test_package(#[default(bool_t())] id_type: Type) -> Package {
let mut module = ModuleBuilder::new();
let df = module
.define_function("test", Signature::new_endo(id_type))
Expand All @@ -57,7 +56,7 @@ fn test_package(#[default(BOOL_T)] id_type: Type) -> Package {

/// A DFG-rooted HUGR.
#[fixture]
fn test_hugr(#[default(BOOL_T)] id_type: Type) -> Hugr {
fn test_hugr(#[default(bool_t())] id_type: Type) -> Hugr {
let mut df = DFGBuilder::new(Signature::new_endo(id_type)).unwrap();
let [i] = df.input_wires_arr();
df.set_outputs([i]).unwrap();
Expand Down Expand Up @@ -120,7 +119,7 @@ fn test_mermaid(test_hugr_file: NamedTempFile, mut cmd: Command) {

#[fixture]
fn bad_hugr_string() -> String {
let df = DFGBuilder::new(Signature::new_endo(type_row![QB_T])).unwrap();
let df = DFGBuilder::new(Signature::new_endo(vec![qb_t()])).unwrap();
let bad_hugr = df.hugr().clone();

serde_json::to_string(&bad_hugr).unwrap()
Expand Down Expand Up @@ -178,7 +177,7 @@ fn test_no_std(test_hugr_string: String, mut val_cmd: Command) {
}

#[fixture]
fn float_hugr_string(#[with(FLOAT64_TYPE)] test_hugr: Hugr) -> String {
fn float_hugr_string(#[with(float64_type())] test_hugr: Hugr) -> String {
serde_json::to_string(&test_hugr).unwrap()
}

Expand All @@ -205,7 +204,7 @@ fn test_float_extension(float_hugr_string: String, mut val_cmd: Command) {
val_cmd.assert().success().stderr(contains(VALID_PRINT));
}
#[fixture]
fn package_string(#[with(FLOAT64_TYPE)] test_package: Package) -> String {
fn package_string(#[with(float64_type())] test_package: Package) -> String {
serde_json::to_string(&test_package).unwrap()
}

Expand Down
23 changes: 10 additions & 13 deletions hugr-core/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
//! ```rust
//! # use hugr::Hugr;
//! # use hugr::builder::{BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, ModuleBuilder, DataflowSubContainer, HugrBuilder};
//! use hugr::extension::prelude::BOOL_T;
//! use hugr::extension::prelude::bool_t;
//! use hugr::std_extensions::logic::{EXTENSION_ID, LOGIC_REG, LogicOp};
//! use hugr::types::Signature;
//!
Expand All @@ -42,7 +42,7 @@
//! let _dfg_handle = {
//! let mut dfg = module_builder.define_function(
//! "main",
//! Signature::new_endo(BOOL_T).with_extension_delta(EXTENSION_ID),
//! Signature::new_endo(bool_t()).with_extension_delta(EXTENSION_ID),
//! )?;
//!
//! // Get the wires from the function inputs.
Expand All @@ -59,7 +59,7 @@
//! let _circuit_handle = {
//! let mut dfg = module_builder.define_function(
//! "circuit",
//! Signature::new_endo(vec![BOOL_T, BOOL_T])
//! Signature::new_endo(vec![bool_t(), bool_t()])
//! .with_extension_delta(EXTENSION_ID),
//! )?;
//! let mut circuit = dfg.as_circuit(dfg.input_wires());
Expand Down Expand Up @@ -238,11 +238,12 @@ pub enum BuilderWiringError {
pub(crate) mod test {
use rstest::fixture;

use crate::extension::prelude::{bool_t, usize_t};
use crate::hugr::{views::HugrView, HugrMut};
use crate::ops;
use crate::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY;
use crate::types::{PolyFuncType, Signature, Type};
use crate::{type_row, Hugr};
use crate::types::{PolyFuncType, Signature};
use crate::Hugr;

use super::handle::BuildHandle;
use super::{
Expand All @@ -251,10 +252,6 @@ pub(crate) mod test {
};
use super::{DataflowSubContainer, HugrBuilder};

pub(super) const NAT: Type = crate::extension::prelude::USIZE_T;
pub(super) const BIT: Type = crate::extension::prelude::BOOL_T;
pub(super) const QB: Type = crate::extension::prelude::QB_T;

/// Wire up inputs of a Dataflow container to the outputs.
pub(crate) fn n_identity<T: DataflowSubContainer>(
dataflow_builder: T,
Expand All @@ -277,31 +274,31 @@ pub(crate) mod test {

#[fixture]
pub(crate) fn simple_dfg_hugr() -> Hugr {
let dfg_builder = DFGBuilder::new(Signature::new(type_row![BIT], type_row![BIT])).unwrap();
let dfg_builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t()])).unwrap();
let [i1] = dfg_builder.input_wires_arr();
dfg_builder.finish_prelude_hugr_with_outputs([i1]).unwrap()
}

#[fixture]
pub(crate) fn simple_funcdef_hugr() -> Hugr {
let fn_builder =
FunctionBuilder::new("test", Signature::new(type_row![BIT], type_row![BIT])).unwrap();
FunctionBuilder::new("test", Signature::new(vec![bool_t()], vec![bool_t()])).unwrap();
let [i1] = fn_builder.input_wires_arr();
fn_builder.finish_prelude_hugr_with_outputs([i1]).unwrap()
}

#[fixture]
pub(crate) fn simple_module_hugr() -> Hugr {
let mut builder = ModuleBuilder::new();
let sig = Signature::new(type_row![BIT], type_row![BIT]);
let sig = Signature::new(vec![bool_t()], vec![bool_t()]);
builder.declare("test", sig.into()).unwrap();
builder.finish_prelude_hugr().unwrap()
}

#[fixture]
pub(crate) fn simple_cfg_hugr() -> Hugr {
let mut cfg_builder =
CFGBuilder::new(Signature::new(type_row![NAT], type_row![NAT])).unwrap();
CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()])).unwrap();
super::cfg::test::build_basic_cfg(&mut cfg_builder).unwrap();
cfg_builder.finish_prelude_hugr().unwrap()
}
Expand Down
46 changes: 24 additions & 22 deletions hugr-core/src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,20 @@ use crate::{hugr::HugrMut, type_row, Hugr};
/// ops, type_row,
/// types::{Signature, SumType, Type},
/// Hugr,
/// extension::prelude::usize_t,
/// };
///
/// const NAT: Type = prelude::USIZE_T;
///
/// fn make_cfg() -> Result<Hugr, BuildError> {
/// let mut cfg_builder = CFGBuilder::new(Signature::new_endo(NAT))?;
/// let mut cfg_builder = CFGBuilder::new(Signature::new_endo(usize_t()))?;
///
/// // Outputs from basic blocks must be packed in a sum which corresponds to
/// // which successor to pick. We'll either choose the first branch and pass
/// // it a NAT, or the second branch and pass it nothing.
/// let sum_variants = vec![type_row![NAT], type_row![]];
/// // it a usize, or the second branch and pass it nothing.
/// let sum_variants = vec![vec![usize_t()].into(), type_row![]];
///
/// // The second argument says what types will be passed through to every
/// // successor, in addition to the appropriate `sum_variants` type.
/// let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT])?;
/// let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), vec![usize_t()].into())?;
///
/// let [inw] = entry_b.input_wires_arr();
/// let entry = {
Expand All @@ -81,10 +80,10 @@ use crate::{hugr::HugrMut, type_row, Hugr};
/// };
///
/// // This block will be the first successor of the entry node. It takes two
/// // `NAT` arguments: one from the `sum_variants` type, and another from the
/// // `usize` arguments: one from the `sum_variants` type, and another from the
/// // entry node's `other_outputs`.
/// let mut successor_builder = cfg_builder.simple_block_builder(
/// inout_sig(type_row![NAT, NAT], NAT),
/// inout_sig(vec![usize_t(), usize_t()], usize_t()),
/// 1, // only one successor to this block
/// )?;
/// let successor_a = {
Expand All @@ -98,7 +97,7 @@ use crate::{hugr::HugrMut, type_row, Hugr};
/// };
///
/// // The only argument to this block is the entry node's `other_outputs`.
/// let mut successor_builder = cfg_builder.simple_block_builder(endo_sig(NAT), 1)?;
/// let mut successor_builder = cfg_builder.simple_block_builder(endo_sig(usize_t()), 1)?;
/// let successor_b = {
/// let sum_unary = successor_builder.add_load_value(ops::Value::unary_unit_sum());
/// let [in_wire] = successor_builder.input_wires_arr();
Expand Down Expand Up @@ -464,9 +463,10 @@ impl BlockBuilder<Hugr> {
pub(crate) mod test {
use crate::builder::{DataflowSubContainer, ModuleBuilder};

use crate::extension::prelude::usize_t;
use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::ValidationError;
use crate::{builder::test::NAT, type_row};
use crate::type_row;
use cool_asserts::assert_matches;

use super::*;
Expand All @@ -475,13 +475,13 @@ pub(crate) mod test {
let build_result = {
let mut module_builder = ModuleBuilder::new();
let mut func_builder = module_builder
.define_function("main", Signature::new(vec![NAT], type_row![NAT]))?;
.define_function("main", Signature::new(vec![usize_t()], vec![usize_t()]))?;
let _f_id = {
let [int] = func_builder.input_wires_arr();

let cfg_id = {
let mut cfg_builder =
func_builder.cfg_builder(vec![(NAT, int)], type_row![NAT])?;
func_builder.cfg_builder(vec![(usize_t(), int)], vec![usize_t()].into())?;
build_basic_cfg(&mut cfg_builder)?;

cfg_builder.finish_sub_container()?
Expand All @@ -498,7 +498,7 @@ pub(crate) mod test {
}
#[test]
fn basic_cfg_hugr() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(Signature::new(type_row![NAT], type_row![NAT]))?;
let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?;
build_basic_cfg(&mut cfg_builder)?;
assert_matches!(cfg_builder.finish_prelude_hugr(), Ok(_));

Expand All @@ -508,7 +508,8 @@ pub(crate) mod test {
pub(crate) fn build_basic_cfg<T: AsMut<Hugr> + AsRef<Hugr>>(
cfg_builder: &mut CFGBuilder<T>,
) -> Result<(), BuildError> {
let sum2_variants = vec![type_row![NAT], type_row![NAT]];
let usize_row: TypeRow = vec![usize_t()].into();
let sum2_variants = vec![usize_row.clone(), usize_row];
let mut entry_b = cfg_builder.entry_builder_exts(
sum2_variants.clone(),
type_row![],
Expand All @@ -520,8 +521,8 @@ pub(crate) mod test {
let sum = entry_b.make_sum(1, sum2_variants, [inw])?;
entry_b.finish_with_outputs(sum, [])?
};
let mut middle_b =
cfg_builder.simple_block_builder(Signature::new(type_row![NAT], type_row![NAT]), 1)?;
let mut middle_b = cfg_builder
.simple_block_builder(Signature::new(vec![usize_t()], vec![usize_t()]), 1)?;
let middle = {
let c = middle_b.add_load_const(ops::Value::unary_unit_sum());
let [inw] = middle_b.input_wires_arr();
Expand All @@ -535,7 +536,7 @@ pub(crate) mod test {
}
#[test]
fn test_dom_edge() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(Signature::new(type_row![NAT], type_row![NAT]))?;
let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?;
let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum());
let sum_variants = vec![type_row![]];

Expand All @@ -551,7 +552,7 @@ pub(crate) mod test {
entry_b.finish_with_outputs(sum, [])?
};
let mut middle_b =
cfg_builder.simple_block_builder(Signature::new(type_row![], type_row![NAT]), 1)?;
cfg_builder.simple_block_builder(Signature::new(type_row![], vec![usize_t()]), 1)?;
let middle = {
let c = middle_b.load_const(&sum_tuple_const);
middle_b.finish_with_outputs(c, [inw])?
Expand All @@ -566,18 +567,19 @@ pub(crate) mod test {

#[test]
fn test_non_dom_edge() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(Signature::new(type_row![NAT], type_row![NAT]))?;
let mut cfg_builder = CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()]))?;
let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum());
let sum_variants = vec![type_row![]];
let mut middle_b =
cfg_builder.simple_block_builder(Signature::new(type_row![NAT], type_row![NAT]), 1)?;
let mut middle_b = cfg_builder
.simple_block_builder(Signature::new(vec![usize_t()], vec![usize_t()]), 1)?;
let [inw] = middle_b.input_wires_arr();
let middle = {
let c = middle_b.load_const(&sum_tuple_const);
middle_b.finish_with_outputs(c, [inw])?
};

let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT])?;
let mut entry_b =
cfg_builder.entry_builder(sum_variants.clone(), vec![usize_t()].into())?;
let entry = {
let sum = entry_b.load_const(&sum_tuple_const);
// entry block uses wire from middle block even though middle block
Expand Down
32 changes: 16 additions & 16 deletions hugr-core/src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,26 +243,23 @@ mod test {
use super::*;
use cool_asserts::assert_matches;

use crate::extension::prelude::{qb_t, usize_t};
use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY};
use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
use crate::utils::test_quantum_extension::{
self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
};
use crate::Extension;
use crate::{
builder::{
test::{build_main, NAT, QB},
DataflowSubContainer,
},
extension::prelude::BOOL_T,
type_row,
builder::{test::build_main, DataflowSubContainer},
extension::prelude::bool_t,
types::Signature,
};

#[test]
fn simple_linear() {
let build_res = build_main(
Signature::new(type_row![QB, QB], type_row![QB, QB])
Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()])
.with_extension_delta(test_quantum_extension::EXTENSION_ID)
.with_extension_delta(float_types::EXTENSION_ID)
.into(),
Expand Down Expand Up @@ -302,7 +299,7 @@ mod test {
ext.add_op(
"MyOp".into(),
"".to_string(),
Signature::new(vec![QB, NAT], vec![QB]),
Signature::new(vec![qb_t(), usize_t()], vec![qb_t()]),
extension_ref,
)
.unwrap();
Expand All @@ -312,12 +309,15 @@ mod test {
.unwrap();

let build_res = build_main(
Signature::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T])
.with_extension_delta(ExtensionSet::from_iter([
test_quantum_extension::EXTENSION_ID,
my_ext_name,
]))
.into(),
Signature::new(
vec![qb_t(), qb_t(), usize_t()],
vec![qb_t(), qb_t(), bool_t()],
)
.with_extension_delta(ExtensionSet::from_iter([
test_quantum_extension::EXTENSION_ID,
my_ext_name,
]))
.into(),
|mut f_build| {
let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();

Expand All @@ -342,7 +342,7 @@ mod test {
#[test]
fn ancillae() {
let build_res = build_main(
Signature::new_endo(QB)
Signature::new_endo(qb_t())
.with_extension_delta(test_quantum_extension::EXTENSION_ID)
.into(),
|mut f_build| {
Expand Down Expand Up @@ -380,7 +380,7 @@ mod test {
#[test]
fn circuit_builder_errors() {
let _build_res = build_main(
Signature::new_endo(type_row![QB, QB]).into(),
Signature::new_endo(vec![qb_t(), qb_t()]).into(),
|mut f_build| {
let mut circ = f_build.as_circuit(f_build.input_wires());
let [q0, q1] = circ.tracked_units_arr();
Expand Down
Loading
Loading