Skip to content

Commit

Permalink
Use minimal cycle basis for ARCs (#408)
Browse files Browse the repository at this point in the history
* use minimal cycle basis

* improve use of cycles

* fix bravyi-haah

* remove leaflessness
  • Loading branch information
quantumjim authored Nov 16, 2023
1 parent aebaa93 commit 5624e3e
Showing 1 changed file with 160 additions and 134 deletions.
294 changes: 160 additions & 134 deletions src/qiskit_qec/circuits/repetition_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import numpy as np
import rustworkx as rx
import networkx as nx

from qiskit import ClassicalRegister, QuantumCircuit, QuantumRegister, transpile
from qiskit.circuit.library import XGate, RZGate
from qiskit.transpiler import PassManager, InstructionDurations
Expand Down Expand Up @@ -625,23 +627,22 @@ def _get_cycles(self):
set of qubits around adjacent cycles is found.
"""

link_graph = self._get_link_graph()
lg_edges = set(link_graph.edge_list())
lg_nodes = link_graph.nodes()
cycles = rx.cycle_basis(link_graph)
cycle_dict = {(lg_nodes[edge[0]], lg_nodes[edge[1]]): list(edge) for edge in lg_edges}
for cycle in cycles:
edges = []
cl = len(cycle)
for j in range(cl):
for edge in [(cycle[j], cycle[(j + 1) % cl]), (cycle[(j + 1) % cl], cycle[j])]:
if edge in lg_edges:
edges.append((lg_nodes[edge[0]], lg_nodes[edge[1]]))
for edge in edges:
cycle_dict[edge] += cycle
for edge, ns in cycle_dict.items():
cycle_dict[edge] = set(ns)
self.cycle_dict = cycle_dict
self.link_graph = self._get_link_graph()
lg_edges = set(self.link_graph.edge_list())
lg_nodes = self.link_graph.nodes()
ng = nx.Graph()
for n0, n1 in self.link_graph.edge_list():
ng.add_edge(n0, n1)
# express the cycles in terms of the ns of the link graph
self.cycles = nx.minimum_cycle_basis(ng)
# and for each pair of data qubits, list the cycles it is a part of
self.cycle_dict = {(lg_nodes[edge[0]], lg_nodes[edge[1]]): set() for edge in lg_edges}
for c, cycle in enumerate(self.cycles):
for n0 in cycle:
for n1 in cycle:
for edge in [(n0, n1), (n1, n0)]:
if edge in lg_edges:
self.cycle_dict[lg_nodes[edge[0]], lg_nodes[edge[1]]].add(c)

def _coloring(self):
"""
Expand Down Expand Up @@ -812,7 +813,7 @@ def _preparation(self):
if graph.degree(n) == 1:
z_logicals.append(node)
# if there are none, just use the first
if not z_logicals: # z_logicals == []
if not z_logicals:
z_logicals = [min(self.code_index.keys())]
self.z_logicals = z_logicals

Expand Down Expand Up @@ -1182,126 +1183,151 @@ def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False):

