Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Update serialisation schema, implement CustomConst serialisation #1005

Merged
merged 18 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 26 additions & 45 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,52 +70,41 @@ class FuncDecl(BaseOp):
signature: PolyFuncType


CustomConst = Any # TODO
class CustomConst(ConfiguredBaseModel):
c: str
v: Any


class ExtensionValue(ConfiguredBaseModel):
"""An extension constant value, that can check it is of a given [CustomType]."""

c: Literal["Extension"] = Field("Extension", title="ValueTag")
e: CustomConst = Field(title="CustomConst")

class Config:
json_schema_extra = {
"required": ["c", "e"],
}
mark-koch marked this conversation as resolved.
Show resolved Hide resolved
v: Literal["Extension"] = Field("Extension", title="ValueTag")
extensions: ExtensionSet
typ: Type
value: CustomConst


class FunctionValue(ConfiguredBaseModel):
"""A higher-order function value."""

c: Literal["Function"] = Field("Function", title="ValueTag")
v: Literal["Function"] = Field("Function", title="ValueTag")
hugr: Any # TODO

class Config:
json_schema_extra = {
"required": ["c", "hugr"],
}


class TupleValue(ConfiguredBaseModel):
"""A constant tuple value."""

c: Literal["Tuple"] = Field("Tuple", title="ValueTag")
v: Literal["Tuple"] = Field("Tuple", title="ValueTag")
vs: list["Value"]

class Config:
json_schema_extra = {
"required": ["c", "vs"],
}


class SumValue(ConfiguredBaseModel):
"""A Sum variant

For any Sum type where this value meets the type of the variant indicated by the tag
"""

c: Literal["Sum"] = Field("Sum", title="ValueTag")
v: Literal["Sum"] = Field("Sum", title="ValueTag")
tag: int
typ: SumType
vs: list["Value"]
Expand All @@ -127,29 +116,26 @@ class Config:
"A Sum variant For any Sum type where this value meets the type "
"of the variant indicated by the tag."
),
"required": ["c", "tag", "typ", "vs"],
}


class Value(RootModel):
"""A constant Value."""

root: ExtensionValue | FunctionValue | TupleValue | SumValue = Field(
discriminator="c"
discriminator="v"
)

class Config:
json_schema_extra = {"required": ["v"]}


class Const(BaseOp):
"""A Const operation definition."""

op: Literal["Const"] = "Const"
v: Value = Field()

class Config:
json_schema_extra = {
"required": ["op", "parent", "v"],
}


# -----------------------------------------------
# --------------- BasicBlock types ------------------
Expand All @@ -163,7 +149,7 @@ class DataflowBlock(BaseOp):
op: Literal["DataflowBlock"] = "DataflowBlock"
inputs: TypeRow = Field(default_factory=list)
other_outputs: TypeRow = Field(default_factory=list)
sum_rows: list[TypeRow] = Field(default_factory=list)
sum_rows: list[TypeRow]
extension_delta: ExtensionSet = Field(default_factory=list)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
Expand All @@ -173,26 +159,18 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None:
self.inputs = inputs
pred = outputs[0].root
assert isinstance(pred, tys.TaggedSumType)
if isinstance(pred.st.root, tys.UnitSum):
self.sum_rows = [[] for _ in range(pred.st.root.size)]
assert isinstance(pred, tys.SumType)
if isinstance(pred.root, tys.UnitSum):
self.sum_rows = [[] for _ in range(pred.root.size)]
else:
self.sum_rows = []
for variant in pred.st.root.rows:
for variant in pred.root.rows:
self.sum_rows.append(variant)
self.other_outputs = outputs[1:]

class Config:
# Needed to avoid random '\n's in the pydantic description
json_schema_extra = {
"required": [
"parent",
"op",
"inputs",
"other_outputs",
"sum_rows",
"extension_delta",
],
"description": "A CFG basic block node. The signature is that of the internal Dataflow graph.",
}

Expand All @@ -205,9 +183,9 @@ class ExitBlock(BaseOp):
cfg_outputs: TypeRow

class Config:
# Needed to avoid random '\n's in the pydantic description
json_schema_extra = {
"description": "The single exit node of the CFG, has no children, stores the types of the CFG node output."
# Needed to avoid random '\n's in the pydantic description
"description": "The single exit node of the CFG, has no children, stores the types of the CFG node output.",
}


