From 0880e37e23e0635b2ede3b734144ec2e998585fe Mon Sep 17 00:00:00 2001 From: Tom Peham Date: Wed, 20 Nov 2024 17:25:41 +0100 Subject: [PATCH] Extend functionality for working with codes. (#288) ## Description This PR extends QECCs capabilities for working with stabilizer codes. The following features are planned: - [x] Refactor Pauli operator and symplectic vector functionality - [ ] Add more codes and unify their creation - [x] Quantum Hamming Codes - [ ] Bring Code - [x] Many-Hypercube Code - [x] Add support for code concatenation ## Checklist: - [x] The pull request only contains commits that are related to it. - [x] I have added appropriate tests and documentation. - [x] I have made sure that all CI jobs on GitHub pass. - [x] The pull request introduces no new warnings and follows the project's style guidelines. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/mqt/qecc/codes/__init__.py | 7 + src/mqt/qecc/codes/concatenation.py | 152 ++++++++++ src/mqt/qecc/codes/constructions.py | 56 ++++ src/mqt/qecc/codes/css_code.py | 49 ++-- .../qecc/codes/{hamming => hamming_15}/hx.npy | Bin .../qecc/codes/{hamming => hamming_15}/hz.npy | Bin src/mqt/qecc/codes/pauli.py | 221 +++++++++++++++ src/mqt/qecc/codes/stabilizer_code.py | 263 +++++++----------- src/mqt/qecc/codes/symplectic.py | 161 +++++++++++ test/python/test_code.py | 257 +++++++++++++++-- 10 files changed, 958 insertions(+), 208 deletions(-) create mode 100644 src/mqt/qecc/codes/concatenation.py create mode 100644 src/mqt/qecc/codes/constructions.py rename src/mqt/qecc/codes/{hamming => hamming_15}/hx.npy (100%) rename src/mqt/qecc/codes/{hamming => hamming_15}/hz.npy (100%) create mode 100644 src/mqt/qecc/codes/pauli.py create mode 100644 src/mqt/qecc/codes/symplectic.py diff --git a/src/mqt/qecc/codes/__init__.py b/src/mqt/qecc/codes/__init__.py index dc4c6768..24a30451 100644 --- a/src/mqt/qecc/codes/__init__.py +++ b/src/mqt/qecc/codes/__init__.py @@ -4,6 +4,8 @@ from .bb_codes import construct_bb_code from .color_code import ColorCode, LatticeType +from .concatenation import ConcatenatedCode, ConcatenatedCSSCode +from .constructions import construct_iceberg_code, construct_many_hypercube_code, construct_quantum_hamming_code from .css_code import CSSCode, InvalidCSSCodeError from .hexagonal_color_code import HexagonalColorCode from .square_octagon_color_code import SquareOctagonColorCode @@ -12,6 +14,8 @@ __all__ = [ "CSSCode", "ColorCode", + "ConcatenatedCSSCode", + "ConcatenatedCode", "HexagonalColorCode", "InvalidCSSCodeError", "InvalidStabilizerCodeError", @@ -19,4 +23,7 @@ "SquareOctagonColorCode", "StabilizerCode", "construct_bb_code", + "construct_iceberg_code", + "construct_many_hypercube_code", + "construct_quantum_hamming_code", ] diff --git a/src/mqt/qecc/codes/concatenation.py b/src/mqt/qecc/codes/concatenation.py new file mode 100644 index 00000000..8e310115 --- /dev/null +++ b/src/mqt/qecc/codes/concatenation.py @@ -0,0 +1,152 @@ +"""Concatenated quantum codes.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from .css_code import CSSCode +from .pauli import Pauli +from .stabilizer_code import InvalidStabilizerCodeError, StabilizerCode +from .symplectic import SymplecticVector + +if TYPE_CHECKING: + from collections.abc import Sequence + + import numpy.typing as npt + + +class ConcatenatedCode(StabilizerCode): + """A concatenated quantum code.""" + + def __init__(self, outer_code: StabilizerCode, inner_code: StabilizerCode | Sequence[StabilizerCode]) -> None: + """Initialize a concatenated quantum code. + + Args: + outer_code: The outer code. + inner_code: The inner code. If a list of codes is provided, the qubits of the outer code are encoded by the different inner codes in the list. + """ + self.outer_code = outer_code + if isinstance(inner_code, list): + self.inner_codes = inner_code + else: + self.inner_codes = [inner_code] * outer_code.n + if not all(code.k == 1 for code in self.inner_codes): + msg = "The inner codes must be stabilizer codes with a single logical qubit." + raise InvalidStabilizerCodeError(msg) + + self.n = sum(code.n for code in self.inner_codes) + generators = [self._outer_pauli_to_physical(p) for p in outer_code.generators] + + x_logicals = None + z_logicals = None + if outer_code.x_logicals is not None: + x_logicals = [self._outer_pauli_to_physical(p) for p in outer_code.x_logicals] + if outer_code.z_logicals is not None: + z_logicals = [self._outer_pauli_to_physical(p) for p in outer_code.z_logicals] + + d = min(code.distance * outer_code.distance for code in self.inner_codes) + StabilizerCode.__init__(self, generators, d, x_logicals, z_logicals) + + def __eq__(self, other: object) -> bool: + """Check if two concatenated codes are equal.""" + if not isinstance(other, ConcatenatedCode): + return NotImplemented + return self.outer_code == other.outer_code and all( + c1 == c2 for c1, c2 in zip(self.inner_codes, other.inner_codes) + ) + + def __hash__(self) -> int: + """Compute the hash of the concatenated code.""" + return hash((self.outer_code, tuple(self.inner_codes))) + + def _outer_pauli_to_physical(self, p: Pauli) -> Pauli: + """Convert a Pauli operator on the outer code to the operator on the concatenated code. + + Args: + p: The Pauli operator. + + Returns: + The Pauli operator on the physical qubits. + """ + if len(p) != self.outer_code.n: + msg = "The Pauli operator must have the same number of qubits as the outer code." + raise InvalidStabilizerCodeError(msg) + concatenated = SymplecticVector.zeros(self.n) + phase = 0 + offset = 0 + for i in range(self.outer_code.n): + c = self.inner_codes[i] + new_offset = offset + c.n + assert c.x_logicals is not None + assert c.z_logicals is not None + if p[i] == "X": + concatenated[offset:new_offset] = c.x_logicals[0].x_part() + concatenated[offset + self.n : new_offset + self.n] = c.x_logicals[0].z_part() + phase += c.x_logicals[0].phase + elif p[i] == "Z": + concatenated[offset:new_offset] = c.z_logicals[0].x_part() + concatenated[offset + self.n : new_offset + self.n] = c.z_logicals[0].z_part() + phase += c.z_logicals[0].phase + + elif p[i] == "Y": + concatenated[offset:new_offset] = c.x_logicals[0].x_part ^ c.z_logicals[0].x_part() + concatenated[offset + self.n : new_offset + self.n] = c.x_logicals[0].z_part ^ c.z_logicals[0].z_part() + phase += c.x_logicals[0].phase + c.z_logicals[0].phase + + offset = new_offset + return Pauli(concatenated, phase) + + +# def _valid_logicals(lst: list[StabilizerTableau | None]) -> TypeGuard[list[StabilizerTableau]]: +# return None not in lst + + +class ConcatenatedCSSCode(ConcatenatedCode, CSSCode): + """A concatenated CSS code.""" + + def __init__(self, outer_code: CSSCode, inner_codes: CSSCode | Sequence[CSSCode]) -> None: + """Initialize a concatenated CSS code. + + Args: + outer_code: The outer code. + inner_codes: The inner code. If a list of codes is provided, the qubits of the outer code are encoded by the different inner codes in the list. + """ + # self.outer_code = outer_code + if isinstance(inner_codes, CSSCode): + inner_codes = [inner_codes] * outer_code.n + + if not all(code.k == 1 for code in inner_codes): + msg = "The inner codes must be CSS codes with a single logical qubit." + raise InvalidStabilizerCodeError(msg) + + ConcatenatedCode.__init__(self, outer_code, inner_codes) + hx = np.array([self._outer_checks_to_physical(check, "X") for check in outer_code.Hx], dtype=np.int8) + hz = np.array([self._outer_checks_to_physical(check, "Z") for check in outer_code.Hz], dtype=np.int8) + d = min(code.distance * outer_code.distance for code in inner_codes) + CSSCode.__init__(self, d, hx, hz) + + def _outer_checks_to_physical(self, check: npt.NDArray[np.int8], operator: str) -> npt.NDArray[np.int8]: + """Convert a check operator on the outer code to the operator on the concatenated code. + + Args: + check: The check operator. + operator: The type of operator to be converted. Either 'X' or 'Z'. + + Returns: + The check operator on the physical qubits. + """ + if check.shape[0] != self.outer_code.n: + msg = "The check operator must have the same number of qubits as the outer code." + raise InvalidStabilizerCodeError(msg) + concatenated = np.zeros((self.n), dtype=np.int8) + offset = 0 + for i in range(self.outer_code.n): + c = self.inner_codes[i] + new_offset = offset + c.n + if check[i] == 1: + logical = c.Lx if operator == "X" else c.Lz + concatenated[offset:new_offset] = logical + offset = new_offset + return concatenated diff --git a/src/mqt/qecc/codes/constructions.py b/src/mqt/qecc/codes/constructions.py new file mode 100644 index 00000000..036ca633 --- /dev/null +++ b/src/mqt/qecc/codes/constructions.py @@ -0,0 +1,56 @@ +"""Constructions of various known stabilizer codes.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from .css_code import CSSCode + +if TYPE_CHECKING: + import numpy.typing as npt + + +def construct_quantum_hamming_code(r: int) -> CSSCode: + """Return the [[2^r, 2^r-r-1, 3]] quantum Hamming code.""" + h = _hamming_code_checks(r) + return CSSCode(3, h, h) + + +def construct_iceberg_code(m: int) -> CSSCode: + """Return the [[2m, 2m-2, 2]] Iceberg code. + + The Iceberg code is a CSS code with stabilizer generators X^2m and Z^2m. + https://errorcorrectionzoo.org/c/iceberg + """ + n = 2 * m + h = np.array([[1] * n], dtype=np.int8) + return CSSCode(2, h, h) + + +def construct_many_hypercube_code(level: int) -> CSSCode: + """Return the [[6^l, 4^l, 2^l]] level l many-hypercube code (https://arxiv.org/abs/2403.16054). + + This code is obtained by (l-1)-fold concatenation of the [[6,4,2]] iceberg code with itself. + """ + code = construct_iceberg_code(3) + + for _ in range(1, level): + sx = np.hstack([code.Lx] * 6, dtype=np.int8) + sx_rem = np.kron(np.eye(6, dtype=np.int8), code.Hx) + sx = np.vstack((sx, sx_rem), dtype=np.int8) + sz = sx + code = CSSCode(code.distance * 2, sx, sz) + return code + + +def _hamming_code_checks(r: int) -> npt.NDArray[np.int8]: + """Return the check matrix for the [2^r-1, 2^r-r-1, 3] Hamming code.""" + n = 2**r - 1 + h = np.zeros((r, n), dtype=int) + # columns are all binary strings up to 2^r + for i in range(1, n + 1): + h[:, i - 1] = np.array([int(x) for x in f"{i:b}".zfill(r)]) + + return h diff --git a/src/mqt/qecc/codes/css_code.py b/src/mqt/qecc/codes/css_code.py index 5373ba0b..707b15fc 100644 --- a/src/mqt/qecc/codes/css_code.py +++ b/src/mqt/qecc/codes/css_code.py @@ -8,6 +8,7 @@ import numpy as np from ldpc import mod2 +from .pauli import StabilizerTableau from .stabilizer_code import StabilizerCode if TYPE_CHECKING: # pragma: no cover @@ -24,8 +25,21 @@ def __init__( Hz: npt.NDArray[np.int8] | None = None, # noqa: N803 x_distance: int | None = None, z_distance: int | None = None, + n: int | None = None, ) -> None: """Initialize the code.""" + if Hx is None and Hz is None: + if n is None: + msg = "If no check matrices are provided, the code size must be specified." + raise InvalidCSSCodeError(msg) + self.Hx = np.zeros((0, n), dtype=np.int8) + self.Hz = np.zeros((0, n), dtype=np.int8) + self.Lx = np.eye(n, dtype=np.int8) + self.Lz = np.eye(n, dtype=np.int8) + triv = StabilizerCode.get_trivial_code(n) + super().__init__(triv.generators, triv.distance, triv.x_logicals, triv.z_logicals) + return + self._check_valid_check_matrices(Hx, Hz) if Hx is None: @@ -46,8 +60,8 @@ def __init__( x_padded = np.hstack([self.Hx, z_padding]) z_padded = np.hstack([x_padding, self.Hz]) - phases = np.zeros((x_padded.shape[0] + z_padded.shape[0], 1), dtype=np.int8) - super().__init__(np.hstack((np.vstack((x_padded, z_padded)), phases)), distance) + phases = np.zeros((x_padded.shape[0] + z_padded.shape[0]), dtype=np.int8) + super().__init__(StabilizerTableau(np.vstack((x_padded, z_padded)), phases), distance) self.distance = distance self.x_distance = x_distance if x_distance is not None else distance @@ -88,14 +102,10 @@ def _compute_logical(m1: npt.NDArray[np.int8], m2: npt.NDArray[np.int8]) -> npt. def get_x_syndrome(self, error: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]: """Compute the x syndrome of the error.""" - if self.Hx is None: - return np.empty((0, error.shape[0]), dtype=np.int8) return self.Hx @ error % 2 def get_z_syndrome(self, error: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]: """Compute the z syndrome of the error.""" - if self.Hz is None: - return np.empty((0, error.shape[0]), dtype=np.int8) return self.Hz @ error % 2 def check_if_logical_x_error(self, residual: npt.NDArray[np.int8]) -> bool: @@ -104,21 +114,19 @@ def check_if_logical_x_error(self, residual: npt.NDArray[np.int8]) -> bool: def check_if_x_stabilizer(self, pauli: npt.NDArray[np.int8]) -> bool: """Check if the Pauli is a stabilizer.""" - assert self.Hx is not None return bool(mod2.rank(np.vstack((self.Hx, pauli))) == mod2.rank(self.Hx)) def check_if_logical_z_error(self, residual: npt.NDArray[np.int8]) -> bool: """Check if the residual is a logical error.""" - return bool((self.Lx @ residual % 2 == 1).any()) + return (self.Hx.shape[0] != 0) and bool((self.Lx @ residual % 2 == 1).any()) def check_if_z_stabilizer(self, pauli: npt.NDArray[np.int8]) -> bool: """Check if the Pauli is a stabilizer.""" - assert self.Hz is not None - return bool(mod2.rank(np.vstack((self.Hz, pauli))) == mod2.rank(self.Hz)) + return (self.Hz.shape[0] != 0) and bool(mod2.rank(np.vstack((self.Hz, pauli))) == mod2.rank(self.Hz)) def stabilizer_eq_x_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDArray[np.int8]) -> bool: """Check if two X errors are in the same coset.""" - if self.Hx is None: + if self.Hx.shape[0] == 0: return bool(np.array_equal(error_1, error_2)) m1 = np.vstack([self.Hx, error_1]) m2 = np.vstack([self.Hx, error_2]) @@ -127,7 +135,7 @@ def stabilizer_eq_x_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDAr def stabilizer_eq_z_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDArray[np.int8]) -> bool: """Check if two Z errors are in the same coset.""" - if self.Hz is None: + if self.Hz.shape[0] == 0: return bool(np.array_equal(error_1, error_2)) m1 = np.vstack([self.Hz, error_1]) m2 = np.vstack([self.Hz, error_2]) @@ -136,8 +144,6 @@ def stabilizer_eq_z_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDAr def is_self_dual(self) -> bool: """Check if the code is self-dual.""" - if self.Hx is None or self.Hz is None: - return False return bool( self.Hx.shape[0] == self.Hz.shape[0] and mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, self.Hz])) ) @@ -145,10 +151,6 @@ def is_self_dual(self) -> bool: @staticmethod def _check_valid_check_matrices(Hx: npt.NDArray[np.int8] | None, Hz: npt.NDArray[np.int8] | None) -> None: # noqa: N803 """Check if the code is a valid CSS code.""" - if Hx is None and Hz is None: - msg = "At least one of the check matrices must be provided" - raise InvalidCSSCodeError(msg) - if Hx is not None and Hz is not None: if Hx.shape[1] != Hz.shape[1]: msg = "Check matrices must have the same number of columns" @@ -157,6 +159,11 @@ def _check_valid_check_matrices(Hx: npt.NDArray[np.int8] | None, Hz: npt.NDArray msg = "The check matrices must be orthogonal" raise InvalidCSSCodeError(msg) + @classmethod + def get_trivial_code(cls, n: int) -> CSSCode: + """Return the trivial code.""" + return CSSCode(1, None, None, n=n) + @staticmethod def from_code_name(code_name: str, distance: int | None = None) -> CSSCode: r"""Return CSSCode object for a known code. @@ -164,11 +171,11 @@ def from_code_name(code_name: str, distance: int | None = None) -> CSSCode: The following codes are supported: - [[7, 1, 3]] Steane (\"Steane\") - [[15, 1, 3]] tetrahedral code (\"Tetrahedral\") - - [[15, 7, 3]] Hamming code (\"Hamming\") - [[9, 1, 3]] Shore code (\"Shor\") - [[12, 2, 4]] Carbon Code (\"Carbon\") - [[9, 1, 3]] rotated surface code (\"Surface, 3\"), also default when no distance is given - [[25, 1, 5]] rotated surface code (\"Surface, 5\") + - [[15, 7, 3]] Hamming code (\"Hamming\") - [[23, 1, 7]] golay code (\"Golay\") Args: @@ -179,23 +186,23 @@ def from_code_name(code_name: str, distance: int | None = None) -> CSSCode: paths = { "steane": prefix / "steane/", "tetrahedral": prefix / "tetrahedral/", - "hamming": prefix / "hamming/", "shor": prefix / "shor/", "surface_3": prefix / "rotated_surface_d3/", "surface_5": prefix / "rotated_surface_d5/", "golay": prefix / "golay/", "carbon": prefix / "carbon/", + "hamming": prefix / "hamming_15/", } distances = { "steane": (3, 3), "tetrahedral": (7, 3), - "hamming": (3, 3), "shor": (3, 3), "golay": (7, 7), "surface_3": (3, 3), "surface_5": (5, 5), "carbon": (4, 4), + "hamming": (3, 3), } # X, Z distances code_name = code_name.lower() diff --git a/src/mqt/qecc/codes/hamming/hx.npy b/src/mqt/qecc/codes/hamming_15/hx.npy similarity index 100% rename from src/mqt/qecc/codes/hamming/hx.npy rename to src/mqt/qecc/codes/hamming_15/hx.npy diff --git a/src/mqt/qecc/codes/hamming/hz.npy b/src/mqt/qecc/codes/hamming_15/hz.npy similarity index 100% rename from src/mqt/qecc/codes/hamming/hz.npy rename to src/mqt/qecc/codes/hamming_15/hz.npy diff --git a/src/mqt/qecc/codes/pauli.py b/src/mqt/qecc/codes/pauli.py new file mode 100644 index 00000000..7d94843f --- /dev/null +++ b/src/mqt/qecc/codes/pauli.py @@ -0,0 +1,221 @@ +"""Class for working with representations of Pauli operators.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from .symplectic import SymplecticMatrix, SymplecticVector + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + import numpy.typing as npt + + +class Pauli: + """Class representing an n-qubit Pauli operator.""" + + def __init__(self, symplectic: SymplecticVector, phase: int = 0) -> None: + """Create a new Pauli operator. + + Args: + symplectic: A 2n x n binary matrix representing the symplectic form of the Pauli operator. The first n entries correspond to X operators, and the second n entries correspond to Z operators. + phase: An integer 0 or 1 representing the phase of the Pauli operator (0 for +, 1 for -). + """ + self.n = symplectic.n + self.symplectic = symplectic + self.phase = phase + + @classmethod + def from_pauli_string(cls, p: str) -> Pauli: + """Create a new Pauli operator from a Pauli string.""" + if not is_pauli_string(p): + msg = f"Invalid Pauli string: {p}" + raise InvalidPauliError(msg) + pauli_start_index = 1 if p[0] in "+-" else 0 + x_part = np.array([c in "XY" for c in p[pauli_start_index:]]).astype(np.int8) + z_part = np.array([c in "ZY" for c in p[pauli_start_index:]]).astype(np.int8) + phase = int(p[0] == "-") + return cls(SymplecticVector(np.concatenate((x_part, z_part))), phase) + + def commute(self, other: Pauli) -> bool: + """Check if this Pauli operator commutes with another Pauli operator.""" + return self.symplectic @ other.symplectic == 0 + + def anticommute(self, other: Pauli) -> bool: + """Check if this Pauli operator anticommutes with another Pauli operator.""" + return not self.commute(other) + + def __mul__(self, other: Pauli) -> Pauli: + """Multiply this Pauli operator by another Pauli operator.""" + if self.n != other.n: + msg = "Pauli operators must have the same number of qubits." + raise InvalidPauliError(msg) + return Pauli(self.symplectic + other.symplectic, (self.phase + other.phase) % 2) + + def __repr__(self) -> str: + """Return a string representation of the Pauli operator.""" + x_part = self.symplectic[: self.n] + z_part = self.symplectic[self.n :] + pauli = [ + "X" if x and not z else "Z" if z and not x else "Y" if x and z else "I" for x, z in zip(x_part, z_part) + ] + return f"{'' if self.phase == 0 else '-'}" + "".join(pauli) + + def as_vector(self) -> npt.NDArray[np.int8]: + """Convert the Pauli operator to a binary vector.""" + return np.concatenate((self.symplectic.vector, np.array([self.phase]))) + + def __len__(self) -> int: + """Return the number of qubits in the Pauli operator.""" + return int(self.n) + + def __getitem__(self, key: int) -> str: + """Return the Pauli operator for a single qubit.""" + if key < 0 or key >= self.n: + msg = "Index out of range." + raise IndexError(msg) + x = self.symplectic[key] + z = self.symplectic[key + self.n] + return "X" if x and not z else "Z" if z and not x else "Y" if x and z else "I" + + def x_part(self) -> npt.NDArray[np.int8]: + """Return the X part of the Pauli operator.""" + return self.symplectic[: self.n] + + def z_part(self) -> npt.NDArray[np.int8]: + """Return the Z part of the Pauli operator.""" + return self.symplectic[self.n :] + + def __eq__(self, other: object) -> bool: + """Check if this Pauli operator is equal to another Pauli operator.""" + if not isinstance(other, Pauli): + return False + return self.symplectic == other.symplectic and self.phase == other.phase + + def __ne__(self, other: object) -> bool: + """Check if this Pauli operator is not equal to another Pauli operator.""" + return not self == other + + def __neg__(self) -> Pauli: + """Return the negation of this Pauli operator.""" + return Pauli(self.symplectic, 1 - self.phase) + + def __hash__(self) -> int: + """Return a hash of the Pauli operator.""" + return hash((self.symplectic, self.phase)) + + +class StabilizerTableau: + """Class representing a stabilizer tableau.""" + + def __init__(self, tableau: SymplecticMatrix | npt.NDArray[np.int8], phase: npt.NDArray[np.int8]) -> None: + """Create a new stabilizer tableau. + + Args: + tableau: Symplectic matrix representing the stabilizer tableau. + phase: An n x 1 binary vector representing the phase of the stabilizer tableau. + """ + if isinstance(tableau, np.ndarray): + self.tableau = SymplecticMatrix(tableau) + else: + self.tableau = tableau + if self.tableau.shape[0] != phase.shape[0]: + msg = "The number of rows in the tableau must match the number of phases." + raise InvalidPauliError(msg) + self.n = self.tableau.n + self.n_rows = self.tableau.shape[0] + self.phase = phase + self.shape = (self.n_rows, self.n) + + @classmethod + def from_paulis(cls, paulis: Sequence[Pauli]) -> StabilizerTableau: + """Create a new stabilizer tableau from a list of Pauli operators.""" + if len(paulis) == 0: + msg = "At least one Pauli operator is required." + raise InvalidPauliError(msg) + n = paulis[0].n + if not all(p.n == n for p in paulis): + msg = "All Pauli operators must have the same number of qubits." + raise InvalidPauliError(msg) + mat = SymplecticMatrix.zeros(len(paulis), n) + phase = np.zeros((len(paulis)), dtype=np.int8) + for i, p in enumerate(paulis): + mat[i] = p.symplectic.vector + phase[i] = p.phase + return cls(mat, phase) + + @classmethod + def from_pauli_strings(cls, pauli_strings: Sequence[str]) -> StabilizerTableau: + """Create a new stabilizer tableau from a list of Pauli strings.""" + if len(pauli_strings) == 0: + msg = "At least one Pauli string is required." + raise InvalidPauliError(msg) + + paulis = [Pauli.from_pauli_string(p) for p in pauli_strings] + return cls.from_paulis(paulis) + + @classmethod + def empty(cls, n: int) -> StabilizerTableau: + """Create a new empty stabilizer tableau.""" + return cls(SymplecticMatrix.empty(n), np.zeros(0, dtype=np.int8)) + + def __eq__(self, other: object) -> bool: + """Check if two stabilizer tableaus are equal.""" + if isinstance(other, list): + if len(other) != self.n_rows: + return False + if isinstance(other[0], Pauli): + other = StabilizerTableau.from_paulis(other) + elif isinstance(other[0], str): + other = StabilizerTableau.from_pauli_strings(other) + else: + return False + + if not isinstance(other, StabilizerTableau): + return False + return bool(self.tableau == other.tableau and np.all(self.phase == other.phase)) + + def __ne__(self, other: object) -> bool: + """Check if two stabilizer tableaus are not equal.""" + return not self == other + + def __len__(self) -> int: + """Return the number of Paulis in the tableau.""" + return int(len(self.tableau)) + + def all_commute(self, other: StabilizerTableau) -> bool: + """Check if all Pauli operators in this stabilizer tableau commute with all Pauli operators in another stabilizer tableau.""" + return bool(np.all((self.tableau @ other.tableau).matrix == 0)) + + def __getitem__(self, key: int) -> Pauli: + """Get a Pauli operator from the stabilizer tableau.""" + return Pauli(SymplecticVector(self.tableau[key]), self.phase[key]) + + def __hash__(self) -> int: + """Compute the hash of the stabilizer tableau.""" + return hash((self.tableau, self.phase)) + + def __iter__(self) -> Iterator[Pauli]: + """Iterate over the Pauli operators in the stabilizer tableau.""" + for i in range(self.n_rows): + yield self[i] + + def as_matrix(self) -> npt.NDArray[np.int8]: + """Convert the stabilizer tableau to a binary matrix.""" + return np.hstack((self.tableau.matrix, self.phase[..., np.newaxis])) + + +def is_pauli_string(p: str) -> bool: + """Check if a string is a valid Pauli string.""" + return len(p) > 0 and all(c in {"I", "X", "Y", "Z"} for c in p[1:]) and p[0] in {"+", "-", "I", "X", "Y", "Z"} + + +class InvalidPauliError(ValueError): + """Exception raised when an invalid Pauli operator is encountered.""" + + def __init__(self, message: str) -> None: + """Create a new InvalidPauliError.""" + super().__init__(message) diff --git a/src/mqt/qecc/codes/stabilizer_code.py b/src/mqt/qecc/codes/stabilizer_code.py index 88832ea8..6a9744c0 100644 --- a/src/mqt/qecc/codes/stabilizer_code.py +++ b/src/mqt/qecc/codes/stabilizer_code.py @@ -2,27 +2,15 @@ from __future__ import annotations -import sys from typing import TYPE_CHECKING import numpy as np -import numpy.typing as npt - -if sys.version_info >= (3, 10): - from typing import TypeAlias - - Pauli: TypeAlias = npt.NDArray[np.int8] | list[str] -else: - from typing import Union - - from typing_extensions import TypeAlias - - Pauli: TypeAlias = Union[npt.NDArray[np.int8], list[str]] - from ldpc import mod2 +from .pauli import Pauli, StabilizerTableau + if TYPE_CHECKING: - from collections.abc import Iterable + import numpy.typing as npt class StabilizerCode: @@ -30,25 +18,48 @@ class StabilizerCode: def __init__( self, - generators: npt.NDArray | list[str], + generators: StabilizerTableau | list[Pauli] | list[str], distance: int | None = None, - Lz: Pauli | None = None, # noqa: N803 - Lx: Pauli | None = None, # noqa: N803 + z_logicals: StabilizerTableau | list[Pauli] | list[str] | None = None, + x_logicals: StabilizerTableau | list[Pauli] | list[str] | None = None, + n: int | None = None, ) -> None: """Initialize the code. Args: - generators: The stabilizer generators of the code. Qiskit has a reverse order of qubits in PauliList. We assume that stabilizers are ordered from left to right in ascending order of qubits. + generators: The stabilizer generators of the code. We assume that stabilizers are ordered from left to right in ascending order of qubits. distance: The distance of the code. - Lz: The logical Z-operators. - Lx: The logical X-operators. + z_logicals: The logical Z-operators. + x_logicals: The logical X-operators. + n: The number of qubits in the code. If not given, it is inferred from the stabilizer generators. """ - self._check_stabilizer_generators(generators) - self.n = get_n_qubits_from_pauli(generators[0]) - self.generators = paulis_to_binary(generators) - self.symplectic_matrix = self.generators[:, :-1] # discard the phase - self.phases = self.generators[:, -1] - self.k = self.n - mod2.rank(self.generators) + # if len(generators) == 0: + # if n is None: + # raise ValueError("Number of qubits must be given if no stabilizer generators are given.") + # if z_logicals is None or x_logicals is None: + # t = StabilizerCode.get_trivial_code(n) + # print(type(self)) + # self.n = n + # self.x_logicals = t.x_logicals + # self.z_logicals = t.z_logicals + # self.k = t.k + # self.generators = t.generators + # self.symplectic = t.symplectic + # self.distance = t.distance + # self._check_code_correct() + + self.generators = self.get_generators(generators, n) + self.symplectic = self.generators.tableau.matrix + + if n is None: + self.n = self.generators.n + else: + self.n = n + + if self.generators.n_rows != 0: + self.k = self.n - mod2.rank(self.generators.as_matrix()) + else: + self.k = self.n if distance is not None and distance <= 0: msg = "Distance must be a positive integer." @@ -56,195 +67,127 @@ def __init__( self.distance = 1 if distance is None else distance # default distance is 1 - if Lz is not None: - self.Lz = paulis_to_binary(Lz) - self.Lz_symplectic = self.Lz[:, :-1] - else: - self.Lz = None - self.Lz_symplectic = None + self.z_logicals = None + self.x_logicals = None - if Lx is not None: - self.Lx = paulis_to_binary(Lx) - self.Lx_symplectic = self.Lx[:, :-1] - else: - self.Lx = None - self.Lx_symplectic = None + if z_logicals is not None: + self.z_logicals = self.get_generators(z_logicals) + if x_logicals is not None: + self.x_logicals = self.get_generators(x_logicals) self._check_code_correct() def __hash__(self) -> int: """Compute a hash for the stabilizer code.""" - return hash(int.from_bytes(self.generators.tobytes(), sys.byteorder)) + return hash(self.generators) def __eq__(self, other: object) -> bool: """Check if two stabilizer codes are equal.""" if not isinstance(other, StabilizerCode): return NotImplemented - rnk = mod2.rank(self.generators) + rnk = mod2.rank(self.generators.as_matrix()) return bool( - rnk == mod2.rank(other.generators) and rnk == mod2.rank(np.vstack((self.generators, other.generators))) + rnk == mod2.rank(other.generators.as_matrix()) + and rnk == mod2.rank(np.vstack((self.generators.as_matrix(), other.generators.as_matrix()))) ) - def get_syndrome(self, error: Pauli) -> npt.NDArray: + def get_syndrome(self, error: Pauli | str) -> npt.NDArray: """Compute the syndrome of the error. Args: error: The error as a pauli string or binary vector. """ - return symplectic_matrix_mul(self.symplectic_matrix, pauli_to_symplectic_vec(error)) + if isinstance(error, str): + error = Pauli.from_pauli_string(error) + return (self.generators.tableau @ error.symplectic).vector def stabs_as_pauli_strings(self) -> list[str]: """Return the stabilizers as Pauli strings.""" - return [binary_to_pauli_string(s) for s in self.generators] + return [str(p) for p in self.generators] - def stabilizer_equivalent(self, p1: Pauli, p2: Pauli) -> bool: + def stabilizer_equivalent(self, p1: Pauli | str, p2: Pauli | str) -> bool: """Check if two Pauli strings are equivalent up to stabilizers of the code.""" - v1 = pauli_to_binary(p1) - v2 = pauli_to_binary(p2) - return bool(mod2.rank(np.vstack((self.generators, v1, v2))) == mod2.rank(np.vstack((self.generators, v1)))) - - @staticmethod - def _check_stabilizer_generators(generators: npt.NDArray[np.int8] | list[str]) -> None: - """Check if the stabilizer generators are valid. Throws an exception if not.""" - if len(generators) == 0: - msg = "Stabilizer code must have at least one generator." - raise InvalidStabilizerCodeError(msg) - if not all(len(generators[0]) == len(g) for g in generators): - msg = "All stabilizer generators must have the same length." - raise InvalidStabilizerCodeError(msg) - - if not isinstance(generators[0], str): - return - - if not all(is_pauli_string(g) for g in generators): - msg = "When providing stabilizer generators as strings, they must be valid Pauli strings." - raise InvalidStabilizerCodeError(msg) + if isinstance(p1, str): + p1 = Pauli.from_pauli_string(p1) + if isinstance(p2, str): + p2 = Pauli.from_pauli_string(p2) + return bool( + mod2.rank(np.vstack((self.generators.as_matrix(), p1.as_vector(), p2.as_vector()))) + == mod2.rank(np.vstack((self.generators.as_matrix(), p1.as_vector()))) + ) def _check_code_correct(self) -> None: """Check if the code is correct. Throws an exception if not.""" - if self.Lz is not None or self.Lx is not None: - if self.Lz is None: + if self.z_logicals is not None or self.x_logicals is not None: + if self.z_logicals is None: msg = "If logical X-operators are given, logical Z-operators must also be given." raise InvalidStabilizerCodeError(msg) - if self.Lx is None: + if self.x_logicals is None: msg = "If logical Z-operators are given, logical X-operators must also be given." raise InvalidStabilizerCodeError(msg) - if self.Lz is None: + if self.z_logicals is None: + return + + if self.x_logicals is None: return - if get_n_qubits_from_pauli(self.Lz[0]) != self.n: + if self.z_logicals.n != self.n: msg = "Logical operators must have the same number of qubits as the stabilizer generators." raise InvalidStabilizerCodeError(msg) - if self.Lz.shape[0] > self.k: + if self.z_logicals.shape[0] > self.k: msg = "Number of logical Z-operators must be at most the number of logical qubits." raise InvalidStabilizerCodeError(msg) - if get_n_qubits_from_pauli(self.Lx[0]) != self.n: + if self.x_logicals.n != self.n: msg = "Logical operators must have the same number of qubits as the stabilizer generators." raise InvalidStabilizerCodeError(msg) - if self.Lx.shape[0] > self.k: + if self.x_logicals.shape[0] > self.k: msg = "Number of logical X-operators must be at most the number of logical qubits." raise InvalidStabilizerCodeError(msg) - if not all_commute(self.Lz_symplectic, self.symplectic_matrix): + if not self.z_logicals.all_commute(self.generators): msg = "Logical Z-operators must anti-commute with the stabilizer generators." raise InvalidStabilizerCodeError(msg) - if not all_commute(self.Lx_symplectic, self.symplectic_matrix): + if not self.x_logicals.all_commute(self.generators): msg = "Logical X-operators must commute with the stabilizer generators." raise InvalidStabilizerCodeError(msg) - commutations = symplectic_matrix_product(self.Lz_symplectic, self.Lx_symplectic) + commutations = (self.z_logicals.tableau @ self.x_logicals.tableau).matrix if not np.all(np.sum(commutations, axis=1) == 1): msg = "Every logical X-operator must anti-commute with exactly one logical Z-operator." raise InvalidStabilizerCodeError(msg) + @staticmethod + def get_generators( + generators: StabilizerTableau | list[Pauli] | list[str], n: int | None = None + ) -> StabilizerTableau: + """Get the stabilizer generators as a StabilizerTableau object. -def pauli_to_binary(p: Pauli) -> npt.NDArray: - """Convert a Pauli string to a binary array.""" - if isinstance(p, np.ndarray): - return p - - # check if there is a sign - phase = 0 - if p[0] in {"+", "-"}: - phase = 0 if p[0] == "+" else 1 - p = p[1:] - x_part = np.array([int(p == "X") for p in p]) - z_part = np.array([int(p == "Z") for p in p]) - y_part = np.array([int(p == "Y") for p in p]) - x_part += y_part - z_part += y_part - return np.hstack((x_part, z_part, np.array([phase]))) - - -def paulis_to_binary(ps: Iterable[Pauli]) -> npt.NDArray: - """Convert a list of Pauli strings to a 2d binary array.""" - return np.array([pauli_to_binary(p) for p in ps]) - - -def binary_to_pauli_string(b: npt.NDArray) -> str: - """Convert a binary array to a Pauli string.""" - x_part = b[: len(b) // 2] - z_part = b[len(b) // 2 : -1] - phase = b[-1] - - pauli = ["X" if x and not z else "Z" if z and not x else "Y" if x and z else "I" for x, z in zip(x_part, z_part)] - return f"{'' if phase == 0 else '-'}" + "".join(pauli) - - -def is_pauli_string(p: str) -> bool: - """Check if a string is a valid Pauli string.""" - return len(p) > 0 and all(c in {"I", "X", "Y", "Z"} for c in p[1:]) and p[0] in {"+", "-", "I", "X", "Y", "Z"} - - -def get_n_qubits_from_pauli(p: Pauli) -> int: - """Get the number of qubits from a Pauli string.""" - if isinstance(p, np.ndarray): - return int(p.shape[0] // 2) - if p[0] in {"+", "-"}: - return len(p) - 1 - return len(p) - - -def commute(p1: npt.NDArray[np.int8], p2: npt.NDArray[np.int8]) -> bool: - """Check if two Paulistrings in binary representation commute.""" - return bool(symplectic_inner_product(p1, p2) == 0) - - -def anti_commute(p1: npt.NDArray[np.int8], p2: npt.NDArray[np.int8]) -> bool: - """Check if two Paulistrings in binary representation anti-commute.""" - return not commute(p1, p2) - - -def all_commute(ps1: npt.NDArray[np.int8], ps2: npt.NDArray[np.int8]) -> bool: - """Check if all Paulistrings in binary representation commute.""" - return bool((symplectic_matrix_product(ps1, ps2) == 0).all()) - - -def symplectic_inner_product(p1: npt.NDArray[np.int8], p2: npt.NDArray[np.int8]) -> int: - """Compute the symplectic inner product of two symplectic vectors.""" - n = p1.shape[0] // 2 - return int((p1[:n] @ p2[n:] + p1[n:] @ p2[:n]) % 2) - - -def symplectic_matrix_product(m1: npt.NDArray[np.int8], m2: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]: - """Compute the symplectic matrix product of two symplectic matrices.""" - n = m1.shape[1] // 2 - return ((m1[:, :n] @ m2[:, n:].T) + (m1[:, n:] @ m2[:, :n].T)) % 2 - - -def symplectic_matrix_mul(m: npt.NDArray[np.int8], v: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]: - """Compute the symplectic matrix product of symplectic matrix with symplectic vector.""" - n = m.shape[1] // 2 - return (m[:, :n] @ v[n:] + m[:, n:] @ v[:n]) % 2 - - -def pauli_to_symplectic_vec(p: Pauli) -> npt.NDArray: - """Convert a Pauli string to a symplectic vector.""" - return pauli_to_binary(p)[:-1] + Args: + generators: The stabilizer generators as a StabilizerTableau object, a list of Pauli objects, or a list of Pauli strings. + n: The number of qubits in the code. Required if generators is an empty list. + """ + if isinstance(generators, list): + if len(generators) == 0: + if n is None: + msg = "Number of qubits must be given if no generators are provided." + raise ValueError(msg) + return StabilizerTableau.empty(n) + if isinstance(generators[0], str): + return StabilizerTableau.from_pauli_strings(generators) # type: ignore[arg-type] + if isinstance(generators[0], Pauli): + return StabilizerTableau.from_paulis(generators) # type: ignore[arg-type] + return generators + + @classmethod + def get_trivial_code(cls, n: int) -> StabilizerCode: + """Get the trivial stabilizer code.""" + z_logicals = ["I" * i + "Z" + "I" * (n - i - 1) for i in range(n)] + x_logicals = ["I" * i + "X" + "I" * (n - i - 1) for i in range(n)] + return StabilizerCode([], distance=1, z_logicals=z_logicals, x_logicals=x_logicals, n=n) class InvalidStabilizerCodeError(ValueError): diff --git a/src/mqt/qecc/codes/symplectic.py b/src/mqt/qecc/codes/symplectic.py new file mode 100644 index 00000000..dfb37573 --- /dev/null +++ b/src/mqt/qecc/codes/symplectic.py @@ -0,0 +1,161 @@ +"""Classes and Methods for working with symplectic vector spaces.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +# from ldpc import mod2 + +if TYPE_CHECKING: + from typing import Any + + import numpy.typing as npt + + +class SymplecticVector: + """Symplectic Vector Class.""" + + def __init__(self, vector: npt.NDArray[np.int8]) -> None: + """Initialize the Symplectic Vector.""" + self.vector = vector + self.n = vector.shape[0] // 2 + + @classmethod + def zeros(cls, n: int) -> SymplecticVector: + """Create a zero vector of length n.""" + return cls(np.zeros(2 * n, dtype=np.int8)) + + @classmethod + def ones(cls, n: int) -> SymplecticVector: + """Create a ones vector of length n.""" + return cls(np.ones(2 * n, dtype=np.int8)) + + def __add__(self, other: SymplecticVector) -> SymplecticVector: + """Add two symplectic vectors.""" + return SymplecticVector((self.vector + other.vector) % 2) + + def __sub__(self, other: SymplecticVector) -> SymplecticVector: + """Subtract two symplectic vectors.""" + return SymplecticVector((self.vector - other.vector) % 2) + + def __neg__(self) -> SymplecticVector: + """Negate the vector.""" + return SymplecticVector(-self.vector) + + def __matmul__(self, other: SymplecticVector) -> int: + """Compute the symplectic inner product.""" + assert self.n == other.n, "Vectors must be of the same length." + return int( + (self.vector[: self.n] @ other.vector[self.n :] - self.vector[self.n :] @ other.vector[: self.n]) % 2 + ) + + def __getitem__(self, key: int | slice) -> Any: # noqa: ANN401 + """Get the value of the vector at index key.""" + return self.vector[key] + + def __setitem__(self, key: int | slice, value: int) -> None: + """Set the value of the vector at index key.""" + self.vector[key] = value + + def __eq__(self, other: object) -> bool: + """Check if two vectors are equal.""" + if not isinstance(other, SymplecticVector): + return False + return np.array_equal(self.vector, other.vector) + + def __ne__(self, other: object) -> bool: + """Check if two vectors are not equal.""" + return not self == other + + def __hash__(self) -> int: + """Return the hash of the vector.""" + return hash(self.vector.to_bytes()) + + +class SymplecticMatrix: + """Symplectic Matrix Class.""" + + def __init__(self, matrix: npt.NDArray[np.int8]) -> None: + """Initialize the Symplectic Matrix.""" + assert matrix.ndim == 2, "Matrix must be 2D." + self.matrix = matrix + self.n = matrix.shape[1] // 2 + self.shape = matrix.shape + + def transpose(self) -> SymplecticMatrix: + """Return the transpose of the matrix.""" + return SymplecticMatrix(self.matrix.T) + + @classmethod + def zeros(cls, n_rows: int, n: int) -> SymplecticMatrix: + """Create a zero matrix of size n.""" + return cls(np.zeros((n_rows, 2 * n), dtype=np.int8)) + + @classmethod + def identity(cls, n: int) -> SymplecticMatrix: + """Create the identity matrix of size n.""" + return cls( + np.block([ + [np.zeros((n, n), dtype=np.int8), np.eye(n, dtype=np.int8)], + [np.eye(n, dtype=np.int8), np.zeros((n, n), dtype=np.int8)], + ]) + ) + + @classmethod + def empty(cls, n: int) -> SymplecticMatrix: + """Create an empty matrix of size n.""" + return cls(np.empty((0, 2 * n), dtype=np.int8)) + + def __add__(self, other: SymplecticMatrix) -> SymplecticMatrix: + """Add two symplectic matrices.""" + return SymplecticMatrix((self.matrix + other.matrix) % 2) + + def __sub__(self, other: SymplecticMatrix) -> SymplecticMatrix: + """Subtract two symplectic matrices.""" + return SymplecticMatrix((self.matrix - other.matrix) % 2) + + def __matmul__(self, other: SymplecticMatrix | SymplecticVector) -> Any: # noqa: ANN401 + """Compute the symplectic product of two matrices.""" + assert self.n == other.n, "Matrices must be of the same size." + n = self.n + if isinstance(other, SymplecticVector): + return SymplecticVector((self.matrix[:, :n] @ other[n:] + self.matrix[:, n:] @ other[:n]) % 2) + m1 = self.matrix + m2 = other.matrix + return SymplecticMatrix(((m1[:, :n] @ m2[:, n:].T) + (m1[:, n:] @ m2[:, :n].T)) % 2) + + def __getitem__(self, key: tuple[int, int] | int | slice) -> Any: # noqa: ANN401 + """Get the value of the matrix at index key.""" + return self.matrix[key] + + def __setitem__(self, key: tuple[int, int] | int | slice, value: npt.NDArray[np.int8]) -> None: + """Set the value of the matrix at index key.""" + self.matrix[key] = value + + def __repr__(self) -> str: + """Return the string representation of the matrix.""" + return str(self.matrix.__repr__()) + + def __iter__(self) -> npt.NDArray[np.int8]: + """Iterate over the rows of the matrix.""" + return self.matrix.__iter__() + + def __eq__(self, other: object) -> bool: + """Check if two matrices are equal.""" + if not isinstance(other, SymplecticMatrix): + return False + return np.array_equal(self.matrix, other.matrix) + + def __ne__(self, other: object) -> bool: + """Check if two matrices are not equal.""" + return not self == other + + def __hash__(self) -> int: + """Return the hash of the matrix.""" + return hash(self.matrix.tobytes()) + + def __len__(self) -> int: + """Return the number of rows in the matrix.""" + return len(self.matrix) diff --git a/test/python/test_code.py b/test/python/test_code.py index ad6a43a2..41a74fc1 100644 --- a/test/python/test_code.py +++ b/test/python/test_code.py @@ -8,16 +8,123 @@ import pytest from mqt.qecc import CSSCode, StabilizerCode -from mqt.qecc.codes import InvalidCSSCodeError, InvalidStabilizerCodeError, construct_bb_code +from mqt.qecc.codes import ( + ConcatenatedCode, + ConcatenatedCSSCode, + InvalidCSSCodeError, + InvalidStabilizerCodeError, + construct_bb_code, + construct_iceberg_code, + construct_many_hypercube_code, + construct_quantum_hamming_code, +) +from mqt.qecc.codes.pauli import InvalidPauliError, Pauli, StabilizerTableau +from mqt.qecc.codes.symplectic import SymplecticMatrix, SymplecticVector if TYPE_CHECKING: # pragma: no cover import numpy.typing as npt +def test_pauli() -> None: + """Test the Pauli class.""" + p1 = Pauli.from_pauli_string("XIZ") + p2 = Pauli(SymplecticVector(np.array([1, 0, 0, 0, 0, 1]))) + assert p1 == p2 + p3 = p1 * p2 + assert p3 == Pauli.from_pauli_string("III") + p4 = Pauli.from_pauli_string("-X") + p5 = Pauli.from_pauli_string("+Z") + p6 = Pauli.from_pauli_string("Y") + assert p4 * p5 != p6 + assert p4 * p5 == -p6 + + assert np.array_equal(p1.x_part(), np.array([1, 0, 0])) + assert np.array_equal(p1.z_part(), np.array([0, 0, 1])) + assert np.array_equal(p6.x_part(), np.array([1])) + assert np.array_equal(p6.z_part(), np.array([1])) + assert len(p1) == 3 + assert len(p6) == 1 + + assert p4.anticommute(p5) + p7 = Pauli.from_pauli_string("XI") + p8 = Pauli.from_pauli_string("IZ") + assert p8.commute(p7) + + with pytest.raises(IndexError): + p1[3] + + +def test_symplectic() -> None: + """Test the SymplecticMatrix and SymplecticVector classes.""" + ones = SymplecticVector.ones(3) + zeros = SymplecticVector.zeros(3) + assert ones - ones == zeros + assert ones + ones == zeros + + v = SymplecticVector(np.array([1, 0, 0, 0, 0, 1])) + w = SymplecticVector(np.array([0, 1, 0, 0, 0, 1])) + assert w + v == v + w + assert w - v == -v + w + + obj = "abc" + assert v != obj + + assert v @ w == 0 + u = SymplecticVector(np.array([0, 0, 1, 0, 0, 0])) + assert v @ u == 1 + + eye = SymplecticMatrix.identity(3) + zero_mat = SymplecticMatrix.zeros(6, 3) + assert eye + eye == zero_mat + assert eye - eye == zero_mat + + vs = [v.vector, w.vector, u.vector, ones.vector, zeros.vector, v.vector] + m = SymplecticMatrix(np.array(vs)) + assert eye @ m.transpose() == m + assert m @ eye == m + + for i, row in enumerate(m): + assert np.array_equal(row, vs[i]) + + assert m != obj + assert len(m) == 6 + assert m.shape == (6, 6) + assert m.n == 3 + + +def test_stabilizer_tableau() -> None: + """Test the StabilizerTableau class.""" + with pytest.raises(InvalidPauliError): + StabilizerTableau.from_pauli_strings([]) + + with pytest.raises(InvalidPauliError): + StabilizerTableau.from_paulis([]) + + m = SymplecticMatrix(np.array([[1, 0], [0, 1]])) + with pytest.raises(InvalidPauliError): + StabilizerTableau(m, np.array([1])) + + p1 = Pauli.from_pauli_string("XIZ") + p2 = Pauli.from_pauli_string("ZIX") + p3 = Pauli.from_pauli_string("IZX") + t1 = StabilizerTableau.from_paulis([p1, p2, p3]) + t2 = StabilizerTableau(np.array([[1, 0, 0, 0, 0, 1], [0, 0, 1, 1, 0, 0], [0, 0, 1, 0, 1, 0]]), np.array([0, 0, 0])) + assert t1 == t2 + + t3 = StabilizerTableau.from_pauli_strings(["ZII", "IZI", "IIZ"]) + assert t1 != t3 + + t4 = StabilizerTableau.from_pauli_strings(["ZII"]) + assert t1 != t4 + + assert t1 == [Pauli.from_pauli_string("XIZ"), Pauli.from_pauli_string("ZIX"), Pauli.from_pauli_string("IZX")] + assert len(t1) == 3 + + @pytest.fixture def rep_code_checks() -> tuple[npt.NDArray[np.int8] | None, npt.NDArray[np.int8] | None]: """Return the parity check matrices for the repetition code.""" - hx = np.array([[1, 1, 0], [0, 0, 1]]) + hx = np.array([[1, 1, 0], [0, 1, 1]]) hz = None return hx, hz @@ -38,12 +145,26 @@ def steane_code_checks() -> tuple[npt.NDArray[np.int8], npt.NDArray[np.int8]]: return hx, hz +@pytest.fixture +def steane_code() -> CSSCode: + """Return the Steane code.""" + hx = np.array([[1, 1, 1, 1, 0, 0, 0], [1, 0, 1, 0, 1, 0, 1], [0, 1, 1, 0, 1, 1, 0]]) + hz = hx + return CSSCode(distance=3, Hx=hx, Hz=hz) + + @pytest.fixture def five_qubit_code_stabs() -> list[str]: """Return the five qubit code.""" return ["XZZXI", "IXZZX", "XIXZZ", "ZXIXZ"] +@pytest.fixture +def five_qubit_code() -> StabilizerCode: + """Return the five qubit code.""" + return StabilizerCode(["XZZXI", "IXZZX", "XIXZZ", "ZXIXZ"], 3, z_logicals=["ZZZZZ"], x_logicals=["XXXXX"]) + + def test_invalid_css_codes() -> None: """Test that an invalid CSS code raises an error.""" # Violates CSS condition @@ -121,6 +242,35 @@ def test_errors(steane_code_checks: tuple[npt.NDArray[np.int8], npt.NDArray[np.i assert code.stabilizer_eq_z_error(e1, e4) +def test_rep_code(rep_code_checks: tuple[npt.NDArray[np.int8], npt.NDArray[np.int8]]) -> None: + """Test utility functions and correctness of the repetition code.""" + hx, hz = rep_code_checks + code = CSSCode(distance=1, Hx=hx, Hz=hz) + assert code.n == 3 + assert code.k == 1 + assert code.distance == 1 + assert not code.is_self_dual() + + e1 = np.array([1, 0, 0], dtype=np.int8) + e2 = np.array([0, 1, 0], dtype=np.int8) + e3 = np.array([0, 0, 1], dtype=np.int8) + assert np.array_equal(code.get_x_syndrome(e1), np.array([1, 0])) + assert np.array_equal(code.get_x_syndrome(e2), np.array([1, 1])) + assert np.array_equal(code.get_x_syndrome(e3), np.array([0, 1])) + + assert code.get_z_syndrome(e1).size == 0 + + assert code.check_if_logical_z_error((e1 + e2 + e3) % 2) + assert not code.check_if_x_stabilizer((e1 + e2 + e3) % 2) + assert code.check_if_x_stabilizer((e1 + e2) % 2) + assert not code.check_if_z_stabilizer((e1 + e2 + e3) % 2) + assert not code.check_if_z_stabilizer((e1 + e3) % 2) + + assert code.stabilizer_eq_x_error(e1, (e1 + e2 + e3) % 2) + assert not code.stabilizer_eq_z_error(e1, (e1 + e2 + e3) % 2) + assert code.stabilizer_eq_z_error(e1, e1) + + def test_steane(steane_code_checks: tuple[npt.NDArray[np.int8], npt.NDArray[np.int8]]) -> None: """Test utility functions and correctness of the Steane code.""" hx, hz = steane_code_checks @@ -138,13 +288,6 @@ def test_steane(steane_code_checks: tuple[npt.NDArray[np.int8], npt.NDArray[np.i assert x_paulis == ["XXXXIII", "XIXIXIX", "IXXIXXI"] assert z_paulis == ["ZZZZIII", "ZIZIZIZ", "IZZIZZI"] - x_log = code.x_logicals_as_pauli_strings()[0] - z_log = code.z_logicals_as_pauli_strings()[0] - assert x_log.count("X") == 3 - assert x_log.count("I") == 4 - assert z_log.count("Z") == 3 - assert z_log.count("I") == 4 - hx_reordered = hx[::-1, :] code_reordered = CSSCode(distance=3, Hx=hx_reordered, Hz=hz) assert code == code_reordered @@ -162,11 +305,11 @@ def test_bb_codes(n: int) -> None: def test_five_qubit_code(five_qubit_code_stabs: list[str]) -> None: """Test that the five qubit code is constructed as a valid stabilizer code.""" - Lz = ["ZZZZZ"] # noqa: N806 - Lx = ["XXXXX"] # noqa: N806 + z_logicals = ["ZZZZZ"] + x_logicals = ["XXXXX"] # Many assertions are already made in the constructor - code = StabilizerCode(five_qubit_code_stabs, distance=3, Lx=Lx, Lz=Lz) + code = StabilizerCode(five_qubit_code_stabs, distance=3, x_logicals=x_logicals, z_logicals=z_logicals) assert code.n == 5 assert code.k == 1 assert code.distance == 3 @@ -197,10 +340,14 @@ def test_stabilizer_sign() -> None: assert np.array_equal(syndrome, np.array([1, 0])) -def test_no_stabilizers() -> None: - """Test that an error is raised if no stabilizers are provided.""" - with pytest.raises(InvalidStabilizerCodeError): - StabilizerCode([]) +def test_trivial_code() -> None: + """Test code with no stabilizers.""" + code = StabilizerCode.get_trivial_code(3) + assert code.n == 3 + assert code.k == 3 + assert code.x_logicals == ["XII", "IXI", "IIX"] + assert code.z_logicals == ["ZII", "IZI", "IIZ"] + assert code.generators.n_rows == 0 def test_negative_distance() -> None: @@ -211,53 +358,109 @@ def test_negative_distance() -> None: def test_different_length_stabilizers() -> None: """Test that an error is raised if stabilizers have different lengths.""" - with pytest.raises(InvalidStabilizerCodeError): + with pytest.raises(InvalidPauliError): StabilizerCode(["ZZZZ", "X", "Y"]) def test_invalid_pauli_strings() -> None: """Test that invalid Pauli strings raise an error.""" - with pytest.raises(InvalidStabilizerCodeError): + with pytest.raises(InvalidPauliError): StabilizerCode(["ABCD", "XIXI", "YIYI"]) def test_no_x_logical() -> None: """Test that an error is raised if no X logical is provided when a Z logical is provided.""" with pytest.raises(InvalidStabilizerCodeError): - StabilizerCode(["ZZZZ", "XXXX"], Lx=["XXII"]) + StabilizerCode(["ZZZZ", "XXXX"], x_logicals=["XXII"]) def test_no_z_logical() -> None: """Test that an error is raised if no Z logical is provided when an X logical is provided.""" with pytest.raises(InvalidStabilizerCodeError): - StabilizerCode(["ZZZZ", "XXXX"], Lz=["ZZII"]) + StabilizerCode(["ZZZZ", "XXXX"], z_logicals=["ZZII"]) def test_logicals_wrong_length() -> None: """Test that an error is raised if the logicals have the wrong length.""" with pytest.raises(InvalidStabilizerCodeError): - StabilizerCode(["ZZZZ", "XXXX"], Lx=["XX"], Lz=["IZZI"]) + StabilizerCode(["ZZZZ", "XXXX"], x_logicals=["XX"], z_logicals=["IZZI"]) with pytest.raises(InvalidStabilizerCodeError): - StabilizerCode(["ZZZZ", "XXXX"], Lx=["IXXI"], Lz=["ZZ"]) + StabilizerCode(["ZZZZ", "XXXX"], x_logicals=["IXXI"], z_logicals=["ZZ"]) def test_commuting_logicals() -> None: """Test that an error is raised if the logicals commute.""" with pytest.raises(InvalidStabilizerCodeError): - StabilizerCode(["ZZZZ", "XXXX"], Lz=["ZZII"], Lx=["XXII"]) + StabilizerCode(["ZZZZ", "XXXX"], z_logicals=["ZZII"], x_logicals=["XXII"]) def test_anticommuting_logicals() -> None: """Test that an error is raised if the logicals anticommute with the stabilizer generators.""" with pytest.raises(InvalidStabilizerCodeError): - StabilizerCode(["ZZZZ", "XXXX"], Lz=["ZIII"], Lx=["IXXI"]) + StabilizerCode(["ZZZZ", "XXXX"], z_logicals=["ZIII"], x_logicals=["IXXI"]) with pytest.raises(InvalidStabilizerCodeError): - StabilizerCode(["ZZZZ", "XXXX"], Lz=["IZZI"], Lx=["XIII"]) + StabilizerCode(["ZZZZ", "XXXX"], z_logicals=["IZZI"], x_logicals=["XIII"]) def test_too_many_logicals() -> None: """Test that an error is raised if too many logicals are provided.""" with pytest.raises(InvalidStabilizerCodeError): - StabilizerCode(["ZZZZ", "XXXX"], Lz=["ZZII", "ZZII", "ZZII"], Lx=["IXXI"]) + StabilizerCode(["ZZZZ", "XXXX"], z_logicals=["ZZII", "ZZII", "ZZII"], x_logicals=["IXXI"]) with pytest.raises(InvalidStabilizerCodeError): - StabilizerCode(["ZZZZ", "XXXX"], Lz=["IZZI"], Lx=["XXII", "XXII", "XXII"]) + StabilizerCode(["ZZZZ", "XXXX"], z_logicals=["IZZI"], x_logicals=["XXII", "XXII", "XXII"]) + + +def test_trivial_concatenation(five_qubit_code: StabilizerCode) -> None: + """Test that the trivial concatenation of a code is the code itself.""" + inner_code = StabilizerCode.get_trivial_code(1) + concatenated = ConcatenatedCode(five_qubit_code, inner_code) + + assert concatenated.n == 5 + assert concatenated.k == 1 + assert concatenated.distance == 3 + assert concatenated == five_qubit_code + + +def test_trivial_css_concatenation(steane_code: CSSCode) -> None: + """Test that the trivial concatenation of a CSS code is the code itself.""" + inner_code = CSSCode.get_trivial_code(1) + concatenated = ConcatenatedCSSCode(steane_code, inner_code) + + assert concatenated.n == 7 + assert concatenated.k == 1 + assert concatenated.distance == 3 + assert concatenated == steane_code + + +def test_hamming_code() -> None: + """Test that the Hamming code is constructed as a valid CSS code.""" + code = construct_quantum_hamming_code(3) + assert code.n == 7 + assert code.k == 1 + assert code.distance == 3 + + +def test_many_hypercube_code_level_1() -> None: + """Test that the many-hypercube code.""" + code = construct_many_hypercube_code(1) + assert code.n == 6 + assert code.k == 4 + assert code.distance == 2 + iceberg = construct_iceberg_code(3) + assert code == iceberg + + +def test_many_hypercube_code_level_2() -> None: + """Test that the many-hypercube code.""" + code = construct_many_hypercube_code(2) + assert code.n == 36 + assert code.k == 16 + assert code.distance == 4 + + +def test_many_hypercube_code_level_3() -> None: + """Test that the many-hypercube code.""" + code = construct_many_hypercube_code(3) + assert code.n == 6**3 + assert code.k == 4**3 + assert code.distance == 2**3