# see whether the bulk nodes are neutral
if bulk_nodes:
# bicolor the nodes of the link graph, such that node edges connect unlike edges
link_qubits = set(node.properties["link qubit"] for node in nodes)
link_graph = self._get_link_graph()
# all the qubits around cycles of the node edges have to be covered
ns_to_do = set()
for edge in [tuple(node.qubits) for node in bulk_nodes]:
ns_to_do = ns_to_do.union(self.cycle_dict[edge])
# start with one of these
if ns_to_do:
n = ns_to_do.pop()
else:
n = 0
node_color = {n: 0}
recently_colored = node_color.copy()
base_neutral = True
# count the number of nodes for each colour throughout
num_nodes = [1, 0]
last_num = [None, None]
fully_converged = False
last_converged = False
while base_neutral and not fully_converged:
# go through all nodes coloured in the last pass
newly_colored = {}
for n, c in recently_colored.items():
# look at all the code qubits that are neighbours
incident_es = link_graph.incident_edges(n)
for e in incident_es:
edge = link_graph.edges()[e]
n0, n1 = link_graph.edge_list()[e]
if n0 == n:
nn = n1
else:
nn = n0
# see if the edge corresponds to one of the given nodes
dc = edge["link qubit"] in link_qubits
# if the neighbour is not yet coloured, colour it
# different color if edge is given node, same otherwise
if nn not in node_color:
new_c = (c + dc) % 2
if nn not in newly_colored:
newly_colored[nn] = new_c
num_nodes[new_c] += 1
# if it is coloured, check the colour is correct
else:
base_neutral = base_neutral and (node_color[nn] == (c + dc) % 2)
for nn, c in newly_colored.items():
node_color[nn] = c
if nn in ns_to_do:
ns_to_do.remove(nn)
recently_colored = newly_colored.copy()
# process is converged once one colour has stoppped growing
# once ns_to_do is empty
converged = (not ns_to_do) and (
(num_nodes[0] == last_num[0] != 0) or (num_nodes[1] == last_num[1] != 0)
)
fully_converged = converged and last_converged
if not fully_converged:
last_num = num_nodes.copy()
last_converged = converged
# see how many qubits are in the converged colour, and determine the min colour
for c in range(2):
if num_nodes[c] == last_num[c]:
conv_color = c
if num_nodes[conv_color] <= self.d / 2:
min_color = conv_color
else:
min_color = (conv_color + 1) % 2
# calculate the number of nodes for the other
num_nodes[(conv_color + 1) % 2] = link_graph.num_nodes() - num_nodes[conv_color]
# get the set of min nodes
min_ns = set()
for n, c in node_color.items():
if c == min_color:
min_ns.add(n)

# see which qubits for logical zs are needed
flipped_logicals_all = [[], []]
if base_neutral:
for qubit in self.z_logicals:
n = link_graph.nodes().index(qubit)
dc = not n in min_ns
flipped_logicals_all[(min_color + dc) % 2].append(qubit)
for j in range(2):
flipped_logicals_all[j] = set(flipped_logicals_all[j])

# list the colours with the max error one first
# (unless we do min only)
cs = []
if not minimal:
cs.append((min_color + 1) % 2)
cs.append(min_color)

# see what happens for both colours
# if neutral for maximal, it's neutral
# otherwise, it is whatever it is for the minimal
for c in cs:
neutral = base_neutral
num_errors = num_nodes[c]
flipped_logicals = flipped_logicals_all[c]

# if unneeded logical zs are given, cluster is not neutral
# (unless this is ignored)
if (not ignore_extra_boundary) and given_logicals.difference(flipped_logicals):
neutral = False
flipped_logicals = set()
# otherwise, report only needed logicals that aren't given
else:
flipped_logicals = flipped_logicals.difference(given_logicals)

# check whether there are frustrated plaquettes
parities = {}
for node in bulk_nodes:
for c in self.cycle_dict[tuple(node.qubits)]:
if c in parities:
parities[c] += 1
else:
parities[c] = 1
for c, parity in parities.items():
parities[c] = parity % 2
frust = any(parities.values())