Expand Down Expand Up @@ -334,8 +312,8 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
# First port is a predicate, i.e. a sum of tuple types. We need to unpack
# those into a list of type rows
pred = in_types[0]
assert isinstance(pred.root, tys.TaggedSumType)
sum = pred.root.st.root
assert isinstance(pred.root, tys.SumType)
sum = pred.root.root
if isinstance(sum, tys.UnitSum):
self.sum_rows = [[] for _ in range(sum.size)]
else:
Expand Down Expand Up @@ -513,6 +491,9 @@ class OpType(RootModel):
| AliasDefn
) = Field(discriminator="op")

class Config:
json_schema_extra = {"required": ["parent", "op"]}


# --------------------------------------
# --------------- OpDef ----------------
Expand Down
3 changes: 0 additions & 3 deletions hugr-py/src/hugr/serialization/testing_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,3 @@ def _pydantic_rebuild(cls, config: ConfigDict = ConfigDict(), **kwargs):
my_classes = dict(ops_classes)
my_classes[cls.__name__] = cls
model_rebuild(my_classes, config=config, **kwargs)

class Config:
title = "HugrTesting"
48 changes: 25 additions & 23 deletions hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ class ConfiguredBaseModel(BaseModel):
model_config = default_model_config

@classmethod
def set_model_config(cls, config: ConfigDict):
cls.model_config = config
def update_model_config(cls, config: ConfigDict):
cls.model_config.update(config)
Comment on lines -51 to +52
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍



# --------------------------------------------
Expand Down Expand Up @@ -99,6 +99,9 @@ class TypeParam(RootModel):
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="tp")

class Config:
json_schema_extra = {"required": ["tp"]}


# ------------------------------------------
# --------------- TypeArg ------------------
Expand Down Expand Up @@ -150,6 +153,9 @@ class TypeArg(RootModel):
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="tya")

class Config:
json_schema_extra = {"required": ["tya"]}


# --------------------------------------------
# --------------- Container ------------------
Expand All @@ -170,24 +176,29 @@ class Array(MultiContainer):
class UnitSum(ConfiguredBaseModel):
"""Simple sum type where all variants are empty tuples."""

t: Literal["Sum"] = "Sum"
s: Literal["Unit"] = "Unit"
size: int


class GeneralSum(ConfiguredBaseModel):
"""General sum type that explicitly stores the types of the variants."""

t: Literal["Sum"] = "Sum"
s: Literal["General"] = "General"
rows: list["TypeRow"]


class SumType(RootModel):
root: Union[UnitSum, GeneralSum] = Field(discriminator="s")
root: Annotated[Union[UnitSum, GeneralSum], Field(discriminator="s")]

# This seems to be required for nested discriminated unions to work
@property
def t(self) -> str:
return self.root.t

class TaggedSumType(ConfiguredBaseModel):
t: Literal["Sum"] = "Sum"
st: SumType
class Config:
json_schema_extra = {"required": ["s"]}


# ----------------------------------------------
Expand Down Expand Up @@ -280,17 +291,13 @@ def join(*bs: "TypeBound") -> "TypeBound":
class Opaque(ConfiguredBaseModel):
"""An opaque Type that can be downcasted by the extensions that define it."""

t: Literal["Opaque"] = "Opaque"
extension: ExtensionId
id: str # Unique identifier of the opaque type.
args: list[TypeArg]
bound: TypeBound


class TaggedOpaque(ConfiguredBaseModel):
t: Literal["Opaque"] = "Opaque"
o: Opaque


class Alias(ConfiguredBaseModel):
"""An Alias Type"""

Expand All @@ -314,16 +321,13 @@ class Type(RootModel):
"""A HUGR type."""

root: Annotated[
Qubit
| Variable
| USize
| FunctionType
| Array
| TaggedSumType
| TaggedOpaque
| Alias,
Qubit | Variable | USize | FunctionType | Array | SumType | Opaque | Alias,
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="t")
Field(discriminator="t"),
]

class Config:
json_schema_extra = {"required": ["t"]}


# -------------------------------------------
Expand Down Expand Up @@ -365,11 +369,9 @@ def model_rebuild(
config: ConfigDict = ConfigDict(),
**kwargs,
):
new_config = default_model_config.copy()
new_config.update(config)
for c in classes.values():
if issubclass(c, ConfiguredBaseModel):
c.set_model_config(new_config)
c.update_model_config(config)
c.model_rebuild(**kwargs)


