From 6c5c8ea4a13091d517d6f79a4c00b3ffa1aae3f8 Mon Sep 17 00:00:00 2001 From: Boldi Date: Wed, 12 Jun 2024 17:52:28 +0100 Subject: [PATCH] Improve typing --- zxlive/rewrite_action.py | 12 +++--- zxlive/rewrite_data.py | 85 +++++++++++++++++++++++++--------------- 2 files changed, 60 insertions(+), 37 deletions(-) diff --git a/zxlive/rewrite_action.py b/zxlive/rewrite_action.py index 4c0f945f..fa85fbd3 100644 --- a/zxlive/rewrite_action.py +++ b/zxlive/rewrite_action.py @@ -67,7 +67,7 @@ def do_rewrite(self, panel: ProofPanel) -> None: panel.undo_stack.push(cmd, anim_before=anim_before, anim_after=anim_after) # TODO: Narrow down the type of the first return value. - def apply_rewrite(self, g: GraphT, matches: list) -> tuple[Any, Optional[Iterable[VT]]]: + def apply_rewrite(self, g: GraphT, matches: list) -> tuple[Any, Optional[list[VT]]]: if self.returns_new_graph: return self.rule(g, matches), None @@ -136,7 +136,7 @@ def from_dict(cls, d: dict, header: str = "", parent: RewriteActionTree | None = ret.append_child(cls.from_dict(actions, group, ret)) return ret - def update_on_selection(self, g, selection, edges) -> None: + def update_on_selection(self, g: GraphT, selection: list[VT], edges: list[ET]) -> None: for child in self.child_items: child.update_on_selection(g, selection, edges) if self.rewrite is not None: @@ -146,13 +146,13 @@ def update_on_selection(self, g, selection, edges) -> None: class RewriteActionTreeModel(QAbstractItemModel): root_item: RewriteActionTree - def __init__(self, data: RewriteActionTree, proof_panel: ProofPanel): + def __init__(self, data: RewriteActionTree, proof_panel: ProofPanel) -> None: super().__init__(proof_panel) self.proof_panel = proof_panel self.root_item = data @classmethod - def from_dict(cls, d: dict, proof_panel: ProofPanel): + def from_dict(cls, d: dict, proof_panel: ProofPanel) -> RewriteActionTreeModel: return RewriteActionTreeModel( RewriteActionTree.from_dict(d), proof_panel @@ -170,7 +170,7 @@ def index(self, row: int, column: int, parent: Union[QModelIndex, QPersistentMod return QModelIndex() def parent(self, index: QModelIndex = None) -> QModelIndex: - if not index.isValid(): + if index is None or not index.isValid(): return QModelIndex() parent_item = index.internalPointer().parent @@ -181,7 +181,7 @@ def parent(self, index: QModelIndex = None) -> QModelIndex: return self.createIndex(parent_item.row(), 0, parent_item) def rowCount(self, parent: QModelIndex = None) -> int: - if parent.column() > 0: + if parent is None or parent.column() > 0: return 0 parent_item = parent.internalPointer() if parent.isValid() else self.root_item return parent_item.child_count() diff --git a/zxlive/rewrite_data.py b/zxlive/rewrite_data.py index 1b5f584f..25ac54e9 100644 --- a/zxlive/rewrite_data.py +++ b/zxlive/rewrite_data.py @@ -51,32 +51,54 @@ def read_custom_rules() -> list[RewriteData]: # So we add them to operations rewrites_graph_theoretic: dict[str, RewriteData] = { - "lcomp": operations["lcomp"], - "pivot": operations["pivot"], - "pivot_boundary": {"text": "boundary pivot", - "tooltip": "Performs a pivot between a Pauli spider and a spider on the boundary.", - "matcher": pyzx.rules.match_pivot_boundary, - "rule": pyzx.rules.pivot, - "type": MATCHES_EDGES, - "copy_first": True}, - "pivot_gadget": {"text": "gadget pivot", - "tooltip": "Performs a pivot between a Pauli spider and a spider with an arbitrary phase, creating a phase gadget.", - "matcher": pyzx.rules.match_pivot_gadget, - "rule": pyzx.rules.pivot, - "type": MATCHES_EDGES, - "copy_first": True}, - "phase_gadget_fuse": {"text": "Fuse phase gadgets", - "tooltip": "Fuses two phase gadgets with the same connectivity.", - "matcher": pyzx.rules.match_phase_gadgets, - "rule": pyzx.rules.merge_phase_gadgets, - "type": MATCHES_VERTICES, - "copy_first": True}, - "supplementarity": {"text": "Supplementarity", - "tooltip": "Looks for a pair of internal spiders with the same connectivity and supplementary angles and removes them.", - "matcher": pyzx.rules.match_supplementarity, - "rule": pyzx.rules.apply_supplementarity, - "type": MATCHES_VERTICES, - "copy_first": False}, + "lcomp": { + "text": "local complementation", + "tooltip": "Deletes a spider with a pi/2 phase by performing a local complementation on its neighbors", + "matcher": pyzx.rules.match_lcomp_parallel, + "rule": pyzx.rules.lcomp, + "type": MATCHES_VERTICES, + "copy_first": True + }, + "pivot": { + "text": "pivot", + "tooltip": "Deletes a pair of spiders with 0/pi phases by performing a pivot", + "matcher": lambda g, matchf: pyzx.rules.match_pivot_parallel(g, matchf, check_edge_types=True), + "rule": pyzx.rules.pivot, + "type": MATCHES_EDGES, + "copy_first": True + }, + "pivot_boundary": { + "text": "boundary pivot", + "tooltip": "Performs a pivot between a Pauli spider and a spider on the boundary.", + "matcher": pyzx.rules.match_pivot_boundary, + "rule": pyzx.rules.pivot, + "type": MATCHES_EDGES, + "copy_first": True + }, + "pivot_gadget": { + "text": "gadget pivot", + "tooltip": "Performs a pivot between a Pauli spider and a spider with an arbitrary phase, creating a phase gadget.", + "matcher": pyzx.rules.match_pivot_gadget, + "rule": pyzx.rules.pivot, + "type": MATCHES_EDGES, + "copy_first": True + }, + "phase_gadget_fuse": { + "text": "Fuse phase gadgets", + "tooltip": "Fuses two phase gadgets with the same connectivity.", + "matcher": pyzx.rules.match_phase_gadgets, + "rule": pyzx.rules.merge_phase_gadgets, + "type": MATCHES_VERTICES, + "copy_first": True + }, + "supplementarity": { + "text": "Supplementarity", + "tooltip": "Looks for a pair of internal spiders with the same connectivity and supplementary angles and removes them.", + "matcher": pyzx.rules.match_supplementarity, + "rule": pyzx.rules.apply_supplementarity, + "type": MATCHES_VERTICES, + "copy_first": False + }, } const_true = lambda graph, matches: matches @@ -102,12 +124,13 @@ def _extract_circuit(graph: GraphT, matches: list) -> GraphT: def ocm_rule(_graph: GraphT, _matches: list) -> pyzx.rules.RewriteOutputType[ET, VT]: return ({}, [], [], True) + ocm_action: RewriteData = { - "text": "OCM", - "tooltip": "Saves the graph with the current vertex positions", - "matcher": const_true, - "rule": ocm_rule, - "type": MATCHES_VERTICES, + "text": "OCM", + "tooltip": "Saves the graph with the current vertex positions", + "matcher": const_true, + "rule": ocm_rule, + "type": MATCHES_VERTICES, } simplifications: dict[str, RewriteData] = {