# if frust==True, no colouring is possible, so no need to do it
if frust:
# neutral if not frustrated
neutral = not frust
# empty because ignored
flipped_logical_nodes = []
for flipped_logical in flipped_logicals:
node = DecodingGraphNode(
is_boundary=True,
qubits=[flipped_logical],
index=self.z_logicals.index(flipped_logical),
# None because not counted
num_errors = None
else:
# bicolor the nodes of the link graph, such that node edges connect unlike edges
link_qubits = set(node.properties["link qubit"] for node in nodes)
link_graph = self.link_graph
# all the qubits around cycles of the node edges have to be covered
ns_to_do = set()
for edge in [tuple(node.qubits) for node in bulk_nodes]:
for c in self.cycle_dict[edge]:
ns_to_do = ns_to_do.union(self.cycles[c])
# if this gives us qubits to start with, we start with one
if ns_to_do:
n = ns_to_do.pop()
else:
# otherwise we commit to covering them all
ns_to_do = set(range(len(link_graph.nodes())))
n = ns_to_do.pop()
node_color = {n: 0}
recently_colored = node_color.copy()
base_neutral = True
# count the number of nodes for each colour throughout
num_nodes = [1, 0]
last_num = [None, None]
fully_converged = False
last_converged = False
while base_neutral and not fully_converged:
# go through all nodes coloured in the last pass
newly_colored = {}
for n, c in recently_colored.items():
# look at all the code qubits that are neighbours
incident_es = link_graph.incident_edges(n)
for e in incident_es:
edge = link_graph.edges()[e]
n0, n1 = link_graph.edge_list()[e]
if n0 == n:
nn = n1
else:
nn = n0
# see if the edge corresponds to one of the given nodes
dc = edge["link qubit"] in link_qubits
# if the neighbour is not yet coloured, colour it
# different color if edge is given node, same otherwise
if nn not in node_color:
new_c = (c + dc) % 2
if nn not in newly_colored:
newly_colored[nn] = new_c
num_nodes[new_c] += 1
# if it is coloured, check the colour is correct
else:
base_neutral = base_neutral and (node_color[nn] == (c + dc) % 2)
for nn, c in newly_colored.items():
node_color[nn] = c
if nn in ns_to_do:
ns_to_do.remove(nn)
recently_colored = newly_colored.copy()
# process is converged once one colour has stoppped growing
# once ns_to_do is empty
converged = (not ns_to_do) and (
(num_nodes[0] == last_num[0] != 0) or (num_nodes[1] == last_num[1] != 0)
)
flipped_logical_nodes.append(node)
fully_converged = converged and last_converged
if not fully_converged:
last_num = num_nodes.copy()
last_converged = converged
# see how many qubits are in the converged colour, and determine the min colour
for c in range(2):
if num_nodes[c] == last_num[c]:
conv_color = c
if num_nodes[conv_color] <= self.d / 2:
min_color = conv_color
else:
min_color = (conv_color + 1) % 2
# calculate the number of nodes for the other
num_nodes[(conv_color + 1) % 2] = link_graph.num_nodes() - num_nodes[conv_color]
# get the set of min nodes
min_ns = set()
for n, c in node_color.items():
if c == min_color:
min_ns.add(n)

# see which qubits for logical zs are needed
flipped_logicals_all = [[], []]
if base_neutral:
for qubit in self.z_logicals:
n = link_graph.nodes().index(qubit)
dc = not n in min_ns
flipped_logicals_all[(min_color + dc) % 2].append(qubit)
for j in range(2):
flipped_logicals_all[j] = set(flipped_logicals_all[j])

# list the colours with the max error one first
# (unless we do min only)
cs = []
if not minimal:
cs.append((min_color + 1) % 2)
cs.append(min_color)

# see what happens for both colours
# if neutral for maximal, it's neutral
# otherwise, it is whatever it is for the minimal
for c in cs:
neutral = base_neutral
num_errors = num_nodes[c]
flipped_logicals = flipped_logicals_all[c]

# if unneeded logical zs are given, cluster is not neutral
# (unless this is ignored)
if (not ignore_extra_boundary) and given_logicals.difference(flipped_logicals):
neutral = False
flipped_logicals = set()
# otherwise, report only needed logicals that aren't given
else:
flipped_logicals = flipped_logicals.difference(given_logicals)

flipped_logical_nodes = []
for flipped_logical in flipped_logicals:
node = DecodingGraphNode(
is_boundary=True,
qubits=[flipped_logical],
index=self.z_logicals.index(flipped_logical),
)
flipped_logical_nodes.append(node)

if neutral and not flipped_logical_nodes:
break
if neutral and not flipped_logical_nodes:
break

else:
# without bulk nodes, neutral only if no boundary nodes are given
Expand Down

0 comments on commit 5624e3e

Please sign in to comment.