Skip to content

Commit

Permalink
allow unfusing W nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
tuomas56 committed Nov 13, 2023
1 parent e7a2c69 commit 9d87e77
Showing 1 changed file with 48 additions and 6 deletions.
54 changes: 48 additions & 6 deletions zxlive/proof_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
QStyleOptionViewItem, QToolButton, QWidget,
QVBoxLayout, QTabWidget, QInputDialog)
from pyzx import VertexType, basicrules
from pyzx.utils import get_z_box_label, set_z_box_label
from pyzx.utils import get_z_box_label, set_z_box_label, get_w_partner, EdgeType

from . import animations as anims
from . import proof_actions
Expand All @@ -29,7 +29,7 @@
from .graphscene import GraphScene
from .graphview import GraphTool, GraphView, WandTrace
from .proof import ProofModel
from .vitem import DragState, VItem
from .vitem import DragState, VItem, get_w_partner_vitem, W_INPUT_OFFSET, SCALE
from .editor_base_panel import string_to_complex, string_to_fraction
from .poly import Poly

Expand Down Expand Up @@ -220,14 +220,14 @@ def cross(a: QPointF, b: QPointF) -> float:
return False
item = filtered[0]
vertex = item.v
if self.graph.type(vertex) not in (VertexType.Z, VertexType.X, VertexType.Z_BOX):
if self.graph.type(vertex) not in (VertexType.Z, VertexType.X, VertexType.Z_BOX, VertexType.W_OUTPUT):
return False

if not trace.shift and basicrules.check_remove_id(self.graph, vertex):
self._remove_id(vertex)
return True

if trace.shift:
if trace.shift and self.graph.type(vertex) != VertexType.W_OUTPUT:
phase_is_complex = (self.graph.type(vertex) == VertexType.Z_BOX)
if phase_is_complex:
prompt = "Enter desired phase value (complex value):"
Expand All @@ -245,7 +245,7 @@ def new_var(_):
except ValueError:
show_error_msg("Invalid Input", error_msg)
return False
else:
elif self.graph.type(vertex) != VertexType.W_OUTPUT:
if self.graph.type(vertex) == VertexType.Z_BOX:
phase = get_z_box_label(self.graph, vertex)
else:
Expand All @@ -268,7 +268,11 @@ def new_var(_):
else:
right.append(neighbor)
mouse_dir = ((start + end) * (1/2)) - pos
self._unfuse(vertex, left, mouse_dir, phase)

if self.graph.type(vertex) == VertexType.W_OUTPUT:
self._unfuse_w(vertex, left, mouse_dir)
else:
self._unfuse(vertex, left, mouse_dir, phase)
return True

def _remove_id(self, v: VT) -> None:
Expand All @@ -278,6 +282,44 @@ def _remove_id(self, v: VT) -> None:
cmd = AddRewriteStep(self.graph_view, new_g, self.step_view, "id")
self.undo_stack.push(cmd, anim_before=anim)

def _unfuse_w(self, v: VT, left_neighbours: list[VT], mouse_dir: QPointF) -> None:
new_g = copy.deepcopy(self.graph)

vi = get_w_partner(self.graph, v)
par_dir = QVector2D(
self.graph.row(v) - self.graph.row(vi),
self.graph.qubit(v) - self.graph.qubit(vi)
).normalized()

perp_dir = QVector2D(mouse_dir - QPointF(self.graph.row(v)/SCALE, self.graph.qubit(v)/SCALE)).normalized()
perp_dir -= QVector2D.dotProduct(perp_dir, par_dir) * par_dir
perp_dir.normalize()

out_offset_x = par_dir.x() * 0.5 + perp_dir.x() * 0.5
out_offset_y = par_dir.y() * 0.5 + perp_dir.y() * 0.5

in_offset_x = out_offset_x - par_dir.x()*W_INPUT_OFFSET
in_offset_y = out_offset_y - par_dir.y()*W_INPUT_OFFSET

left_vert = new_g.add_vertex(VertexType.W_OUTPUT,
qubit=self.graph.qubit(v) + out_offset_y,
row=self.graph.row(v) + out_offset_x)
left_vert_i = new_g.add_vertex(VertexType.W_INPUT,
qubit=self.graph.qubit(v) + in_offset_y,
row=self.graph.row(v) + in_offset_x)
new_g.add_edge((left_vert_i, left_vert), EdgeType.W_IO)
new_g.add_edge((v, left_vert_i))
new_g.set_row(v, self.graph.row(v))
new_g.set_qubit(v, self.graph.qubit(v))
for neighbor in left_neighbours:
new_g.add_edge((neighbor, left_vert),
self.graph.edge_type((v, neighbor)))
new_g.remove_edge((v, neighbor))

anim = anims.unfuse(self.graph, new_g, v, self.graph_scene)
cmd = AddRewriteStep(self.graph_view, new_g, self.step_view, "unfuse")
self.undo_stack.push(cmd, anim_after=anim)

def _unfuse(self, v: VT, left_neighbours: list[VT], mouse_dir: QPointF, phase: Poly | complex | Fraction) -> None:
def snap_vector(v: QVector2D) -> None:
if abs(v.x()) > abs(v.y()):
Expand Down

0 comments on commit 9d87e77

Please sign in to comment.