Skip to content

Commit

Permalink
Merge pull request #216 from mjsutcliffe99/master
Browse files Browse the repository at this point in the history
Vertex/edge cutting
  • Loading branch information
jvdwetering authored Aug 6, 2024
2 parents a9a5a65 + 2aecd70 commit ffff4f8
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 2 deletions.
94 changes: 93 additions & 1 deletion pyzx/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import numpy as np

from .utils import EdgeType, VertexType, toggle_edge
from .utils import EdgeType, VertexType, toggle_vertex, toggle_edge, ave_pos
from . import simplify
from .circuit import Circuit
from .graph.base import BaseGraph,VT,ET
Expand Down Expand Up @@ -482,3 +482,95 @@ def replace_1_1(g: BaseGraph[VT,ET], verts: List[VT]) -> BaseGraph[VT,ET]:
g.add_edge((verts[0],w),EdgeType.HADAMARD)
for v in verts: g.add_to_phase(v,Fraction(-1,4))
return g

def cut_vertex(g,v):
"""Applies the ``cutting'' decomposition to a vertex, as used in, for example: https://arxiv.org/pdf/2403.10964."""
g = g.clone()
g0 = g.clone()
g1 = g.clone()
g0.remove_vertex(v)
g1.remove_vertex(v)

n = len(g.neighbors(v))
g0.scalar.add_power(-n)
g1.scalar.add_power(-n)
g1.scalar.add_phase(g.phase(v)) # account for e^(i*pi*alpha) on right branch

vtype = toggle_vertex(g.type(v))

for i in g.neighbors(v):
etype = g.edge_type((v,i)) # maintain edge type
qubit = ave_pos(g.qubit(v),g.qubit(i),1/2)
row = ave_pos(g.row(v),g.row(i),1/2)

newV = g0.add_vertex(vtype,qubit,row,0) # add and connect the new vertices
g0.add_edge((newV,i),etype)
newV = g1.add_vertex(vtype,qubit,row,1)
g1.add_edge((newV,i),etype)

return (g0,g1)

def cut_edge(g,e,ty=1):
"""Applies the ``cutting'' decomposition to an edge, as used in, for example: https://arxiv.org/pdf/2403.10964. The type ty decides whether to cut with Z- branches or X- branches."""
g = g.clone()
g0 = g.clone()
g1 = g.clone()
g0.remove_edge(e)
g1.remove_edge(e)

etype = g.edge_type(e)

g0.scalar.add_power(-2)
g1.scalar.add_power(-2)

x0,x1 = g.row(e[0]), g.row(e[1])
y0,y1 = g.qubit(e[0]), g.qubit(e[1])

qubit1 = ave_pos(y0,y1,1/3)
row1 = ave_pos(x0,x1,1/3)
qubit2 = ave_pos(y0,y1,2/3)
row2 = ave_pos(x0,x1,2/3)

v = g0.add_vertex(ty=ty,qubit=qubit1,row=row1,phase=0)
g0.add_edge((v,e[0]),1)
v = g0.add_vertex(ty=ty,qubit=qubit2,row=row2,phase=0)
g0.add_edge((v,e[1]),etype)

v = g1.add_vertex(ty=ty,qubit=qubit1,row=row1,phase=1)
g1.add_edge((v,e[0]),1)
v = g1.add_vertex(ty=ty,qubit=qubit2,row=row2,phase=1)
g1.add_edge((v,e[1]),etype)

return (g0,g1)

def cut_wishbone(g,v,neighs,ph):
"""Applies the ``wishbone cut'' (or ``separator cut'') decomposition to vertex v of graph g, pulling out the neighbours ``neighs'' and a phase ``ph'', as used in: [PAPER UPCOMING]."""
g = g.clone()

for i in neighs:
if not i in g.neighbors(v):
raise ValueError("Attempted illegal wishbone cut. Vertex " + str(i) + " is not a neighbor of target vertex " + str(v) + ".")

