From d44348760b18d1200436789073e86e0b0907fc82 Mon Sep 17 00:00:00 2001 From: David Yonge-Mallo Date: Wed, 1 Nov 2023 19:10:12 +0100 Subject: [PATCH 01/37] Enable ZX Live to be run from inside of a notebook. --- .gitignore | 1 + embed_zxlive_demo.ipynb | 79 +++++++++++++++++++++++++++++++++++++++++ zxlive/app.py | 24 +++++++++++-- zxlive/mainwindow.py | 40 +++++++++++++++------ 4 files changed, 131 insertions(+), 13 deletions(-) create mode 100644 embed_zxlive_demo.ipynb diff --git a/.gitignore b/.gitignore index 2c093ca3..dfd41b2e 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ dist *.qasm .env *.zxr +.ipynb_checkpoints diff --git a/embed_zxlive_demo.ipynb b/embed_zxlive_demo.ipynb new file mode 100644 index 00000000..5ad2c3b4 --- /dev/null +++ b/embed_zxlive_demo.ipynb @@ -0,0 +1,79 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "dc0b153f-dc2f-447e-82de-6e71b9d389d2", + "metadata": {}, + "source": [ + "# Demo of embedded ZX Live running inside Jupyter Notebook\n", + "\n", + "First, run the cell below. An instance of ZX Live will open, with two identical graphs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e45e8da-f33b-4d13-8327-e394ae096ed9", + "metadata": {}, + "outputs": [], + "source": [ + "%gui qt6\n", + "from zxlive import app\n", + "\n", + "import pyzx as zx\n", + "\n", + "g1 = zx.Graph()\n", + "g1.add_vertex(zx.VertexType.Z, 0, 0)\n", + "g1.add_vertex(zx.VertexType.X, 0, 1)\n", + "g1.add_edge((0, 1))\n", + "g2 = g1.clone()\n", + "zx.draw(g1)\n", + "zx.draw(g2)\n", + "\n", + "zxl = app.get_embedded_app()\n", + "zxl.edit_graph(g1, 'g1')\n", + "zxl.edit_graph(g2, 'g2')" + ] + }, + { + "cell_type": "markdown", + "id": "88f425a0-d50d-40e0-86cf-eecbc3a27277", + "metadata": {}, + "source": [ + "After making some edits and saving them from within ZX Live, run the following cell to see the changes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a0a8a17-1795-4b13-aedf-30b346388e23", + "metadata": {}, + "outputs": [], + "source": [ + "zx.draw(g1)\n", + "zx.draw(g2)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/zxlive/app.py b/zxlive/app.py index b586c052..9a33ae0e 100644 --- a/zxlive/app.py +++ b/zxlive/app.py @@ -19,6 +19,8 @@ from PySide6.QtCore import QCommandLineParser import sys from .mainwindow import MainWindow +from .common import GraphT +from typing import Optional sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx @@ -29,6 +31,9 @@ class ZXLive(QApplication): ... """ + main_window: Optional[MainWindow] = None + is_embedded: bool = False + def __init__(self) -> None: super().__init__(sys.argv) self.setApplicationName('ZX Live') @@ -47,9 +52,24 @@ def __init__(self) -> None: for f in parser.positionalArguments(): self.main_window.open_file_from_path(f) + def edit_graph(self, g: GraphT, name: Optional[str] = "Embedded Graph") -> None: + """Opens a ZX Live window from within a notebook to edit a graph.""" + assert self.is_embedded + if not self.main_window: + self.main_window = MainWindow(True) + self.main_window.show() + self.main_window.new_graph(g, name) -def main() -> None: - """Main entry point for ZX Live""" +def get_embedded_app() -> ZXLive: + """Main entry point for ZX Live as an embedded app inside a jupyter notebook.""" + app = QApplication.instance() or ZXLive() + app.__class__ = ZXLive + app.is_embedded = True + return app + + +def main() -> None: + """Main entry point for ZX Live as a standalone app.""" zxl = ZXLive() zxl.exec_() diff --git a/zxlive/mainwindow.py b/zxlive/mainwindow.py index 17ec7940..aaa12d11 100644 --- a/zxlive/mainwindow.py +++ b/zxlive/mainwindow.py @@ -55,12 +55,14 @@ class MainWindow(QMainWindow): rewrite_form: QFormLayout left_graph: Optional[GraphT] right_graph: Optional[GraphT] + embedded_graph: Optional[GraphT] - def __init__(self) -> None: + def __init__(self, is_embedded: bool = False) -> None: super().__init__() self.settings = QSettings("zxlive", "zxlive") self.setWindowTitle("zxlive") + self.is_embedded = is_embedded w = QWidget(self) w.setLayout(QVBoxLayout()) @@ -89,15 +91,21 @@ def __init__(self) -> None: menu = self.menuBar() new_graph = self._new_action("&New", self.new_graph, QKeySequence.StandardKey.New, - "Create a new tab with an empty graph", alt_shortcut = QKeySequence.StandardKey.AddTab) + "Create a new tab with an empty graph", alt_shortcut=QKeySequence.StandardKey.AddTab) open_file = self._new_action("&Open...", self.open_file, QKeySequence.StandardKey.Open, "Open a file-picker dialog to choose a new diagram") self.close_action = self._new_action("Close", self.handle_close_action, QKeySequence.StandardKey.Close, - "Closes the window", alt_shortcut = QKeySequence("Ctrl+W")) - # TODO: We should remember if we have saved the diagram before, - # and give an open to overwrite this file with a Save action - self.save_file = self._new_action("&Save", self.handle_save_file_action, QKeySequence.StandardKey.Save, - "Save the diagram by overwriting the previous loaded file.") + "Closes the window", alt_shortcut=QKeySequence("Ctrl+W")) + if not self.is_embedded: + # TODO: We should remember if we have saved the diagram before, + # and give an option to overwrite this file with a Save action. + self.save_diagram = self._new_action("&Save", self.handle_save_file_action, + QKeySequence.StandardKey.Save, + "Save the diagram by overwriting the previous loaded file.") + else: + self.save_diagram = self._new_action("&Save to notebook", self.handle_save_embedded_graph_action, + QKeySequence.StandardKey.Save, + "Save the diagram back to the notebook.") self.save_as = self._new_action("Save &as...", self.handle_save_as_action, QKeySequence.StandardKey.SaveAs, "Opens a file-picker dialog to save the diagram in a chosen file format") @@ -106,7 +114,7 @@ def __init__(self) -> None: file_menu.addAction(open_file) file_menu.addSeparator() file_menu.addAction(self.close_action) - file_menu.addAction(self.save_file) + file_menu.addAction(self.save_diagram) file_menu.addAction(self.save_as) self.undo_action = self._new_action("Undo", self.undo, QKeySequence.StandardKey.Undo, @@ -170,11 +178,12 @@ def __init__(self) -> None: self.simplify_menu.addAction(action) self.simplify_menu.menuAction().setVisible(False) - graph = construct_circuit() - self.new_graph(graph) + if not self.is_embedded: + graph = construct_circuit() + self.new_graph(graph) def _reset_menus(self, has_active_tab: bool) -> None: - self.save_file.setEnabled(has_active_tab) + self.save_diagram.setEnabled(has_active_tab) self.save_as.setEnabled(has_active_tab) self.cut_action.setEnabled(has_active_tab) self.copy_action.setEnabled(has_active_tab) @@ -376,6 +385,12 @@ def handle_save_as_action(self) -> bool: self.tab_widget.setTabText(i,name) return True + def handle_save_embedded_graph_action(self) -> bool: + assert self.is_embedded and self.embedded_graph is not None + assert self.active_panel is not None and isinstance(self.active_panel, GraphEditPanel) + self.embedded_graph.__dict__.update(self.active_panel.graph.__dict__) + self.active_panel.undo_stack.setClean() + return False def cut_graph(self) -> None: assert self.active_panel is not None @@ -415,6 +430,9 @@ def new_graph(self, graph: Optional[GraphT] = None, name: Optional[str] = None) panel = GraphEditPanel(_graph, self.undo_action, self.redo_action) panel.start_derivation_signal.connect(self.new_deriv) if name is None: name = "New Graph" + if self.is_embedded: + assert graph is not None + self.embedded_graph = graph self._new_panel(panel, name) def new_rule_editor(self, rule: Optional[CustomRule] = None, name: Optional[str] = None) -> None: From f6ae727887ecc1031a2780f388e5c1f0eb1fbfee Mon Sep 17 00:00:00 2001 From: David Yonge-Mallo Date: Fri, 10 Nov 2023 13:09:46 +0100 Subject: [PATCH 02/37] Allow embedded ZXLive to pass a copy of a graph back to the notebook. --- ...e_demo.ipynb => embedded_zxlive_demo.ipynb | 26 ++++---- zxlive/app.py | 17 +++--- zxlive/base_panel.py | 4 ++ zxlive/mainwindow.py | 60 +++++++++---------- 4 files changed, 54 insertions(+), 53 deletions(-) rename embed_zxlive_demo.ipynb => embedded_zxlive_demo.ipynb (66%) diff --git a/embed_zxlive_demo.ipynb b/embedded_zxlive_demo.ipynb similarity index 66% rename from embed_zxlive_demo.ipynb rename to embedded_zxlive_demo.ipynb index 5ad2c3b4..510024cb 100644 --- a/embed_zxlive_demo.ipynb +++ b/embedded_zxlive_demo.ipynb @@ -5,9 +5,9 @@ "id": "dc0b153f-dc2f-447e-82de-6e71b9d389d2", "metadata": {}, "source": [ - "# Demo of embedded ZX Live running inside Jupyter Notebook\n", + "# Demo of embedded ZXLive running inside Jupyter Notebook\n", "\n", - "First, run the cell below. An instance of ZX Live will open, with two identical graphs." + "First, run the cell below. An instance of ZXLive will open, with two identical graphs." ] }, { @@ -22,17 +22,15 @@ "\n", "import pyzx as zx\n", "\n", - "g1 = zx.Graph()\n", - "g1.add_vertex(zx.VertexType.Z, 0, 0)\n", - "g1.add_vertex(zx.VertexType.X, 0, 1)\n", - "g1.add_edge((0, 1))\n", - "g2 = g1.clone()\n", - "zx.draw(g1)\n", - "zx.draw(g2)\n", + "g = zx.Graph()\n", + "g.add_vertex(zx.VertexType.Z, 0, 0)\n", + "g.add_vertex(zx.VertexType.X, 0, 1)\n", + "g.add_edge((0, 1))\n", + "zx.draw(g)\n", "\n", "zxl = app.get_embedded_app()\n", - "zxl.edit_graph(g1, 'g1')\n", - "zxl.edit_graph(g2, 'g2')" + "zxl.edit_graph(g, 'g1')\n", + "zxl.edit_graph(g, 'g2')" ] }, { @@ -40,7 +38,7 @@ "id": "88f425a0-d50d-40e0-86cf-eecbc3a27277", "metadata": {}, "source": [ - "After making some edits and saving them from within ZX Live, run the following cell to see the changes." + "After making some edits and saving them from within ZXLive, run the following cell to see the changes." ] }, { @@ -50,8 +48,8 @@ "metadata": {}, "outputs": [], "source": [ - "zx.draw(g1)\n", - "zx.draw(g2)" + "zx.draw(zxl.get_copy_of_graph('g1'))\n", + "zx.draw(zxl.get_copy_of_graph('g2'))" ] } ], diff --git a/zxlive/app.py b/zxlive/app.py index ecb360ef..5841ee1b 100644 --- a/zxlive/app.py +++ b/zxlive/app.py @@ -32,7 +32,6 @@ class ZXLive(QApplication): """ main_window: Optional[MainWindow] = None - is_embedded: bool = False def __init__(self) -> None: super().__init__(sys.argv) @@ -52,20 +51,22 @@ def __init__(self) -> None: for f in parser.positionalArguments(): self.main_window.open_file_from_path(f) - def edit_graph(self, g: GraphT, name: Optional[str] = "Embedded Graph") -> None: - """Opens a ZX Live window from within a notebook to edit a graph.""" - assert self.is_embedded + def edit_graph(self, g: GraphT, name: str) -> None: + """Opens a ZXLive window from within a notebook to edit a graph.""" if not self.main_window: - self.main_window = MainWindow(True) + self.main_window = MainWindow() self.main_window.show() - self.main_window.new_graph(g, name) + self.main_window.open_graph_from_notebook(g, name) + + def get_copy_of_graph(self, name: str) -> GraphT: + """Returns a copy of the graph which has the given name.""" + return self.main_window.get_copy_of_graph(name) def get_embedded_app() -> ZXLive: - """Main entry point for ZX Live as an embedded app inside a jupyter notebook.""" + """Main entry point for ZXLive as an embedded app inside a jupyter notebook.""" app = QApplication.instance() or ZXLive() app.__class__ = ZXLive - app.is_embedded = True return app diff --git a/zxlive/base_panel.py b/zxlive/base_panel.py index 1d882271..93035a7d 100644 --- a/zxlive/base_panel.py +++ b/zxlive/base_panel.py @@ -87,6 +87,10 @@ def clear_graph(self) -> None: cmd = SetGraph(self.graph_view, empty_graph) self.undo_stack.push(cmd) + def replace_graph(self, graph: GraphT) -> None: + cmd = SetGraph(self.graph_view, graph) + self.undo_stack.push(cmd) + def select_all(self) -> None: self.graph_scene.select_all() diff --git a/zxlive/mainwindow.py b/zxlive/mainwindow.py index 956e0610..3391c430 100644 --- a/zxlive/mainwindow.py +++ b/zxlive/mainwindow.py @@ -47,24 +47,19 @@ class MainWindow(QMainWindow): - """A simple window containing a single `GraphView` - This is just an example, and should be replaced with - something more sophisticated. - """ + """The main window of the ZXLive application.""" edit_panel: GraphEditPanel proof_panel: ProofPanel rewrite_form: QFormLayout left_graph: Optional[GraphT] right_graph: Optional[GraphT] - embedded_graph: Optional[GraphT] - def __init__(self, is_embedded: bool = False) -> None: + def __init__(self) -> None: super().__init__() self.settings = QSettings("zxlive", "zxlive") self.setWindowTitle("zxlive") - self.is_embedded = is_embedded w = QWidget(self) w.setLayout(QVBoxLayout()) @@ -98,16 +93,10 @@ def __init__(self, is_embedded: bool = False) -> None: "Open a file-picker dialog to choose a new diagram") self.close_action = self._new_action("Close", self.handle_close_action, QKeySequence.StandardKey.Close, "Closes the window", alt_shortcut=QKeySequence("Ctrl+W")) - if not self.is_embedded: - # TODO: We should remember if we have saved the diagram before, - # and give an option to overwrite this file with a Save action. - self.save_diagram = self._new_action("&Save", self.handle_save_file_action, - QKeySequence.StandardKey.Save, - "Save the diagram by overwriting the previous loaded file.") - else: - self.save_diagram = self._new_action("&Save to notebook", self.handle_save_embedded_graph_action, - QKeySequence.StandardKey.Save, - "Save the diagram back to the notebook.") + # TODO: We should remember if we have saved the diagram before, + # and give an option to overwrite this file with a Save action. + self.save_file = self._new_action("&Save", self.handle_save_file_action, QKeySequence.StandardKey.Save, + "Save the diagram by overwriting the previous loaded file.") self.save_as = self._new_action("Save &as...", self.handle_save_as_action, QKeySequence.StandardKey.SaveAs, "Opens a file-picker dialog to save the diagram in a chosen file format") @@ -116,7 +105,7 @@ def __init__(self, is_embedded: bool = False) -> None: file_menu.addAction(open_file) file_menu.addSeparator() file_menu.addAction(self.close_action) - file_menu.addAction(self.save_diagram) + file_menu.addAction(self.save_file) file_menu.addAction(self.save_as) self.undo_action = self._new_action("Undo", self.undo, QKeySequence.StandardKey.Undo, @@ -179,12 +168,11 @@ def __init__(self, is_embedded: bool = False) -> None: self._reset_menus(False) - if not self.is_embedded: - graph = construct_circuit() - self.new_graph(graph) + graph = construct_circuit() + self.new_graph(graph) def _reset_menus(self, has_active_tab: bool) -> None: - self.save_diagram.setEnabled(has_active_tab) + self.save_file.setEnabled(has_active_tab) self.save_as.setEnabled(has_active_tab) self.cut_action.setEnabled(has_active_tab) self.copy_action.setEnabled(has_active_tab) @@ -384,12 +372,6 @@ def handle_save_as_action(self) -> bool: self.tab_widget.setTabText(i,name) return True - def handle_save_embedded_graph_action(self) -> bool: - assert self.is_embedded and self.embedded_graph is not None - assert self.active_panel is not None and isinstance(self.active_panel, GraphEditPanel) - self.embedded_graph.__dict__.update(self.active_panel.graph.__dict__) - self.active_panel.undo_stack.setClean() - return False def cut_graph(self) -> None: assert self.active_panel is not None @@ -444,11 +426,27 @@ def new_graph(self, graph: Optional[GraphT] = None, name: Optional[str] = None) panel = GraphEditPanel(_graph, self.undo_action, self.redo_action) panel.start_derivation_signal.connect(self.new_deriv) if name is None: name = "New Graph" - if self.is_embedded: - assert graph is not None - self.embedded_graph = graph self._new_panel(panel, name) + def open_graph_from_notebook(self, graph: GraphT, name: str = None) -> None: + """Opens a ZXLive window from within a notebook to edit a graph. + + Replaces the graph in an existing tab if it has the same name.""" + # TODO: handle multiple tabs with the same name somehow + for i in range(self.tab_widget.count()): + if self.tab_widget.tabText(i) == name or self.tab_widget.tabText(i) == name + "*": + self.tab_widget.setCurrentIndex(i) + self.active_panel.replace_graph(graph) + return + self.new_graph(copy.deepcopy(graph), name) + + def get_copy_of_graph(self, name: str): + # TODO: handle multiple tabs with the same name somehow + for i in range(self.tab_widget.count()): + if self.tab_widget.tabText(i) == name or self.tab_widget.tabText(i) == name + "*": + return copy.deepcopy(self.tab_widget.widget(i).graph_scene.g) + return None + def new_rule_editor(self, rule: Optional[CustomRule] = None, name: Optional[str] = None) -> None: if rule is None: graph1 = GraphT() From 73da7b8ebbd3eaf193cc76af572c036a5f3bdec5 Mon Sep 17 00:00:00 2001 From: giodefelice Date: Mon, 13 Nov 2023 17:29:40 +0000 Subject: [PATCH 03/37] add button to refresh custom rules --- zxlive/proof_panel.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/zxlive/proof_panel.py b/zxlive/proof_panel.py index 08b436e1..d1bcb45e 100644 --- a/zxlive/proof_panel.py +++ b/zxlive/proof_panel.py @@ -98,11 +98,16 @@ def _toolbar_sections(self) -> Iterator[ToolbarSection]: self.identity_choice[1].setText("X") self.identity_choice[1].setCheckable(True) + self.refresh_rules = QToolButton(self) + self.refresh_rules.setText("Refresh rules") + self.refresh_rules.clicked.connect(self._refresh_rules) + yield ToolbarSection(*self.identity_choice, exclusive=True) yield ToolbarSection(*self.actions()) + yield ToolbarSection(self.refresh_rules) def init_action_groups(self) -> None: - self.action_groups = [group.copy() for group in proof_actions.action_groups] + # self.action_groups = [group.copy() for group in proof_actions.action_groups] custom_rules = [] for root, dirs, files in os.walk(get_custom_rules_path()): for file in files: @@ -327,6 +332,28 @@ def _proof_step_selected(self, selected: QItemSelection, deselected: QItemSelect cmd = GoToRewriteStep(self.graph_view, self.step_view, deselected.first().topLeft().row(), selected.first().topLeft().row()) self.undo_stack.push(cmd) + def _refresh_rules(self): + self.actions_bar.removeTab(self.actions_bar.count() - 1) + custom_rules = [] + for root, dirs, files in os.walk(get_custom_rules_path()): + for file in files: + if file.endswith(".zxr"): + zxr_file = os.path.join(root, file) + with open(zxr_file, "r") as f: + rule = CustomRule.from_json(f.read()).to_proof_action() + custom_rules.append(rule) + group = proof_actions.ProofActionGroup("Custom rules", *custom_rules).copy() + hlayout = QHBoxLayout() + group.init_buttons(self) + for action in group.actions: + assert action.button is not None + hlayout.addWidget(action.button) + hlayout.addStretch() + widget = QWidget() + widget.setLayout(hlayout) + widget.action_group = group + self.actions_bar.addTab(widget, group.name) + class ProofStepItemDelegate(QStyledItemDelegate): """This class controls the painting of items in the proof steps list view. From 138f5a1e2c88b1f8c69d83381e7a6d5aba3f94ad Mon Sep 17 00:00:00 2001 From: giodefelice Date: Mon, 13 Nov 2023 18:43:20 +0000 Subject: [PATCH 04/37] uncomment action_groups copying --- zxlive/proof_panel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zxlive/proof_panel.py b/zxlive/proof_panel.py index d1bcb45e..a78f3254 100644 --- a/zxlive/proof_panel.py +++ b/zxlive/proof_panel.py @@ -107,7 +107,7 @@ def _toolbar_sections(self) -> Iterator[ToolbarSection]: yield ToolbarSection(self.refresh_rules) def init_action_groups(self) -> None: - # self.action_groups = [group.copy() for group in proof_actions.action_groups] + self.action_groups = [group.copy() for group in proof_actions.action_groups] custom_rules = [] for root, dirs, files in os.walk(get_custom_rules_path()): for file in files: From d97cf96901b1ecea133ea305a368e791119a374b Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 13 Nov 2023 19:57:31 +0000 Subject: [PATCH 05/37] Add proof export to tikz --- zxlive/dialogs.py | 14 +++++--- zxlive/mainwindow.py | 28 ++++++++++++---- zxlive/settings_dialog.py | 67 +++++++++++++++++++++++++++++++++++++++ zxlive/tikz.py | 55 ++++++++++++++++++++++++++++++++ 4 files changed, 153 insertions(+), 11 deletions(-) create mode 100644 zxlive/tikz.py diff --git a/zxlive/dialogs.py b/zxlive/dialogs.py index 19c3de06..cd4efa96 100644 --- a/zxlive/dialogs.py +++ b/zxlive/dialogs.py @@ -191,7 +191,7 @@ def get_file_path_and_format(parent: QWidget, filter: str) -> Optional[tuple[str return file_path, selected_format -def export_diagram_dialog(graph: GraphT, parent: QWidget) -> Optional[tuple[str, FileFormat]]: +def save_diagram_dialog(graph: GraphT, parent: QWidget) -> Optional[tuple[str, FileFormat]]: file_path_and_format = get_file_path_and_format(parent, ";;".join([f.filter for f in FileFormat if f != FileFormat.ZXProof])) if file_path_and_format is None or not file_path_and_format[0]: return None @@ -215,7 +215,7 @@ def export_diagram_dialog(graph: GraphT, parent: QWidget) -> Optional[tuple[str, return file_path, selected_format -def export_proof_dialog(proof_model: ProofModel, parent: QWidget) -> Optional[tuple[str, FileFormat]]: +def safe_proof_dialog(proof_model: ProofModel, parent: QWidget) -> Optional[tuple[str, FileFormat]]: file_path_and_format = get_file_path_and_format(parent, FileFormat.ZXProof.filter) if file_path_and_format is None or not file_path_and_format[0]: return None @@ -225,7 +225,7 @@ def export_proof_dialog(proof_model: ProofModel, parent: QWidget) -> Optional[tu return None return file_path, selected_format -def export_rule_dialog(rule: CustomRule, parent: QWidget) -> Optional[tuple[str, FileFormat]]: +def safe_rule_dialog(rule: CustomRule, parent: QWidget) -> Optional[tuple[str, FileFormat]]: file_path_and_format = get_file_path_and_format(parent, FileFormat.ZXRule.filter) if file_path_and_format is None or not file_path_and_format[0]: return None @@ -235,6 +235,12 @@ def export_rule_dialog(rule: CustomRule, parent: QWidget) -> Optional[tuple[str, return None return file_path, selected_format +def export_proof_dialog(parent: QWidget) -> Optional[str]: + file_path_and_format = get_file_path_and_format(parent, FileFormat.TikZ.filter) + if file_path_and_format is None or not file_path_and_format[0]: + return None + return file_path_and_format[0] + def get_lemma_name_and_description(parent: MainWindow) -> tuple[Optional[str], Optional[str]]: dialog = QDialog() parent.rewrite_form = QFormLayout(dialog) @@ -283,7 +289,7 @@ def add_rewrite() -> None: return rule = CustomRule(parent.left_graph, parent.right_graph, name.text(), description.toPlainText()) check_rule(rule, show_error=True) - if export_rule_dialog(rule, parent): + if safe_rule_dialog(rule, parent): dialog.accept() button_box.accepted.connect(add_rewrite) button_box.rejected.connect(dialog.reject) diff --git a/zxlive/mainwindow.py b/zxlive/mainwindow.py index a55feb0c..7c5aa90f 100644 --- a/zxlive/mainwindow.py +++ b/zxlive/mainwindow.py @@ -36,15 +36,17 @@ from .custom_rule import CustomRule, check_rule from .dialogs import (FileFormat, ImportGraphOutput, ImportProofOutput, ImportRuleOutput, create_new_rewrite, - export_diagram_dialog, export_proof_dialog, - export_rule_dialog, get_lemma_name_and_description, - import_diagram_dialog, import_diagram_from_file, show_error_msg) + save_diagram_dialog, safe_proof_dialog, + safe_rule_dialog, get_lemma_name_and_description, + import_diagram_dialog, import_diagram_from_file, show_error_msg, + export_proof_dialog) from zxlive.settings_dialog import open_settings_dialog from .editor_base_panel import EditorBasePanel from .edit_panel import GraphEditPanel from .proof_panel import ProofPanel from .rule_panel import RulePanel +from .tikz import proof_to_tikz class MainWindow(QMainWindow): @@ -103,6 +105,8 @@ def __init__(self) -> None: "Save the diagram by overwriting the previous loaded file.") self.save_as = self._new_action("Save &as...", self.handle_save_as_action, QKeySequence.StandardKey.SaveAs, "Opens a file-picker dialog to save the diagram in a chosen file format") + self.export_tikz_proof = self._new_action("Export to tikz", self.handle_export_tikz_proof_action, None, + "Exports the proof to tikz") file_menu = menu.addMenu("&File") file_menu.addAction(new_graph) @@ -111,6 +115,7 @@ def __init__(self) -> None: file_menu.addAction(self.close_action) file_menu.addAction(self.save_file) file_menu.addAction(self.save_as) + file_menu.addAction(self.export_tikz_proof) self.undo_action = self._new_action("Undo", self.undo, QKeySequence.StandardKey.Undo, "Undoes the last action", "undo.svg") @@ -363,12 +368,12 @@ def handle_save_file_action(self) -> bool: def handle_save_as_action(self) -> bool: assert self.active_panel is not None if isinstance(self.active_panel, ProofPanel): - out = export_proof_dialog(self.active_panel.proof_model, self) + out = safe_proof_dialog(self.active_panel.proof_model, self) elif isinstance(self.active_panel, RulePanel): check_rule(self.active_panel.get_rule(), show_error=True) - out = export_rule_dialog(self.active_panel.get_rule(), self) + out = safe_rule_dialog(self.active_panel.get_rule(), self) else: - out = export_diagram_dialog(self.active_panel.graph_scene.g, self) + out = save_diagram_dialog(self.active_panel.graph_scene.g, self) if out is None: return False file_path, file_type = out self.active_panel.file_path = file_path @@ -379,6 +384,15 @@ def handle_save_as_action(self) -> bool: self.tab_widget.setTabText(i,name) return True + def handle_export_tikz_proof_action(self) -> bool: + assert isinstance(self.active_panel, ProofPanel) + path = export_proof_dialog(self) + if path is None: + show_error_msg("Export failed", "Invalid path") + return False + print(path) + with open(path, "w") as f: + f.write(proof_to_tikz(self.active_panel.proof_model)) def cut_graph(self) -> None: assert self.active_panel is not None @@ -518,7 +532,7 @@ def proof_as_lemma(self) -> None: lhs_graph = self.active_panel.proof_model.graphs[0] rhs_graph = self.active_panel.proof_model.graphs[-1] rule = CustomRule(lhs_graph, rhs_graph, name, description) - export_rule_dialog(rule, self) + safe_rule_dialog(rule, self) def update_colors(self) -> None: if self.active_panel is not None: diff --git a/zxlive/settings_dialog.py b/zxlive/settings_dialog.py index 20b7adc8..f45c23ab 100644 --- a/zxlive/settings_dialog.py +++ b/zxlive/settings_dialog.py @@ -59,6 +59,20 @@ "tikz/edge-import": ", ".join(pyzx.tikz.synonyms_edge), "tikz/edge-H-import": ", ".join(pyzx.tikz.synonyms_hedge), "tikz/edge-W-import": ", ".join(pyzx.tikz.synonyms_wedge), + + "tikz/layout/hspace": 2, + "tikz/layout/vspace": 2, + "tikz/layout/max-width": 10, + + "tikz/names/fuse spiders": "f", + "tikz/names/bialgebra": "b", + "tikz/names/change color to Z": "cc", + "tikz/names/change color to X": "cc", + "tikz/names/remove identity": "id", + "tikz/names/Add Z identity": "id", + "tikz/names/copy 0/pi spider": "cp", + "tikz/names/push Pauli": "pi", + "tikz/names/decompose hadamard": "eu", } color_schemes = { @@ -68,6 +82,14 @@ 'gidney': "Gidney's Black & White", } + +# Initialise settings +settings = QSettings("zxlive", "zxlive") +for key, value in defaults.items(): + if not settings.contains(key): + settings.setValue(key, value) + + class SettingsDialog(QDialog): def __init__(self, parent: MainWindow) -> None: super().__init__(parent) @@ -156,6 +178,51 @@ def __init__(self, parent: MainWindow) -> None: self.add_setting(form_import, "tikz/z-box-import", "Z box", 'str') self.add_setting(form_import, "tikz/edge-W-import", "W io edge", 'str') + ##### Tikz Layout settings ##### + panel_tikz_layout = QWidget() + vlayout = QVBoxLayout() + panel_tikz_layout.setLayout(vlayout) + tab_widget.addTab(panel_tikz_layout, "Tikz layout") + + vlayout.addWidget(QLabel("Tikz layout settings")) + + form_layout = QFormLayout() + w = QWidget() + w.setLayout(form_layout) + vlayout.addWidget(w) + vlayout.addStretch() + + self.add_setting(form_layout, "tikz/layout/hspace", "Horizontal spacing", "float") + self.add_setting(form_layout, "tikz/layout/vspace", "Vertical spacing", "float") + self.add_setting(form_layout, "tikz/layout/max-width", "Maximum width", 'float') + + + ##### Tikz rule name settings ##### + panel_tikz_names = QWidget() + vlayout = QVBoxLayout() + panel_tikz_names.setLayout(vlayout) + tab_widget.addTab(panel_tikz_names, "Tikz rule names") + + vlayout.addWidget(QLabel("Tikz rule name settings")) + vlayout.addWidget(QLabel("Mapping of pyzx rule names to tikz display strings")) + + form_names = QFormLayout() + w = QWidget() + w.setLayout(form_names) + vlayout.addWidget(w) + vlayout.addStretch() + + self.add_setting(form_names, "tikz/names/fuse spiders", "fuse spiders", "str") + self.add_setting(form_names, "tikz/names/bialgebra", "bialgebra", "str") + self.add_setting(form_names, "tikz/names/change color to Z", "change color to Z", "str") + self.add_setting(form_names, "tikz/names/change color to X", "change color to X", "str") + self.add_setting(form_names, "tikz/names/remove identity", "remove identity", "str") + self.add_setting(form_names, "tikz/names/Add Z identity", "add Z identity", "str") + self.add_setting(form_names, "tikz/names/copy 0/pi spider", "copy 0/pi spider", "str") + self.add_setting(form_names, "tikz/names/push Pauli", "push Pauli", "str") + self.add_setting(form_names, "tikz/names/decompose hadamard", "decompose hadamard", "str") + + ##### Okay/Cancel Buttons ##### w= QWidget() diff --git a/zxlive/tikz.py b/zxlive/tikz.py new file mode 100644 index 00000000..1b74391c --- /dev/null +++ b/zxlive/tikz.py @@ -0,0 +1,55 @@ +from typing import Union + +from PySide6.QtCore import QSettings +from pyzx.graph.graph_s import GraphS +from pyzx.tikz import TIKZ_BASE, _to_tikz + +from zxlive.proof import ProofModel + + + + + +def proof_to_tikz(proof: ProofModel) -> str: + settings = QSettings("zxlive", "zxlive") + vspace = settings.value("tikz/layout/vspace") + hspace = settings.value("tikz/layout/hspace") + max_width = settings.value("tikz/layout/max-width") + draw_scalar = False + + xoffset = -max_width + yoffset = -10 + idoffset = 0 + total_verts, total_edges = [], [] + for i, g in enumerate(proof.graphs): + # Compute graph dimensions + width = max(g.row(v) for v in g.vertices()) - min(g.row(v) for v in g.vertices()) + height = max(g.qubit(v) for v in g.vertices()) - min(g.qubit(v) for v in g.vertices()) + + # Translate graph so that the first vertex starts at 0 + min_x = min(g.row(v) for v in g.vertices()) + g = g.translate(-min_x, 0) + + if i > 0: + rewrite = proof.steps[i-1] + # Try to look up name in settings + name = settings.value(f"tikz/names/{rewrite.rule}") if settings.contains(f"tikz/names/{rewrite.rule}") else rewrite.rule + eq = f"\\node [style=none] ({idoffset}) at ({xoffset - hspace/2:.2f}, {-yoffset - height/2:.2f}) {{$\\overset{{\\mathit{{{name}}}}}{{=}}$}};" + total_verts.append(eq) + idoffset += 1 + + verts, edges = _to_tikz(g, draw_scalar, xoffset, yoffset, idoffset) + total_verts.extend(verts) + total_edges.extend(edges) + + if xoffset + hspace > max_width: + xoffset = -max_width + yoffset += height + vspace + else: + xoffset += width + hspace + + max_index = max(g.vertices()) + 2 * g.num_inputs() + 2 + idoffset += max_index + + return TIKZ_BASE.format(vertices="\n".join(total_verts), edges="\n".join(total_edges)) + From 66b6b4ef8405eecde2acca05bf5c54b7e8e8fb7f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 13 Nov 2023 20:24:23 +0000 Subject: [PATCH 06/37] Add circuit input dialog --- zxlive/dialogs.py | 7 ++++++- zxlive/edit_panel.py | 31 +++++++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/zxlive/dialogs.py b/zxlive/dialogs.py index 19c3de06..63773a11 100644 --- a/zxlive/dialogs.py +++ b/zxlive/dialogs.py @@ -7,7 +7,7 @@ from PySide6.QtCore import QFile, QIODevice, QTextStream from PySide6.QtWidgets import (QDialog, QDialogButtonBox, QFileDialog, QFormLayout, QLineEdit, QMessageBox, - QPushButton, QTextEdit, QWidget) + QPushButton, QTextEdit, QWidget, QInputDialog) from pyzx import Circuit, extract_circuit from .common import GraphT @@ -102,6 +102,11 @@ def import_diagram_dialog(parent: QWidget) -> Optional[ImportGraphOutput | Impor return import_diagram_from_file(file_path, selected_filter) +def create_circuit_dialog(parent: QWidget) -> Optional[str]: + s, success = QInputDialog.getMultiLineText(parent, "Circuit input", "Write a circuit in QASM format", "qreg qs[3];\n") + return s if success else None + + def import_diagram_from_file(file_path: str, selected_filter: str = FileFormat.All.filter) -> \ Optional[ImportGraphOutput | ImportProofOutput | ImportRuleOutput]: """Imports a diagram from a given file path. diff --git a/zxlive/edit_panel.py b/zxlive/edit_panel.py index 4cb46930..70fbf5b7 100644 --- a/zxlive/edit_panel.py +++ b/zxlive/edit_panel.py @@ -6,11 +6,13 @@ from PySide6.QtCore import Signal from PySide6.QtGui import QAction from PySide6.QtWidgets import (QToolButton) -from pyzx import EdgeType, VertexType +from pyzx import EdgeType, VertexType, Circuit +from pyzx.circuit.qasmparser import QASMParser from .base_panel import ToolbarSection +from .commands import UpdateGraph from .common import GraphT -from .dialogs import show_error_msg +from .dialogs import show_error_msg, create_circuit_dialog from .editor_base_panel import EditorBasePanel from .graphscene import EditGraphScene from .graphview import GraphView @@ -48,11 +50,17 @@ def __init__(self, graph: GraphT, *actions: QAction) -> None: def _toolbar_sections(self) -> Iterator[ToolbarSection]: yield from super()._toolbar_sections() + self.input_circuit = QToolButton(self) + self.input_circuit.setText("Input Circuit") + self.input_circuit.clicked.connect(self._input_circuit) + yield ToolbarSection(self.input_circuit) + self.start_derivation = QToolButton(self) self.start_derivation.setText("Start Derivation") self.start_derivation.clicked.connect(self._start_derivation) yield ToolbarSection(self.start_derivation) + def _start_derivation(self) -> None: if not self.graph_scene.g.is_well_formed(): show_error_msg("Graph is not well-formed") @@ -63,3 +71,22 @@ def _start_derivation(self) -> None: if isinstance(phase, Poly): phase.freeze() self.start_derivation_signal.emit(new_g) + + def _input_circuit(self) -> None: + qasm = create_circuit_dialog(self) + if qasm is not None: + new_g = copy.deepcopy(self.graph_scene.g) + try: + circ = QASMParser().parse(qasm, strict=False).to_graph() + except TypeError as err: + show_error_msg("Invalid circuit", str(err)) + return + except Exception: + show_error_msg("Invalid circuit", "Couldn't parse QASM code") + return + + new_verts, new_edges = new_g.merge(circ) + cmd = UpdateGraph(self.graph_view, new_g) + self.undo_stack.push(cmd) + self.graph_scene.select_vertices(new_verts) + From 8f1a4d08817b8f41ffbd0bf8c142c115ec8de180 Mon Sep 17 00:00:00 2001 From: John van de Wetering Date: Mon, 13 Nov 2023 20:50:16 +0000 Subject: [PATCH 07/37] Use Poly class in PyZX --- zxlive/app.py | 4 +- zxlive/editor_base_panel.py | 2 +- zxlive/parse_poly.py | 58 ------------ zxlive/poly.py | 180 ------------------------------------ 4 files changed, 3 insertions(+), 241 deletions(-) delete mode 100644 zxlive/parse_poly.py delete mode 100644 zxlive/poly.py diff --git a/zxlive/app.py b/zxlive/app.py index 2b9fba83..590b8ab4 100644 --- a/zxlive/app.py +++ b/zxlive/app.py @@ -19,11 +19,11 @@ from PySide6.QtCore import QCommandLineParser from PySide6.QtGui import QIcon import sys +sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx + from .mainwindow import MainWindow from .common import get_data -#sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx - # The following hack is needed on windows in order to show the icon in the taskbar # See https://stackoverflow.com/questions/1551605/how-to-set-applications-taskbar-icon-in-windows-7/1552105#1552105 import os diff --git a/zxlive/editor_base_panel.py b/zxlive/editor_base_panel.py index 70d20882..6c78a50c 100644 --- a/zxlive/editor_base_panel.py +++ b/zxlive/editor_base_panel.py @@ -15,6 +15,7 @@ QSpacerItem, QSplitter, QToolButton, QWidget) from pyzx import EdgeType, VertexType from pyzx.utils import get_w_partner, vertex_is_w +from pyzx.symbolic import Poly from .base_panel import BasePanel, ToolbarSection @@ -26,7 +27,6 @@ from .eitem import HAD_EDGE_BLUE from .graphscene import EditGraphScene from .parse_poly import parse -from .poly import Poly, new_var from .vitem import BLACK diff --git a/zxlive/parse_poly.py b/zxlive/parse_poly.py deleted file mode 100644 index f70a33ab..00000000 --- a/zxlive/parse_poly.py +++ /dev/null @@ -1,58 +0,0 @@ -from .poly import Poly, new_const - -from typing import Any, Callable -from lark import Lark, Transformer -from functools import reduce -from operator import add, mul -from fractions import Fraction - -poly_grammar = Lark(""" - start : "(" start ")" | term ("+" term)* - term : (intf | frac)? factor ("*" factor)* - ?factor : intf | frac | pi | pifrac | var - var : CNAME - intf : INT - pi : "\\pi" | "pi" - frac : INT "/" INT - pifrac : [INT] pi "/" INT - - %import common.INT - %import common.CNAME - %import common.WS - %ignore WS - """, - parser='lalr', - maybe_placeholders=True) - -class PolyTransformer(Transformer): - def __init__(self, new_var: Callable[[str], Poly]): - super().__init__() - - self._new_var = new_var - - def start(self, items: list[Poly]) -> Poly: - return reduce(add, items) - - def term(self, items: list[Poly]) -> Poly: - return reduce(mul, items) - - def var(self, items: list[Any]) -> Poly: - v = str(items[0]) - return self._new_var(v) - - def pi(self, _: list[Any]) -> Poly: - return new_const(1) - - def intf(self, items: list[Any]) -> Poly: - return new_const(int(items[0])) - - def frac(self, items: list[Any]) -> Poly: - return new_const(Fraction(int(items[0]), int(items[1]))) - - def pifrac(self, items: list[Any]) -> Poly: - numerator = int(items[0]) if items[0] else 1 - return new_const(Fraction(numerator, int(items[2]))) - -def parse(expr: str, new_var: Callable[[str], Poly]) -> Poly: - tree = poly_grammar.parse(expr) - return PolyTransformer(new_var).transform(tree) diff --git a/zxlive/poly.py b/zxlive/poly.py deleted file mode 100644 index 0dfac718..00000000 --- a/zxlive/poly.py +++ /dev/null @@ -1,180 +0,0 @@ -from fractions import Fraction -from typing import Union, Optional - - -class Var: - name: str - _is_bool: bool - _types_dict: Optional[Union[bool, dict[str, bool]]] - - def __init__(self, name: str, data: Union[bool, dict[str, bool]]): - self.name = name - if isinstance(data, dict): - self._types_dict = data - self._frozen = False - self._is_bool = False - else: - self._types_dict = None - self._frozen = True - self._is_bool = data - - @property - def is_bool(self) -> bool: - if self._frozen: - return self._is_bool - else: - assert isinstance(self._types_dict, dict) - return self._types_dict[self.name] - - def __repr__(self) -> str: - return self.name - - def __lt__(self, other: 'Var') -> bool: - if int(self.is_bool) == int(other.is_bool): - return self.name < other.name - return int(self.is_bool) < int(other.is_bool) - - def __hash__(self) -> int: - # Variables with the same name map to the same type - # within the same graph, so no need to include is_bool - # in the hash. - return int(hash(self.name)) - - def __eq__(self, other: object) -> bool: - return self.__hash__() == other.__hash__() - - def freeze(self) -> None: - if not self._frozen: - assert isinstance(self._types_dict, dict) - self._is_bool = self._types_dict[self.name] - self._frozen = True - self._types_dict = None - - def __copy__(self) -> 'Var': - if self._frozen: - return Var(self.name, self.is_bool) - else: - assert isinstance(self._types_dict, dict) - return Var(self.name, self._types_dict) - - def __deepcopy__(self, _memo: object) -> 'Var': - return self.__copy__() - -class Term: - vars: list[tuple[Var, int]] - - def __init__(self, vars: list[tuple[Var,int]]) -> None: - self.vars = vars - - def freeze(self) -> None: - for var, _ in self.vars: - var.freeze() - - def free_vars(self) -> set[Var]: - return set(var for var, _ in self.vars) - - def __repr__(self) -> str: - vs = [] - for v, c in self.vars: - if c == 1: - vs.append(f'{v}') - else: - vs.append(f'{v}^{c}') - return '*'.join(vs) - - def __mul__(self, other: 'Term') -> 'Term': - vs = dict() - for v, c in self.vars + other.vars: - if v not in vs: vs[v] = c - else: vs[v] += c - # TODO deal with fractional / symbolic powers - if v.is_bool and c > 1: - vs[v] = 1 - return Term([(v, c) for v, c in vs.items()]) - - def __hash__(self) -> int: - return hash(tuple(sorted(self.vars))) - - def __eq__(self, other: object) -> bool: - return self.__hash__() == other.__hash__() - - -class Poly: - terms: list[tuple[Union[int, float, Fraction], Term]] - - def __init__(self, terms: list[tuple[Union[int, float, Fraction], Term]]) -> None: - self.terms = terms - - def freeze(self) -> None: - for _, term in self.terms: - term.freeze() - - def free_vars(self) -> set[Var]: - output = set() - for _, term in self.terms: - output.update(term.free_vars()) - return output - - def __add__(self, other: 'Poly') -> 'Poly': - if isinstance(other, (int, float, Fraction)): - other = Poly([(other, Term([]))]) - counter = dict() - for c, t in self.terms + other.terms: - if t not in counter: counter[t] = c - else: counter[t] += c - if all(tt[0].is_bool for tt in t.vars): - counter[t] = counter[t] % 2 - - # remove terms with coefficient 0 - for t in list(counter.keys()): - if counter[t] == 0: - del counter[t] - return Poly([(c, t) for t, c in counter.items()]) - - __radd__ = __add__ - - def __mul__(self, other: 'Poly') -> 'Poly': - if isinstance(other, (int, float)): - other = Poly([(other, Term([]))]) - p = Poly([]) - for c1, t1 in self.terms: - for c2, t2 in other.terms: - p += Poly([(c1 * c2, t1 * t2)]) - return p - - __rmul__ = __mul__ - - def __repr__(self) -> str: - ts = [] - for c, t in self.terms: - if t == Term([]): - ts.append(f'{c}') - elif c == 1: - ts.append(f'{t}') - else: - ts.append(f'{c}{t}') - return ' + '.join(ts) - - def __eq__(self, other: object) -> bool: - if isinstance(other, (int, float, Fraction)): - if other == 0: - other = Poly([]) - else: - other = Poly([(other, Term([]))]) - assert isinstance(other, Poly) - return set(self.terms) == set(other.terms) - - @property - def is_pauli(self) -> bool: - for c, t in self.terms: - if not all(v.is_bool for v, _ in t.vars): - return False - if c % 1 != 0: - return False - return True - -def new_var(name: str, types_dict: Union[bool, dict[str, bool]]) -> Poly: - return Poly([(1, Term([(Var(name, types_dict), 1)]))]) - -def new_const(coeff: Union[int, Fraction]) -> Poly: - return Poly([(coeff, Term([]))]) From 1fa641cabca97fdc32bf4d174964b841a0235931 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 13 Nov 2023 21:30:03 +0000 Subject: [PATCH 08/37] Mypy compliance --- zxlive/animations.py | 3 ++- zxlive/app.py | 2 +- zxlive/base_panel.py | 3 ++- zxlive/commands.py | 5 +++-- zxlive/common.py | 2 +- zxlive/custom_rule.py | 4 ++-- zxlive/editor_base_panel.py | 12 ++++++------ zxlive/eitem.py | 1 + zxlive/graphscene.py | 3 ++- zxlive/graphview.py | 6 ++++-- zxlive/mainwindow.py | 4 ++-- zxlive/parse_poly.py | 2 +- zxlive/proof.py | 9 ++++++--- zxlive/proof_actions.py | 2 +- zxlive/proof_panel.py | 14 ++++++++------ zxlive/settings_dialog.py | 27 ++++++++++++++------------- zxlive/vitem.py | 1 + 17 files changed, 57 insertions(+), 43 deletions(-) diff --git a/zxlive/animations.py b/zxlive/animations.py index f901a79c..6c3b776d 100644 --- a/zxlive/animations.py +++ b/zxlive/animations.py @@ -183,7 +183,8 @@ def fuse(dragged: VItem, target: VItem, meet_halfway: bool = False) -> QAbstract if not meet_halfway: group.addAnimation(move(dragged, target=target.pos(), duration=100, ease=QEasingCurve(QEasingCurve.Type.OutQuad))) else: - halfway_pos = (dragged.pos() + target.pos()) / 2 + sum_pos = dragged.pos() + target.pos() + halfway_pos = QPointF(sum_pos.x() / 2, sum_pos.y() / 2) group.addAnimation(move(dragged, target=halfway_pos, duration=100, ease=QEasingCurve(QEasingCurve.Type.OutQuad))) group.addAnimation(move(target, target=halfway_pos, duration=100, ease=QEasingCurve(QEasingCurve.Type.OutQuad))) group.addAnimation(scale(target, target=1, duration=100, ease=QEasingCurve(QEasingCurve.Type.InBack))) diff --git a/zxlive/app.py b/zxlive/app.py index 2b9fba83..e525f4fc 100644 --- a/zxlive/app.py +++ b/zxlive/app.py @@ -30,7 +30,7 @@ if os.name == 'nt': import ctypes myappid = 'quantomatic.zxlive.zxlive.1.0.0' # arbitrary string - ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(myappid) + ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(myappid) # type: ignore class ZXLive(QApplication): diff --git a/zxlive/base_panel.py b/zxlive/base_panel.py index f08c3044..e971b0ed 100644 --- a/zxlive/base_panel.py +++ b/zxlive/base_panel.py @@ -96,7 +96,8 @@ def deselect_all(self) -> None: def copy_selection(self) -> GraphT: selection = list(self.graph_scene.selected_vertices) copied_graph = self.graph.subgraph_from_vertices(selection) - assert isinstance(copied_graph, GraphT) + # Mypy issue: https://github.com/python/mypy/issues/11673 + assert isinstance(copied_graph, GraphT) # type: ignore return copied_graph def update_colors(self) -> None: diff --git a/zxlive/commands.py b/zxlive/commands.py index 0303e9ea..880142f3 100644 --- a/zxlive/commands.py +++ b/zxlive/commands.py @@ -15,6 +15,7 @@ from .common import ET, VT, W_INPUT_OFFSET, GraphT from .graphview import GraphView +from .poly import Poly from .proof import ProofModel, Rewrite @@ -304,9 +305,9 @@ def redo(self) -> None: class ChangePhase(BaseCommand): """Updates the phase of a spider.""" v: VT - new_phase: Union[Fraction, int] + new_phase: Union[Fraction, Poly, complex] - _old_phase: Optional[Union[Fraction, int]] = field(default=None, init=False) + _old_phase: Optional[Union[Fraction, Poly, complex]] = field(default=None, init=False) def undo(self) -> None: assert self._old_phase is not None diff --git a/zxlive/common.py b/zxlive/common.py index 6b858e88..a5d7ac91 100644 --- a/zxlive/common.py +++ b/zxlive/common.py @@ -158,7 +158,7 @@ def _get_synonyms(key: str, default: list[str]) -> list[str]: def to_tikz(g: GraphT) -> str: - return pyzx.tikz.to_tikz(g) + return pyzx.tikz.to_tikz(g) # type: ignore def from_tikz(s: str) -> GraphT: try: diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 9d3e5a23..e3dee9f0 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -93,8 +93,8 @@ def from_json(cls, json_str: str) -> "CustomRule": d = json.loads(json_str) lhs_graph = GraphT.from_json(d['lhs_graph']) rhs_graph = GraphT.from_json(d['rhs_graph']) - assert (isinstance(lhs_graph, GraphT) and - isinstance(rhs_graph, GraphT)) + # Mypy issue: https://github.com/python/mypy/issues/11673 + assert (isinstance(lhs_graph, GraphT) and isinstance(rhs_graph, GraphT)) # type: ignore return cls(lhs_graph, rhs_graph, d['name'], d['description']) def to_proof_action(self) -> "ProofAction": diff --git a/zxlive/editor_base_panel.py b/zxlive/editor_base_panel.py index 70d20882..9230491d 100644 --- a/zxlive/editor_base_panel.py +++ b/zxlive/editor_base_panel.py @@ -39,7 +39,7 @@ class ShapeType(Enum): class DrawPanelNodeType(TypedDict): text: str - icon: tuple[ShapeType, str] + icon: tuple[ShapeType, QColor] def vertices_data() -> dict[VertexType.Type, DrawPanelNodeType]: @@ -54,8 +54,8 @@ def vertices_data() -> dict[VertexType.Type, DrawPanelNodeType]: def edges_data() -> dict[EdgeType.Type, DrawPanelNodeType]: return { - EdgeType.SIMPLE: {"text": "Simple", "icon": (ShapeType.LINE, BLACK)}, - EdgeType.HADAMARD: {"text": "Hadamard", "icon": (ShapeType.DASHED_LINE, HAD_EDGE_BLUE)}, + EdgeType.SIMPLE: {"text": "Simple", "icon": (ShapeType.LINE, QColor(BLACK))}, + EdgeType.HADAMARD: {"text": "Hadamard", "icon": (ShapeType.DASHED_LINE, QColor(HAD_EDGE_BLUE))}, } @@ -211,7 +211,7 @@ def __init__(self, variable_types: dict[str, bool]) -> None: self._variable_types = variable_types self._widget = QWidget() - lpal = QApplication.palette("QListWidget") + lpal = QApplication.palette("QListWidget") # type: ignore palette = QPalette() palette.setBrush(QPalette.ColorRole.Window, lpal.base()) self._widget.setAutoFillBackground(True) @@ -354,14 +354,14 @@ def populate_list_widget(list_widget: QListWidget, list_widget.setCurrentRow(row) -def create_icon(shape: ShapeType, color: str) -> QIcon: +def create_icon(shape: ShapeType, color: QColor) -> QIcon: icon = QIcon() pixmap = QPixmap(64, 64) pixmap.fill(Qt.GlobalColor.transparent) painter = QPainter(pixmap) painter.setRenderHint(QPainter.RenderHint.Antialiasing) painter.setPen(QPen(QColor(BLACK), 6)) - painter.setBrush(QColor(color)) + painter.setBrush(color) if shape == ShapeType.CIRCLE: painter.drawEllipse(4, 4, 56, 56) elif shape == ShapeType.SQUARE: diff --git a/zxlive/eitem.py b/zxlive/eitem.py index 20d85ee7..d6b6a03e 100644 --- a/zxlive/eitem.py +++ b/zxlive/eitem.py @@ -93,6 +93,7 @@ def paint(self, painter: QPainter, option: QStyleOptionGraphicsItem, widget: Opt # By default, Qt draws a dashed rectangle around selected items. # We have our own implementation to draw selected vertices, so # we intercept the selected option here. + assert hasattr(option, "state") option.state &= ~QStyle.StateFlag.State_Selected super().paint(painter, option, widget) diff --git a/zxlive/graphscene.py b/zxlive/graphscene.py index 2fda49fc..e14b8250 100644 --- a/zxlive/graphscene.py +++ b/zxlive/graphscene.py @@ -121,7 +121,8 @@ def update_graph(self, new: GraphT, select_new: bool = False) -> None: self.removeItem(e_item) new_g = diff.apply_diff(self.g) - assert isinstance(new_g, GraphT) + # Mypy issue: https://github.com/python/mypy/issues/11673 + assert isinstance(new_g, GraphT) # type: ignore self.g = new_g # g now contains the new graph, # but we still need to update the scene diff --git a/zxlive/graphview.py b/zxlive/graphview.py index b9d13cb1..b4415956 100644 --- a/zxlive/graphview.py +++ b/zxlive/graphview.py @@ -119,7 +119,7 @@ def mousePressEvent(self, e: QMouseEvent) -> None: elif self.tool == GraphTool.MagicWand: pos = self.mapToScene(e.pos()) shift = e.modifiers() & Qt.KeyboardModifier.ShiftModifier - self.wand_trace = WandTrace(pos, shift) + self.wand_trace = WandTrace(pos, bool(shift)) self.wand_path = QGraphicsPathItem() self.graph_scene.addItem(self.wand_path) pen = QPen(QColor(WAND_COLOR), WAND_WIDTH) @@ -302,6 +302,7 @@ def mousePressEvent(self, e: QMouseEvent) -> None: class Sparkles(QObject): + def __init__(self, graph_scene: GraphScene) -> None: super().__init__() self.graph_scene = graph_scene @@ -314,7 +315,7 @@ def __init__(self, graph_scene: GraphScene) -> None: vx = speed * math.cos(angle) / SPARKLE_STEPS vy = speed * math.sin(angle) / SPARKLE_STEPS self.sparkle_deltas.append((vx, vy)) - self.timer_id = None + self.timer_id: Optional[int] = None def emit_sparkles(self, pos: QPointF, mult: int) -> None: if not self.timer_id: @@ -335,6 +336,7 @@ def timerEvent(self, event: QTimerEvent) -> None: sparkle.timer_step() def stop(self) -> None: + assert self.timer_id is not None self.killTimer(self.timer_id) self.timer_id = None for sparkle in reversed(self.sparkles): diff --git a/zxlive/mainwindow.py b/zxlive/mainwindow.py index a55feb0c..6f363a81 100644 --- a/zxlive/mainwindow.py +++ b/zxlive/mainwindow.py @@ -214,7 +214,7 @@ def _new_action(self, name: str, trigger: Callable, shortcut: QKeySequence | QKe if not action.shortcuts(): action.setShortcut(alt_shortcut) elif alt_shortcut not in action.shortcuts(): - action.setShortcuts([shortcut, alt_shortcut]) + action.setShortcuts([shortcut, alt_shortcut]) # type: ignore return action @property @@ -501,7 +501,7 @@ def format_str(c: complex) -> str: for i in range(matrix.shape[0]): for j in range(matrix.shape[1]): entry = QTableWidgetItem(format_str(matrix[i, j])) - entry.setFlags(entry.flags() & ~Qt.ItemIsEditable) + entry.setFlags(entry.flags() & ~Qt.ItemFlag.ItemIsEditable) table.setItem(i, j, entry) table.resizeColumnsToContents() table.resizeRowsToContents() diff --git a/zxlive/parse_poly.py b/zxlive/parse_poly.py index f70a33ab..42e9da97 100644 --- a/zxlive/parse_poly.py +++ b/zxlive/parse_poly.py @@ -24,7 +24,7 @@ parser='lalr', maybe_placeholders=True) -class PolyTransformer(Transformer): +class PolyTransformer(Transformer[Poly]): def __init__(self, new_var: Callable[[str], Poly]): super().__init__() diff --git a/zxlive/proof.py b/zxlive/proof.py index 93b74634..5649aa8d 100644 --- a/zxlive/proof.py +++ b/zxlive/proof.py @@ -113,7 +113,8 @@ def pop_rewrite(self) -> tuple[Rewrite, GraphT]: def get_graph(self, index: int) -> GraphT: """Returns the grap at a given position in the proof.""" copy = self.graphs[index].copy() - assert isinstance(copy, GraphT) + # Mypy issue: https://github.com/python/mypy/issues/11673 + assert isinstance(copy, GraphT) # type: ignore return copy def to_json(self) -> str: @@ -132,11 +133,13 @@ def from_json(json_str: str) -> "ProofModel": """Deserializes the model from JSON.""" d = json.loads(json_str) initial_graph = GraphT.from_tikz(d["initial_graph"]) - assert isinstance(initial_graph, GraphT) + # Mypy issue: https://github.com/python/mypy/issues/11673 + assert isinstance(initial_graph, GraphT) # type: ignore model = ProofModel(initial_graph) for step in d["proof_steps"]: rewrite = Rewrite.from_json(step) rewritten_graph = rewrite.diff.apply_diff(model.graphs[-1]) - assert isinstance(rewritten_graph, GraphT) + # Mypy issue: https://github.com/python/mypy/issues/11673 + assert isinstance(rewritten_graph, GraphT) # type: ignore model.add_rewrite(rewrite, rewritten_graph) return model diff --git a/zxlive/proof_actions.py b/zxlive/proof_actions.py index 004479dd..46df31bc 100644 --- a/zxlive/proof_actions.py +++ b/zxlive/proof_actions.py @@ -208,7 +208,7 @@ def rule(g: GraphT, matches: list) -> pyzx.rules.RewriteOutputType[ET,VT]: return ({}, [], [], True) return rule -def _extract_circuit(graph, matches): +def _extract_circuit(graph: GraphT, matches: list) -> GraphT: graph.auto_detect_io() simplify.full_reduce(graph) return extract_circuit(graph).to_graph() diff --git a/zxlive/proof_panel.py b/zxlive/proof_panel.py index a518ac8f..266c1c22 100644 --- a/zxlive/proof_panel.py +++ b/zxlive/proof_panel.py @@ -51,7 +51,7 @@ def __init__(self, graph: GraphT, *actions: QAction) -> None: self.graph_view.set_graph(graph) self.actions_bar = QTabWidget(self) - self.layout().insertWidget(1, self.actions_bar) + self.layout().insertWidget(1, self.actions_bar) # type: ignore self.init_action_groups() self.actions_bar.currentChanged.connect(self.update_on_selection) @@ -131,7 +131,7 @@ def init_action_groups(self) -> None: widget = QWidget() widget.setLayout(hlayout) - widget.action_group = group + setattr(widget, "action_group", group) self.actions_bar.addTab(widget, group.name) def parse_selection(self) -> tuple[list[VT], list[ET]]: @@ -148,7 +148,8 @@ def parse_selection(self) -> tuple[list[VT], list[ET]]: def update_on_selection(self) -> None: selection, edges = self.parse_selection() g = self.graph_scene.g - self.actions_bar.currentWidget().action_group.update_active(g, selection, edges) + action_group = getattr(self.actions_bar.currentWidget(), "action_group") + action_group.update_active(g, selection, edges) def _vert_moved(self, vs: list[tuple[VT, float, float]]) -> None: cmd = MoveNodeInStep(self.graph_view, vs, self.step_view) @@ -244,7 +245,7 @@ def cross(a: QPointF, b: QPointF) -> float: if not ok: return False try: - def new_var(_): + def new_var(_: str) -> Poly: raise ValueError() phase = string_to_complex(text) if phase_is_complex else string_to_fraction(text, new_var) except ValueError: @@ -367,7 +368,7 @@ def _proof_step_selected(self, selected: QItemSelection, deselected: QItemSelect cmd = GoToRewriteStep(self.graph_view, self.step_view, deselected.first().topLeft().row(), selected.first().topLeft().row()) self.undo_stack.push(cmd) - def _refresh_rules(self): + def _refresh_rules(self) -> None: self.actions_bar.removeTab(self.actions_bar.count() - 1) custom_rules = [] for root, dirs, files in os.walk(get_custom_rules_path()): @@ -386,7 +387,7 @@ def _refresh_rules(self): hlayout.addStretch() widget = QWidget() widget.setLayout(hlayout) - widget.action_group = group + setattr(widget, "action_group", group) self.actions_bar.addTab(widget, group.name) @@ -406,6 +407,7 @@ class ProofStepItemDelegate(QStyledItemDelegate): def paint(self, painter: QPainter, option: QStyleOptionViewItem, index: Union[QModelIndex, QPersistentModelIndex]) -> None: painter.save() + assert hasattr(option, "state") and hasattr(option, "rect") and hasattr(option, "font") # Draw background painter.setPen(Qt.GlobalColor.transparent) diff --git a/zxlive/settings_dialog.py b/zxlive/settings_dialog.py index 20b7adc8..28cd980f 100644 --- a/zxlive/settings_dialog.py +++ b/zxlive/settings_dialog.py @@ -69,9 +69,9 @@ } class SettingsDialog(QDialog): - def __init__(self, parent: MainWindow) -> None: - super().__init__(parent) - self.parent = parent + def __init__(self, main_window: MainWindow) -> None: + super().__init__(main_window) + self.main_window = main_window self.setWindowTitle("Settings") self.settings = QSettings("zxlive", "zxlive") self.value_dict: Dict[str,QWidget] = {} @@ -173,6 +173,7 @@ def __init__(self, parent: MainWindow) -> None: def add_setting(self,form:QFormLayout, name:str, label:str, ty:str, data:Any=None) -> None: val = self.settings.value(name) + widget: QWidget if val is None: val = defaults[name] if ty == 'str': widget = QLineEdit() @@ -180,12 +181,10 @@ def add_setting(self,form:QFormLayout, name:str, label:str, ty:str, data:Any=Non widget.setText(val) elif ty == 'int': widget = QSpinBox() - val = int(val) - widget.setValue(val) + widget.setValue(int(val)) # type: ignore elif ty == 'float': widget = QDoubleSpinBox() - val = float(val) - widget.setValue(val) + widget.setValue(float(val)) # type: ignore elif ty == 'folder': widget = QWidget() hlayout = QHBoxLayout() @@ -194,10 +193,10 @@ def add_setting(self,form:QFormLayout, name:str, label:str, ty:str, data:Any=Non val = str(val) widget_line.setText(val) def browse() -> None: - directory = QFileDialog.getExistingDirectory(self,"Pick folder",options=QFileDialog.ShowDirsOnly) + directory = QFileDialog.getExistingDirectory(self,"Pick folder",options=QFileDialog.Option.ShowDirsOnly) if directory: widget_line.setText(directory) - widget.text_value = directory + setattr(widget, "text_value", directory) hlayout.addWidget(widget_line) button = QPushButton("Browse") button.clicked.connect(browse) @@ -206,9 +205,9 @@ def browse() -> None: widget = QComboBox() val = str(val) assert isinstance(data, dict) - widget.addItems(data.values()) + widget.addItems(list(data.values())) widget.setCurrentText(data[val]) - widget.data = data + setattr(widget, "data", data) form.addRow(label, widget) @@ -231,8 +230,10 @@ def okay(self) -> None: self.settings.setValue(name, widget.text_value) set_pyzx_tikz_settings() if self.settings.value("color-scheme") != self.prev_color_scheme: - colors.set_color_scheme(self.settings.value("color-scheme")) - self.parent.update_colors() + theme = self.settings.value("color-scheme") + assert isinstance(theme, str) + colors.set_color_scheme(theme) + self.main_window.update_colors() self.accept() def cancel(self) -> None: diff --git a/zxlive/vitem.py b/zxlive/vitem.py index 9df61b5b..02c89f66 100644 --- a/zxlive/vitem.py +++ b/zxlive/vitem.py @@ -222,6 +222,7 @@ def paint(self, painter: QPainter, option: QStyleOptionGraphicsItem, widget: Opt # By default, Qt draws a dashed rectangle around selected items. # We have our own implementation to draw selected vertices, so # we intercept the selected option here. + assert hasattr(option, "state") option.state &= ~QStyle.StateFlag.State_Selected super().paint(painter, option, widget) From a9ff770179baf61d019a08ad133649759141c3ef Mon Sep 17 00:00:00 2001 From: John van de Wetering Date: Mon, 13 Nov 2023 21:31:50 +0000 Subject: [PATCH 09/37] Fixed bugs with saving parameters --- zxlive/edit_panel.py | 2 +- zxlive/editor_base_panel.py | 71 ++++++++++++++++--------------------- zxlive/rule_panel.py | 2 +- 3 files changed, 33 insertions(+), 42 deletions(-) diff --git a/zxlive/edit_panel.py b/zxlive/edit_panel.py index 4cb46930..78e7d32a 100644 --- a/zxlive/edit_panel.py +++ b/zxlive/edit_panel.py @@ -7,6 +7,7 @@ from PySide6.QtGui import QAction from PySide6.QtWidgets import (QToolButton) from pyzx import EdgeType, VertexType +from pyzx.symbolic import Poly from .base_panel import ToolbarSection from .common import GraphT @@ -14,7 +15,6 @@ from .editor_base_panel import EditorBasePanel from .graphscene import EditGraphScene from .graphview import GraphView -from .poly import Poly class GraphEditPanel(EditorBasePanel): diff --git a/zxlive/editor_base_panel.py b/zxlive/editor_base_panel.py index 6c78a50c..57386bf7 100644 --- a/zxlive/editor_base_panel.py +++ b/zxlive/editor_base_panel.py @@ -15,9 +15,9 @@ QSpacerItem, QSplitter, QToolButton, QWidget) from pyzx import EdgeType, VertexType from pyzx.utils import get_w_partner, vertex_is_w +from pyzx.graph.jsonparser import string_to_phase from pyzx.symbolic import Poly - from .base_panel import BasePanel, ToolbarSection from .commands import (AddEdge, AddNode, AddWNode, ChangeEdgeColor, ChangeNodeType, ChangePhase, MoveNode, SetGraph, @@ -26,7 +26,6 @@ from .dialogs import show_error_msg from .eitem import HAD_EDGE_BLUE from .graphscene import EditGraphScene -from .parse_poly import parse from .vitem import BLACK @@ -98,16 +97,8 @@ def update_colors(self) -> None: super().update_colors() self.update_side_bar() - def update_variable_viewer(self) -> None: - self.update_side_bar() - def _populate_variables(self) -> None: - self.variable_types = {} - for vert in self.graph.vertices(): - phase = self.graph.phase(vert) - if isinstance(phase, Poly): - for var in phase.free_vars(): - self.variable_types[var.name] = var.is_bool + self.variable_types = self.graph.variable_types.copy() def _tool_clicked(self, tool: ToolType) -> None: self.graph_scene.curr_tool = tool @@ -189,18 +180,18 @@ def vert_double_clicked(self, v: VT) -> None: if not ok: return None try: - new_phase = string_to_complex(input_) if phase_is_complex else string_to_fraction(input_, self._new_var) + new_phase = string_to_complex(input_) if phase_is_complex else string_to_phase(input_, graph) except ValueError: show_error_msg("Invalid Input", error_msg) return None cmd = ChangePhase(self.graph_view, v, new_phase) self.undo_stack.push(cmd) - - def _new_var(self, name: str) -> Poly: - if name not in self.variable_types: - self.variable_types[name] = False - self.variable_viewer.add_item(name) - return new_var(name, self.variable_types) + # For some reason it is important we first push to the stack before we do the following. + if len(graph.variable_types) != len(self.variable_types): + new_vars = graph.variable_types.keys() - self.variable_types.keys() + self.variable_types.update(graph.variable_types) + for v in new_vars: + self.variable_viewer.add_item(v) class VariableViewer(QScrollArea): @@ -378,28 +369,28 @@ def create_icon(shape: ShapeType, color: str) -> QIcon: return icon -def string_to_fraction(string: str, new_var_: Callable[[str], Poly]) -> Union[Fraction, Poly]: - if not string: - return Fraction(0) - try: - s = string.lower().replace(' ', '') - s = re.sub('\\*?(pi|\u04c0)\\*?', '', s) - if '.' in s or 'e' in s: - return Fraction(float(s)) - elif '/' in s: - a, b = s.split("/", 2) - if not a: - return Fraction(1, int(b)) - if a == '-': - a = '-1' - return Fraction(int(a), int(b)) - else: - return Fraction(int(s)) - except ValueError: - try: - return parse(string, new_var_) - except Exception as e: - raise ValueError(e) +#def string_to_fraction(string: str, new_var_: Callable[[str], Poly]) -> Union[Fraction, Poly]: +# if not string: +# return Fraction(0) +# try: +# s = string.lower().replace(' ', '') +# s = re.sub('\\*?(pi|\u04c0)\\*?', '', s) +# if '.' in s or 'e' in s: +# return Fraction(float(s)) +# elif '/' in s: +# a, b = s.split("/", 2) +# if not a: +# return Fraction(1, int(b)) +# if a == '-': +# a = '-1' +# return Fraction(int(a), int(b)) +# else: +# return Fraction(int(s)) +# except ValueError: +# try: +# return parse(string, new_var_) +# except Exception as e: +# raise ValueError(e) def string_to_complex(string: str) -> complex: diff --git a/zxlive/rule_panel.py b/zxlive/rule_panel.py index 61f4c567..179fca7e 100644 --- a/zxlive/rule_panel.py +++ b/zxlive/rule_panel.py @@ -6,6 +6,7 @@ from PySide6.QtGui import QAction from PySide6.QtWidgets import QLineEdit from pyzx import EdgeType, VertexType +from pyzx.symbolic import Poly from .base_panel import ToolbarSection @@ -14,7 +15,6 @@ from .editor_base_panel import EditorBasePanel from .graphscene import EditGraphScene from .graphview import RuleEditGraphView -from .poly import Poly class RulePanel(EditorBasePanel): From 9d87e77a419b9fb2d27d201d01d67f51cb43bd60 Mon Sep 17 00:00:00 2001 From: Tuomas Laakkonen Date: Mon, 13 Nov 2023 21:35:11 +0000 Subject: [PATCH 10/37] allow unfusing W nodes --- zxlive/proof_panel.py | 54 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 6 deletions(-) diff --git a/zxlive/proof_panel.py b/zxlive/proof_panel.py index 5c3975be..b8b3d97c 100644 --- a/zxlive/proof_panel.py +++ b/zxlive/proof_panel.py @@ -15,7 +15,7 @@ QStyleOptionViewItem, QToolButton, QWidget, QVBoxLayout, QTabWidget, QInputDialog) from pyzx import VertexType, basicrules -from pyzx.utils import get_z_box_label, set_z_box_label +from pyzx.utils import get_z_box_label, set_z_box_label, get_w_partner, EdgeType from . import animations as anims from . import proof_actions @@ -29,7 +29,7 @@ from .graphscene import GraphScene from .graphview import GraphTool, GraphView, WandTrace from .proof import ProofModel -from .vitem import DragState, VItem +from .vitem import DragState, VItem, get_w_partner_vitem, W_INPUT_OFFSET, SCALE from .editor_base_panel import string_to_complex, string_to_fraction from .poly import Poly @@ -220,14 +220,14 @@ def cross(a: QPointF, b: QPointF) -> float: return False item = filtered[0] vertex = item.v - if self.graph.type(vertex) not in (VertexType.Z, VertexType.X, VertexType.Z_BOX): + if self.graph.type(vertex) not in (VertexType.Z, VertexType.X, VertexType.Z_BOX, VertexType.W_OUTPUT): return False if not trace.shift and basicrules.check_remove_id(self.graph, vertex): self._remove_id(vertex) return True - if trace.shift: + if trace.shift and self.graph.type(vertex) != VertexType.W_OUTPUT: phase_is_complex = (self.graph.type(vertex) == VertexType.Z_BOX) if phase_is_complex: prompt = "Enter desired phase value (complex value):" @@ -245,7 +245,7 @@ def new_var(_): except ValueError: show_error_msg("Invalid Input", error_msg) return False - else: + elif self.graph.type(vertex) != VertexType.W_OUTPUT: if self.graph.type(vertex) == VertexType.Z_BOX: phase = get_z_box_label(self.graph, vertex) else: @@ -268,7 +268,11 @@ def new_var(_): else: right.append(neighbor) mouse_dir = ((start + end) * (1/2)) - pos - self._unfuse(vertex, left, mouse_dir, phase) + + if self.graph.type(vertex) == VertexType.W_OUTPUT: + self._unfuse_w(vertex, left, mouse_dir) + else: + self._unfuse(vertex, left, mouse_dir, phase) return True def _remove_id(self, v: VT) -> None: @@ -278,6 +282,44 @@ def _remove_id(self, v: VT) -> None: cmd = AddRewriteStep(self.graph_view, new_g, self.step_view, "id") self.undo_stack.push(cmd, anim_before=anim) + def _unfuse_w(self, v: VT, left_neighbours: list[VT], mouse_dir: QPointF) -> None: + new_g = copy.deepcopy(self.graph) + + vi = get_w_partner(self.graph, v) + par_dir = QVector2D( + self.graph.row(v) - self.graph.row(vi), + self.graph.qubit(v) - self.graph.qubit(vi) + ).normalized() + + perp_dir = QVector2D(mouse_dir - QPointF(self.graph.row(v)/SCALE, self.graph.qubit(v)/SCALE)).normalized() + perp_dir -= QVector2D.dotProduct(perp_dir, par_dir) * par_dir + perp_dir.normalize() + + out_offset_x = par_dir.x() * 0.5 + perp_dir.x() * 0.5 + out_offset_y = par_dir.y() * 0.5 + perp_dir.y() * 0.5 + + in_offset_x = out_offset_x - par_dir.x()*W_INPUT_OFFSET + in_offset_y = out_offset_y - par_dir.y()*W_INPUT_OFFSET + + left_vert = new_g.add_vertex(VertexType.W_OUTPUT, + qubit=self.graph.qubit(v) + out_offset_y, + row=self.graph.row(v) + out_offset_x) + left_vert_i = new_g.add_vertex(VertexType.W_INPUT, + qubit=self.graph.qubit(v) + in_offset_y, + row=self.graph.row(v) + in_offset_x) + new_g.add_edge((left_vert_i, left_vert), EdgeType.W_IO) + new_g.add_edge((v, left_vert_i)) + new_g.set_row(v, self.graph.row(v)) + new_g.set_qubit(v, self.graph.qubit(v)) + for neighbor in left_neighbours: + new_g.add_edge((neighbor, left_vert), + self.graph.edge_type((v, neighbor))) + new_g.remove_edge((v, neighbor)) + + anim = anims.unfuse(self.graph, new_g, v, self.graph_scene) + cmd = AddRewriteStep(self.graph_view, new_g, self.step_view, "unfuse") + self.undo_stack.push(cmd, anim_after=anim) + def _unfuse(self, v: VT, left_neighbours: list[VT], mouse_dir: QPointF, phase: Poly | complex | Fraction) -> None: def snap_vector(v: QVector2D) -> None: if abs(v.x()) > abs(v.y()): From 52ce7b4610c646302007b9d4d5140d7be620d9e1 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 13 Nov 2023 21:37:24 +0000 Subject: [PATCH 11/37] Cleanup --- zxlive/mainwindow.py | 1 - zxlive/tikz.py | 6 ------ 2 files changed, 7 deletions(-) diff --git a/zxlive/mainwindow.py b/zxlive/mainwindow.py index 7c5aa90f..32727f41 100644 --- a/zxlive/mainwindow.py +++ b/zxlive/mainwindow.py @@ -390,7 +390,6 @@ def handle_export_tikz_proof_action(self) -> bool: if path is None: show_error_msg("Export failed", "Invalid path") return False - print(path) with open(path, "w") as f: f.write(proof_to_tikz(self.active_panel.proof_model)) diff --git a/zxlive/tikz.py b/zxlive/tikz.py index 1b74391c..2dc12fc0 100644 --- a/zxlive/tikz.py +++ b/zxlive/tikz.py @@ -1,15 +1,9 @@ -from typing import Union - from PySide6.QtCore import QSettings -from pyzx.graph.graph_s import GraphS from pyzx.tikz import TIKZ_BASE, _to_tikz from zxlive.proof import ProofModel - - - def proof_to_tikz(proof: ProofModel) -> str: settings = QSettings("zxlive", "zxlive") vspace = settings.value("tikz/layout/vspace") From cfe78b422f9d56fd7875fa51e60810cc856a7daf Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 13 Nov 2023 21:43:47 +0000 Subject: [PATCH 12/37] Fix mypy --- zxlive/mainwindow.py | 1 + zxlive/tikz.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/zxlive/mainwindow.py b/zxlive/mainwindow.py index dceabe45..4180ec6b 100644 --- a/zxlive/mainwindow.py +++ b/zxlive/mainwindow.py @@ -392,6 +392,7 @@ def handle_export_tikz_proof_action(self) -> bool: return False with open(path, "w") as f: f.write(proof_to_tikz(self.active_panel.proof_model)) + return True def cut_graph(self) -> None: assert self.active_panel is not None diff --git a/zxlive/tikz.py b/zxlive/tikz.py index 2dc12fc0..237c1b4e 100644 --- a/zxlive/tikz.py +++ b/zxlive/tikz.py @@ -11,6 +11,8 @@ def proof_to_tikz(proof: ProofModel) -> str: max_width = settings.value("tikz/layout/max-width") draw_scalar = False + assert isinstance(vspace, float) and isinstance(hspace, float) and isinstance(max_width, float) + xoffset = -max_width yoffset = -10 idoffset = 0 From f62718d6594d85f31fb3efa7d42a76cde385ab5c Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:25:36 +0100 Subject: [PATCH 13/37] only check matrix for non-symbolic rules --- zxlive/custom_rule.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 9d3e5a23..0f76eaaf 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -154,13 +154,14 @@ def check_rule(rule: CustomRule, show_error: bool = True) -> bool: from .dialogs import show_error_msg show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different numbers of inputs or outputs.") return False - left_matrix, right_matrix = rule.lhs_graph.to_matrix(), rule.rhs_graph.to_matrix() - if not np.allclose(left_matrix, right_matrix): - if show_error: - from .dialogs import show_error_msg - if np.allclose(left_matrix / np.linalg.norm(left_matrix), right_matrix / np.linalg.norm(right_matrix)): - show_error_msg("Warning!", "The left-hand side and right-hand side of the rule differ by a scalar.") - else: - show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.") - return False + if not rule.lhs_graph.variable_types and not rule.rhs_graph.variable_types: + left_matrix, right_matrix = rule.lhs_graph.to_matrix(), rule.rhs_graph.to_matrix() + if not np.allclose(left_matrix, right_matrix): + if show_error: + from .dialogs import show_error_msg + if np.allclose(left_matrix / np.linalg.norm(left_matrix), right_matrix / np.linalg.norm(right_matrix)): + show_error_msg("Warning!", "The left-hand side and right-hand side of the rule differ by a scalar.") + else: + show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.") + return False return True From 057179997b593ef36740eaaaf265748b694742df Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:26:39 +0100 Subject: [PATCH 14/37] rewrite rule matching with symbolic parameters --- zxlive/custom_rule.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 0f76eaaf..144a23c9 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -11,6 +11,8 @@ from pyzx.utils import EdgeType, VertexType from shapely import Polygon +from pyzx.symbolic import Poly + from .common import ET, VT, GraphT if TYPE_CHECKING: @@ -75,8 +77,10 @@ def matcher(self, graph: GraphT, in_selection: Callable[[VT], bool]) -> list[VT] vertices = [v for v in graph.vertices() if in_selection(v)] subgraph_nx, _ = create_subgraph(graph, vertices) graph_matcher = GraphMatcher(self.lhs_graph_nx, subgraph_nx, - node_match=categorical_node_match(['type', 'phase'], default=[1, 0])) - if graph_matcher.is_isomorphic(): + node_match=categorical_node_match('type', 1)) + matchings = list(graph_matcher.match()) + matchings = filter_matchings_if_symbolic_compatible(matchings, self.lhs_graph_nx, subgraph_nx) + if len(matchings) > 0: return vertices return [] @@ -102,6 +106,32 @@ def to_proof_action(self) -> "ProofAction": return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) +def match_symbolic_parameters(match, left, right): + params = {} + left_phase = left.nodes.data('phase', default=0) + right_phase = right.nodes.data('phase', default=0) + for v in left.nodes(): + if isinstance(left_phase[v], Poly): + if str(left_phase[v]) in params: + if params[str(left_phase)] != right_phase[match[v]]: + raise ValueError("Symbolic parameters do not match") + else: + params[str(left_phase[v])] = right_phase[match[v]] + elif left_phase[v] != right_phase[match[v]]: + raise ValueError("Parameters do not match") + return params + +def filter_matchings_if_symbolic_compatible(matchings, left, right): + new_matchings = [] + for matching in matchings: + try: + match_symbolic_parameters(matching, left, right) + new_matchings.append(matching) + except ValueError: + pass + return new_matchings + + def to_networkx(graph: GraphT) -> nx.Graph: G = nx.Graph() v_data = {v: {"type": graph.type(v), From 745d55e21491cde50f0718e8d8ba67be3352976d Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:28:14 +0100 Subject: [PATCH 15/37] applying custom rule with symbolic parameters --- zxlive/custom_rule.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 144a23c9..8e33da6c 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -33,8 +33,11 @@ def __init__(self, lhs_graph: GraphT, rhs_graph: GraphT, name: str, description: def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutputType[ET,VT]: subgraph_nx, boundary_mapping = create_subgraph(graph, vertices) graph_matcher = GraphMatcher(self.lhs_graph_nx, subgraph_nx, - node_match=categorical_node_match(['type', 'phase'], default=[1, 0])) - matching = list(graph_matcher.match())[0] + node_match=categorical_node_match('type', 1)) + matchings = graph_matcher.match() + matchings = filter_matchings_if_symbolic_compatible(matchings, self.lhs_graph_nx, subgraph_nx) + matching = matchings[0] + symbolic_params_map = match_symbolic_parameters(matching, self.lhs_graph_nx, subgraph_nx) vertices_to_remove = [] for v in matching: @@ -55,10 +58,15 @@ def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutpu vertex_map = boundary_vertex_map for v in self.rhs_graph_nx.nodes(): if self.rhs_graph_nx.nodes()[v]['type'] != VertexType.BOUNDARY: + phase = self.rhs_graph_nx.nodes()[v]['phase'] + if isinstance(phase, Poly): + phase = phase.substitute(symbolic_params_map) + if phase.free_vars() == set(): + phase = phase.terms[0][0] vertex_map[v] = graph.add_vertex(ty = self.rhs_graph_nx.nodes()[v]['type'], row = vertex_positions[v][0], qubit = vertex_positions[v][1], - phase = self.rhs_graph_nx.nodes()[v]['phase'],) + phase = phase,) # create etab to add edges etab = {} From 6517efcd438f747d17f0793c3bba4df7bdce801e Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:39:35 +0100 Subject: [PATCH 16/37] typo in match_symbolic_parameters --- zxlive/custom_rule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 8e33da6c..1f810132 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -121,7 +121,7 @@ def match_symbolic_parameters(match, left, right): for v in left.nodes(): if isinstance(left_phase[v], Poly): if str(left_phase[v]) in params: - if params[str(left_phase)] != right_phase[match[v]]: + if params[str(left_phase[v])] != right_phase[match[v]]: raise ValueError("Symbolic parameters do not match") else: params[str(left_phase[v])] = right_phase[match[v]] From 58ca1cc95e239712cbae3ef699025ca79664cec8 Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 01:33:39 +0100 Subject: [PATCH 17/37] small bug fix --- zxlive/custom_rule.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 1f810132..82ed0e80 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -132,6 +132,8 @@ def match_symbolic_parameters(match, left, right): def filter_matchings_if_symbolic_compatible(matchings, left, right): new_matchings = [] for matching in matchings: + if len(matching) != len(left): + continue try: match_symbolic_parameters(matching, left, right) new_matchings.append(matching) From 726ad180078d69477a0fdcfded477ce0a485ab7f Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 01:34:02 +0100 Subject: [PATCH 18/37] get var method for symbolic parameters --- zxlive/custom_rule.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 82ed0e80..c838a40b 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -114,17 +114,26 @@ def to_proof_action(self) -> "ProofAction": return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) +def get_var(v): + if not isinstance(v, Poly): + raise ValueError("Not a symbolic parameter") + if len(v.terms) != 1: + raise ValueError("Only single-term symbolic parameters are supported") + if len(v.terms[0][1].vars) != 1: + raise ValueError("Only single-variable symbolic parameters are supported") + return v.terms[0][1].vars[0][0] + def match_symbolic_parameters(match, left, right): params = {} left_phase = left.nodes.data('phase', default=0) right_phase = right.nodes.data('phase', default=0) for v in left.nodes(): if isinstance(left_phase[v], Poly): - if str(left_phase[v]) in params: - if params[str(left_phase[v])] != right_phase[match[v]]: + if get_var(left_phase[v]) in params: + if params[get_var(left_phase[v])] != right_phase[match[v]]: raise ValueError("Symbolic parameters do not match") else: - params[str(left_phase[v])] = right_phase[match[v]] + params[get_var(left_phase[v])] = right_phase[match[v]] elif left_phase[v] != right_phase[match[v]]: raise ValueError("Parameters do not match") return params From 3cf64733dbc31437e04940db47eb247c7146d034 Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 17:35:43 +0100 Subject: [PATCH 19/37] add warnings for custom rules with symbolic parameters --- zxlive/custom_rule.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index c838a40b..a86fd0c9 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -213,4 +213,17 @@ def check_rule(rule: CustomRule, show_error: bool = True) -> bool: else: show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.") return False + else: + if not (rule.rhs_graph.variable_types.items() <= rule.lhs_graph.variable_types.items()): + if show_error: + from .dialogs import show_error_msg + show_error_msg("Warning!", "The right-hand side has more free variables than the left-hand side.") + return False + for vertex in rule.lhs_graph.vertices(): + if isinstance(rule.lhs_graph.phase(vertex), Poly): + if len(rule.lhs_graph.phase(vertex).free_vars()) > 1: + if show_error: + from .dialogs import show_error_msg + show_error_msg("Warning!", "Only one symbolic parameter per vertex is supported on the left-hand side.") + return False return True From 223712562fc47d46ea138356ea6e3c8606200735 Mon Sep 17 00:00:00 2001 From: Tuomas Laakkonen Date: Tue, 14 Nov 2023 16:56:57 +0000 Subject: [PATCH 20/37] tikz layout settings should be floats by default --- zxlive/settings_dialog.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/zxlive/settings_dialog.py b/zxlive/settings_dialog.py index 964f5374..41b570d2 100644 --- a/zxlive/settings_dialog.py +++ b/zxlive/settings_dialog.py @@ -60,9 +60,9 @@ "tikz/edge-H-import": ", ".join(pyzx.tikz.synonyms_hedge), "tikz/edge-W-import": ", ".join(pyzx.tikz.synonyms_wedge), - "tikz/layout/hspace": 2, - "tikz/layout/vspace": 2, - "tikz/layout/max-width": 10, + "tikz/layout/hspace": 2.0, + "tikz/layout/vspace": 2.0, + "tikz/layout/max-width": 10.0, "tikz/names/fuse spiders": "f", "tikz/names/bialgebra": "b", From 04169436da874d8ce71d9211f70047d29e29a4e9 Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 18:47:17 +0100 Subject: [PATCH 21/37] symbolic rewrites support linear terms Co-authored-by: Tuomas Laakkonen --- zxlive/custom_rule.py | 64 +++++++++++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index a86fd0c9..c3c81abf 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -114,30 +114,62 @@ def to_proof_action(self) -> "ProofAction": return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) -def get_var(v): +def get_linear(v): if not isinstance(v, Poly): raise ValueError("Not a symbolic parameter") - if len(v.terms) != 1: - raise ValueError("Only single-term symbolic parameters are supported") - if len(v.terms[0][1].vars) != 1: - raise ValueError("Only single-variable symbolic parameters are supported") - return v.terms[0][1].vars[0][0] + if len(v.terms) > 2 or len(v.free_vars()) > 1: + raise ValueError("Only linear symbolic parameters are supported") + if len(v.terms) == 0: + return 1, None, 0 + elif len(v.terms) == 1: + if len(v.terms[0][1].vars) > 0: + var_term = v.terms[0] + const = 0 + else: + const = v.terms[0][0] + return 1, None, const + else: + if len(v.terms[0][1].vars) > 0: + var_term = v.terms[0] + const = v.terms[1][0] + else: + var_term = v.terms[1] + const = v.terms[0][0] + coeff = var_term[0] + var, power = var_term[1].vars[0] + if power != 1: + raise ValueError("Only linear symbolic parameters are supported") + return coeff, var, const + def match_symbolic_parameters(match, left, right): params = {} left_phase = left.nodes.data('phase', default=0) right_phase = right.nodes.data('phase', default=0) + + def check_phase_equality(v): + if left_phase[v] != right_phase[match[v]]: + raise ValueError("Parameters do not match") + + def update_params(v, var, coeff, const): + var_value = (right_phase[match[v]] - const) / coeff + if var in params and params[var] != var_value: + raise ValueError("Symbolic parameters do not match") + params[var] = var_value + for v in left.nodes(): if isinstance(left_phase[v], Poly): - if get_var(left_phase[v]) in params: - if params[get_var(left_phase[v])] != right_phase[match[v]]: - raise ValueError("Symbolic parameters do not match") - else: - params[get_var(left_phase[v])] = right_phase[match[v]] - elif left_phase[v] != right_phase[match[v]]: - raise ValueError("Parameters do not match") + coeff, var, const = get_linear(left_phase[v]) + if var is None: + check_phase_equality(v) + continue + update_params(v, var, coeff, const) + else: + check_phase_equality(v) + return params + def filter_matchings_if_symbolic_compatible(matchings, left, right): new_matchings = [] for matching in matchings: @@ -221,9 +253,11 @@ def check_rule(rule: CustomRule, show_error: bool = True) -> bool: return False for vertex in rule.lhs_graph.vertices(): if isinstance(rule.lhs_graph.phase(vertex), Poly): - if len(rule.lhs_graph.phase(vertex).free_vars()) > 1: + try: + get_linear(rule.lhs_graph.phase(vertex)) + except ValueError as e: if show_error: from .dialogs import show_error_msg - show_error_msg("Warning!", "Only one symbolic parameter per vertex is supported on the left-hand side.") + show_error_msg("Warning!", str(e)) return False return True From 8a1ab003f5fd781fd93a36fbfbab845cc06fa6ff Mon Sep 17 00:00:00 2001 From: John van de Wetering Date: Mon, 13 Nov 2023 20:50:16 +0000 Subject: [PATCH 22/37] Use Poly class in PyZX --- zxlive/app.py | 4 +- zxlive/editor_base_panel.py | 2 +- zxlive/parse_poly.py | 58 ------------ zxlive/poly.py | 180 ------------------------------------ 4 files changed, 3 insertions(+), 241 deletions(-) delete mode 100644 zxlive/parse_poly.py delete mode 100644 zxlive/poly.py diff --git a/zxlive/app.py b/zxlive/app.py index e525f4fc..56a2991c 100644 --- a/zxlive/app.py +++ b/zxlive/app.py @@ -19,11 +19,11 @@ from PySide6.QtCore import QCommandLineParser from PySide6.QtGui import QIcon import sys +sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx + from .mainwindow import MainWindow from .common import get_data -#sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx - # The following hack is needed on windows in order to show the icon in the taskbar # See https://stackoverflow.com/questions/1551605/how-to-set-applications-taskbar-icon-in-windows-7/1552105#1552105 import os diff --git a/zxlive/editor_base_panel.py b/zxlive/editor_base_panel.py index 9230491d..79cfeccd 100644 --- a/zxlive/editor_base_panel.py +++ b/zxlive/editor_base_panel.py @@ -15,6 +15,7 @@ QSpacerItem, QSplitter, QToolButton, QWidget) from pyzx import EdgeType, VertexType from pyzx.utils import get_w_partner, vertex_is_w +from pyzx.symbolic import Poly from .base_panel import BasePanel, ToolbarSection @@ -26,7 +27,6 @@ from .eitem import HAD_EDGE_BLUE from .graphscene import EditGraphScene from .parse_poly import parse -from .poly import Poly, new_var from .vitem import BLACK diff --git a/zxlive/parse_poly.py b/zxlive/parse_poly.py deleted file mode 100644 index 42e9da97..00000000 --- a/zxlive/parse_poly.py +++ /dev/null @@ -1,58 +0,0 @@ -from .poly import Poly, new_const - -from typing import Any, Callable -from lark import Lark, Transformer -from functools import reduce -from operator import add, mul -from fractions import Fraction - -poly_grammar = Lark(""" - start : "(" start ")" | term ("+" term)* - term : (intf | frac)? factor ("*" factor)* - ?factor : intf | frac | pi | pifrac | var - var : CNAME - intf : INT - pi : "\\pi" | "pi" - frac : INT "/" INT - pifrac : [INT] pi "/" INT - - %import common.INT - %import common.CNAME - %import common.WS - %ignore WS - """, - parser='lalr', - maybe_placeholders=True) - -class PolyTransformer(Transformer[Poly]): - def __init__(self, new_var: Callable[[str], Poly]): - super().__init__() - - self._new_var = new_var - - def start(self, items: list[Poly]) -> Poly: - return reduce(add, items) - - def term(self, items: list[Poly]) -> Poly: - return reduce(mul, items) - - def var(self, items: list[Any]) -> Poly: - v = str(items[0]) - return self._new_var(v) - - def pi(self, _: list[Any]) -> Poly: - return new_const(1) - - def intf(self, items: list[Any]) -> Poly: - return new_const(int(items[0])) - - def frac(self, items: list[Any]) -> Poly: - return new_const(Fraction(int(items[0]), int(items[1]))) - - def pifrac(self, items: list[Any]) -> Poly: - numerator = int(items[0]) if items[0] else 1 - return new_const(Fraction(numerator, int(items[2]))) - -def parse(expr: str, new_var: Callable[[str], Poly]) -> Poly: - tree = poly_grammar.parse(expr) - return PolyTransformer(new_var).transform(tree) diff --git a/zxlive/poly.py b/zxlive/poly.py deleted file mode 100644 index 0dfac718..00000000 --- a/zxlive/poly.py +++ /dev/null @@ -1,180 +0,0 @@ -from fractions import Fraction -from typing import Union, Optional - - -class Var: - name: str - _is_bool: bool - _types_dict: Optional[Union[bool, dict[str, bool]]] - - def __init__(self, name: str, data: Union[bool, dict[str, bool]]): - self.name = name - if isinstance(data, dict): - self._types_dict = data - self._frozen = False - self._is_bool = False - else: - self._types_dict = None - self._frozen = True - self._is_bool = data - - @property - def is_bool(self) -> bool: - if self._frozen: - return self._is_bool - else: - assert isinstance(self._types_dict, dict) - return self._types_dict[self.name] - - def __repr__(self) -> str: - return self.name - - def __lt__(self, other: 'Var') -> bool: - if int(self.is_bool) == int(other.is_bool): - return self.name < other.name - return int(self.is_bool) < int(other.is_bool) - - def __hash__(self) -> int: - # Variables with the same name map to the same type - # within the same graph, so no need to include is_bool - # in the hash. - return int(hash(self.name)) - - def __eq__(self, other: object) -> bool: - return self.__hash__() == other.__hash__() - - def freeze(self) -> None: - if not self._frozen: - assert isinstance(self._types_dict, dict) - self._is_bool = self._types_dict[self.name] - self._frozen = True - self._types_dict = None - - def __copy__(self) -> 'Var': - if self._frozen: - return Var(self.name, self.is_bool) - else: - assert isinstance(self._types_dict, dict) - return Var(self.name, self._types_dict) - - def __deepcopy__(self, _memo: object) -> 'Var': - return self.__copy__() - -class Term: - vars: list[tuple[Var, int]] - - def __init__(self, vars: list[tuple[Var,int]]) -> None: - self.vars = vars - - def freeze(self) -> None: - for var, _ in self.vars: - var.freeze() - - def free_vars(self) -> set[Var]: - return set(var for var, _ in self.vars) - - def __repr__(self) -> str: - vs = [] - for v, c in self.vars: - if c == 1: - vs.append(f'{v}') - else: - vs.append(f'{v}^{c}') - return '*'.join(vs) - - def __mul__(self, other: 'Term') -> 'Term': - vs = dict() - for v, c in self.vars + other.vars: - if v not in vs: vs[v] = c - else: vs[v] += c - # TODO deal with fractional / symbolic powers - if v.is_bool and c > 1: - vs[v] = 1 - return Term([(v, c) for v, c in vs.items()]) - - def __hash__(self) -> int: - return hash(tuple(sorted(self.vars))) - - def __eq__(self, other: object) -> bool: - return self.__hash__() == other.__hash__() - - -class Poly: - terms: list[tuple[Union[int, float, Fraction], Term]] - - def __init__(self, terms: list[tuple[Union[int, float, Fraction], Term]]) -> None: - self.terms = terms - - def freeze(self) -> None: - for _, term in self.terms: - term.freeze() - - def free_vars(self) -> set[Var]: - output = set() - for _, term in self.terms: - output.update(term.free_vars()) - return output - - def __add__(self, other: 'Poly') -> 'Poly': - if isinstance(other, (int, float, Fraction)): - other = Poly([(other, Term([]))]) - counter = dict() - for c, t in self.terms + other.terms: - if t not in counter: counter[t] = c - else: counter[t] += c - if all(tt[0].is_bool for tt in t.vars): - counter[t] = counter[t] % 2 - - # remove terms with coefficient 0 - for t in list(counter.keys()): - if counter[t] == 0: - del counter[t] - return Poly([(c, t) for t, c in counter.items()]) - - __radd__ = __add__ - - def __mul__(self, other: 'Poly') -> 'Poly': - if isinstance(other, (int, float)): - other = Poly([(other, Term([]))]) - p = Poly([]) - for c1, t1 in self.terms: - for c2, t2 in other.terms: - p += Poly([(c1 * c2, t1 * t2)]) - return p - - __rmul__ = __mul__ - - def __repr__(self) -> str: - ts = [] - for c, t in self.terms: - if t == Term([]): - ts.append(f'{c}') - elif c == 1: - ts.append(f'{t}') - else: - ts.append(f'{c}{t}') - return ' + '.join(ts) - - def __eq__(self, other: object) -> bool: - if isinstance(other, (int, float, Fraction)): - if other == 0: - other = Poly([]) - else: - other = Poly([(other, Term([]))]) - assert isinstance(other, Poly) - return set(self.terms) == set(other.terms) - - @property - def is_pauli(self) -> bool: - for c, t in self.terms: - if not all(v.is_bool for v, _ in t.vars): - return False - if c % 1 != 0: - return False - return True - -def new_var(name: str, types_dict: Union[bool, dict[str, bool]]) -> Poly: - return Poly([(1, Term([(Var(name, types_dict), 1)]))]) - -def new_const(coeff: Union[int, Fraction]) -> Poly: - return Poly([(coeff, Term([]))]) From dfaa9f62eb833354e6e882be92dc22bd67e8b116 Mon Sep 17 00:00:00 2001 From: John van de Wetering Date: Mon, 13 Nov 2023 21:31:50 +0000 Subject: [PATCH 23/37] Fixed bugs with saving parameters --- zxlive/edit_panel.py | 4 +-- zxlive/editor_base_panel.py | 71 ++++++++++++++++--------------------- zxlive/rule_panel.py | 2 +- 3 files changed, 34 insertions(+), 43 deletions(-) diff --git a/zxlive/edit_panel.py b/zxlive/edit_panel.py index 70fbf5b7..1bf64b54 100644 --- a/zxlive/edit_panel.py +++ b/zxlive/edit_panel.py @@ -6,8 +6,9 @@ from PySide6.QtCore import Signal from PySide6.QtGui import QAction from PySide6.QtWidgets import (QToolButton) -from pyzx import EdgeType, VertexType, Circuit +from pyzx import EdgeType, VertexType from pyzx.circuit.qasmparser import QASMParser +from pyzx.symbolic import Poly from .base_panel import ToolbarSection from .commands import UpdateGraph @@ -16,7 +17,6 @@ from .editor_base_panel import EditorBasePanel from .graphscene import EditGraphScene from .graphview import GraphView -from .poly import Poly class GraphEditPanel(EditorBasePanel): diff --git a/zxlive/editor_base_panel.py b/zxlive/editor_base_panel.py index 79cfeccd..d0881352 100644 --- a/zxlive/editor_base_panel.py +++ b/zxlive/editor_base_panel.py @@ -15,9 +15,9 @@ QSpacerItem, QSplitter, QToolButton, QWidget) from pyzx import EdgeType, VertexType from pyzx.utils import get_w_partner, vertex_is_w +from pyzx.graph.jsonparser import string_to_phase from pyzx.symbolic import Poly - from .base_panel import BasePanel, ToolbarSection from .commands import (AddEdge, AddNode, AddWNode, ChangeEdgeColor, ChangeNodeType, ChangePhase, MoveNode, SetGraph, @@ -26,7 +26,6 @@ from .dialogs import show_error_msg from .eitem import HAD_EDGE_BLUE from .graphscene import EditGraphScene -from .parse_poly import parse from .vitem import BLACK @@ -98,16 +97,8 @@ def update_colors(self) -> None: super().update_colors() self.update_side_bar() - def update_variable_viewer(self) -> None: - self.update_side_bar() - def _populate_variables(self) -> None: - self.variable_types = {} - for vert in self.graph.vertices(): - phase = self.graph.phase(vert) - if isinstance(phase, Poly): - for var in phase.free_vars(): - self.variable_types[var.name] = var.is_bool + self.variable_types = self.graph.variable_types.copy() def _tool_clicked(self, tool: ToolType) -> None: self.graph_scene.curr_tool = tool @@ -189,18 +180,18 @@ def vert_double_clicked(self, v: VT) -> None: if not ok: return None try: - new_phase = string_to_complex(input_) if phase_is_complex else string_to_fraction(input_, self._new_var) + new_phase = string_to_complex(input_) if phase_is_complex else string_to_phase(input_, graph) except ValueError: show_error_msg("Invalid Input", error_msg) return None cmd = ChangePhase(self.graph_view, v, new_phase) self.undo_stack.push(cmd) - - def _new_var(self, name: str) -> Poly: - if name not in self.variable_types: - self.variable_types[name] = False - self.variable_viewer.add_item(name) - return new_var(name, self.variable_types) + # For some reason it is important we first push to the stack before we do the following. + if len(graph.variable_types) != len(self.variable_types): + new_vars = graph.variable_types.keys() - self.variable_types.keys() + self.variable_types.update(graph.variable_types) + for v in new_vars: + self.variable_viewer.add_item(v) class VariableViewer(QScrollArea): @@ -378,28 +369,28 @@ def create_icon(shape: ShapeType, color: QColor) -> QIcon: return icon -def string_to_fraction(string: str, new_var_: Callable[[str], Poly]) -> Union[Fraction, Poly]: - if not string: - return Fraction(0) - try: - s = string.lower().replace(' ', '') - s = re.sub('\\*?(pi|\u04c0)\\*?', '', s) - if '.' in s or 'e' in s: - return Fraction(float(s)) - elif '/' in s: - a, b = s.split("/", 2) - if not a: - return Fraction(1, int(b)) - if a == '-': - a = '-1' - return Fraction(int(a), int(b)) - else: - return Fraction(int(s)) - except ValueError: - try: - return parse(string, new_var_) - except Exception as e: - raise ValueError(e) +#def string_to_fraction(string: str, new_var_: Callable[[str], Poly]) -> Union[Fraction, Poly]: +# if not string: +# return Fraction(0) +# try: +# s = string.lower().replace(' ', '') +# s = re.sub('\\*?(pi|\u04c0)\\*?', '', s) +# if '.' in s or 'e' in s: +# return Fraction(float(s)) +# elif '/' in s: +# a, b = s.split("/", 2) +# if not a: +# return Fraction(1, int(b)) +# if a == '-': +# a = '-1' +# return Fraction(int(a), int(b)) +# else: +# return Fraction(int(s)) +# except ValueError: +# try: +# return parse(string, new_var_) +# except Exception as e: +# raise ValueError(e) def string_to_complex(string: str) -> complex: diff --git a/zxlive/rule_panel.py b/zxlive/rule_panel.py index 61f4c567..179fca7e 100644 --- a/zxlive/rule_panel.py +++ b/zxlive/rule_panel.py @@ -6,6 +6,7 @@ from PySide6.QtGui import QAction from PySide6.QtWidgets import QLineEdit from pyzx import EdgeType, VertexType +from pyzx.symbolic import Poly from .base_panel import ToolbarSection @@ -14,7 +15,6 @@ from .editor_base_panel import EditorBasePanel from .graphscene import EditGraphScene from .graphview import RuleEditGraphView -from .poly import Poly class RulePanel(EditorBasePanel): From 7cc448f7449b9c76808279640a9df71020a09eca Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:25:36 +0100 Subject: [PATCH 24/37] only check matrix for non-symbolic rules --- zxlive/custom_rule.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index e3dee9f0..515e5288 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -154,13 +154,14 @@ def check_rule(rule: CustomRule, show_error: bool = True) -> bool: from .dialogs import show_error_msg show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different numbers of inputs or outputs.") return False - left_matrix, right_matrix = rule.lhs_graph.to_matrix(), rule.rhs_graph.to_matrix() - if not np.allclose(left_matrix, right_matrix): - if show_error: - from .dialogs import show_error_msg - if np.allclose(left_matrix / np.linalg.norm(left_matrix), right_matrix / np.linalg.norm(right_matrix)): - show_error_msg("Warning!", "The left-hand side and right-hand side of the rule differ by a scalar.") - else: - show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.") - return False + if not rule.lhs_graph.variable_types and not rule.rhs_graph.variable_types: + left_matrix, right_matrix = rule.lhs_graph.to_matrix(), rule.rhs_graph.to_matrix() + if not np.allclose(left_matrix, right_matrix): + if show_error: + from .dialogs import show_error_msg + if np.allclose(left_matrix / np.linalg.norm(left_matrix), right_matrix / np.linalg.norm(right_matrix)): + show_error_msg("Warning!", "The left-hand side and right-hand side of the rule differ by a scalar.") + else: + show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.") + return False return True From ea0e3ec6c0036ece65acda22f5f51a06a708a5b6 Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:26:39 +0100 Subject: [PATCH 25/37] rewrite rule matching with symbolic parameters --- zxlive/custom_rule.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 515e5288..ea0e232e 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -11,6 +11,8 @@ from pyzx.utils import EdgeType, VertexType from shapely import Polygon +from pyzx.symbolic import Poly + from .common import ET, VT, GraphT if TYPE_CHECKING: @@ -75,8 +77,10 @@ def matcher(self, graph: GraphT, in_selection: Callable[[VT], bool]) -> list[VT] vertices = [v for v in graph.vertices() if in_selection(v)] subgraph_nx, _ = create_subgraph(graph, vertices) graph_matcher = GraphMatcher(self.lhs_graph_nx, subgraph_nx, - node_match=categorical_node_match(['type', 'phase'], default=[1, 0])) - if graph_matcher.is_isomorphic(): + node_match=categorical_node_match('type', 1)) + matchings = list(graph_matcher.match()) + matchings = filter_matchings_if_symbolic_compatible(matchings, self.lhs_graph_nx, subgraph_nx) + if len(matchings) > 0: return vertices return [] @@ -102,6 +106,32 @@ def to_proof_action(self) -> "ProofAction": return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) +def match_symbolic_parameters(match, left, right): + params = {} + left_phase = left.nodes.data('phase', default=0) + right_phase = right.nodes.data('phase', default=0) + for v in left.nodes(): + if isinstance(left_phase[v], Poly): + if str(left_phase[v]) in params: + if params[str(left_phase)] != right_phase[match[v]]: + raise ValueError("Symbolic parameters do not match") + else: + params[str(left_phase[v])] = right_phase[match[v]] + elif left_phase[v] != right_phase[match[v]]: + raise ValueError("Parameters do not match") + return params + +def filter_matchings_if_symbolic_compatible(matchings, left, right): + new_matchings = [] + for matching in matchings: + try: + match_symbolic_parameters(matching, left, right) + new_matchings.append(matching) + except ValueError: + pass + return new_matchings + + def to_networkx(graph: GraphT) -> nx.Graph: G = nx.Graph() v_data = {v: {"type": graph.type(v), From 4eaa8ca5182cb5d70d9acbfd4985f0a1185af15e Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:28:14 +0100 Subject: [PATCH 26/37] applying custom rule with symbolic parameters --- zxlive/custom_rule.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index ea0e232e..ca549d66 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -33,8 +33,11 @@ def __init__(self, lhs_graph: GraphT, rhs_graph: GraphT, name: str, description: def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutputType[ET,VT]: subgraph_nx, boundary_mapping = create_subgraph(graph, vertices) graph_matcher = GraphMatcher(self.lhs_graph_nx, subgraph_nx, - node_match=categorical_node_match(['type', 'phase'], default=[1, 0])) - matching = list(graph_matcher.match())[0] + node_match=categorical_node_match('type', 1)) + matchings = graph_matcher.match() + matchings = filter_matchings_if_symbolic_compatible(matchings, self.lhs_graph_nx, subgraph_nx) + matching = matchings[0] + symbolic_params_map = match_symbolic_parameters(matching, self.lhs_graph_nx, subgraph_nx) vertices_to_remove = [] for v in matching: @@ -55,10 +58,15 @@ def __call__(self, graph: GraphT, vertices: list[VT]) -> pyzx.rules.RewriteOutpu vertex_map = boundary_vertex_map for v in self.rhs_graph_nx.nodes(): if self.rhs_graph_nx.nodes()[v]['type'] != VertexType.BOUNDARY: + phase = self.rhs_graph_nx.nodes()[v]['phase'] + if isinstance(phase, Poly): + phase = phase.substitute(symbolic_params_map) + if phase.free_vars() == set(): + phase = phase.terms[0][0] vertex_map[v] = graph.add_vertex(ty = self.rhs_graph_nx.nodes()[v]['type'], row = vertex_positions[v][0], qubit = vertex_positions[v][1], - phase = self.rhs_graph_nx.nodes()[v]['phase'],) + phase = phase,) # create etab to add edges etab = {} From 56b1ee5ad652153686b0a7722fe788f9cabb70ca Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 00:39:35 +0100 Subject: [PATCH 27/37] typo in match_symbolic_parameters --- zxlive/custom_rule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index ca549d66..65a131eb 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -121,7 +121,7 @@ def match_symbolic_parameters(match, left, right): for v in left.nodes(): if isinstance(left_phase[v], Poly): if str(left_phase[v]) in params: - if params[str(left_phase)] != right_phase[match[v]]: + if params[str(left_phase[v])] != right_phase[match[v]]: raise ValueError("Symbolic parameters do not match") else: params[str(left_phase[v])] = right_phase[match[v]] From f4264a14326852d86de84570d989734ae12d906c Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 01:33:39 +0100 Subject: [PATCH 28/37] small bug fix --- zxlive/custom_rule.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 65a131eb..44e26af8 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -132,6 +132,8 @@ def match_symbolic_parameters(match, left, right): def filter_matchings_if_symbolic_compatible(matchings, left, right): new_matchings = [] for matching in matchings: + if len(matching) != len(left): + continue try: match_symbolic_parameters(matching, left, right) new_matchings.append(matching) From 26bad0ce919ea280ff55755757c9889928003e46 Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 01:34:02 +0100 Subject: [PATCH 29/37] get var method for symbolic parameters --- zxlive/custom_rule.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 44e26af8..c2533c85 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -114,17 +114,26 @@ def to_proof_action(self) -> "ProofAction": return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) +def get_var(v): + if not isinstance(v, Poly): + raise ValueError("Not a symbolic parameter") + if len(v.terms) != 1: + raise ValueError("Only single-term symbolic parameters are supported") + if len(v.terms[0][1].vars) != 1: + raise ValueError("Only single-variable symbolic parameters are supported") + return v.terms[0][1].vars[0][0] + def match_symbolic_parameters(match, left, right): params = {} left_phase = left.nodes.data('phase', default=0) right_phase = right.nodes.data('phase', default=0) for v in left.nodes(): if isinstance(left_phase[v], Poly): - if str(left_phase[v]) in params: - if params[str(left_phase[v])] != right_phase[match[v]]: + if get_var(left_phase[v]) in params: + if params[get_var(left_phase[v])] != right_phase[match[v]]: raise ValueError("Symbolic parameters do not match") else: - params[str(left_phase[v])] = right_phase[match[v]] + params[get_var(left_phase[v])] = right_phase[match[v]] elif left_phase[v] != right_phase[match[v]]: raise ValueError("Parameters do not match") return params From 1053ac1eece6bfdcf178ff8f65bf107db5845099 Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 17:35:43 +0100 Subject: [PATCH 30/37] add warnings for custom rules with symbolic parameters --- zxlive/custom_rule.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index c2533c85..accfe13d 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -213,4 +213,17 @@ def check_rule(rule: CustomRule, show_error: bool = True) -> bool: else: show_error_msg("Warning!", "The left-hand side and right-hand side of the rule have different semantics.") return False + else: + if not (rule.rhs_graph.variable_types.items() <= rule.lhs_graph.variable_types.items()): + if show_error: + from .dialogs import show_error_msg + show_error_msg("Warning!", "The right-hand side has more free variables than the left-hand side.") + return False + for vertex in rule.lhs_graph.vertices(): + if isinstance(rule.lhs_graph.phase(vertex), Poly): + if len(rule.lhs_graph.phase(vertex).free_vars()) > 1: + if show_error: + from .dialogs import show_error_msg + show_error_msg("Warning!", "Only one symbolic parameter per vertex is supported on the left-hand side.") + return False return True From 40f1fa4a46cea4b9bd0988905e6efa47b4d960e1 Mon Sep 17 00:00:00 2001 From: Razin Shaikh Date: Tue, 14 Nov 2023 18:47:17 +0100 Subject: [PATCH 31/37] symbolic rewrites support linear terms Co-authored-by: Tuomas Laakkonen --- zxlive/custom_rule.py | 64 +++++++++++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 15 deletions(-) diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index accfe13d..349a9463 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -114,30 +114,62 @@ def to_proof_action(self) -> "ProofAction": return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) -def get_var(v): +def get_linear(v): if not isinstance(v, Poly): raise ValueError("Not a symbolic parameter") - if len(v.terms) != 1: - raise ValueError("Only single-term symbolic parameters are supported") - if len(v.terms[0][1].vars) != 1: - raise ValueError("Only single-variable symbolic parameters are supported") - return v.terms[0][1].vars[0][0] + if len(v.terms) > 2 or len(v.free_vars()) > 1: + raise ValueError("Only linear symbolic parameters are supported") + if len(v.terms) == 0: + return 1, None, 0 + elif len(v.terms) == 1: + if len(v.terms[0][1].vars) > 0: + var_term = v.terms[0] + const = 0 + else: + const = v.terms[0][0] + return 1, None, const + else: + if len(v.terms[0][1].vars) > 0: + var_term = v.terms[0] + const = v.terms[1][0] + else: + var_term = v.terms[1] + const = v.terms[0][0] + coeff = var_term[0] + var, power = var_term[1].vars[0] + if power != 1: + raise ValueError("Only linear symbolic parameters are supported") + return coeff, var, const + def match_symbolic_parameters(match, left, right): params = {} left_phase = left.nodes.data('phase', default=0) right_phase = right.nodes.data('phase', default=0) + + def check_phase_equality(v): + if left_phase[v] != right_phase[match[v]]: + raise ValueError("Parameters do not match") + + def update_params(v, var, coeff, const): + var_value = (right_phase[match[v]] - const) / coeff + if var in params and params[var] != var_value: + raise ValueError("Symbolic parameters do not match") + params[var] = var_value + for v in left.nodes(): if isinstance(left_phase[v], Poly): - if get_var(left_phase[v]) in params: - if params[get_var(left_phase[v])] != right_phase[match[v]]: - raise ValueError("Symbolic parameters do not match") - else: - params[get_var(left_phase[v])] = right_phase[match[v]] - elif left_phase[v] != right_phase[match[v]]: - raise ValueError("Parameters do not match") + coeff, var, const = get_linear(left_phase[v]) + if var is None: + check_phase_equality(v) + continue + update_params(v, var, coeff, const) + else: + check_phase_equality(v) + return params + def filter_matchings_if_symbolic_compatible(matchings, left, right): new_matchings = [] for matching in matchings: @@ -221,9 +253,11 @@ def check_rule(rule: CustomRule, show_error: bool = True) -> bool: return False for vertex in rule.lhs_graph.vertices(): if isinstance(rule.lhs_graph.phase(vertex), Poly): - if len(rule.lhs_graph.phase(vertex).free_vars()) > 1: + try: + get_linear(rule.lhs_graph.phase(vertex)) + except ValueError as e: if show_error: from .dialogs import show_error_msg - show_error_msg("Warning!", "Only one symbolic parameter per vertex is supported on the left-hand side.") + show_error_msg("Warning!", str(e)) return False return True From b58cc302db905475e1ed9ea1100ae47cc3c54426 Mon Sep 17 00:00:00 2001 From: John van de Wetering Date: Wed, 15 Nov 2023 11:26:22 +0000 Subject: [PATCH 32/37] Comment out the using of local copy of pyzx --- zxlive/app.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/zxlive/app.py b/zxlive/app.py index 56a2991c..7171eacd 100644 --- a/zxlive/app.py +++ b/zxlive/app.py @@ -18,8 +18,9 @@ from PySide6.QtWidgets import QApplication from PySide6.QtCore import QCommandLineParser from PySide6.QtGui import QIcon -import sys -sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx + +#import sys +#sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx from .mainwindow import MainWindow from .common import get_data From 75f7666621c742a204bca01b00b6ee266a6c2af9 Mon Sep 17 00:00:00 2001 From: John van de Wetering Date: Wed, 15 Nov 2023 11:31:43 +0000 Subject: [PATCH 33/37] Remove lark from requirements as it is now only used in pyzx itself --- pyproject.toml | 3 +-- requirements.txt | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c07058cd..9a352e57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,8 +33,7 @@ dependencies = [ "networkx", "numpy", "shapely", - "lark>=1.1.5", - "pyperclip>=1.8.1" + "pyperclip" ] [project.optional-dependencies] diff --git a/requirements.txt b/requirements.txt index 1fc881f4..8c7f1726 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ pyzx @ git+https://github.com/Quantomatic/pyzx -lark~=1.1.7 networkx~=3.1 numpy~=1.25.2 pytest-qt~=4.2.0 From 7f25048d25860d3b42a9957354d0dd03d5a5f59c Mon Sep 17 00:00:00 2001 From: Boldi Date: Wed, 15 Nov 2023 12:08:38 +0000 Subject: [PATCH 34/37] fix missing import --- zxlive/app.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/zxlive/app.py b/zxlive/app.py index 1f6763fd..6fc133ce 100644 --- a/zxlive/app.py +++ b/zxlive/app.py @@ -19,8 +19,8 @@ from PySide6.QtCore import QCommandLineParser from PySide6.QtGui import QIcon -#import sys -#sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx +import sys +# sys.path.insert(0, '../pyzx') # So that it can find a local copy of pyzx from .mainwindow import MainWindow from .common import get_data, GraphT From de4a72533ea5a541a02a0858d4b132b49c25f09a Mon Sep 17 00:00:00 2001 From: Boldi Date: Wed, 15 Nov 2023 12:15:20 +0000 Subject: [PATCH 35/37] Remove string_to_fraction test as it is in pyzx now --- test/test_editor_base_panel.py | 53 +--------------------------------- 1 file changed, 1 insertion(+), 52 deletions(-) diff --git a/test/test_editor_base_panel.py b/test/test_editor_base_panel.py index 02043d88..6b7cfb12 100644 --- a/test/test_editor_base_panel.py +++ b/test/test_editor_base_panel.py @@ -18,60 +18,9 @@ import pytest -from zxlive.editor_base_panel import string_to_fraction, string_to_complex -from zxlive.poly import Poly, Term, Var, new_var +from zxlive.editor_base_panel import string_to_complex -def test_string_to_fraction() -> None: - types_dict = {'a': False, 'b': False} - - def _new_var(name: str) -> Poly: - return new_var(name, types_dict) - - # Test empty input clears the phase. - assert string_to_fraction('', _new_var) == Fraction(0) - - # Test different ways of specifying integer multiples of pi. - assert string_to_fraction('3', _new_var) == Fraction(3) - assert string_to_fraction('3pi', _new_var) == Fraction(3) - assert string_to_fraction('3*pi', _new_var) == Fraction(3) - assert string_to_fraction('pi*3', _new_var) == Fraction(3) - - # Test different ways of specifying fractions. - assert string_to_fraction('pi/2', _new_var) == Fraction(1, 2) - assert string_to_fraction('-pi/2', _new_var) == Fraction(-1, 2) - assert string_to_fraction('5/2', _new_var) == Fraction(5, 2) - assert string_to_fraction('5pi/2', _new_var) == Fraction(5, 2) - assert string_to_fraction('5*pi/2', _new_var) == Fraction(5, 2) - assert string_to_fraction('pi*5/2', _new_var) == Fraction(5, 2) - assert string_to_fraction('5/2pi', _new_var) == Fraction(5, 2) - assert string_to_fraction('5/2*pi', _new_var) == Fraction(5, 2) - assert string_to_fraction('5/pi*2', _new_var) == Fraction(5, 2) - - # Test different ways of specifying floats. - assert string_to_fraction('5.5', _new_var) == Fraction(11, 2) - assert string_to_fraction('5.5pi', _new_var) == Fraction(11, 2) - assert string_to_fraction('25e-1', _new_var) == Fraction(5, 2) - assert string_to_fraction('5.5*pi', _new_var) == Fraction(11, 2) - assert string_to_fraction('pi*5.5', _new_var) == Fraction(11, 2) - - # Test a fractional phase specified with variables. - assert (string_to_fraction('a*b', _new_var) == - Poly([(1, Term([(Var('a', types_dict), 1), (Var('b', types_dict), 1)]))])) - assert (string_to_fraction('2*a', _new_var) == - Poly([(2, Term([(Var('a', types_dict), 1)]))])) - assert (string_to_fraction('2a', _new_var) == - Poly([(2, Term([(Var('a', types_dict), 1)]))])) - assert (string_to_fraction('3/2a', _new_var) == - Poly([(3/2, Term([(Var('a', types_dict), 1)]))])) - assert (string_to_fraction('3a+2b', _new_var) == - Poly([(3, Term([(Var('a', types_dict), 1)])), (2, Term([(Var('b', types_dict), 1)]))])) - - - # Test bad input. - with pytest.raises(ValueError): - string_to_fraction('bad input', _new_var) - def test_string_to_complex() -> None: # Test empty input clears the phase. assert string_to_complex('') == 0 From 8bcc98284f1449a57c98ff8aff923ba3ce32bd1e Mon Sep 17 00:00:00 2001 From: Boldi Date: Wed, 15 Nov 2023 11:51:21 +0000 Subject: [PATCH 36/37] dev: Moving from tab based rewrite buttons to tree based side panel --- zxlive/animations.py | 64 ++++++- zxlive/custom_rule.py | 9 +- zxlive/proof_actions.py | 382 --------------------------------------- zxlive/proof_panel.py | 110 ++++------- zxlive/rewrite_action.py | 220 ++++++++++++++++++++++ zxlive/rewrite_data.py | 253 ++++++++++++++++++++++++++ 6 files changed, 580 insertions(+), 458 deletions(-) delete mode 100644 zxlive/proof_actions.py create mode 100644 zxlive/rewrite_action.py create mode 100644 zxlive/rewrite_data.py diff --git a/zxlive/animations.py b/zxlive/animations.py index 6c3b776d..e11da23d 100644 --- a/zxlive/animations.py +++ b/zxlive/animations.py @@ -1,16 +1,24 @@ +from __future__ import annotations + import itertools import random -from typing import Optional, Callable +from typing import Optional, Callable, TYPE_CHECKING from PySide6.QtCore import QEasingCurve, QPointF, QAbstractAnimation, \ QParallelAnimationGroup from PySide6.QtGui import QUndoStack, QUndoCommand from pyzx.utils import vertex_is_w -from .common import VT, GraphT, pos_to_view +from .custom_rule import CustomRule +from .rewrite_data import operations +from .common import VT, GraphT, pos_to_view, ANIMATION_DURATION from .graphscene import GraphScene from .vitem import VItem, VItemAnimation, VITEM_UNSELECTED_Z, VITEM_SELECTED_Z, get_w_partner_vitem +if TYPE_CHECKING: + from .proof_panel import ProofPanel + from .rewrite_action import RewriteAction + class AnimatedUndoStack(QUndoStack): """An undo stack that can play animations between actions.""" @@ -256,3 +264,55 @@ def unfuse(before: GraphT, after: GraphT, src: VT, scene: GraphScene) -> QAbstra return morph_graph(before, after, scene, to_start=lambda _: src, to_end=lambda _: None, duration=700, ease=QEasingCurve(QEasingCurve.Type.OutElastic)) + +def make_animation(self: RewriteAction, panel: ProofPanel, g, matches, rem_verts) -> tuple: + anim_before = None + anim_after = None + if self.name == operations['spider']['text'] or self.name == operations['fuse_w']['text']: + anim_before = QParallelAnimationGroup() + for v1, v2 in matches: + if v1 in rem_verts: + v1, v2 = v2, v1 + anim_before.addAnimation(fuse(panel.graph_scene.vertex_map[v2], panel.graph_scene.vertex_map[v1])) + elif self.name == operations['to_z']['text']: + print('To do: animate ' + self.name) + elif self.name == operations['to_x']['text']: + print('To do: animate ' + self.name) + elif self.name == operations['rem_id']['text']: + anim_before = QParallelAnimationGroup() + for m in matches: + anim_before.addAnimation(remove_id(panel.graph_scene.vertex_map[m[0]])) + elif self.name == operations['copy']['text']: + anim_before = QParallelAnimationGroup() + for m in matches: + anim_before.addAnimation(fuse(panel.graph_scene.vertex_map[m[0]], + panel.graph_scene.vertex_map[m[1]])) + anim_after = QParallelAnimationGroup() + for m in matches: + anim_after.addAnimation(strong_comp(panel.graph, g, m[1], panel.graph_scene)) + elif self.name == operations['pauli']['text']: + print('To do: animate ' + self.name) + elif self.name == operations['bialgebra']['text']: + anim_before = QParallelAnimationGroup() + for v1, v2 in matches: + anim_before.addAnimation(fuse(panel.graph_scene.vertex_map[v1], + panel.graph_scene.vertex_map[v2], meet_halfway=True)) + anim_after = QParallelAnimationGroup() + for v1, v2 in matches: + v2_row, v2_qubit = panel.graph.row(v2), panel.graph.qubit(v2) + panel.graph.set_row(v2, (panel.graph.row(v1) + v2_row) / 2) + panel.graph.set_qubit(v2, (panel.graph.qubit(v1) + v2_qubit) / 2) + anim_after.addAnimation(strong_comp(panel.graph, g, v2, panel.graph_scene)) + panel.graph.set_row(v2, v2_row) + panel.graph.set_qubit(v2, v2_qubit) + elif isinstance(self.rule, CustomRule) and self.rule.last_rewrite_center is not None: + center = self.rule.last_rewrite_center + duration = ANIMATION_DURATION / 2 + anim_before = morph_graph_to_center(panel.graph, lambda v: v not in g.graph, + panel.graph_scene, center, duration, + QEasingCurve(QEasingCurve.Type.InQuad)) + anim_after = morph_graph_from_center(g, lambda v: v not in panel.graph.graph, + panel.graph_scene, center, duration, + QEasingCurve(QEasingCurve.Type.OutQuad)) + + return anim_before, anim_after diff --git a/zxlive/custom_rule.py b/zxlive/custom_rule.py index 349a9463..daff45cd 100644 --- a/zxlive/custom_rule.py +++ b/zxlive/custom_rule.py @@ -16,7 +16,7 @@ from .common import ET, VT, GraphT if TYPE_CHECKING: - from .proof_actions import ProofAction + from .rewrite_data import RewriteData class CustomRule: def __init__(self, lhs_graph: GraphT, rhs_graph: GraphT, name: str, description: str) -> None: @@ -109,9 +109,10 @@ def from_json(cls, json_str: str) -> "CustomRule": assert (isinstance(lhs_graph, GraphT) and isinstance(rhs_graph, GraphT)) # type: ignore return cls(lhs_graph, rhs_graph, d['name'], d['description']) - def to_proof_action(self) -> "ProofAction": - from .proof_actions import MATCHES_VERTICES, ProofAction - return ProofAction(self.name, self.matcher, self, MATCHES_VERTICES, self.description) + def to_rewrite_data(self) -> "RewriteData": + from .rewrite_data import MATCHES_VERTICES + return {"text": self.name, "matcher": self.matcher, "rule": self, "type": MATCHES_VERTICES, + "tooltip": self.description, 'copy_first': False, 'returns_new_graph': False} def get_linear(v): diff --git a/zxlive/proof_actions.py b/zxlive/proof_actions.py deleted file mode 100644 index 46df31bc..00000000 --- a/zxlive/proof_actions.py +++ /dev/null @@ -1,382 +0,0 @@ -import copy -from dataclasses import dataclass, field, replace -from typing import Callable, Literal, Optional, TYPE_CHECKING - -import pyzx -from pyzx import simplify, extract_circuit - -from PySide6.QtWidgets import QPushButton, QButtonGroup -from PySide6.QtCore import QParallelAnimationGroup, QEasingCurve - -from . import animations as anims -from .commands import AddRewriteStep -from .common import ANIMATION_DURATION, ET, GraphT, VT -from .custom_rule import CustomRule -from .dialogs import show_error_msg - -if TYPE_CHECKING: - from .proof_panel import ProofPanel - -operations = copy.deepcopy(pyzx.editor.operations) - -MatchType = Literal[1, 2] - -# Copied from pyzx.editor_actions -MATCHES_VERTICES: MatchType = 1 -MATCHES_EDGES: MatchType = 2 - - -@dataclass -class ProofAction(object): - name: str - matcher: Callable[[GraphT, Callable], list] - rule: Callable[[GraphT, list], pyzx.rules.RewriteOutputType[ET,VT]] - match_type: MatchType - tooltip: str - copy_first: bool = field(default=False) # Whether the graph should be copied before trying to test whether it matches. Needed if the matcher changes the graph. - returns_new_graph: bool = field(default=False) # Whether the rule returns a new graph instead of returning the rewrite changes. - button: Optional[QPushButton] = field(default=None, init=False) - - @classmethod - def from_dict(cls, d: dict) -> "ProofAction": - if 'copy_first' not in d: - d['copy_first'] = False - if 'returns_new_graph' not in d: - d['returns_new_graph'] = False - return cls(d['text'], d['matcher'], d['rule'], d['type'], d['tooltip'], d['copy_first'], d['returns_new_graph']) - - def do_rewrite(self, panel: "ProofPanel") -> None: - verts, edges = panel.parse_selection() - g = copy.deepcopy(panel.graph_scene.g) - - if self.match_type == MATCHES_VERTICES: - matches = self.matcher(g, lambda v: v in verts) - else: - matches = self.matcher(g, lambda e: e in edges) - - try: - if self.returns_new_graph: - g = self.rule(g, matches) - else: - etab, rem_verts, rem_edges, check_isolated_vertices = self.rule(g, matches) - g.remove_edges(rem_edges) - g.remove_vertices(rem_verts) - g.add_edge_table(etab) - except Exception as e: - show_error_msg('Error while applying rewrite rule', str(e)) - return - - cmd = AddRewriteStep(panel.graph_view, g, panel.step_view, self.name) - anim_before = None - anim_after = None - if self.name == operations['spider']['text'] or self.name == operations['fuse_w']['text']: - anim_before = QParallelAnimationGroup() - for v1, v2 in matches: - if v1 in rem_verts: - v1, v2 = v2, v1 - anim_before.addAnimation(anims.fuse(panel.graph_scene.vertex_map[v2], panel.graph_scene.vertex_map[v1])) - elif self.name == operations['to_z']['text']: - print('To do: animate ' + self.name) - elif self.name == operations['to_x']['text']: - print('To do: animate ' + self.name) - elif self.name == operations['rem_id']['text']: - anim_before = QParallelAnimationGroup() - for m in matches: - anim_before.addAnimation(anims.remove_id(panel.graph_scene.vertex_map[m[0]])) - elif self.name == operations['copy']['text']: - anim_before = QParallelAnimationGroup() - for m in matches: - anim_before.addAnimation(anims.fuse(panel.graph_scene.vertex_map[m[0]], - panel.graph_scene.vertex_map[m[1]])) - anim_after = QParallelAnimationGroup() - for m in matches: - anim_after.addAnimation(anims.strong_comp(panel.graph, g, m[1], panel.graph_scene)) - elif self.name == operations['pauli']['text']: - print('To do: animate ' + self.name) - elif self.name == operations['bialgebra']['text']: - anim_before = QParallelAnimationGroup() - for v1, v2 in matches: - anim_before.addAnimation(anims.fuse(panel.graph_scene.vertex_map[v1], - panel.graph_scene.vertex_map[v2], meet_halfway=True)) - anim_after = QParallelAnimationGroup() - for v1, v2 in matches: - v2_row, v2_qubit = panel.graph.row(v2), panel.graph.qubit(v2) - panel.graph.set_row(v2, (panel.graph.row(v1) + v2_row) / 2) - panel.graph.set_qubit(v2, (panel.graph.qubit(v1) + v2_qubit) / 2) - anim_after.addAnimation(anims.strong_comp(panel.graph, g, v2, panel.graph_scene)) - panel.graph.set_row(v2, v2_row) - panel.graph.set_qubit(v2, v2_qubit) - elif isinstance(self.rule, CustomRule) and self.rule.last_rewrite_center is not None: - center = self.rule.last_rewrite_center - duration = ANIMATION_DURATION / 2 - anim_before = anims.morph_graph_to_center(panel.graph, lambda v: v not in g.graph, - panel.graph_scene, center, duration, - QEasingCurve(QEasingCurve.Type.InQuad)) - anim_after = anims.morph_graph_from_center(g, lambda v: v not in panel.graph.graph, - panel.graph_scene, center, duration, - QEasingCurve(QEasingCurve.Type.OutQuad)) - - panel.undo_stack.push(cmd, anim_before=anim_before, anim_after=anim_after) - - def update_active(self, g: GraphT, verts: list[VT], edges: list[ET]) -> None: - if self.copy_first: - g = copy.deepcopy(g) - if self.match_type == MATCHES_VERTICES: - matches = self.matcher(g, lambda v: v in verts) - else: - matches = self.matcher(g, lambda e: e in edges) - - if self.button is None: return - if matches: - self.button.setEnabled(True) - else: - self.button.setEnabled(False) - - -class ProofActionGroup(object): - def __init__(self, name: str, *actions: ProofAction) -> None: - self.name = name - self.actions = actions - self.btn_group: Optional[QButtonGroup] = None - self.parent_panel = None - - def copy(self) -> "ProofActionGroup": - copied_actions = [] - for action in self.actions: - action_copy = replace(action) - action_copy.button = None - copied_actions.append(action_copy) - return ProofActionGroup(self.name, *copied_actions) - - def init_buttons(self, parent: "ProofPanel") -> None: - self.btn_group = QButtonGroup(parent) - self.btn_group.setExclusive(False) - def create_rewrite(action: ProofAction, parent: "ProofPanel") -> Callable[[], None]: # Needed to prevent weird bug with closures in signals - def rewriter() -> None: - action.do_rewrite(parent) - return rewriter - for action in self.actions: - if action.button is not None: continue - btn = QPushButton(action.name, parent) - btn.setMaximumWidth(150) - btn.setStatusTip(action.tooltip) - btn.setEnabled(False) - btn.clicked.connect(create_rewrite(action, parent)) - self.btn_group.addButton(btn) - action.button = btn - - def update_active(self, g: GraphT, verts: list[VT], edges: list[ET]) -> None: - for action in self.actions: - action.update_active(g, verts, edges) - -# We want additional actions that are not part of the original PyZX editor -# So we add them to operations - -operations.update({ - "pivot_boundary": {"text": "boundary pivot", - "tooltip": "Performs a pivot between a Pauli spider and a spider on the boundary.", - "matcher": pyzx.rules.match_pivot_boundary, - "rule": pyzx.rules.pivot, - "type": MATCHES_EDGES, - "copy_first": True}, - "pivot_gadget": {"text": "gadget pivot", - "tooltip": "Performs a pivot between a Pauli spider and a spider with an arbitrary phase, creating a phase gadget.", - "matcher": pyzx.rules.match_pivot_gadget, - "rule": pyzx.rules.pivot, - "type": MATCHES_EDGES, - "copy_first": True}, - "phase_gadget_fuse": {"text": "Fuse phase gadgets", - "tooltip": "Fuses two phase gadgets with the same connectivity.", - "matcher": pyzx.rules.match_phase_gadgets, - "rule": pyzx.rules.merge_phase_gadgets, - "type": MATCHES_VERTICES, - "copy_first": True}, - "supplementarity": {"text": "Supplementarity", - "tooltip": "Looks for a pair of internal spiders with the same connectivity and supplementary angles and removes them.", - "matcher": pyzx.rules.match_supplementarity, - "rule": pyzx.rules.apply_supplementarity, - "type": MATCHES_VERTICES, - "copy_first": False}, - } -) - -always_true = lambda graph, matches: matches - -def apply_simplification(simplification: Callable[[GraphT], GraphT]) -> Callable[[GraphT, list], pyzx.rules.RewriteOutputType[ET,VT]]: - def rule(g: GraphT, matches: list) -> pyzx.rules.RewriteOutputType[ET,VT]: - simplification(g) - return ({}, [], [], True) - return rule - -def _extract_circuit(graph: GraphT, matches: list) -> GraphT: - graph.auto_detect_io() - simplify.full_reduce(graph) - return extract_circuit(graph).to_graph() - -simplifications: dict = { - 'bialg_simp': { - "text": "bialgebra", - "tooltip": "bialg_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.bialg_simp), - "type": MATCHES_VERTICES, - }, - 'spider_simp': { - "text": "spider fusion", - "tooltip": "spider_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.spider_simp), - "type": MATCHES_VERTICES, - }, - 'id_simp': { - "text": "id", - "tooltip": "id_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.id_simp), - "type": MATCHES_VERTICES, - }, - 'phase_free_simp': { - "text": "phase free", - "tooltip": "phase_free_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.phase_free_simp), - "type": MATCHES_VERTICES, - }, - 'pivot_simp': { - "text": "pivot", - "tooltip": "pivot_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.pivot_simp), - "type": MATCHES_VERTICES, - }, - 'pivot_gadget_simp': { - "text": "pivot gadget", - "tooltip": "pivot_gadget_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.pivot_gadget_simp), - "type": MATCHES_VERTICES, - }, - 'pivot_boundary_simp': { - "text": "pivot boundary", - "tooltip": "pivot_boundary_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.pivot_boundary_simp), - "type": MATCHES_VERTICES, - }, - 'gadget_simp': { - "text": "gadget", - "tooltip": "gadget_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.gadget_simp), - "type": MATCHES_VERTICES, - }, - 'lcomp_simp': { - "text": "local complementation", - "tooltip": "lcomp_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.lcomp_simp), - "type": MATCHES_VERTICES, - }, - 'clifford_simp': { - "text": "clifford simplification", - "tooltip": "clifford_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.clifford_simp), - "type": MATCHES_VERTICES, - }, - 'tcount': { - "text": "tcount", - "tooltip": "tcount", - "matcher": always_true, - "rule": apply_simplification(simplify.tcount), - "type": MATCHES_VERTICES, - }, - 'to_gh': { - "text": "to green-hadamard form", - "tooltip": "to_gh", - "matcher": always_true, - "rule": apply_simplification(simplify.to_gh), - "type": MATCHES_VERTICES, - }, - 'to_rg': { - "text": "to red-green form", - "tooltip": "to_rg", - "matcher": always_true, - "rule": apply_simplification(simplify.to_rg), - "type": MATCHES_VERTICES, - }, - 'full_reduce': { - "text": "full reduce", - "tooltip": "full_reduce", - "matcher": always_true, - "rule": apply_simplification(simplify.full_reduce), - "type": MATCHES_VERTICES, - }, - 'teleport_reduce': { - "text": "teleport reduce", - "tooltip": "teleport_reduce", - "matcher": always_true, - "rule": apply_simplification(simplify.teleport_reduce), - "type": MATCHES_VERTICES, - }, - 'reduce_scalar': { - "text": "reduce scalar", - "tooltip": "reduce_scalar", - "matcher": always_true, - "rule": apply_simplification(simplify.reduce_scalar), - "type": MATCHES_VERTICES, - }, - 'supplementarity_simp': { - "text": "supplementarity", - "tooltip": "supplementarity_simp", - "matcher": always_true, - "rule": apply_simplification(simplify.supplementarity_simp), - "type": MATCHES_VERTICES, - }, - 'to_clifford_normal_form_graph': { - "text": "to clifford normal form", - "tooltip": "to_clifford_normal_form_graph", - "matcher": always_true, - "rule": apply_simplification(simplify.to_clifford_normal_form_graph), - "type": MATCHES_VERTICES, - }, - 'extract_circuit': { - "text": "circuit extraction", - "tooltip": "extract_circuit", - "matcher": always_true, - "rule": _extract_circuit, - "type": MATCHES_VERTICES, - "returns_new_graph": True, - }, -} - - -spider_fuse = ProofAction.from_dict(operations['spider']) -to_z = ProofAction.from_dict(operations['to_z']) -to_x = ProofAction.from_dict(operations['to_x']) -rem_id = ProofAction.from_dict(operations['rem_id']) -copy_action = ProofAction.from_dict(operations['copy']) -pauli = ProofAction.from_dict(operations['pauli']) -bialgebra = ProofAction.from_dict(operations['bialgebra']) -euler_rule = ProofAction.from_dict(operations['euler']) -rules_basic = ProofActionGroup("Basic rules", spider_fuse, to_z, to_x, rem_id, copy_action, pauli, bialgebra, euler_rule).copy() - -lcomp = ProofAction.from_dict(operations['lcomp']) -pivot = ProofAction.from_dict(operations['pivot']) -pivot_boundary = ProofAction.from_dict(operations['pivot_boundary']) -pivot_gadget = ProofAction.from_dict(operations['pivot_gadget']) -supplementarity = ProofAction.from_dict(operations['supplementarity']) -rules_graph_theoretic = ProofActionGroup("Graph-like rules", lcomp, pivot, pivot_boundary, pivot_gadget, supplementarity).copy() - -w_fuse = ProofAction.from_dict(operations['fuse_w']) -z_to_z_box = ProofAction.from_dict(operations['z_to_z_box']) -rules_zxw = ProofActionGroup("ZXW rules",spider_fuse, w_fuse, z_to_z_box).copy() - -hbox_to_edge = ProofAction.from_dict(operations['had2edge']) -fuse_hbox = ProofAction.from_dict(operations['fuse_hbox']) -mult_hbox = ProofAction.from_dict(operations['mult_hbox']) -rules_zh = ProofActionGroup("ZH rules", hbox_to_edge, fuse_hbox, mult_hbox).copy() - -simplification_actions = ProofActionGroup("Simplification routines", *[ProofAction.from_dict(s) for s in simplifications.values()]).copy() - -action_groups = [rules_basic, rules_graph_theoretic, rules_zxw, rules_zh, simplification_actions] diff --git a/zxlive/proof_panel.py b/zxlive/proof_panel.py index 1697257d..d502e44c 100644 --- a/zxlive/proof_panel.py +++ b/zxlive/proof_panel.py @@ -1,30 +1,26 @@ from __future__ import annotations import copy -import os -from fractions import Fraction from typing import Iterator, Union, cast import pyzx from PySide6.QtCore import (QItemSelection, QModelIndex, QPersistentModelIndex, QPointF, QRect, QSize, Qt) from PySide6.QtGui import (QAction, QColor, QFont, QFontMetrics, QIcon, - QPainter, QPen, QVector2D) -from PySide6.QtWidgets import (QAbstractItemView, QHBoxLayout, QListView, + QPainter, QPen, QVector2D, QFontInfo) +from PySide6.QtWidgets import (QAbstractItemView, QListView, QStyle, QStyledItemDelegate, - QStyleOptionViewItem, QToolButton, QWidget, - QVBoxLayout, QTabWidget, QInputDialog) + QStyleOptionViewItem, QToolButton, + QInputDialog, QTreeView) from pyzx import VertexType, basicrules from pyzx.graph.jsonparser import string_to_phase from pyzx.utils import get_z_box_label, set_z_box_label, get_w_partner, EdgeType, FractionLike from . import animations as anims -from . import proof_actions from .base_panel import BasePanel, ToolbarSection from .commands import AddRewriteStep, GoToRewriteStep, MoveNodeInStep -from .common import (get_custom_rules_path, ET, SCALE, VT, GraphT, get_data, +from .common import (ET, VT, GraphT, get_data, pos_from_view, pos_to_view, colors) -from .custom_rule import CustomRule from .dialogs import show_error_msg from .eitem import EItem from .graphscene import GraphScene @@ -32,6 +28,8 @@ from .proof import ProofModel from .vitem import DragState, VItem, W_INPUT_OFFSET, SCALE from .editor_base_panel import string_to_complex +from .rewrite_data import action_groups, refresh_custom_rules +from .rewrite_action import RewriteActionTreeModel class ProofPanel(BasePanel): @@ -41,8 +39,6 @@ def __init__(self, graph: GraphT, *actions: QAction) -> None: super().__init__(*actions) self.graph_scene = GraphScene() self.graph_scene.vertices_moved.connect(self._vert_moved) - # TODO: Right now this calls for every single vertex selected, even if we select many at the same time - self.graph_scene.selectionChanged.connect(self.update_on_selection) self.graph_scene.vertex_double_clicked.connect(self._vert_double_clicked) @@ -50,10 +46,9 @@ def __init__(self, graph: GraphT, *actions: QAction) -> None: self.splitter.addWidget(self.graph_view) self.graph_view.set_graph(graph) - self.actions_bar = QTabWidget(self) - self.layout().insertWidget(1, self.actions_bar) # type: ignore - self.init_action_groups() - self.actions_bar.currentChanged.connect(self.update_on_selection) + self.rewrites_panel = QTreeView(self) + self.splitter.insertWidget(0, self.rewrites_panel) + self.init_rewrites_bar() self.graph_view.wand_trace_finished.connect(self._wand_trace_finished) self.graph_scene.vertex_dragged.connect(self._vertex_dragged) @@ -104,35 +99,29 @@ def _toolbar_sections(self) -> Iterator[ToolbarSection]: self.refresh_rules = QToolButton(self) self.refresh_rules.setText("Refresh rules") - self.refresh_rules.clicked.connect(self._refresh_rules) + self.refresh_rules.clicked.connect(self._refresh_rewrites_model) yield ToolbarSection(*self.identity_choice, exclusive=True) yield ToolbarSection(*self.actions()) yield ToolbarSection(self.refresh_rules) - def init_action_groups(self) -> None: - self.action_groups = [group.copy() for group in proof_actions.action_groups] - custom_rules = [] - for root, dirs, files in os.walk(get_custom_rules_path()): - for file in files: - if file.endswith(".zxr"): - zxr_file = os.path.join(root, file) - with open(zxr_file, "r") as f: - rule = CustomRule.from_json(f.read()).to_proof_action() - custom_rules.append(rule) - self.action_groups.append(proof_actions.ProofActionGroup("Custom rules", *custom_rules).copy()) - for group in self.action_groups: - hlayout = QHBoxLayout() - group.init_buttons(self) - for action in group.actions: - assert action.button is not None - hlayout.addWidget(action.button) - hlayout.addStretch() - - widget = QWidget() - widget.setLayout(hlayout) - setattr(widget, "action_group", group) - self.actions_bar.addTab(widget, group.name) + def init_rewrites_bar(self) -> None: + self.rewrites_panel.setUniformRowHeights(True) + self.rewrites_panel.setSelectionMode(QAbstractItemView.SelectionMode.NoSelection) + fi = QFontInfo(self.font()) + + self.rewrites_panel.setStyleSheet( + f''' + QTreeView::Item:hover {{ + background-color: #e2f4ff; + }} + QTreeView::Item{{ + height:{fi.pixelSize() * 2}px; + }} + ''') + + # Set the models + self._refresh_rewrites_model() def parse_selection(self) -> tuple[list[VT], list[ET]]: selection = list(self.graph_scene.selected_vertices) @@ -145,12 +134,6 @@ def parse_selection(self) -> tuple[list[VT], list[ET]]: return selection, edges - def update_on_selection(self) -> None: - selection, edges = self.parse_selection() - g = self.graph_scene.g - action_group = getattr(self.actions_bar.currentWidget(), "action_group") - action_group.update_active(g, selection, edges) - def _vert_moved(self, vs: list[tuple[VT, float, float]]) -> None: cmd = MoveNodeInStep(self.graph_view, vs, self.step_view) self.undo_stack.push(cmd) @@ -232,7 +215,7 @@ def cross(a: QPointF, b: QPointF) -> float: if not trace.shift and basicrules.check_remove_id(self.graph, vertex): self._remove_id(vertex) return True - + if trace.shift and self.graph.type(vertex) != VertexType.W_OUTPUT: phase_is_complex = (self.graph.type(vertex) == VertexType.Z_BOX) if phase_is_complex: @@ -254,7 +237,7 @@ def cross(a: QPointF, b: QPointF) -> float: phase = get_z_box_label(self.graph, vertex) else: phase = self.graph.phase(vertex) - + start = trace.hit[item][0] end = trace.hit[item][-1] if start.y() > end.y(): @@ -272,7 +255,7 @@ def cross(a: QPointF, b: QPointF) -> float: else: right.append(neighbor) mouse_dir = ((start + end) * (1/2)) - pos - + if self.graph.type(vertex) == VertexType.W_OUTPUT: self._unfuse_w(vertex, left, mouse_dir) else: @@ -291,7 +274,7 @@ def _unfuse_w(self, v: VT, left_neighbours: list[VT], mouse_dir: QPointF) -> Non vi = get_w_partner(self.graph, v) par_dir = QVector2D( - self.graph.row(v) - self.graph.row(vi), + self.graph.row(v) - self.graph.row(vi), self.graph.qubit(v) - self.graph.qubit(vi) ).normalized() @@ -408,27 +391,14 @@ def _proof_step_selected(self, selected: QItemSelection, deselected: QItemSelect cmd = GoToRewriteStep(self.graph_view, self.step_view, deselected.first().topLeft().row(), selected.first().topLeft().row()) self.undo_stack.push(cmd) - def _refresh_rules(self) -> None: - self.actions_bar.removeTab(self.actions_bar.count() - 1) - custom_rules = [] - for root, dirs, files in os.walk(get_custom_rules_path()): - for file in files: - if file.endswith(".zxr"): - zxr_file = os.path.join(root, file) - with open(zxr_file, "r") as f: - rule = CustomRule.from_json(f.read()).to_proof_action() - custom_rules.append(rule) - group = proof_actions.ProofActionGroup("Custom rules", *custom_rules).copy() - hlayout = QHBoxLayout() - group.init_buttons(self) - for action in group.actions: - assert action.button is not None - hlayout.addWidget(action.button) - hlayout.addStretch() - widget = QWidget() - widget.setLayout(hlayout) - setattr(widget, "action_group", group) - self.actions_bar.addTab(widget, group.name) + def _refresh_rewrites_model(self) -> None: + refresh_custom_rules() + model = RewriteActionTreeModel.from_dict(action_groups, self) + self.rewrites_panel.setModel(model) + self.rewrites_panel.clicked.connect(model.do_rewrite) + # TODO: Right now this calls for every single vertex selected, even if we select many at the same time + self.graph_scene.selectionChanged.connect(model.update_on_selection) + self.rewrites_panel.expandAll() class ProofStepItemDelegate(QStyledItemDelegate): diff --git a/zxlive/rewrite_action.py b/zxlive/rewrite_action.py new file mode 100644 index 00000000..1d2b01b4 --- /dev/null +++ b/zxlive/rewrite_action.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import copy +from dataclasses import dataclass, field +from typing import Callable, TYPE_CHECKING + +import pyzx +from PySide6.QtCore import Qt, QAbstractItemModel, QModelIndex + +from .animations import make_animation +from .commands import AddRewriteStep +from .common import ET, GraphT, VT +from .dialogs import show_error_msg +from .rewrite_data import is_rewrite_data, RewriteData, MatchType, MATCHES_VERTICES + +if TYPE_CHECKING: + from .proof_panel import ProofPanel + +operations = copy.deepcopy(pyzx.editor.operations) + + +@dataclass +class RewriteAction: + name: str + matcher: Callable[[GraphT, Callable], list] + rule: Callable[[GraphT, list], pyzx.rules.RewriteOutputType[ET, VT]] + match_type: MatchType + tooltip: str + # Whether the graph should be copied before trying to test whether it matches. + # Needed if the matcher changes the graph. + copy_first: bool = field(default=False) + # Whether the rule returns a new graph instead of returning the rewrite changes. + returns_new_graph: bool = field(default=False) + enabled: bool = field(default=False) + + @classmethod + def from_rewrite_data(cls, d: RewriteData) -> RewriteAction: + return cls( + name=d['text'], + matcher=d['matcher'], + rule=d['rule'], + match_type=d['type'], + tooltip=d['tooltip'], + copy_first=d.get('copy_first', False), + returns_new_graph=d.get('returns_new_graph', False), + ) + + def do_rewrite(self, panel: ProofPanel) -> None: + if not self.enabled: + return + + g = copy.deepcopy(panel.graph_scene.g) + verts, edges = panel.parse_selection() + + matches = self.matcher(g, lambda v: v in verts) \ + if self.match_type == MATCHES_VERTICES \ + else self.matcher(g, lambda e: e in edges) + + try: + g, rem_verts = self.apply_rewrite(g, matches) + except Exception as e: + show_error_msg('Error while applying rewrite rule', str(e)) + return + + cmd = AddRewriteStep(panel.graph_view, g, panel.step_view, self.name) + anim_before, anim_after = make_animation(self, panel, g, matches, rem_verts) + panel.undo_stack.push(cmd, anim_before=anim_before, anim_after=anim_after) + + def apply_rewrite(self, g: GraphT, matches: list): + if self.returns_new_graph: + return self.rule(g, matches), None + + etab, rem_verts, rem_edges, check_isolated_vertices = self.rule(g, matches) + g.remove_edges(rem_edges) + g.remove_vertices(rem_verts) + g.add_edge_table(etab) + return g, rem_verts + + def update_active(self, g: GraphT, verts: list[VT], edges: list[ET]) -> None: + if self.copy_first: + g = copy.deepcopy(g) + self.enabled = bool( + self.matcher(g, lambda v: v in verts) + if self.match_type == MATCHES_VERTICES + else self.matcher(g, lambda e: e in edges) + ) + + +@dataclass +class RewriteActionTree: + id: str + rewrite: RewriteAction | None + child_items: list[RewriteActionTree] + parent: RewriteActionTree | None + + @property + def is_rewrite(self) -> bool: + return self.rewrite is not None + + @property + def rewrite_action(self) -> RewriteAction: + assert self.rewrite is not None + return self.rewrite + + def append_child(self, child: RewriteActionTree) -> None: + self.child_items.append(child) + + def child(self, row: int) -> RewriteActionTree: + assert -len(self.child_items) <= row < len(self.child_items) + return self.child_items[row] + + def child_count(self) -> int: + return len(self.child_items) + + def row(self) -> int | None: + return self.parent.child_items.index(self) if self.parent else None + + def header(self) -> str: + return self.id if self.rewrite is None else self.rewrite.name + + def tooltip(self) -> str: + return "" if self.rewrite is None else self.rewrite.tooltip + + def enabled(self) -> bool: + return self.rewrite is None or self.rewrite.enabled + + @classmethod + def from_dict(cls, d: dict, header: str = "", parent: RewriteActionTree | None = None) -> RewriteActionTree: + if is_rewrite_data(d): + return RewriteActionTree( + header, RewriteAction.from_rewrite_data(d), [], parent + ) + ret = RewriteActionTree(header, None, [], parent) + for group, actions in d.items(): + ret.append_child(cls.from_dict(actions, group, ret)) + return ret + + def update_on_selection(self, g, selection, edges): + for child in self.child_items: + child.update_on_selection(g, selection, edges) + if self.rewrite is not None: + self.rewrite.update_active(g, selection, edges) + + +class RewriteActionTreeModel(QAbstractItemModel): + root_item: RewriteActionTree + + def __init__(self, data: RewriteActionTree, proof_panel: ProofPanel): + super().__init__(proof_panel) + self.proof_panel = proof_panel + self.root_item = data + + @classmethod + def from_dict(cls, d: dict, proof_panel: ProofPanel): + return RewriteActionTreeModel( + RewriteActionTree.from_dict(d), + proof_panel + ) + + def index(self, row: int, column: int, parent: QModelIndex = None) -> QModelIndex: + if not self.hasIndex(row, column, parent): + return QModelIndex() + + parentItem = parent.internalPointer() if parent.isValid() else self.root_item + + if childItem := parentItem.child(row): + return self.createIndex(row, column, childItem) + return QModelIndex() + + def parent(self, index: QModelIndex = None) -> QModelIndex: + if not index.isValid(): + return QModelIndex() + + parentItem = index.internalPointer().parent + + if parentItem == self.root_item: + return QModelIndex() + + return self.createIndex(parentItem.row(), 0, parentItem) + + def rowCount(self, parent: QModelIndex = None) -> int: + if parent.column() > 0: + return 0 + parentItem = parent.internalPointer() if parent.isValid() else self.root_item + return parentItem.child_count() + + def columnCount(self, parent: QModelIndex = None) -> int: + return 1 + + def flags(self, index: QModelIndex) -> Qt.ItemFlag: + if index.isValid(): + return Qt.ItemFlag.ItemIsEnabled if index.internalPointer().enabled() else Qt.ItemFlag.NoItemFlags + return Qt.ItemFlag.ItemIsEnabled + + def data(self, index: QModelIndex, role: Qt.ItemDataRole = Qt.ItemDataRole.DisplayRole) -> str: + if index.isValid() and role == Qt.ItemDataRole.DisplayRole: + return index.internalPointer().header() + if index.isValid() and role == Qt.ItemDataRole.ToolTipRole: + return index.internalPointer().tooltip() + elif not index.isValid(): + return self.root_item.header() + + def headerData(self, section: int, orientation: Qt.Orientation, + role: Qt.ItemDataRole = Qt.ItemDataRole.DisplayRole) -> str: + if orientation == Qt.Orientation.Horizontal and role == Qt.ItemDataRole.DisplayRole: + return self.root_item.header() + return "" + + def do_rewrite(self, index: QModelIndex) -> None: + if not index.isValid(): + return + node = index.internalPointer() + if node.is_rewrite: + node.rewrite_action.do_rewrite(self.proof_panel) + + def update_on_selection(self) -> None: + selection, edges = self.proof_panel.parse_selection() + g = self.proof_panel.graph_scene.g + + self.root_item.update_on_selection(g, selection, edges) diff --git a/zxlive/rewrite_data.py b/zxlive/rewrite_data.py new file mode 100644 index 00000000..555c9a60 --- /dev/null +++ b/zxlive/rewrite_data.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +import copy +import os +from typing import Callable, Literal, TypedDict + +import pyzx +from pyzx import simplify, extract_circuit + +from .common import ET, GraphT, VT, get_custom_rules_path +from .custom_rule import CustomRule + +operations = copy.deepcopy(pyzx.editor.operations) + +MatchType = Literal[1, 2] + +# Copied from pyzx.editor_actions +MATCHES_VERTICES: MatchType = 1 +MATCHES_EDGES: MatchType = 2 + + +class RewriteData(TypedDict): + text: str + matcher: Callable[[GraphT, Callable], list] + rule: Callable[[GraphT, list], pyzx.rules.RewriteOutputType[ET, VT]] + type: MatchType + tooltip: str + copy_first: bool | None + returns_new_graph: bool | None + + +def is_rewrite_data(d: dict) -> bool: + proof_action_keys = {"text", "tooltip", "matcher", "rule", "type"} + return proof_action_keys.issubset(set(d.keys())) + + +def read_custom_rules() -> list[RewriteData]: + custom_rules = [] + for root, dirs, files in os.walk(get_custom_rules_path()): + for file in files: + if file.endswith(".zxr"): + zxr_file = os.path.join(root, file) + with open(zxr_file, "r") as f: + rule = CustomRule.from_json(f.read()).to_rewrite_data() + custom_rules.append(rule) + return custom_rules + + +# We want additional actions that are not part of the original PyZX editor +# So we add them to operations + +rewrites_graph_theoretic: dict[str, RewriteData] = { + "pivot_boundary": {"text": "boundary pivot", + "tooltip": "Performs a pivot between a Pauli spider and a spider on the boundary.", + "matcher": pyzx.rules.match_pivot_boundary, + "rule": pyzx.rules.pivot, + "type": MATCHES_EDGES, + "copy_first": True}, + "pivot_gadget": {"text": "gadget pivot", + "tooltip": "Performs a pivot between a Pauli spider and a spider with an arbitrary phase, creating a phase gadget.", + "matcher": pyzx.rules.match_pivot_gadget, + "rule": pyzx.rules.pivot, + "type": MATCHES_EDGES, + "copy_first": True}, + "phase_gadget_fuse": {"text": "Fuse phase gadgets", + "tooltip": "Fuses two phase gadgets with the same connectivity.", + "matcher": pyzx.rules.match_phase_gadgets, + "rule": pyzx.rules.merge_phase_gadgets, + "type": MATCHES_VERTICES, + "copy_first": True}, + "supplementarity": {"text": "Supplementarity", + "tooltip": "Looks for a pair of internal spiders with the same connectivity and supplementary angles and removes them.", + "matcher": pyzx.rules.match_supplementarity, + "rule": pyzx.rules.apply_supplementarity, + "type": MATCHES_VERTICES, + "copy_first": False}, +} + +const_true = lambda graph, matches: matches + + +def apply_simplification(simplification: Callable[[GraphT], GraphT]) -> Callable[ + [GraphT, list], pyzx.rules.RewriteOutputType[ET, VT]]: + def rule(g: GraphT, matches: list) -> pyzx.rules.RewriteOutputType[ET, VT]: + simplification(g) + return ({}, [], [], True) + + return rule + + +def _extract_circuit(graph: GraphT, matches: list) -> GraphT: + graph.auto_detect_io() + simplify.full_reduce(graph) + return extract_circuit(graph).to_graph() + + +simplifications: dict[str, RewriteData] = { + 'bialg_simp': { + "text": "bialgebra", + "tooltip": "bialg_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.bialg_simp), + "type": MATCHES_VERTICES, + }, + 'spider_simp': { + "text": "spider fusion", + "tooltip": "spider_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.spider_simp), + "type": MATCHES_VERTICES, + }, + 'id_simp': { + "text": "id", + "tooltip": "id_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.id_simp), + "type": MATCHES_VERTICES, + }, + 'phase_free_simp': { + "text": "phase free", + "tooltip": "phase_free_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.phase_free_simp), + "type": MATCHES_VERTICES, + }, + 'pivot_simp': { + "text": "pivot", + "tooltip": "pivot_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.pivot_simp), + "type": MATCHES_VERTICES, + }, + 'pivot_gadget_simp': { + "text": "pivot gadget", + "tooltip": "pivot_gadget_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.pivot_gadget_simp), + "type": MATCHES_VERTICES, + }, + 'pivot_boundary_simp': { + "text": "pivot boundary", + "tooltip": "pivot_boundary_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.pivot_boundary_simp), + "type": MATCHES_VERTICES, + }, + 'gadget_simp': { + "text": "gadget", + "tooltip": "gadget_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.gadget_simp), + "type": MATCHES_VERTICES, + }, + 'lcomp_simp': { + "text": "local complementation", + "tooltip": "lcomp_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.lcomp_simp), + "type": MATCHES_VERTICES, + }, + 'clifford_simp': { + "text": "clifford simplification", + "tooltip": "clifford_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.clifford_simp), + "type": MATCHES_VERTICES, + }, + 'tcount': { + "text": "tcount", + "tooltip": "tcount", + "matcher": const_true, + "rule": apply_simplification(simplify.tcount), + "type": MATCHES_VERTICES, + }, + 'to_gh': { + "text": "to green-hadamard form", + "tooltip": "to_gh", + "matcher": const_true, + "rule": apply_simplification(simplify.to_gh), + "type": MATCHES_VERTICES, + }, + 'to_rg': { + "text": "to red-green form", + "tooltip": "to_rg", + "matcher": const_true, + "rule": apply_simplification(simplify.to_rg), + "type": MATCHES_VERTICES, + }, + 'full_reduce': { + "text": "full reduce", + "tooltip": "full_reduce", + "matcher": const_true, + "rule": apply_simplification(simplify.full_reduce), + "type": MATCHES_VERTICES, + }, + 'teleport_reduce': { + "text": "teleport reduce", + "tooltip": "teleport_reduce", + "matcher": const_true, + "rule": apply_simplification(simplify.teleport_reduce), + "type": MATCHES_VERTICES, + }, + 'reduce_scalar': { + "text": "reduce scalar", + "tooltip": "reduce_scalar", + "matcher": const_true, + "rule": apply_simplification(simplify.reduce_scalar), + "type": MATCHES_VERTICES, + }, + 'supplementarity_simp': { + "text": "supplementarity", + "tooltip": "supplementarity_simp", + "matcher": const_true, + "rule": apply_simplification(simplify.supplementarity_simp), + "type": MATCHES_VERTICES, + }, + 'to_clifford_normal_form_graph': { + "text": "to clifford normal form", + "tooltip": "to_clifford_normal_form_graph", + "matcher": const_true, + "rule": apply_simplification(simplify.to_clifford_normal_form_graph), + "type": MATCHES_VERTICES, + }, + 'extract_circuit': { + "text": "circuit extraction", + "tooltip": "extract_circuit", + "matcher": const_true, + "rule": _extract_circuit, + "type": MATCHES_VERTICES, + "returns_new_graph": True, + }, +} + +rules_basic = {"spider", "to_z", "to_x", "rem_id", "copy", "pauli", "bialgebra", "euler"} + +rules_zxw = {"spider", "fuse_w", "z_to_z_box"} + +rules_zh = {"had2edge", "fuse_hbox", "mult_hbox"} + +action_groups = { + "Basic rules": {key: operations[key] for key in rules_basic}, + "Graph-like rules": rewrites_graph_theoretic, + "ZXW rules": {key: operations[key] for key in rules_zxw}, + "ZH rules": {key: operations[key] for key in rules_zh}, + "Simplification routines": simplifications, +} + + +def refresh_custom_rules() -> None: + action_groups["Custom rules"] = {rule["text"]: rule for rule in read_custom_rules()} + + +refresh_custom_rules() From bb898d66fb5b2addf9a73b9e3101f42161d1604c Mon Sep 17 00:00:00 2001 From: Boldi Date: Wed, 15 Nov 2023 14:47:51 +0000 Subject: [PATCH 37/37] Fix disabled button issue on Windows --- zxlive/proof_panel.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/zxlive/proof_panel.py b/zxlive/proof_panel.py index d502e44c..92526dc5 100644 --- a/zxlive/proof_panel.py +++ b/zxlive/proof_panel.py @@ -118,6 +118,9 @@ def init_rewrites_bar(self) -> None: QTreeView::Item{{ height:{fi.pixelSize() * 2}px; }} + QTreeView::Item:!enabled {{ + color: #c0c0c0; + }} ''') # Set the models