Skip to content

Commit

Permalink
Update required fields in serialisation schema, removing TaggedSumTyp…
Browse files Browse the repository at this point in the history
…e, TaggedOpaqueType
  • Loading branch information
doug-q committed May 8, 2024
1 parent caadff2 commit c8fdc48
Show file tree
Hide file tree
Showing 8 changed files with 278 additions and 271 deletions.
31 changes: 15 additions & 16 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,17 @@ class Value(RootModel):
discriminator="v"
)

class Config:
# Needed to avoid random '\n's in the pydantic description
json_schema_extra = {"required": ["v"]}


class Const(BaseOp):
"""A Const operation definition."""

op: Literal["Const"] = "Const"
v: Value = Field()

class Config:
json_schema_extra = {
"required": ["op", "parent", "v"],
}


# -----------------------------------------------
# --------------- BasicBlock types ------------------
Expand All @@ -151,7 +150,7 @@ class DataflowBlock(BaseOp):
op: Literal["DataflowBlock"] = "DataflowBlock"
inputs: TypeRow = Field(default_factory=list)
other_outputs: TypeRow = Field(default_factory=list)
sum_rows: list[TypeRow] = Field(default_factory=list)
sum_rows: list[TypeRow]
extension_delta: ExtensionSet = Field(default_factory=list)

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
Expand All @@ -173,14 +172,6 @@ def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None:
class Config:
# Needed to avoid random '\n's in the pydantic description
json_schema_extra = {
"required": [
"parent",
"op",
"inputs",
"other_outputs",
"sum_rows",
"extension_delta",
],
"description": "A CFG basic block node. The signature is that of the internal Dataflow graph.",
}

Expand All @@ -193,9 +184,9 @@ class ExitBlock(BaseOp):
cfg_outputs: TypeRow

class Config:
# Needed to avoid random '\n's in the pydantic description
json_schema_extra = {
"description": "The single exit node of the CFG, has no children, stores the types of the CFG node output."
# Needed to avoid random '\n's in the pydantic description
"description": "The single exit node of the CFG, has no children, stores the types of the CFG node output.",
}


Expand All @@ -218,6 +209,11 @@ def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
assert len(in_types) == 0
self.types = list(out_types)

# class Config:
# json_schema_extra = {
# "required": ["op", "parent", "types"]
# }


class Output(DataflowOp):
"""An output node. The inputs are the outputs of the function."""
Expand Down Expand Up @@ -501,6 +497,9 @@ class OpType(RootModel):
| AliasDefn
) = Field(discriminator="op")

class Config:
json_schema_extra = {"required": ["parent", "op"]}


# --------------------------------------
# --------------- OpDef ----------------
Expand Down
40 changes: 18 additions & 22 deletions hugr-py/src/hugr/serialization/tys.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ class ConfiguredBaseModel(BaseModel):
model_config = default_model_config

@classmethod
def set_model_config(cls, config: ConfigDict):
cls.model_config = config
def update_model_config(cls, config: ConfigDict):
cls.model_config.update(config)


# --------------------------------------------
Expand Down Expand Up @@ -99,6 +99,9 @@ class TypeParam(RootModel):
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="tp")

class Config:
json_schema_extra = {"required": ["tp"]}


# ------------------------------------------
# --------------- TypeArg ------------------
Expand Down Expand Up @@ -150,6 +153,9 @@ class TypeArg(RootModel):
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="tya")

class Config:
json_schema_extra = {"required": ["tya"]}


# --------------------------------------------
# --------------- Container ------------------
Expand All @@ -170,24 +176,24 @@ class Array(MultiContainer):
class UnitSum(ConfiguredBaseModel):
"""Simple sum type where all variants are empty tuples."""

t: Literal["Sum"] = "Sum"
s: Literal["Unit"] = "Unit"
size: int


class GeneralSum(ConfiguredBaseModel):
"""General sum type that explicitly stores the types of the variants."""

t: Literal["Sum"] = "Sum"
s: Literal["General"] = "General"
rows: list["TypeRow"]


