Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multigraph self loop #242

Merged
merged 10 commits into from
Jun 27, 2024
11 changes: 10 additions & 1 deletion pyzx/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,13 @@ def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> 'BaseGraph':
graph did not.
"""
from .graph import Graph # imported here to prevent circularity
from .multigraph import Multigraph
if (backend is None):
backend = type(self).backend
g = Graph(backend = backend)
if isinstance(self, Multigraph) and isinstance(g, Multigraph):
g.set_auto_simplify(self._auto_simplify) # type: ignore
# mypy issue https://github.com/python/mypy/issues/16413
g.track_phases = self.track_phases
g.scalar = self.scalar.copy(conjugate=adjoint)
g.merge_vdata = self.merge_vdata
Expand Down Expand Up @@ -390,14 +394,19 @@ def merge(self, other: 'BaseGraph') -> Tuple[List[VT],List[ET]]:
def subgraph_from_vertices(self,verts: List[VT]) -> 'BaseGraph':
"""Returns the subgraph consisting of the specified vertices."""
from .graph import Graph # imported here to prevent circularity
from .multigraph import Multigraph
g = Graph(backend=type(self).backend)
if isinstance(self, Multigraph) and isinstance(g, Multigraph):
g.set_auto_simplify(self._auto_simplify) # type: ignore
# mypy issue https://github.com/python/mypy/issues/16413
ty = self.types()
rs = self.rows()
qs = self.qubits()
phase = self.phases()
grounds = self.grounds()

edges = [self.edge(v,w) for v in verts for w in verts if self.connected(v,w)]
edges = [e for e in self.edges() \
if self.edge_st(e)[0] in verts and self.edge_st(e)[1] in verts]

vert_map = dict()
for v in verts:
Expand Down
10 changes: 5 additions & 5 deletions pyzx/graph/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
from collections import Counter
from typing import Any, Callable, Generic, Optional, List, Dict, Tuple
import copy

from ..utils import VertexType, EdgeType, FractionLike, FloatInt, phase_to_s
from .base import BaseGraph, VT, ET
Expand Down Expand Up @@ -56,12 +57,11 @@ def calculate_diff(self, g1: BaseGraph[VT,ET], g2: BaseGraph[VT,ET]) -> None:
self.new_edges = []
self.removed_edges = []

for e in (new_edges - old_edges):
for e in Counter(new_edges - old_edges).elements():
self.new_edges.append((g2.edge_st(e), g2.edge_type(e)))

for e in (old_edges - new_edges):
for e in Counter(old_edges - new_edges).elements():
s,t = g1.edge_st(e)
if s in self.removed_verts or t in self.removed_verts: continue
self.removed_edges.append(e)

for v in new_verts:
Expand Down Expand Up @@ -94,8 +94,8 @@ def calculate_diff(self, g1: BaseGraph[VT,ET], g2: BaseGraph[VT,ET]) -> None:

def apply_diff(self,g: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]:
g = copy.deepcopy(g)
g.remove_vertices(self.removed_verts)
g.remove_edges(self.removed_edges)
g.remove_vertices(self.removed_verts)
for v in self.new_verts:
g.add_vertex_indexed(v)
g.set_position(v,*self.changed_pos[v])
Expand Down
10 changes: 5 additions & 5 deletions pyzx/graph/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ def clone(self) -> 'Multigraph':
cpy.phase_mult = self.phase_mult.copy()
cpy.max_phase_index = self.max_phase_index
return cpy

def set_auto_simplify(self, s: bool):
"""Automatically remove parallel edges as edges are added"""
self._auto_simplify = s

def multigraph(self):
return False

Expand Down Expand Up @@ -209,7 +209,7 @@ def remove_vertices(self, vertices):
e = self.graph[v][v1]
self.nedges -= e.s + e.h
del self.graph[v][v1]
del self.graph[v1][v]
if v != v1: del self.graph[v1][v]
# remove the vertex
del self.graph[v]
del self.ty[v]
Expand Down Expand Up @@ -244,7 +244,7 @@ def remove_edge(self, edge):

if e.is_empty():
del self.graph[s][t]
del self.graph[t][s]
if s != t: del self.graph[t][s]

self.nedges -= 1

Expand All @@ -270,7 +270,7 @@ def edges(self, s=None, t=None):
if s == None:
for v0,adj in self.graph.items():
for v1, e in adj.items():
if v1 > v0:
if v1 >= v0:
for _ in range(e.s): yield (v0, v1, EdgeType.SIMPLE)
for _ in range(e.h): yield (v0, v1, EdgeType.HADAMARD)
for _ in range(e.w_io): yield (v0, v1, EdgeType.W_IO)
Expand Down
Loading