Skip to content

Commit

Permalink
Merge pull request #258 from boldar99/fix-paralell-self-loops-tensor
Browse files Browse the repository at this point in the history
Fix tensorfy for parallel and loop edges
  • Loading branch information
jvdwetering authored Jul 16, 2024
2 parents f8954b5 + 8ba8451 commit 1c2da27
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
19 changes: 14 additions & 5 deletions pyzx/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
'adjoint', 'is_unitary','tensor_to_matrix',
'find_scalar_correction']

import itertools
from math import pi, sqrt

from typing import Optional
Expand Down Expand Up @@ -85,9 +86,7 @@ def W_to_tensor(arity: int) -> np.ndarray:
return m

def pop_and_shift(verts, indices):
res = []
for v in verts:
res.append(indices[v].pop())
res = [indices[v].pop() for v in verts if v in indices]
for i in sorted(res,reverse=True):
for w,l in indices.items():
l2 = []
Expand Down Expand Up @@ -131,8 +130,11 @@ def tensorfy(g: 'BaseGraph[VT,ET]', preserve_scalar:bool=True) -> np.ndarray:

for i,r in enumerate(sorted(verts_row.keys())):
for v in sorted(verts_row[r]):
neigh = list(g.neighbors(v))
d = len(neigh)
neigh = list(itertools.chain.from_iterable(
set(g.edge_st(e)) - {v} for e in g.incident_edges(v)
))
self_loops = [e for e in g.incident_edges(v) if g.edge_s(e) == g.edge_t(e)]
d = len(neigh) + len(self_loops) * 2
if v in inputs:
if types[v] != VertexType.BOUNDARY: raise ValueError("Wrong type for input:", v, types[v])
continue # inputs already taken care of
Expand Down Expand Up @@ -161,6 +163,13 @@ def tensorfy(g: 'BaseGraph[VT,ET]', preserve_scalar:bool=True) -> np.ndarray:
t = Z_box_to_tensor(d, label)
else:
raise ValueError("Vertex %s has non-ZXH type but is not an input or output" % str(v))
for sl in self_loops:
if g.edge_type(sl) == EdgeType.HADAMARD:
t = np.tensordot(t,had)
elif g.edge_type(sl) == EdgeType.SIMPLE:
t = np.trace(t)
else:
raise NotImplementedError(f"Tensor contraction with {repr(sl)} self-loops is not implemented.")
nn = list(filter(lambda n: rows[n]<r or (rows[n]==r and n<v), neigh)) # TODO: allow ordering on vertex indices?
ety = {n:g.edge_type(g.edge(v,n)) for n in nn}
nn.sort(key=lambda n: ety[n])
Expand Down
41 changes: 41 additions & 0 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
sys.path.append('..')
sys.path.append('.')
from pyzx.graph import Graph
from pyzx.graph.multigraph import Multigraph
from pyzx.generate import cliffords
from pyzx.circuit import Circuit

Expand Down Expand Up @@ -132,6 +133,46 @@ def test_adjoint(self):
circ_adj = tensorfy(circ.adjoint())
self.assertTrue(compare_tensors(t_adj,circ_adj))

def test_multiedge_scalar(self):
g = Multigraph()
g.set_auto_simplify(False)
i1 = g.add_vertex(1,0,0)
i2 = g.add_vertex(2,1,0)
g.add_edges([(i1, i2)] * 3)
self.assertTrue(compare_tensors(g, np.array([np.sqrt(2)**(-1)]), preserve_scalar=True))

def test_self_loop_scalar(self):
g = Multigraph()
g.set_auto_simplify(False)
i1 = g.add_vertex(1,0,0)
g.add_edge((i1, i1))
self.assertTrue(compare_tensors(g, np.array([2]), preserve_scalar=True))
g.add_edge((i1, i1), 2)
self.assertTrue(compare_tensors(g, np.array([0]), preserve_scalar=True))

def test_self_loop_state(self):
g = Multigraph()
g.set_auto_simplify(False)
i0 = g.add_vertex(0,0,0)
i1 = g.add_vertex(2,0,1)
g.set_inputs((i0,))
g.add_edge((i0, i1))
self.assertTrue(compare_tensors(g, np.array([1,0])))
g.add_edge((i1, i1), 2)
self.assertTrue(compare_tensors(g, np.array([0,1])))

def test_self_loop_and_parallel_edge_map(self):
g = Multigraph()
g.set_auto_simplify(False)
i0 = g.add_vertex(0,0,0)
i1 = g.add_vertex(2,0,1)
i2 = g.add_vertex(1,0,2)
i3 = g.add_vertex(0,0,3)
g.set_inputs((i0,))
g.set_outputs((i3,))
g.add_edges([(i0, i1), (i1, i1)] + [(i1, i2)] * 2)
g.add_edges([(i2, i2), (i2, i3)], 2)
self.assertTrue(compare_tensors(g, np.array([[0,0],[1,0]])))

if __name__ == '__main__':
unittest.main()

0 comments on commit 1c2da27

Please sign in to comment.