diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 14f68e7..368ace9 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -1,7 +1,7 @@ import json from fractions import Fraction -from typing import TYPE_CHECKING, Callable, Optional, Sequence, Dict, Union +from typing import TYPE_CHECKING, Callable, Optional, Sequence, Dict, Union, Any import networkx as nx import numpy as np @@ -175,8 +175,11 @@ def to_json(self) -> str: }) @classmethod - def from_json(cls, json_str: str) -> "CustomRule": - d = json.loads(json_str) + def from_json(cls, json_str: Union[str,Dict[str,Any]]) -> "CustomRule": + if isinstance(json_str, str): + d = json.loads(json_str) + else: + d = json_str lhs_graph = GraphT.from_json(d['lhs_graph']) rhs_graph = GraphT.from_json(d['rhs_graph']) # Mypy issue: https://github.com/python/mypy/issues/11673 diff --git a/zxlive/proof.py b/zxlive/proof.py index c4e9e05..4bd79c8 100644 --- a/zxlive/proof.py +++ b/zxlive/proof.py @@ -1,5 +1,5 @@ import json -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, Dict if TYPE_CHECKING: from .proof_panel import ProofPanel @@ -24,19 +24,26 @@ class Rewrite(NamedTuple): graph: GraphT # New graph after applying the rewrite grouped_rewrites: Optional[list['Rewrite']] = None # Optional field to store the grouped rewrites - def to_json(self) -> str: - """Serializes the rewrite to JSON.""" - return json.dumps({ + def to_dict(self) -> Dict[str, Any]: + """Serializes the rewrite to Python dictionary.""" + return { "display_name": self.display_name, "rule": self.rule, - "graph": self.graph.to_json(), - "grouped_rewrites": [r.to_json() for r in self.grouped_rewrites] if self.grouped_rewrites else None - }) + "graph": self.graph.to_dict(), + "grouped_rewrites": [r.to_dict() for r in self.grouped_rewrites] if self.grouped_rewrites else None + } + + def to_json(self) -> str: + """Serializes the rewrite to JSON.""" + return json.dumps(self.to_dict()) @staticmethod - def from_json(json_str: str) -> "Rewrite": - """Deserializes the rewrite from JSON.""" - d = json.loads(json_str) + def from_json(json_str: Union[str,Dict[str,Any]]) -> "Rewrite": + """Deserializes the rewrite from JSON or Python dict.""" + if isinstance(json_str, str): + d = json.loads(json_str) + else: + d = json_str grouped_rewrites = d.get("grouped_rewrites") graph = GraphT.from_json(d["graph"]) assert isinstance(graph, GraphT) @@ -183,20 +190,27 @@ def ungroup_steps(self, index: int) -> None: self.createIndex(index + len(individual_steps), 0), []) - def to_json(self) -> str: - """Serializes the model to JSON.""" - initial_graph = self.initial_graph.to_json() - proof_steps = [step.to_json() for step in self.steps] + def to_dict(self) -> Dict[str,Any]: + """Serializes the model to Python dict.""" + initial_graph = self.initial_graph.to_dict() + proof_steps = [step.to_dict() for step in self.steps] - return json.dumps({ + return { "initial_graph": initial_graph, "proof_steps": proof_steps - }) + } + + def to_json(self) -> str: + """Serializes the model to JSON.""" + return json.dumps(self.to_dict()) @staticmethod - def from_json(json_str: str) -> "ProofModel": - """Deserializes the model from JSON.""" - d = json.loads(json_str) + def from_json(json_str: Union[str,Dict[str,Any]]) -> "ProofModel": + """Deserializes the model from JSON or Python dict.""" + if isinstance(json_str, str): + d = json.loads(json_str) + else: + d = json_str initial_graph = GraphT.from_json(d["initial_graph"]) # Mypy issue: https://github.com/python/mypy/issues/11673 assert isinstance(initial_graph, GraphT) # type: ignore