Skip to content

Commit

Permalink
Extend functionality for working with codes. (#288)
Browse files Browse the repository at this point in the history
## Description

This PR extends QECCs capabilities for working with stabilizer codes.
The following features are planned:

- [x] Refactor Pauli operator and symplectic vector functionality
- [ ] Add more codes and unify their creation
  - [x] Quantum Hamming Codes
  - [ ] Bring Code
  - [x] Many-Hypercube Code
- [x] Add support for code concatenation


## Checklist:

<!---
This checklist serves as a reminder of a couple of things that ensure
your pull request will be merged swiftly.
-->

- [x] The pull request only contains commits that are related to it.
- [x] I have added appropriate tests and documentation.
- [x] I have made sure that all CI jobs on GitHub pass.
- [x] The pull request introduces no new warnings and follows the
project's style guidelines.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
pehamTom and pre-commit-ci[bot] authored Nov 20, 2024
1 parent ef25257 commit 0880e37
Show file tree
Hide file tree
Showing 10 changed files with 958 additions and 208 deletions.
7 changes: 7 additions & 0 deletions src/mqt/qecc/codes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from .bb_codes import construct_bb_code
from .color_code import ColorCode, LatticeType
from .concatenation import ConcatenatedCode, ConcatenatedCSSCode
from .constructions import construct_iceberg_code, construct_many_hypercube_code, construct_quantum_hamming_code
from .css_code import CSSCode, InvalidCSSCodeError
from .hexagonal_color_code import HexagonalColorCode
from .square_octagon_color_code import SquareOctagonColorCode
Expand All @@ -12,11 +14,16 @@
__all__ = [
"CSSCode",
"ColorCode",
"ConcatenatedCSSCode",
"ConcatenatedCode",
"HexagonalColorCode",
"InvalidCSSCodeError",
"InvalidStabilizerCodeError",
"LatticeType",
"SquareOctagonColorCode",
"StabilizerCode",
"construct_bb_code",
"construct_iceberg_code",
"construct_many_hypercube_code",
"construct_quantum_hamming_code",
]
152 changes: 152 additions & 0 deletions src/mqt/qecc/codes/concatenation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Concatenated quantum codes."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from .css_code import CSSCode
from .pauli import Pauli
from .stabilizer_code import InvalidStabilizerCodeError, StabilizerCode
from .symplectic import SymplecticVector

if TYPE_CHECKING:
from collections.abc import Sequence

import numpy.typing as npt


class ConcatenatedCode(StabilizerCode):
"""A concatenated quantum code."""

def __init__(self, outer_code: StabilizerCode, inner_code: StabilizerCode | Sequence[StabilizerCode]) -> None:
"""Initialize a concatenated quantum code.
Args:
outer_code: The outer code.
inner_code: The inner code. If a list of codes is provided, the qubits of the outer code are encoded by the different inner codes in the list.
"""
self.outer_code = outer_code
if isinstance(inner_code, list):
self.inner_codes = inner_code
else:
self.inner_codes = [inner_code] * outer_code.n
if not all(code.k == 1 for code in self.inner_codes):
msg = "The inner codes must be stabilizer codes with a single logical qubit."
raise InvalidStabilizerCodeError(msg)

self.n = sum(code.n for code in self.inner_codes)
generators = [self._outer_pauli_to_physical(p) for p in outer_code.generators]

x_logicals = None
z_logicals = None
if outer_code.x_logicals is not None:
x_logicals = [self._outer_pauli_to_physical(p) for p in outer_code.x_logicals]
if outer_code.z_logicals is not None:
z_logicals = [self._outer_pauli_to_physical(p) for p in outer_code.z_logicals]

d = min(code.distance * outer_code.distance for code in self.inner_codes)
StabilizerCode.__init__(self, generators, d, x_logicals, z_logicals)

def __eq__(self, other: object) -> bool:
"""Check if two concatenated codes are equal."""
if not isinstance(other, ConcatenatedCode):
return NotImplemented
return self.outer_code == other.outer_code and all(
c1 == c2 for c1, c2 in zip(self.inner_codes, other.inner_codes)
)

def __hash__(self) -> int:
"""Compute the hash of the concatenated code."""
return hash((self.outer_code, tuple(self.inner_codes)))

def _outer_pauli_to_physical(self, p: Pauli) -> Pauli:
"""Convert a Pauli operator on the outer code to the operator on the concatenated code.
Args:
p: The Pauli operator.
Returns:
The Pauli operator on the physical qubits.
"""
if len(p) != self.outer_code.n:
msg = "The Pauli operator must have the same number of qubits as the outer code."
raise InvalidStabilizerCodeError(msg)
concatenated = SymplecticVector.zeros(self.n)
phase = 0
offset = 0
for i in range(self.outer_code.n):
c = self.inner_codes[i]
new_offset = offset + c.n
assert c.x_logicals is not None
assert c.z_logicals is not None
if p[i] == "X":
concatenated[offset:new_offset] = c.x_logicals[0].x_part()
concatenated[offset + self.n : new_offset + self.n] = c.x_logicals[0].z_part()
phase += c.x_logicals[0].phase
elif p[i] == "Z":
concatenated[offset:new_offset] = c.z_logicals[0].x_part()
concatenated[offset + self.n : new_offset + self.n] = c.z_logicals[0].z_part()
phase += c.z_logicals[0].phase

elif p[i] == "Y":
concatenated[offset:new_offset] = c.x_logicals[0].x_part ^ c.z_logicals[0].x_part()
concatenated[offset + self.n : new_offset + self.n] = c.x_logicals[0].z_part ^ c.z_logicals[0].z_part()
phase += c.x_logicals[0].phase + c.z_logicals[0].phase

offset = new_offset
return Pauli(concatenated, phase)


# def _valid_logicals(lst: list[StabilizerTableau | None]) -> TypeGuard[list[StabilizerTableau]]:
# return None not in lst


class ConcatenatedCSSCode(ConcatenatedCode, CSSCode):
"""A concatenated CSS code."""

def __init__(self, outer_code: CSSCode, inner_codes: CSSCode | Sequence[CSSCode]) -> None:
"""Initialize a concatenated CSS code.
Args:
outer_code: The outer code.
inner_codes: The inner code. If a list of codes is provided, the qubits of the outer code are encoded by the different inner codes in the list.
"""
# self.outer_code = outer_code
if isinstance(inner_codes, CSSCode):
inner_codes = [inner_codes] * outer_code.n

if not all(code.k == 1 for code in inner_codes):
msg = "The inner codes must be CSS codes with a single logical qubit."
raise InvalidStabilizerCodeError(msg)

ConcatenatedCode.__init__(self, outer_code, inner_codes)
hx = np.array([self._outer_checks_to_physical(check, "X") for check in outer_code.Hx], dtype=np.int8)
hz = np.array([self._outer_checks_to_physical(check, "Z") for check in outer_code.Hz], dtype=np.int8)
d = min(code.distance * outer_code.distance for code in inner_codes)
CSSCode.__init__(self, d, hx, hz)

def _outer_checks_to_physical(self, check: npt.NDArray[np.int8], operator: str) -> npt.NDArray[np.int8]:
"""Convert a check operator on the outer code to the operator on the concatenated code.
Args:
check: The check operator.
operator: The type of operator to be converted. Either 'X' or 'Z'.
Returns:
The check operator on the physical qubits.
"""
if check.shape[0] != self.outer_code.n:
msg = "The check operator must have the same number of qubits as the outer code."
raise InvalidStabilizerCodeError(msg)
concatenated = np.zeros((self.n), dtype=np.int8)
offset = 0
for i in range(self.outer_code.n):
c = self.inner_codes[i]
new_offset = offset + c.n
if check[i] == 1:
logical = c.Lx if operator == "X" else c.Lz
concatenated[offset:new_offset] = logical
offset = new_offset
return concatenated
56 changes: 56 additions & 0 deletions src/mqt/qecc/codes/constructions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Constructions of various known stabilizer codes."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from .css_code import CSSCode

if TYPE_CHECKING:
import numpy.typing as npt


def construct_quantum_hamming_code(r: int) -> CSSCode:
"""Return the [[2^r, 2^r-r-1, 3]] quantum Hamming code."""
h = _hamming_code_checks(r)
return CSSCode(3, h, h)


def construct_iceberg_code(m: int) -> CSSCode:
"""Return the [[2m, 2m-2, 2]] Iceberg code.
The Iceberg code is a CSS code with stabilizer generators X^2m and Z^2m.
https://errorcorrectionzoo.org/c/iceberg
"""
n = 2 * m
h = np.array([[1] * n], dtype=np.int8)
return CSSCode(2, h, h)


def construct_many_hypercube_code(level: int) -> CSSCode:
"""Return the [[6^l, 4^l, 2^l]] level l many-hypercube code (https://arxiv.org/abs/2403.16054).
This code is obtained by (l-1)-fold concatenation of the [[6,4,2]] iceberg code with itself.
"""
code = construct_iceberg_code(3)

for _ in range(1, level):
sx = np.hstack([code.Lx] * 6, dtype=np.int8)
sx_rem = np.kron(np.eye(6, dtype=np.int8), code.Hx)
sx = np.vstack((sx, sx_rem), dtype=np.int8)
sz = sx
code = CSSCode(code.distance * 2, sx, sz)
return code


def _hamming_code_checks(r: int) -> npt.NDArray[np.int8]:
"""Return the check matrix for the [2^r-1, 2^r-r-1, 3] Hamming code."""
n = 2**r - 1
h = np.zeros((r, n), dtype=int)
# columns are all binary strings up to 2^r
for i in range(1, n + 1):
h[:, i - 1] = np.array([int(x) for x in f"{i:b}".zfill(r)])

return h
49 changes: 28 additions & 21 deletions src/mqt/qecc/codes/css_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
from ldpc import mod2

from .pauli import StabilizerTableau
from .stabilizer_code import StabilizerCode

if TYPE_CHECKING: # pragma: no cover
Expand All @@ -24,8 +25,21 @@ def __init__(
Hz: npt.NDArray[np.int8] | None = None, # noqa: N803
x_distance: int | None = None,
z_distance: int | None = None,
n: int | None = None,
) -> None:
"""Initialize the code."""
if Hx is None and Hz is None:
if n is None:
msg = "If no check matrices are provided, the code size must be specified."
raise InvalidCSSCodeError(msg)
self.Hx = np.zeros((0, n), dtype=np.int8)
self.Hz = np.zeros((0, n), dtype=np.int8)
self.Lx = np.eye(n, dtype=np.int8)
self.Lz = np.eye(n, dtype=np.int8)
triv = StabilizerCode.get_trivial_code(n)
super().__init__(triv.generators, triv.distance, triv.x_logicals, triv.z_logicals)
return

self._check_valid_check_matrices(Hx, Hz)

if Hx is None:
Expand All @@ -46,8 +60,8 @@ def __init__(

x_padded = np.hstack([self.Hx, z_padding])
z_padded = np.hstack([x_padding, self.Hz])
phases = np.zeros((x_padded.shape[0] + z_padded.shape[0], 1), dtype=np.int8)
super().__init__(np.hstack((np.vstack((x_padded, z_padded)), phases)), distance)
phases = np.zeros((x_padded.shape[0] + z_padded.shape[0]), dtype=np.int8)
super().__init__(StabilizerTableau(np.vstack((x_padded, z_padded)), phases), distance)

self.distance = distance
self.x_distance = x_distance if x_distance is not None else distance
Expand Down Expand Up @@ -88,14 +102,10 @@ def _compute_logical(m1: npt.NDArray[np.int8], m2: npt.NDArray[np.int8]) -> npt.

def get_x_syndrome(self, error: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
"""Compute the x syndrome of the error."""
if self.Hx is None:
return np.empty((0, error.shape[0]), dtype=np.int8)
return self.Hx @ error % 2

def get_z_syndrome(self, error: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
"""Compute the z syndrome of the error."""
if self.Hz is None:
return np.empty((0, error.shape[0]), dtype=np.int8)
return self.Hz @ error % 2

def check_if_logical_x_error(self, residual: npt.NDArray[np.int8]) -> bool:
Expand All @@ -104,21 +114,19 @@ def check_if_logical_x_error(self, residual: npt.NDArray[np.int8]) -> bool:

def check_if_x_stabilizer(self, pauli: npt.NDArray[np.int8]) -> bool:
"""Check if the Pauli is a stabilizer."""
assert self.Hx is not None
return bool(mod2.rank(np.vstack((self.Hx, pauli))) == mod2.rank(self.Hx))

def check_if_logical_z_error(self, residual: npt.NDArray[np.int8]) -> bool:
"""Check if the residual is a logical error."""
return bool((self.Lx @ residual % 2 == 1).any())
return (self.Hx.shape[0] != 0) and bool((self.Lx @ residual % 2 == 1).any())

def check_if_z_stabilizer(self, pauli: npt.NDArray[np.int8]) -> bool:
"""Check if the Pauli is a stabilizer."""
assert self.Hz is not None
return bool(mod2.rank(np.vstack((self.Hz, pauli))) == mod2.rank(self.Hz))
return (self.Hz.shape[0] != 0) and bool(mod2.rank(np.vstack((self.Hz, pauli))) == mod2.rank(self.Hz))

def stabilizer_eq_x_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDArray[np.int8]) -> bool:
"""Check if two X errors are in the same coset."""
if self.Hx is None:
if self.Hx.shape[0] == 0:
return bool(np.array_equal(error_1, error_2))
m1 = np.vstack([self.Hx, error_1])
m2 = np.vstack([self.Hx, error_2])
Expand All @@ -127,7 +135,7 @@ def stabilizer_eq_x_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDAr

def stabilizer_eq_z_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDArray[np.int8]) -> bool:
"""Check if two Z errors are in the same coset."""
if self.Hz is None:
if self.Hz.shape[0] == 0:
return bool(np.array_equal(error_1, error_2))
m1 = np.vstack([self.Hz, error_1])
m2 = np.vstack([self.Hz, error_2])
Expand All @@ -136,19 +144,13 @@ def stabilizer_eq_z_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDAr

def is_self_dual(self) -> bool:
"""Check if the code is self-dual."""
if self.Hx is None or self.Hz is None:
return False
return bool(
self.Hx.shape[0] == self.Hz.shape[0] and mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, self.Hz]))
)