class SumType(RootModel):
root: Union[UnitSum, GeneralSum] = Field(discriminator="s")


class TaggedSumType(ConfiguredBaseModel):
t: Literal["Sum"] = "Sum"
st: SumType
class Config:
json_schema_extra = {"required": ["s"]}


# ----------------------------------------------
Expand Down Expand Up @@ -280,17 +286,13 @@ def join(*bs: "TypeBound") -> "TypeBound":
class Opaque(ConfiguredBaseModel):
"""An opaque Type that can be downcasted by the extensions that define it."""

t: Literal["Opaque"] = "Opaque"
extension: ExtensionId
id: str # Unique identifier of the opaque type.
args: list[TypeArg]
bound: TypeBound


class TaggedOpaque(ConfiguredBaseModel):
t: Literal["Opaque"] = "Opaque"
o: Opaque


class Alias(ConfiguredBaseModel):
"""An Alias Type"""

Expand All @@ -314,17 +316,13 @@ class Type(RootModel):
"""A HUGR type."""

root: Annotated[
Qubit
| Variable
| USize
| FunctionType
| Array
| TaggedSumType
| TaggedOpaque
| Alias,
Qubit | Variable | USize | FunctionType | Array | SumType | Opaque | Alias,
WrapValidator(_json_custom_error_validator),
] = Field(discriminator="t")

class Config:
json_schema_extra = {"required": ["t"]}


# -------------------------------------------
# --------------- TypeRow -------------------
Expand Down Expand Up @@ -365,11 +363,9 @@ def model_rebuild(
config: ConfigDict = ConfigDict(),
**kwargs,
):
new_config = default_model_config.copy()
new_config.update(config)
for c in classes.values():
if issubclass(c, ConfiguredBaseModel):
c.set_model_config(new_config)
c.update_model_config(config)
c.model_rebuild(**kwargs)


Expand Down
12 changes: 6 additions & 6 deletions hugr/src/types/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ pub(super) enum SerSimpleType {
Q,
I,
G(Box<FunctionType>),
Sum { st: SumType },
Sum(SumType),
Array { inner: Box<SerSimpleType>, len: u64 },
Opaque { o: CustomType },
Opaque(CustomType),
Alias(AliasDecl),
V { i: usize, b: TypeBound },
}
Expand All @@ -29,11 +29,11 @@ impl From<Type> for SerSimpleType {
// TODO short circuiting for array.
let Type(value, _) = value;
match value {
TypeEnum::Extension(o) => SerSimpleType::Opaque { o },
TypeEnum::Extension(o) => SerSimpleType::Opaque(o),
TypeEnum::Alias(a) => SerSimpleType::Alias(a),
TypeEnum::Function(sig) => SerSimpleType::G(sig),
TypeEnum::Variable(i, b) => SerSimpleType::V { i, b },
TypeEnum::Sum(st) => SerSimpleType::Sum { st },
TypeEnum::Sum(st) => SerSimpleType::Sum(st),
}
}
}
Expand All @@ -44,11 +44,11 @@ impl From<SerSimpleType> for Type {
SerSimpleType::Q => QB_T,
SerSimpleType::I => USIZE_T,
SerSimpleType::G(sig) => Type::new_function(*sig),
SerSimpleType::Sum { st } => st.into(),
SerSimpleType::Sum(st) => st.into(),
SerSimpleType::Array { inner, len } => {
array_type(TypeArg::BoundedNat { n: len }, (*inner).into())
}
SerSimpleType::Opaque { o } => Type::new_extension(o),
SerSimpleType::Opaque(o) => Type::new_extension(o),
SerSimpleType::Alias(a) => Type::new_alias(a),
SerSimpleType::V { i, b } => Type::new_var_use(i, b),
}
Expand Down
2 changes: 1 addition & 1 deletion specification/schema/.gitattributes
Original file line number Diff line number Diff line change
@@ -1 +1 @@
*schema_v*.json -diff
*schema*.json -diff
Loading

0 comments on commit c8fdc48

Please sign in to comment.