Skip to content

Commit

Permalink
Merge pull request #243 from Quantomatic/improve-typing
Browse files Browse the repository at this point in the history
Improve typing
  • Loading branch information
RazinShaikh authored Jun 12, 2024
2 parents 0afbc3c + 6c5c8ea commit 70ccc3b
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 37 deletions.
12 changes: 6 additions & 6 deletions zxlive/rewrite_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def do_rewrite(self, panel: ProofPanel) -> None:
panel.undo_stack.push(cmd, anim_before=anim_before, anim_after=anim_after)

# TODO: Narrow down the type of the first return value.
def apply_rewrite(self, g: GraphT, matches: list) -> tuple[Any, Optional[Iterable[VT]]]:
def apply_rewrite(self, g: GraphT, matches: list) -> tuple[Any, Optional[list[VT]]]:
if self.returns_new_graph:
return self.rule(g, matches), None

Expand Down Expand Up @@ -136,7 +136,7 @@ def from_dict(cls, d: dict, header: str = "", parent: RewriteActionTree | None =
ret.append_child(cls.from_dict(actions, group, ret))
return ret

def update_on_selection(self, g, selection, edges) -> None:
def update_on_selection(self, g: GraphT, selection: list[VT], edges: list[ET]) -> None:
for child in self.child_items:
child.update_on_selection(g, selection, edges)
if self.rewrite is not None:
Expand All @@ -146,13 +146,13 @@ def update_on_selection(self, g, selection, edges) -> None:
class RewriteActionTreeModel(QAbstractItemModel):
root_item: RewriteActionTree

def __init__(self, data: RewriteActionTree, proof_panel: ProofPanel):
def __init__(self, data: RewriteActionTree, proof_panel: ProofPanel) -> None:
super().__init__(proof_panel)
self.proof_panel = proof_panel
self.root_item = data

@classmethod
def from_dict(cls, d: dict, proof_panel: ProofPanel):
def from_dict(cls, d: dict, proof_panel: ProofPanel) -> RewriteActionTreeModel:
return RewriteActionTreeModel(
RewriteActionTree.from_dict(d),
proof_panel
Expand All @@ -170,7 +170,7 @@ def index(self, row: int, column: int, parent: Union[QModelIndex, QPersistentMod
return QModelIndex()

def parent(self, index: QModelIndex = None) -> QModelIndex:
if not index.isValid():
if index is None or not index.isValid():
return QModelIndex()

parent_item = index.internalPointer().parent
Expand All @@ -181,7 +181,7 @@ def parent(self, index: QModelIndex = None) -> QModelIndex:
return self.createIndex(parent_item.row(), 0, parent_item)

def rowCount(self, parent: QModelIndex = None) -> int:
if parent.column() > 0:
if parent is None or parent.column() > 0:
return 0
parent_item = parent.internalPointer() if parent.isValid() else self.root_item
return parent_item.child_count()
Expand Down
85 changes: 54 additions & 31 deletions zxlive/rewrite_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,32 +51,54 @@ def read_custom_rules() -> list[RewriteData]:
# So we add them to operations

rewrites_graph_theoretic: dict[str, RewriteData] = {
"lcomp": operations["lcomp"],
"pivot": operations["pivot"],
"pivot_boundary": {"text": "boundary pivot",
"tooltip": "Performs a pivot between a Pauli spider and a spider on the boundary.",
"matcher": pyzx.rules.match_pivot_boundary,
"rule": pyzx.rules.pivot,
"type": MATCHES_EDGES,
"copy_first": True},
"pivot_gadget": {"text": "gadget pivot",
"tooltip": "Performs a pivot between a Pauli spider and a spider with an arbitrary phase, creating a phase gadget.",
"matcher": pyzx.rules.match_pivot_gadget,
"rule": pyzx.rules.pivot,
"type": MATCHES_EDGES,
"copy_first": True},
"phase_gadget_fuse": {"text": "Fuse phase gadgets",
"tooltip": "Fuses two phase gadgets with the same connectivity.",
"matcher": pyzx.rules.match_phase_gadgets,
"rule": pyzx.rules.merge_phase_gadgets,
"type": MATCHES_VERTICES,
"copy_first": True},
"supplementarity": {"text": "Supplementarity",
"tooltip": "Looks for a pair of internal spiders with the same connectivity and supplementary angles and removes them.",
"matcher": pyzx.rules.match_supplementarity,
"rule": pyzx.rules.apply_supplementarity,
"type": MATCHES_VERTICES,
"copy_first": False},
"lcomp": {
"text": "local complementation",
"tooltip": "Deletes a spider with a pi/2 phase by performing a local complementation on its neighbors",
"matcher": pyzx.rules.match_lcomp_parallel,
"rule": pyzx.rules.lcomp,
"type": MATCHES_VERTICES,
"copy_first": True
},
"pivot": {
"text": "pivot",
"tooltip": "Deletes a pair of spiders with 0/pi phases by performing a pivot",
"matcher": lambda g, matchf: pyzx.rules.match_pivot_parallel(g, matchf, check_edge_types=True),
"rule": pyzx.rules.pivot,
"type": MATCHES_EDGES,
"copy_first": True
},
"pivot_boundary": {
"text": "boundary pivot",
"tooltip": "Performs a pivot between a Pauli spider and a spider on the boundary.",
"matcher": pyzx.rules.match_pivot_boundary,
"rule": pyzx.rules.pivot,
"type": MATCHES_EDGES,
"copy_first": True
},
"pivot_gadget": {
"text": "gadget pivot",
"tooltip": "Performs a pivot between a Pauli spider and a spider with an arbitrary phase, creating a phase gadget.",
"matcher": pyzx.rules.match_pivot_gadget,
"rule": pyzx.rules.pivot,
"type": MATCHES_EDGES,
"copy_first": True
},
"phase_gadget_fuse": {
"text": "Fuse phase gadgets",
"tooltip": "Fuses two phase gadgets with the same connectivity.",
"matcher": pyzx.rules.match_phase_gadgets,
"rule": pyzx.rules.merge_phase_gadgets,
"type": MATCHES_VERTICES,
"copy_first": True
},
"supplementarity": {
"text": "Supplementarity",
"tooltip": "Looks for a pair of internal spiders with the same connectivity and supplementary angles and removes them.",
"matcher": pyzx.rules.match_supplementarity,
"rule": pyzx.rules.apply_supplementarity,
"type": MATCHES_VERTICES,
"copy_first": False
},
}

const_true = lambda graph, matches: matches
Expand All @@ -102,12 +124,13 @@ def _extract_circuit(graph: GraphT, matches: list) -> GraphT:
def ocm_rule(_graph: GraphT, _matches: list) -> pyzx.rules.RewriteOutputType[ET, VT]:
return ({}, [], [], True)


ocm_action: RewriteData = {
"text": "OCM",
"tooltip": "Saves the graph with the current vertex positions",
"matcher": const_true,
"rule": ocm_rule,
"type": MATCHES_VERTICES,
"text": "OCM",
"tooltip": "Saves the graph with the current vertex positions",
"matcher": const_true,
"rule": ocm_rule,
"type": MATCHES_VERTICES,
}

simplifications: dict[str, RewriteData] = {
Expand Down

0 comments on commit 70ccc3b

Please sign in to comment.