neighs_left = set(g.neighbors(v)).symmetric_difference(neighs)
neighs_right = neighs

phase_left = g.phase(v) - ph
phase_right = ph

v_left = g.add_vertex(qubit=g.qubit(v),row=g.row(v)-0.5,ty=g.type(v),phase=phase_left)
v_right = g.add_vertex(qubit=g.qubit(v),row=g.row(v)+0.5,ty=g.type(v),phase=phase_right)

for i in neighs_left: g.add_edge((v_left,i),g.edge_type((v,i)))
for i in neighs_right: g.add_edge((v_right,i),g.edge_type((v,i)))

g.remove_vertex(v)

#--

gLeft = g.clone()
gRight = g.clone()

gRight.add_to_phase(v_left,1)
gRight.add_to_phase(v_right,1)

return (gLeft,gRight)
3 changes: 3 additions & 0 deletions pyzx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,6 @@ def get_z_box_label(g, v):
def set_z_box_label(g, v, label):
assert g.type(v) == VertexType.Z_BOX
g.set_vdata(v, 'label', label)

# Return position 'perc'%-distance between 2 points:
def ave_pos(a,b,perc=1/2): return (abs(a-b))*(perc) + min(a,b)
39 changes: 38 additions & 1 deletion tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,32 @@

import unittest
import sys
import random
from types import ModuleType
from typing import Optional

if __name__ == '__main__':
sys.path.append('..')
sys.path.append('.')
from pyzx.circuit import Circuit
from pyzx.simulate import replace_magic_states
from pyzx.simulate import replace_magic_states, cut_vertex, cut_edge
from pyzx.generate import cliffords
from pyzx.simplify import full_reduce

np: Optional[ModuleType]
try:
import numpy as np
except ImportError:
np = None

def rand_graph(qubits=5,depth=10):
g = cliffords(qubits,depth)
g.apply_state('0'*qubits)
g.apply_effect('0'*qubits)
return g

def round_complex(scalar,decimal_places):
return round(scalar.real,decimal_places) + round(scalar.imag,decimal_places)*1j

@unittest.skipUnless(np, "numpy needs to be installed for this to run")
class TestSimulate(unittest.TestCase):
Expand All @@ -43,6 +54,32 @@ def test_magic_state_decomposition_is_correct(self):
g = c.to_graph()
gsum = replace_magic_states(g)
self.assertTrue(np.allclose(g.to_tensor(), gsum.to_tensor()))

def test_vertex_cut(self,repeats=20):
for i in range(1,repeats):
g = rand_graph() # generate random Clifford graph
v_cut = random.randrange(len(g.vertices()))
g0,g1 = cut_vertex(g,v_cut) # apply random vertex cut

for g_i in (g,g0,g1): full_reduce(g_i)

scal = round_complex(g.scalar.to_number(),3) # the scalar from fully reducing g
scalCut = round_complex(g0.scalar.to_number()+g1.scalar.to_number(),3) # the sum of scalars from the cut graph
assert(scal == scalCut)

def test_edge_cut(self,repeats=20):
for i in range(1,repeats):
g = rand_graph() # generate random Clifford graph
rand_v = random.randrange(len(g.vertices()))
rand_neigh = list(g.neighbors(rand_v))[random.randrange(len(g.neighbors(rand_v)))]
e_cut = (rand_v,rand_neigh) # apply random edge cut

g0,g1 = cut_edge(g,e_cut)
for g_i in (g,g0,g1): full_reduce(g_i)

scal = round_complex(g.scalar.to_number(),3) # the scalar from fully reducing g
scalCut = round_complex(g0.scalar.to_number()+g1.scalar.to_number(),3) # the sum of scalars from the cut graph
assert(scal == scalCut)


if __name__ == '__main__':
Expand Down

0 comments on commit ffff4f8

Please sign in to comment.