Skip to content

Commit

Permalink
Merge branch 'main' into just-holes
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor authored Dec 23, 2024
2 parents f783923 + 17736fa commit 6f9d28a
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 42 deletions.
2 changes: 1 addition & 1 deletion brat/examples/karlheinz.brat
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ answer = energy(results)
evaluate(obs :: Observable
,q :: Quantity
,a :: Ansatz
,rs :: List Real
,rs :: List(Real)
) -> Real
evaluate = ?eval

Expand Down
10 changes: 10 additions & 0 deletions brat/examples/let.brat
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,13 @@ nums' = let xs = map(inc, [0,2,3]) in xs

nums'' :: List(Int)
nums'' = let i2 = {inc; inc} in map(i2, xs)

dyad :: Int, Bool
dyad = 42, true

bind2 :: Bool
bind2 = let i, b = dyad in b

-- It shouldn't matter if we put brackets in the binding sites
bind2' :: Bool
bind2' = let (i, b) = dyad in b
1 change: 1 addition & 0 deletions brat/test/Test/Checking.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import Test.Tasty.HUnit
import Test.Tasty.Silver

expectedCheckingFails = map ("examples" </>) ["nested-abstractors.brat"
,"karlheinz.brat"
,"karlheinz_alias.brat"
,"hea.brat"
]
Expand Down
4 changes: 1 addition & 3 deletions brat/test/Test/Parsing.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ testParse file = testCase (show file) $ do
Left err -> assertFailure (show err)
Right _ -> return () -- OK

expectedParsingFails = map ("examples" </>) [
"karlheinz.brat",
"thin.brat"]
expectedParsingFails = ["examples" </> "thin.brat"]

parseXF = expectFailForPaths expectedParsingFails testParse

Expand Down
4 changes: 2 additions & 2 deletions hugr_extension/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "brat_extension"
version = "0.3.0"
version = "0.4.0"
edition = "2021"

[lib]
Expand All @@ -9,7 +9,7 @@ bench = false
path = "src/lib.rs"

[dependencies]
hugr = "0.8.0"
hugr = "0.9.0"
serde = "1.0"
serde_json = "1.0.97"

Expand Down
12 changes: 5 additions & 7 deletions hugr_extension/src/ctor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ use enum_iterator::Sequence;
use hugr::{
ops::NamedOp,
std_extensions::{arithmetic::int_types, collections},
types::{
type_param::TypeParam, CustomType, FunctionType, PolyFuncType, Type, TypeArg, TypeBound,
},
types::{type_param::TypeParam, CustomType, PolyFuncType, Signature, Type, TypeArg, TypeBound},
};
use smol_str::{format_smolstr, SmolStr};
use std::str::FromStr;
Expand Down Expand Up @@ -82,8 +80,8 @@ impl Ctor for BratCtor {
impl Ctor for NatCtor {
fn signature(self) -> PolyFuncType {
match self {
NatCtor::zero => FunctionType::new(vec![], vec![nat_type()]).into(),
NatCtor::succ => FunctionType::new(vec![nat_type()], vec![nat_type()]).into(),
NatCtor::zero => Signature::new(vec![], vec![nat_type()]).into(),
NatCtor::succ => Signature::new(vec![nat_type()], vec![nat_type()]).into(),
}
}
}
Expand All @@ -94,11 +92,11 @@ impl Ctor for VecCtor {
let ta = Type::new_var_use(0, TypeBound::Any);
match self {
VecCtor::nil => {
PolyFuncType::new(vec![tp], FunctionType::new(vec![], vec![vec_type(&ta)]))
PolyFuncType::new(vec![tp], Signature::new(vec![], vec![vec_type(&ta)]))
}
VecCtor::cons => PolyFuncType::new(
vec![tp],
FunctionType::new(vec![ta.clone(), vec_type(&ta)], vec![vec_type(&ta)]),
Signature::new(vec![ta.clone(), vec_type(&ta)], vec![vec_type(&ta)]),
),
}
}
Expand Down
35 changes: 20 additions & 15 deletions hugr_extension/src/defs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ use hugr::{
},
ops::NamedOp,
std_extensions::collections::list_type,
types::{type_param::TypeParam, FunctionType, PolyFuncType, Type, TypeArg, TypeBound, TypeEnum},
types::{
type_param::TypeParam, FuncValueType, PolyFuncTypeRV, Type, TypeArg, TypeBound, TypeEnum,
},
};