@staticmethod
def _check_valid_check_matrices(Hx: npt.NDArray[np.int8] | None, Hz: npt.NDArray[np.int8] | None) -> None: # noqa: N803
"""Check if the code is a valid CSS code."""
if Hx is None and Hz is None:
msg = "At least one of the check matrices must be provided"
raise InvalidCSSCodeError(msg)

if Hx is not None and Hz is not None:
if Hx.shape[1] != Hz.shape[1]:
msg = "Check matrices must have the same number of columns"
Expand All @@ -157,18 +159,23 @@ def _check_valid_check_matrices(Hx: npt.NDArray[np.int8] | None, Hz: npt.NDArray
msg = "The check matrices must be orthogonal"
raise InvalidCSSCodeError(msg)

@classmethod
def get_trivial_code(cls, n: int) -> CSSCode:
"""Return the trivial code."""
return CSSCode(1, None, None, n=n)

@staticmethod
def from_code_name(code_name: str, distance: int | None = None) -> CSSCode:
r"""Return CSSCode object for a known code.
The following codes are supported:
- [[7, 1, 3]] Steane (\"Steane\")
- [[15, 1, 3]] tetrahedral code (\"Tetrahedral\")
- [[15, 7, 3]] Hamming code (\"Hamming\")
- [[9, 1, 3]] Shore code (\"Shor\")
- [[12, 2, 4]] Carbon Code (\"Carbon\")
- [[9, 1, 3]] rotated surface code (\"Surface, 3\"), also default when no distance is given
- [[25, 1, 5]] rotated surface code (\"Surface, 5\")
- [[15, 7, 3]] Hamming code (\"Hamming\")
- [[23, 1, 7]] golay code (\"Golay\")
Args:
Expand All @@ -179,23 +186,23 @@ def from_code_name(code_name: str, distance: int | None = None) -> CSSCode:
paths = {
"steane": prefix / "steane/",
"tetrahedral": prefix / "tetrahedral/",
"hamming": prefix / "hamming/",
"shor": prefix / "shor/",
"surface_3": prefix / "rotated_surface_d3/",
"surface_5": prefix / "rotated_surface_d5/",
"golay": prefix / "golay/",
"carbon": prefix / "carbon/",
"hamming": prefix / "hamming_15/",
}

distances = {
"steane": (3, 3),
"tetrahedral": (7, 3),
"hamming": (3, 3),
"shor": (3, 3),
"golay": (7, 7),
"surface_3": (3, 3),
"surface_5": (5, 5),
"carbon": (4, 4),
"hamming": (3, 3),
} # X, Z distances

code_name = code_name.lower()
Expand Down
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 0880e37

Please sign in to comment.