Expand Down
4 changes: 2 additions & 2 deletions hugr/src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR
/// Generate a graph that loads and outputs `consts` in order, validating
/// against `reg`.
fn const_graph(consts: Vec<Value>, reg: &ExtensionRegistry) -> Hugr {
let const_types = consts.iter().map(Value::const_type).collect_vec();
let const_types = consts.iter().map(Value::get_type).collect_vec();
let mut b = DFGBuilder::new(FunctionType::new(type_row![], const_types)).unwrap();

let outputs = consts
Expand Down Expand Up @@ -338,7 +338,7 @@ mod test {
let list: Value = ListValue::new(BOOL_T, [Value::unit_sum(0, 1).unwrap()]).into();
let mut build = DFGBuilder::new(FunctionType::new(
type_row![],
vec![list.const_type().clone()],
vec![list.get_type().clone()],
))
.unwrap();

Expand Down
1 change: 0 additions & 1 deletion hugr/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@
//! # }
//! # doctest().unwrap();
//! ```
//!
use thiserror::Error;

use crate::extension::SignatureError;
Expand Down
2 changes: 1 addition & 1 deletion hugr/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ pub trait Dataflow: Container {
let load_n = self
.add_dataflow_op(
ops::LoadConstant {
datatype: op.const_type().clone(),
datatype: op.get_type().clone(),
},
// Constant wire from the constant value node
vec![Wire::new(const_node, OutgoingPort::from(0))],
Expand Down
38 changes: 18 additions & 20 deletions hugr/src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,11 @@ use crate::{
/// +------------+
/// */
/// use hugr::{
/// builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder},
/// Hugr,
/// extension::{ExtensionSet, prelude},
/// types::{FunctionType, Type, SumType},
/// ops,
/// type_row,
/// builder::{BuildError, CFGBuilder, Container, Dataflow, HugrBuilder},
/// extension::{prelude, ExtensionSet},
/// ops, type_row,
/// types::{FunctionType, SumType, Type},
/// Hugr,
/// };
///
/// const NAT: Type = prelude::USIZE_T;
Expand All @@ -75,7 +74,7 @@ use crate::{
/// let left_42 = ops::Value::sum(
/// 0,
/// [prelude::ConstUsize::new(42).into()],
/// SumType::new(sum_variants.clone())
/// SumType::new(sum_variants.clone()),
/// )?;
/// let sum = entry_b.add_load_value(left_42);
///
Expand All @@ -85,11 +84,10 @@ use crate::{
/// // This block will be the first successor of the entry node. It takes two
/// // `NAT` arguments: one from the `sum_variants` type, and another from the
/// // entry node's `other_outputs`.
/// let mut successor_builder =
/// cfg_builder.simple_block_builder(
/// FunctionType::new(type_row![NAT, NAT], type_row![NAT]),
/// 1 // only one successor to this block
/// )?;
/// let mut successor_builder = cfg_builder.simple_block_builder(
/// FunctionType::new(type_row![NAT, NAT], type_row![NAT]),
/// 1, // only one successor to this block
/// )?;
/// let successor_a = {
/// // This block has one successor. The choice is denoted by a unary sum.
/// let sum_unary = successor_builder.add_load_const(ops::Value::unary_unit_sum());
Expand All @@ -100,14 +98,14 @@ use crate::{
/// successor_builder.finish_with_outputs(sum_unary, [in_wire])?
/// };
///
/// // The only argument to this block is the entry node's `other_outputs`.
/// let mut successor_builder =
/// cfg_builder.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
/// let successor_b = {
/// let sum_unary = successor_builder.add_load_value(ops::Value::unary_unit_sum());
/// let [in_wire] = successor_builder.input_wires_arr();
/// successor_builder.finish_with_outputs(sum_unary, [in_wire])?
/// };
/// // The only argument to this block is the entry node's `other_outputs`.
/// let mut successor_builder = cfg_builder
/// .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
/// let successor_b = {
/// let sum_unary = successor_builder.add_load_value(ops::Value::unary_unit_sum());
/// let [in_wire] = successor_builder.input_wires_arr();
/// successor_builder.finish_with_outputs(sum_unary, [in_wire])?
/// };
/// let exit = cfg_builder.exit_block();
/// cfg_builder.branch(&entry, 0, &successor_a)?; // branch 0 goes to successor_a
/// cfg_builder.branch(&entry, 1, &successor_b)?; // branch 1 goes to successor_b
Expand Down
Loading
Loading