Skip to content

Commit

Permalink
Fix self loops, add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
boldar99 committed Jul 16, 2024
1 parent 0947f34 commit e49c74c
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 6 deletions.
17 changes: 14 additions & 3 deletions pyzx/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
'adjoint', 'is_unitary','tensor_to_matrix',
'find_scalar_correction']

import itertools
from math import pi, sqrt

from typing import Optional
Expand Down Expand Up @@ -85,7 +86,7 @@ def W_to_tensor(arity: int) -> np.ndarray:
return m

def pop_and_shift(verts, indices):
res = [indices[v].pop() for v in verts]
res = [indices[v].pop() for v in verts if v in indices]
for i in sorted(res,reverse=True):
for w,l in indices.items():
l2 = []
Expand Down Expand Up @@ -129,8 +130,11 @@ def tensorfy(g: 'BaseGraph[VT,ET]', preserve_scalar:bool=True) -> np.ndarray:

for i,r in enumerate(sorted(verts_row.keys())):
for v in sorted(verts_row[r]):
neigh = [list(set(g.edge_st(e)) - {v})[0] for e in g.incident_edges(v)]
d = len(neigh)
neigh = list(itertools.chain.from_iterable(
set(g.edge_st(e)) - {v} for e in g.incident_edges(v)
))
self_loops = [e for e in g.incident_edges(v) if g.edge_s(e) == g.edge_t(e)]
d = len(neigh) + len(self_loops) * 2
if v in inputs:
if types[v] != VertexType.BOUNDARY: raise ValueError("Wrong type for input:", v, types[v])
continue # inputs already taken care of
Expand Down Expand Up @@ -159,6 +163,13 @@ def tensorfy(g: 'BaseGraph[VT,ET]', preserve_scalar:bool=True) -> np.ndarray:
t = Z_box_to_tensor(d, label)
else:
raise ValueError("Vertex %s has non-ZXH type but is not an input or output" % str(v))
for sl in self_loops:
if g.edge_type(sl) == EdgeType.HADAMARD:
t = np.tensordot(t,had)
elif g.edge_type(sl) == EdgeType.SIMPLE:
t = np.trace(t)
else:
raise NotImplementedError(f"Tensor contraction for {repr(sl)} self loops are not implemented.")
nn = list(filter(lambda n: rows[n]<r or (rows[n]==r and n<v), neigh)) # TODO: allow ordering on vertex indices?
ety = {n:g.edge_type(g.edge(v,n)) for n in nn}
nn.sort(key=lambda n: ety[n])
Expand Down
29 changes: 26 additions & 3 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from types import ModuleType
from typing import Optional


if __name__ == '__main__':
sys.path.append('..')
sys.path.append('.')
Expand Down Expand Up @@ -136,7 +135,7 @@ def test_adjoint(self):

def test_multiedge_scalar(self):
g = Multigraph()
g._auto_simplify = False
g.set_auto_simplify(False)
i1 = g.add_vertex(1,0,0)
i2 = g.add_vertex(2,1,0)
g.add_edges([(i1, i2)] * 3)
Expand All @@ -147,9 +146,33 @@ def test_self_loop_scalar(self):
g.set_auto_simplify(False)
i1 = g.add_vertex(1,0,0)
g.add_edge((i1, i1))
self.assertTrue(compare_tensors(g, np.array([2]), preserve_scalar=True))
g.add_edge((i1, i1), 2)
self.assertTrue(compare_tensors(g, np.array([0]), preserve_scalar=True))


def test_self_loop_state(self):
g = Multigraph()
g.set_auto_simplify(False)
i0 = g.add_vertex(0,0,0)
i1 = g.add_vertex(2,0,1)
g.set_inputs((i0,))
g.add_edge((i0, i1))
self.assertTrue(compare_tensors(g, np.array([1,0])))
g.add_edge((i1, i1), 2)
self.assertTrue(compare_tensors(g, np.array([0,1])))

def test_self_loop_and_parallel_edge_map(self):
g = Multigraph()
g.set_auto_simplify(False)
i0 = g.add_vertex(0,0,0)
i1 = g.add_vertex(2,0,1)
i2 = g.add_vertex(1,0,2)
i3 = g.add_vertex(0,0,3)
g.set_inputs((i0,))
g.set_outputs((i3,))
g.add_edges([(i0, i1), (i1, i1)] + [(i1, i2)] * 2)
g.add_edges([(i2, i2), (i2, i3)], 2)
self.assertTrue(compare_tensors(g, np.array([[0,0],[1,0]])))

if __name__ == '__main__':
unittest.main()

0 comments on commit e49c74c

Please sign in to comment.