Skip to content

Commit

Permalink
352 uf for arc (#362)
Browse files Browse the repository at this point in the history
* Small fixes to make UnionFind run with real data (#355)

* some quick hacks to try to make the UnionFind Decoder run with real data. Reversal of string in process() is not needed, and qubit-index fix for the correction

* fixing unittests for unionfind

* prevent infinite while loop

* hacky avoidance of infinite loop

* fix flatten nodes

* improve cluster checking for ARCs

* allow check_nodes to declare non-optimal clusters as neutral

* remove bug test on while loop

* make everything work!

* improve printing of nodes

* add minimal kwarg to check_nodes

* change test to linear ARC

* change test to linear ARC

* use rep codes to test UF

* put uf into standard form

* don't test things that don't work

* lint and black

* move cluster processing to base

* make all cluster methods use standard form

* restore uf test file (but streamlined)

* fix inconsistencies with logical strings

---------

Co-authored-by: Milan Liepelt <[email protected]>
  • Loading branch information
quantumjim and milanliepelt authored May 5, 2023
1 parent 603c677 commit ada699d
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 299 deletions.
4 changes: 3 additions & 1 deletion src/qiskit_qec/circuits/code_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def string2nodes(self, string, **kwargs):
pass

@abstractmethod
def check_nodes(self, nodes, ignore_extra_boundary=False):
def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False):
"""
Determines whether a given set of nodes are neutral. If so, also
determines any additional logical readout qubits that would be
Expand All @@ -65,6 +65,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False):
nodes (list): List of nodes, of the type produced by `string2nodes`.
ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are
ignored.
minimal (bool): Whether output should only reflect the minimal error
case.
Returns:
neutral (bool): Whether the nodes independently correspond to a valid
set of errors.
Expand Down
74 changes: 52 additions & 22 deletions src/qiskit_qec/circuits/repetition_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,12 +370,12 @@ def flatten_nodes(nodes: List[DecodingGraphNode]):
for node in nodes:
if nodes_per_link[node.properties["link qubit"]] % 2:
flat_node = copy(node)
# FIXME: Seems unsafe.
flat_node.time = None
flat_nodes.append(flat_node)
if flat_node not in flat_nodes:
flat_nodes.append(flat_node)
return flat_nodes

def check_nodes(self, nodes, ignore_extra_boundary=False):
def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False):
"""
Determines whether a given set of nodes are neutral. If so, also
determines any additional logical readout qubits that would be
Expand All @@ -385,6 +385,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False):
nodes (list): List of nodes, of the type produced by `string2nodes`.
ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are
ignored.
minimal (bool): Whether output should only reflect the minimal error
case.
Returns:
neutral (bool): Whether the nodes independently correspond to a valid
set of errors.
Expand Down Expand Up @@ -422,9 +424,17 @@ def check_nodes(self, nodes, ignore_extra_boundary=False):
# and majority
error_c_max = str((int(error_c_min) + 1) % 2)

# calculate all required info for the max to see if that is fully neutral
# if not, calculate and output for the min case
for error_c in [error_c_max, error_c_min]:
# list the colours with the max error one first
# (unless we do min only)
error_cs = []
if minimal:
error_cs.append(error_c_max)
error_cs.append(error_c_min)

# see what happens for both colours
# if neutral for maximal, it's neutral
# otherwise, it is whatever it is for the minimal
for error_c in error_cs:
num_errors = colors.count(error_c)

# determine the corresponding flipped logicals
Expand Down Expand Up @@ -1036,6 +1046,16 @@ def _process_string(self, string):

return new_string

def string2raw_logicals(self, string):
"""
Extracts raw logicals from output string.
Args:
string (string): Results string from which to extract logicals
Returns:
list: Raw values for logical operators that correspond to nodes.
"""
return _separate_string(self._process_string(string))[0]

def string2nodes(self, string, **kwargs) -> List[DecodingGraphNode]:
"""
Convert output string from circuits into a set of nodes.
Expand Down Expand Up @@ -1114,10 +1134,11 @@ def flatten_nodes(nodes: List[DecodingGraphNode]):
if nodes_per_link[node.properties["link qubit"]] % 2:
flat_node = deepcopy(node)
flat_node.time = None
flat_nodes.append(flat_node)
if flat_node not in flat_nodes:
flat_nodes.append(flat_node)
return flat_nodes

def check_nodes(self, nodes, ignore_extra_boundary=False):
def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False):
"""
Determines whether a given set of nodes are neutral. If so, also
determines any additional logical readout qubits that would be
Expand All @@ -1127,6 +1148,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False):
nodes (list): List of nodes, of the type produced by `string2nodes`.
ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are
ignored.
minimal (bool): Whether output should only reflect the minimal error
case.
Returns:
neutral (bool): Whether the nodes independently correspond to a valid
set of errors.
Expand All @@ -1151,10 +1174,10 @@ def check_nodes(self, nodes, ignore_extra_boundary=False):
nodes = self.flatten_nodes(nodes)
link_qubits = set(node.properties["link qubit"] for node in nodes)
node_color = {0: 0}
neutral = True
base_neutral = True
link_graph = self._get_link_graph()
ns_to_do = set(n for n in range(1, len(link_graph.nodes())))
while ns_to_do and neutral:
while ns_to_do and base_neutral:
# go through all coloured nodes
newly_colored = {}
for n, c in node_color.items():
Expand All @@ -1175,14 +1198,14 @@ def check_nodes(self, nodes, ignore_extra_boundary=False):
newly_colored[nn] = (c + dc) % 2
# if it is coloured, check the colour is correct
else:
neutral = neutral and (node_color[nn] == (c + dc) % 2)
base_neutral = base_neutral and (node_color[nn] == (c + dc) % 2)
for nn, c in newly_colored.items():
node_color[nn] = c
ns_to_do.remove(nn)

