Skip to content

Commit

Permalink
Automatically add resource req for custom ops
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Nov 9, 2023
1 parent ef8f4d3 commit 443ccf7
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
11 changes: 9 additions & 2 deletions src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ mod test {
use super::*;
use cool_asserts::assert_matches;

use crate::extension::ExtensionSet;
use crate::{
builder::{
test::{build_main, NAT, QB},
Expand Down Expand Up @@ -179,12 +180,18 @@ mod test {
"MyOp",
"unknown op".to_string(),
vec![],
Some(FunctionType::new(vec![QB, NAT], vec![QB])),
Some(
FunctionType::new(vec![QB, NAT], vec![QB]).with_extension_delta(
&ExtensionSet::singleton(&"MissingRsrc".try_into().unwrap()),
),
),
))
.into(),
);
let build_res = build_main(
FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]).pure(),
FunctionType::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T])
.with_extension_delta(&ExtensionSet::singleton(&"MissingRsrc".try_into().unwrap()))
.pure(),
|mut f_build| {
let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();

Expand Down
28 changes: 26 additions & 2 deletions src/ops/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use smol_str::SmolStr;
use std::sync::Arc;
use thiserror::Error;

use crate::extension::{ExtensionId, ExtensionRegistry, OpDef, SignatureError};
use crate::extension::{ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError};
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::{HugrView, NodeType};
use crate::types::{type_param::TypeArg, FunctionType};
Expand Down Expand Up @@ -103,13 +103,17 @@ pub struct ExtensionOp {

impl ExtensionOp {
/// Create a new ExtensionOp given the type arguments and specified input extensions
///
/// Automatically adds the OpDef's extensions to the signature.
pub fn new(
def: Arc<OpDef>,
args: impl Into<Vec<TypeArg>>,
exts: &ExtensionRegistry,
) -> Result<Self, SignatureError> {
let args = args.into();
let signature = def.compute_signature(&args, exts)?;
let signature = def
.compute_signature(&args, exts)?
.with_extension_delta(&ExtensionSet::singleton(def.extension()));
Ok(Self {
def,
args,
Expand Down Expand Up @@ -180,13 +184,17 @@ fn qualify_name(res_id: &ExtensionId, op_name: &SmolStr) -> SmolStr {

impl OpaqueOp {
/// Creates a new OpaqueOp from all the fields we'd expect to serialize.
///
/// Automatically includes `extension` in the signature if `signature` is provided.
pub fn new(
extension: ExtensionId,
op_name: impl Into<SmolStr>,
description: String,
args: impl Into<Vec<TypeArg>>,
signature: Option<FunctionType>,
) -> Self {
let signature =
signature.map(|s| s.with_extension_delta(&ExtensionSet::singleton(&extension)));
Self {
extension,
op_name: op_name.into(),
Expand Down Expand Up @@ -327,4 +335,20 @@ mod test {
assert_eq!(op.description(), "desc");
assert_eq!(op.args(), &[TypeArg::Type { ty: USIZE_T }]);
}

#[test]
fn new_opaque_op_with_signature() {
let op = OpaqueOp::new(
"res".try_into().unwrap(),
"op",
"desc".into(),
vec![TypeArg::Type { ty: USIZE_T }],
Some(FunctionType::new_linear(vec![])),
);
let op: ExternalOp = op.into();
assert_eq!(
op.signature().extension_reqs,
ExtensionSet::singleton(&"res".try_into().unwrap())
);
}
}

0 comments on commit 443ccf7

Please sign in to comment.