use lazy_static::lazy_static;
Expand Down Expand Up @@ -83,14 +85,17 @@ impl MakeOpDef for BratOpDef {
let sig = ctor.signature();
let input = sig.body().output(); // Ctor output is input for the test
let output = Type::new_sum(vec![input.clone(), sig.body().input().clone()]);
PolyFuncType::new(sig.params(), FunctionType::new(input.clone(), vec![output]))
.into()
PolyFuncTypeRV::new(
sig.params(),
FuncValueType::new(input.clone(), vec![output]),
)
.into()
}
Replicate => PolyFuncType::new(
Replicate => PolyFuncTypeRV::new(
[TypeParam::Type {
b: TypeBound::Copyable,
}],
FunctionType::new(
FuncValueType::new(
vec![USIZE_T, Type::new_var_use(0, TypeBound::Copyable)],
vec![list_type(Type::new_var_use(0, TypeBound::Copyable))],
),
Expand All @@ -107,15 +112,15 @@ impl MakeOpDef for BratOpDef {
/// Binary compute_signature function for the `Hole` op
struct HoleSigFun();
impl SignatureFromArgs for HoleSigFun {
fn compute_signature(&self, arg_values: &[TypeArg]) -> Result<PolyFuncType, SignatureError> {
fn compute_signature(&self, arg_values: &[TypeArg]) -> Result<PolyFuncTypeRV, SignatureError> {
// The Hole op expects a nat identifier and two type sequences specifiying
// the signature of the hole
match arg_values {
[TypeArg::BoundedNat { n: _ }, TypeArg::Type { ty: fun_ty }] => {
let TypeEnum::Function(sig) = fun_ty.as_type_enum().clone() else {
return Err(SignatureError::InvalidTypeArgs);
};
Ok(PolyFuncType::new([], *sig))
Ok(PolyFuncTypeRV::new([], *sig))
}
_ => Err(SignatureError::InvalidTypeArgs),
}
Expand All @@ -133,7 +138,7 @@ impl SignatureFromArgs for HoleSigFun {
/// Binary compute_signature function for the `Substitute` op
struct SubstituteSigFun();
impl SignatureFromArgs for SubstituteSigFun {
fn compute_signature(&self, arg_values: &[TypeArg]) -> Result<PolyFuncType, SignatureError> {
fn compute_signature(&self, arg_values: &[TypeArg]) -> Result<PolyFuncTypeRV, SignatureError> {
// The Substitute op expects a function signature and a list of hole signatures
match arg_values {
[TypeArg::Type { ty: outer_fun_ty }, TypeArg::Sequence { elems: hole_sigs }] => {
Expand All @@ -144,7 +149,7 @@ impl SignatureFromArgs for SubstituteSigFun {
};
inputs.push(inner_fun_ty.clone())
}
Ok(FunctionType::new(inputs, vec![outer_fun_ty.clone()]).into())
Ok(FuncValueType::new(inputs, vec![outer_fun_ty.clone()]).into())
}
_ => Err(SignatureError::InvalidTypeArgs),
}
Expand All @@ -168,7 +173,7 @@ impl SignatureFromArgs for SubstituteSigFun {
/// Binary compute_signature function for the `Partial` op
struct PartialSigFun();
impl SignatureFromArgs for PartialSigFun {
fn compute_signature(&self, arg_values: &[TypeArg]) -> Result<PolyFuncType, SignatureError> {
fn compute_signature(&self, arg_values: &[TypeArg]) -> Result<PolyFuncTypeRV, SignatureError> {
// The Partial op expects a type sequence specifying the supplied partial inputs, a type
// sequence specifiying the remaining inputs and a type sequence for the function outputs.
match arg_values {
Expand All @@ -177,13 +182,13 @@ impl SignatureFromArgs for PartialSigFun {
let other_inputs = row_from_arg(other_inputs)?;
let outputs = row_from_arg(outputs)?;
let res_func =
Type::new_function(FunctionType::new(other_inputs.clone(), outputs.clone()));
let mut inputs = vec![Type::new_function(FunctionType::new(
Type::new_function(FuncValueType::new(other_inputs.clone(), outputs.clone()));
let mut inputs = vec![Type::new_function(FuncValueType::new(
[partial_inputs.clone(), other_inputs].concat(),
outputs,
))];
inputs.extend(partial_inputs);
Ok(FunctionType::new(inputs, vec![res_func]).into())
Ok(FuncValueType::new(inputs, vec![res_func]).into())
}
_ => Err(SignatureError::InvalidTypeArgs),
}
Expand All @@ -200,11 +205,11 @@ impl SignatureFromArgs for PartialSigFun {
/// Binary compute_signature function for the `Panic` op
struct PanicSigFun();
impl SignatureFromArgs for PanicSigFun {
fn compute_signature(&self, arg_values: &[TypeArg]) -> Result<PolyFuncType, SignatureError> {
fn compute_signature(&self, arg_values: &[TypeArg]) -> Result<PolyFuncTypeRV, SignatureError> {
// The Panic op expects two type sequences specifiying the signature of the op
match arg_values {
[input, output] => {
Ok(FunctionType::new(row_from_arg(input)?, row_from_arg(output)?).into())
Ok(FuncValueType::new(row_from_arg(input)?, row_from_arg(output)?).into())
}
_ => Err(SignatureError::InvalidTypeArgs),
}
Expand Down
38 changes: 26 additions & 12 deletions hugr_extension/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use hugr::{
SignatureError,
},
ops::{custom::ExtensionOp, NamedOp, OpTrait},
types::{FunctionType, TypeArg, TypeEnum, TypeRow},
types::{Signature, TypeArg, TypeEnum, TypeRow},
};
use smol_str::{format_smolstr, SmolStr};

Expand All @@ -16,18 +16,18 @@ use crate::{ctor::BratCtor, defs::BratOpDef};
pub enum BratOp {
Hole {
idx: u64,
sig: FunctionType,
sig: Signature,
},
Substitute {
func_sig: FunctionType,
hole_sigs: Vec<FunctionType>,
func_sig: Signature,
hole_sigs: Vec<Signature>,
},
Partial {
inputs: TypeRow,
output_sig: FunctionType,
output_sig: Signature,
},
Panic {
sig: FunctionType,
sig: Signature,
},
Ctor {
ctor: BratCtor,
Expand Down Expand Up @@ -78,9 +78,20 @@ impl MakeExtensionOp for BratOp {
_ => Err(SignatureError::InvalidTypeArgs.into()),
})
.collect();
let closed_sig = Signature::try_from(*func_sig.clone())
.map_err(|_| SignatureError::InvalidTypeArgs)?;

let closed_hole_sigs: Result<Vec<Signature>, SignatureError> = hole_sigs?
.iter()
.map(|a| {
Signature::try_from(a.clone())
.map_err(|_| SignatureError::InvalidTypeArgs)
})
.collect();

Ok(BratOp::Substitute {
func_sig: *func_sig.clone(),
hole_sigs: hole_sigs?,
func_sig: closed_sig,
hole_sigs: closed_hole_sigs?,
})
}
_ => Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs)),
Expand All @@ -92,7 +103,8 @@ impl MakeExtensionOp for BratOp {
};
Ok(BratOp::Partial {
inputs: partial_inputs.to_vec().into(),
output_sig: *output_sig.clone(),
output_sig: Signature::try_from(*output_sig.clone())
.expect("Invalid type arg to Partial"),
})
}
_ => Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs)),
Expand Down Expand Up @@ -138,10 +150,12 @@ impl MakeExtensionOp for BratOp {
],
BratOp::Partial { inputs, output_sig } => vec![
arg_from_row(inputs),
arg_from_row(output_sig.input()),
arg_from_row(output_sig.output()),
arg_from_row(output_sig.input().into()),
arg_from_row(output_sig.output().into()),
],
BratOp::Panic { sig } => vec![arg_from_row(sig.input()), arg_from_row(sig.output())],
BratOp::Panic { sig } => {
vec![arg_from_row(sig.input().into()), arg_from_row(sig.output())]
}
BratOp::Ctor { args, .. } => args.clone(),
BratOp::PrimCtorTest { args, .. } => args.clone(),
BratOp::Replicate(arg) => vec![arg.clone()],
Expand Down
4 changes: 2 additions & 2 deletions hugr_validator/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
[package]
name = "hugr_validator"
version = "0.3.0"
version = "0.4.0"
edition = "2021"

[dependencies]
hugr = "0.8.0"
hugr = "0.9.0"
serde_json = "*"
brat_extension = { path = "../hugr_extension" }

0 comments on commit 6f9d28a

Please sign in to comment.