# see which qubits for logical zs are needed
flipped_logicals_all = [[], []]
if neutral:
if base_neutral:
for inside_c in range(2):
for n, c in node_color.items():
qubit = link_graph.nodes()[n]
Expand All @@ -1196,22 +1219,28 @@ def check_nodes(self, nodes, ignore_extra_boundary=False):
for n, c in node_color.items():
num_nodes[c] += 1

if num_nodes[0] == num_nodes[1]:
min_cs = [0, 1]
else:
min_cs = [int(sum(node_color.values()) < len(node_color) / 2)]
# list the colours with the max error one first
# (unless we do min only)
min_color = int(sum(node_color.values()) < len(node_color) / 2)
cs = []
if not minimal:
cs.append((min_color + 1) % 2)
cs.append(min_color)

# see what happens for both colours
# once full neutrality us found, go for it!
for c in min_cs:
this_neutral = neutral
# if neutral for maximal, it's neutral
# otherwise, it is whatever it is for the minimal
for c in cs:

neutral = base_neutral
num_errors = num_nodes[c]
flipped_logicals = flipped_logicals_all[c]

# if unneeded logical zs are given, cluster is not neutral
# (unless this is ignored)
if (not ignore_extra_boundary) and given_logicals.difference(flipped_logicals):
this_neutral = False
neutral = False
flipped_logicals = set()
# otherwise, report only needed logicals that aren't given
else:
flipped_logicals = flipped_logicals.difference(given_logicals)
Expand All @@ -1225,8 +1254,7 @@ def check_nodes(self, nodes, ignore_extra_boundary=False):
)
flipped_logical_nodes.append(node)

if this_neutral and flipped_logical_nodes == []:
neutral = this_neutral
if neutral and flipped_logical_nodes == []:
break

else:
Expand All @@ -1250,6 +1278,8 @@ def is_cluster_neutral(self, atypical_nodes):
to the method.
Args:
atypical_nodes (dictionary in the form of the return value of string2nodes)
ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are
ignored.
"""
neutral, logicals, _ = self.check_nodes(atypical_nodes)
return neutral and not logicals
Expand Down
16 changes: 9 additions & 7 deletions src/qiskit_qec/circuits/surface_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ def __init__(self, d: int, T: int, basis: str = "z", resets=True):
# set info needed for css codes
self.css_x_gauge_ops = [[q for q in plaq if q is not None] for plaq in self.xplaqs]
self.css_x_stabilizer_ops = self.css_x_gauge_ops
self.css_x_logical = self._logicals["x"][0]
self.css_x_boundary = self._logicals["x"][0] + self._logicals["x"][1]
self.css_x_logical = [self._logicals["x"][0]]
self.css_x_boundary = [self._logicals["x"][0] + self._logicals["x"][1]]
self.css_z_gauge_ops = [[q for q in plaq if q is not None] for plaq in self.zplaqs]
self.css_z_stabilizer_ops = self.css_z_gauge_ops
self.css_z_logical = self._logicals["z"][0]
self.css_z_boundary = self._logicals["z"][0] + self._logicals["z"][1]
self.css_z_logical = [self._logicals["z"][0]]
self.css_z_boundary = [self._logicals["z"][0] + self._logicals["z"][1]]
self.round_schedule = self.basis
self.blocks = T

Expand Down Expand Up @@ -342,7 +342,7 @@ def string2raw_logicals(self, string):
Z[0] = (Z[0] + int(final_readout[j * self.d])) % 2
# evaluated using right side
Z[1] = (Z[1] + int(final_readout[(j + 1) * self.d - 1])) % 2
return str(Z[0]) + " " + str(Z[1])
return [str(Z[0]), str(Z[1])]

def _process_string(self, string):
# get logical readout
Expand All @@ -353,7 +353,7 @@ def _process_string(self, string):

# the space separated string of syndrome changes then gets a
# double space separated logical value on the end
new_string = measured_Z + " " + syndrome_changes
new_string = " ".join(measured_Z) + " " + syndrome_changes

return new_string

Expand Down Expand Up @@ -416,7 +416,7 @@ def string2nodes(self, string, **kwargs):
nodes.append(node)
return nodes

def check_nodes(self, nodes, ignore_extra_boundary=False):
def check_nodes(self, nodes, ignore_extra_boundary=False, minimal=False):
"""
Determines whether a given set of nodes are neutral. If so, also
determines any additional logical readout qubits that would be
Expand All @@ -426,6 +426,8 @@ def check_nodes(self, nodes, ignore_extra_boundary=False):
nodes (list): List of nodes, of the type produced by `string2nodes`.
ignore_extra_boundary (bool): If `True`, undeeded boundary nodes are
ignored.
minimal (bool): Whether output should only reflect the minimal error
case.
Returns:
neutral (bool): Whether the nodes independently correspond to a valid
set of errors.
Expand Down
Loading

0 comments on commit ada699d

Please sign in to comment.