From 84376907811a2ac938157d56e6772fc2b5e4fb58 Mon Sep 17 00:00:00 2001 From: Stefan Krastanov Date: Wed, 17 Jan 2024 22:13:45 -0500 Subject: [PATCH] flesh out `evaluate_decoder` --- src/affectedqubits.jl | 2 + src/ecc/ECC.jl | 3 +- src/ecc/circuits.jl | 4 +- src/ecc/decoder_pipeline.jl | 223 +++++++++++++++++------------------- src/pauli_frames.jl | 2 +- 5 files changed, 115 insertions(+), 119 deletions(-) diff --git a/src/affectedqubits.jl b/src/affectedqubits.jl index 09c9217e6..da6314a65 100644 --- a/src/affectedqubits.jl +++ b/src/affectedqubits.jl @@ -12,7 +12,9 @@ affectedqubits(p::PauliOperator) = 1:length(p) affectedqubits(m::Union{AbstractMeasurement,sMRX,sMRY,sMRZ}) = (m.qubit,) affectedqubits(v::VerifyOp) = v.indices affectedqubits(c::CliffordOperator) = 1:nqubits(c) +affectedqubits(c::ClassicalXOR) = () affectedbits(o) = () affectedbits(m::sMRZ) = (m.bit,) affectedbits(m::sMZ) = (m.bit,) +affectedbits(c::ClassicalXOR) = (c.bits..., c.store) diff --git a/src/ecc/ECC.jl b/src/ecc/ECC.jl index 5854076d3..f9878d98e 100644 --- a/src/ecc/ECC.jl +++ b/src/ecc/ECC.jl @@ -16,7 +16,8 @@ export parity_checks, code_n, code_s, code_k, rate, distance, isdegenerate, faults_matrix, naive_syndrome_circuit, shor_syndrome_circuit, naive_encoding_circuit, CSS, Unicycle, Bicycle, - Shor9, Steane7, Cleve8, Perfect5, Bitflip3 + Shor9, Steane7, Cleve8, Perfect5, Bitflip3, + evaluate_decoder, TableDecoder """Parity check tableau of a code.""" function parity_checks end diff --git a/src/ecc/circuits.jl b/src/ecc/circuits.jl index 0f62a7057..f71cc6b99 100644 --- a/src/ecc/circuits.jl +++ b/src/ecc/circuits.jl @@ -142,7 +142,7 @@ Use the `ancillary_index` and `bit_index` arguments to offset where the correspo Ancillary qubits Returns: - - The cat state preparation circuit. + - The ancillary cat state preparation circuit. - The Shor syndrome measurement circuit. - The number of ancillary qubits that were added. - The list of bit indices that store the final measurement results. @@ -165,7 +165,7 @@ and stores the measurement result into classical bits starting at `bit_index`. The final measurement result is the XOR of all the bits. Returns: - - The cat state preparation circuit. + - The ancillary cat state preparation circuit. - The Shor syndrome measurement circuit. - One more than the index of the last added ancillary qubit. - One more than the index of the last added classical bit. diff --git a/src/ecc/decoder_pipeline.jl b/src/ecc/decoder_pipeline.jl index af3538b80..211f4b48a 100644 --- a/src/ecc/decoder_pipeline.jl +++ b/src/ecc/decoder_pipeline.jl @@ -1,60 +1,90 @@ +"""An abstract type for QECC syndrome decoding algorithms. + +All `AbstractSyndromeDecoder` types are expected to: +- have a `parity_checks` method giving the parity checks for the code under study +- have a `decode` method that guesses error which caused the syndrome +- have an `evaluate_decoder` method which runs a full simulation but it supports only a small number of ECC protocols""" abstract type AbstractSyndromeDecoder end -function evaluate_decoder(d::AbstractSyndromeDecoder, nsamples, init_error, gate_error, syndrome_circuit_func, encoding_circuit_func) - pre_X = [sHadamard(i) for i in n-k+1:n] - X_error = evaluate_classical_decoder(d, nsamples, init_error, gate_error, syndrome_circuit_func, encoding_circuit_func, logicalxview, 1, d.k, pre_X) - Z_error = evaluate_classical_decoder(d, nsamples, init_error, gate_error, syndrome_circuit_func, encoding_circuit_func, logicalzview, d.k + 1, 2 * d.k) - return (X_error, Z_error) +"""An abstract type mostly used by [`evaluate_decoder`](@ref) to specify in what context to evaluate an ECC.""" +abstract type AbstractECCSetup end + +"""A helper function that takes a parity check tableau and an `AbstractECCSetup` type and provides the circuit that needs to be simulated.""" +function physical_ECC_circuit end # XXX Do not export! This might need to be refactored as we add more interesting setups! + +"""Configuration for ECC evaluators that simulate the Shor-style syndrome measurement (without a flag qubit). + +The simulated circuit includes: +- perfect noiseless encoding (encoding and its fault tolerance are not being studied here) +- one round of "memory noise" after the encoding but before the syndrome measurement +- perfect preparation of entangled ancillary qubits +- noisy Shor-style syndrome measurement (only two-qubit gate noise) +- noiseless "logical state measurement" (providing the comparison data when evaluating the decoder) +""" +struct ShorSyndromeECCSetup <: AbstractECCSetup + mem_noise::Float64 + two_qubit_gate_noise::Float64 + function ShorSyndromeECCSetup(mem_noise, two_qubit_gate_noise) + 0<=mem_noise<=1 || throw(DomainError(mem_noise, "The memory noise in `ShorSyndromeECCSetup` should be between 0 and 1.")) + 0<=two_qubit_gate_noise<=1 || throw(DomainError(two_qubit_gate_noise, "The two-qubit gate noise in `ShorSyndromeECCSetup` should be between 0 and 1.")) + new(mem_noise, two_qubit_gate_noise) + end end -function evaluate_classical_decoder(d::AbstractSyndromeDecoder, nsamples, init_error, gate_error, syndrome_circuit_func, encoding_circuit_func, logical_view_function, guess_start, guess_stop, pre_circuit = nothing) - H = d.H - O = d.faults_matrix - syndrome_circuit = syndrome_circuit_func(H) - - n = d.n - s = d.s - k = d.k - - errors = [PauliError(i, init_error) for i in 1:n]; - - md = MixedDestabilizer(H) - - full_circuit = [] - - logview = logical_view_function(md) - logcirc, _ = syndrome_circuit_func(logview) +function physical_ECC_circuit(H, setup::ShorSyndromeECCSetup) + prep_anc, syndrome_circ, n_anc, syndrome_bits = shor_syndrome_circuit(H) + noisy_syndrome_circ = syndrome_circ # add_two_qubit_gate_noise(syndrome_circ, gate_error) + mem_error_circ = [PauliError(i, setup.mem_noise) for i in 1:nqubits(H)]; + circ = [prep_anc..., mem_error_circ..., noisy_syndrome_circ...] + circ, syndrome_bits, n_anc +end - noisy_syndrome_circuit = add_two_qubit_gate_noise(syndrome_circuit, gate_error); +"""Evaluate the performance of a given decoder (e.g. [`TableDecoder`](@ref)) and a given style of running an ECC code (e.g. [`ShorSyndromeECCSetup`](@ref))""" +function evaluate_decoder(d::AbstractSyndromeDecoder, setup::AbstractECCSetup, nsamples::Int) + H = parity_checks(d) + n = code_n(H) + k = code_k(H) + O = faults_matrix(H) + + physical_noisy_circ, syndrome_bits, n_anc = physical_ECC_circuit(H, setup) + encoding_circ = naive_encoding_circuit(H) + preX = [sHadamard(i) for i in n-k+1:n] + + mdH = MixedDestabilizer(H) + logX_circ, _, logX_bits = naive_syndrome_circuit(logicalxview(mdH), n_anc+1, last(syndrome_bits)+1) + logZ_circ, _, logZ_bits = naive_syndrome_circuit(logicalzview(mdH), n_anc+1, last(syndrome_bits)+1) + + X_error = evaluate_decoder( + d, nsamples, + [encoding_circ..., physical_noisy_circ..., logZ_circ...], + syndrome_bits, logZ_bits, O[length(logZ_bits)+1:end,:]) + Z_error = evaluate_decoder( + d, nsamples, + [preX..., encoding_circ..., physical_noisy_circ..., logX_circ...], + syndrome_bits, logX_bits, O[1:length(logZ_bits),:]) + return (X_error, Z_error) +end - for gate in logcirc - type = typeof(gate) - if type == sMRZ - push!(syndrome_circuit, sMRZ(gate.qubit+s, gate.bit+s)) - else - push!(syndrome_circuit, type(gate.q1, gate.q2+s)) - end - end +"""Evaluate the performance of an error-correcting circuit. - ecirc = encoding_circuit_func(syndrome_circuit) - if isnothing(pre_circuit) - full_circuit = vcat(pre_circuits, ecirc, errors, noisy_syndrome_circuit) - else - full_circuit = vcat(ecirc, errors, noisy_syndrome_circuit) - end +This method requires you give the circuit that performs both syndrome measurements and (probably noiseless) logical state measurements. +The faults matrix that translates an error vector into corresponding logical errors is necessary as well. - frames = PauliFrame(nframes, n+s+k, s+k) - pftrajectories(frames, full_circuit) - syndromes = pfmeasurements(frames)[:, 1:s] - logical_syndromes = pfmeasurements(frames)[:, s+1: s+k] +This is a relatively barebones method that assumes the user prepares necessary circuits, etc. +It is a method that is used internally by more user-frienly methods providing automatic conversion of codes and noise models +to the necessary noisy circuits. +""" +function evaluate_decoder(d::AbstractSyndromeDecoder, nsamples, circuit, syndrome_bits, logical_bits, faults_submatrix) + frames = pftrajectories(circuit;trajectories=nsamples,threads=true) + syndromes = @view pfmeasurements(frames)[:, syndrome_bits] + measured_faults = @view pfmeasurements(frames)[:, logical_bits] + decoded = 0 for i in 1:nsamples - guess = decode(d, syndromes[i]) - - # result should be concatinated guess of the X and Z checks - result = (O * (guess))[guess_start:guess_stop] - - if result == logical_syndromes[i] + guess = decode(d, @view syndromes[i,:]) + isnothing(guess) && continue + guess_faults = faults_submatrix * guess + if guess_faults == @view measured_faults[i,:] decoded += 1 end end @@ -75,19 +105,42 @@ struct TableDecoder <: AbstractSyndromeDecoder k """The lookup table corresponding to the code, slow to create""" lookup_table - """The time taken to create the lookup table + decode the code a specified number of time""" - time end -function TableDecoder(Hx, Hz) - c = CSS(Hx, Hz) +function TableDecoder(c) H = parity_checks(c) s, n = size(H) _, _, r = canonicalize!(Base.copy(H), ranks=true) k = n - r - lookup_table, time, _ = @timed create_lookup_table(H) - faults_matrix = faults_matrix(H) - return TableDecoder(H, n, s, k, faults_matrix, lookup_table, time) + lookup_table = create_lookup_table(H) + fm = faults_matrix(H) + return TableDecoder(H, n, s, k, fm, lookup_table) +end + +parity_checks(d::TableDecoder) = d.H + +function create_lookup_table(code::Stabilizer) + lookup_table = Dict() + constraints, qubits = size(code) + # In the case of no errors + lookup_table[ zeros(UInt8, constraints) ] = stab_to_gf2(zero(PauliOperator, qubits)) + # In the case of single bit errors + for bit_to_be_flipped in 1:qubits + for error_type in [single_x, single_y, single_z] + # Generate e⃗ + error = error_type(qubits, bit_to_be_flipped) + # Calculate s⃗ + # (check which stabilizer rows do not commute with the Pauli error) + syndrome = comm(error, code) + # Store s⃗ → e⃗ + lookup_table[syndrome] = stab_to_gf2(error) + end + end + lookup_table +end; + +function decode(d::TableDecoder, syndrome_sample) + return get(d.lookup_table, syndrome_sample, nothing) end struct BeliefPropDecoder <: AbstractSyndromeDecoder @@ -156,9 +209,7 @@ function BeliefPropDecoder(Hx, Hz) return BeliefPropDecoder(H, faults_matrix, n, s, k, log_probabs, channel_probs, numchecks_X, b2c_X, c2b_X, numchecks_Z, b2c_Z, c2b_Z, err, sparse_Cx, sparse_CxT, sparse_Cz, sparse_CzT) end -function decode(d::TableDecoder, syndrome_sample) - return get(d.lookup_table, syndrome_sample, nothing) -end +parity_checks(d::BeliefPropDecoder) = d.H function decode(d::BeliefPropDecoder, syndrome_sample) row_x = syndrome_sample[1:d.numchecks_X] @@ -168,61 +219,3 @@ function decode(d::BeliefPropDecoder, syndrome_sample) KguessZ, success = syndrome_decode(d.sparse_Cz, d.sparse_CzT, d.row_z, d.max_iters, d.channel_probs, d.b2c_Z, d.c2b_Z, d.log_probabs, Base.copy(d.err)) guess = vcat(KguessZ, KguessX) end - - -## NOT WORKING -function evaluate_classical_decoder(H, nsamples, init_error, gate_error, syndrome_circuit_func, encoding_circuit_func, logical_view_func, decoder_func, pre_circuit = nothing) - decoded = 0 - - H_stab = Stabilizer(fill(0x0, size(Hx, 2)), H, zeros(Bool, size(H))) - - O = faults_matrix(H_stab) - syndrome_circuit = syndrome_circuit_func(H_stab) - - s, n = size(H) - k = n - s - - errors = [PauliError(i, init_error) for i in 1:n]; - - md = MixedDestabilizer(H_stab) - - full_circuit = [] - - logview = logical_view_func(md) - logcirc, _ = syndrome_circuit_func(logview) - - noisy_syndrome_circuit = add_two_qubit_gate_noise(syndrome_circuit, gate_error); - - for gate in logcirc - type = typeof(gate) - if type == sMRZ - push!(circuit, sMRZ(gate.qubit+s, gate.bit+s)) - else - push!(circuit, type(gate.q1, gate.q2+s)) - end - end - - ecirc = encoding_circuit_func(syndrome_circuit) - if isnothing(pre_circuit) - full_circuit = vcat(pre_circuits, ecirc, errors, noisy_syndrome_circuit) - else - full_circuit = vcat(ecirc, errors, noisy_syndrome_circuit) - end - - frames = PauliFrame(nframes, n+s+k, s+k) - pftrajectories(frames, full_circuit) - syndromes = pfmeasurements(frames)[:, 1:s] - logical_syndromes = pfmeasurements(frames)[:, s+1: s+k] - - for i in 1:nsamples - guess = decode(decoder_obj, syndromes[i]) # TODO: replace 'decoder_obj' with proper object - - result = (O * (guess)) - - if result == logical_syndromes[i] - decoded += 1 - end - end - - return (nsamples - decoded) / nsamples -end diff --git a/src/pauli_frames.jl b/src/pauli_frames.jl index 71d5dc10b..461978032 100644 --- a/src/pauli_frames.jl +++ b/src/pauli_frames.jl @@ -164,7 +164,7 @@ function pftrajectories(circuit;trajectories=5000,threads=true) end function _create_pauliframe(ccircuit; trajectories=5000) - qmax=maximum((maximum(affectedqubits(g)) for g in ccircuit)) + qmax=maximum((maximum(affectedqubits(g),init=1) for g in ccircuit)) bmax=maximum((maximum(affectedbits(g),init=1) for g in ccircuit)) return PauliFrame(trajectories, qmax, bmax) end