Skip to content

Commit

Permalink
flesh out evaluate_decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Krastanov committed Jan 18, 2024
1 parent e834c31 commit 8437690
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 119 deletions.
2 changes: 2 additions & 0 deletions src/affectedqubits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ affectedqubits(p::PauliOperator) = 1:length(p)
affectedqubits(m::Union{AbstractMeasurement,sMRX,sMRY,sMRZ}) = (m.qubit,)
affectedqubits(v::VerifyOp) = v.indices
affectedqubits(c::CliffordOperator) = 1:nqubits(c)
affectedqubits(c::ClassicalXOR) = ()

affectedbits(o) = ()
affectedbits(m::sMRZ) = (m.bit,)
affectedbits(m::sMZ) = (m.bit,)
affectedbits(c::ClassicalXOR) = (c.bits..., c.store)
3 changes: 2 additions & 1 deletion src/ecc/ECC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ export parity_checks, code_n, code_s, code_k, rate, distance,
isdegenerate, faults_matrix,
naive_syndrome_circuit, shor_syndrome_circuit, naive_encoding_circuit,
CSS, Unicycle, Bicycle,
Shor9, Steane7, Cleve8, Perfect5, Bitflip3
Shor9, Steane7, Cleve8, Perfect5, Bitflip3,
evaluate_decoder, TableDecoder

"""Parity check tableau of a code."""
function parity_checks end
Expand Down
4 changes: 2 additions & 2 deletions src/ecc/circuits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ Use the `ancillary_index` and `bit_index` arguments to offset where the correspo
Ancillary qubits
Returns:
- The cat state preparation circuit.
- The ancillary cat state preparation circuit.
- The Shor syndrome measurement circuit.
- The number of ancillary qubits that were added.
- The list of bit indices that store the final measurement results.
Expand All @@ -165,7 +165,7 @@ and stores the measurement result into classical bits starting at `bit_index`.
The final measurement result is the XOR of all the bits.
Returns:
- The cat state preparation circuit.
- The ancillary cat state preparation circuit.
- The Shor syndrome measurement circuit.
- One more than the index of the last added ancillary qubit.
- One more than the index of the last added classical bit.
Expand Down
223 changes: 108 additions & 115 deletions src/ecc/decoder_pipeline.jl
Original file line number Diff line number Diff line change
@@ -1,60 +1,90 @@
"""An abstract type for QECC syndrome decoding algorithms.
All `AbstractSyndromeDecoder` types are expected to:
- have a `parity_checks` method giving the parity checks for the code under study
- have a `decode` method that guesses error which caused the syndrome
- have an `evaluate_decoder` method which runs a full simulation but it supports only a small number of ECC protocols"""
abstract type AbstractSyndromeDecoder end

function evaluate_decoder(d::AbstractSyndromeDecoder, nsamples, init_error, gate_error, syndrome_circuit_func, encoding_circuit_func)
pre_X = [sHadamard(i) for i in n-k+1:n]
X_error = evaluate_classical_decoder(d, nsamples, init_error, gate_error, syndrome_circuit_func, encoding_circuit_func, logicalxview, 1, d.k, pre_X)
Z_error = evaluate_classical_decoder(d, nsamples, init_error, gate_error, syndrome_circuit_func, encoding_circuit_func, logicalzview, d.k + 1, 2 * d.k)
return (X_error, Z_error)
"""An abstract type mostly used by [`evaluate_decoder`](@ref) to specify in what context to evaluate an ECC."""
abstract type AbstractECCSetup end

"""A helper function that takes a parity check tableau and an `AbstractECCSetup` type and provides the circuit that needs to be simulated."""
function physical_ECC_circuit end # XXX Do not export! This might need to be refactored as we add more interesting setups!

"""Configuration for ECC evaluators that simulate the Shor-style syndrome measurement (without a flag qubit).
The simulated circuit includes:
- perfect noiseless encoding (encoding and its fault tolerance are not being studied here)
- one round of "memory noise" after the encoding but before the syndrome measurement
- perfect preparation of entangled ancillary qubits
- noisy Shor-style syndrome measurement (only two-qubit gate noise)
- noiseless "logical state measurement" (providing the comparison data when evaluating the decoder)
"""
struct ShorSyndromeECCSetup <: AbstractECCSetup
mem_noise::Float64
two_qubit_gate_noise::Float64
function ShorSyndromeECCSetup(mem_noise, two_qubit_gate_noise)
0<=mem_noise<=1 || throw(DomainError(mem_noise, "The memory noise in `ShorSyndromeECCSetup` should be between 0 and 1."))
0<=two_qubit_gate_noise<=1 || throw(DomainError(two_qubit_gate_noise, "The two-qubit gate noise in `ShorSyndromeECCSetup` should be between 0 and 1."))
new(mem_noise, two_qubit_gate_noise)
end
end

function evaluate_classical_decoder(d::AbstractSyndromeDecoder, nsamples, init_error, gate_error, syndrome_circuit_func, encoding_circuit_func, logical_view_function, guess_start, guess_stop, pre_circuit = nothing)
H = d.H
O = d.faults_matrix
syndrome_circuit = syndrome_circuit_func(H)

n = d.n
s = d.s
k = d.k

errors = [PauliError(i, init_error) for i in 1:n];

md = MixedDestabilizer(H)

full_circuit = []

logview = logical_view_function(md)
logcirc, _ = syndrome_circuit_func(logview)
function physical_ECC_circuit(H, setup::ShorSyndromeECCSetup)
prep_anc, syndrome_circ, n_anc, syndrome_bits = shor_syndrome_circuit(H)
noisy_syndrome_circ = syndrome_circ # add_two_qubit_gate_noise(syndrome_circ, gate_error)
mem_error_circ = [PauliError(i, setup.mem_noise) for i in 1:nqubits(H)];
circ = [prep_anc..., mem_error_circ..., noisy_syndrome_circ...]
circ, syndrome_bits, n_anc
end

noisy_syndrome_circuit = add_two_qubit_gate_noise(syndrome_circuit, gate_error);
"""Evaluate the performance of a given decoder (e.g. [`TableDecoder`](@ref)) and a given style of running an ECC code (e.g. [`ShorSyndromeECCSetup`](@ref))"""
function evaluate_decoder(d::AbstractSyndromeDecoder, setup::AbstractECCSetup, nsamples::Int)
H = parity_checks(d)
n = code_n(H)
k = code_k(H)
O = faults_matrix(H)

physical_noisy_circ, syndrome_bits, n_anc = physical_ECC_circuit(H, setup)
encoding_circ = naive_encoding_circuit(H)
preX = [sHadamard(i) for i in n-k+1:n]

mdH = MixedDestabilizer(H)
logX_circ, _, logX_bits = naive_syndrome_circuit(logicalxview(mdH), n_anc+1, last(syndrome_bits)+1)
logZ_circ, _, logZ_bits = naive_syndrome_circuit(logicalzview(mdH), n_anc+1, last(syndrome_bits)+1)

X_error = evaluate_decoder(
d, nsamples,
[encoding_circ..., physical_noisy_circ..., logZ_circ...],
syndrome_bits, logZ_bits, O[length(logZ_bits)+1:end,:])
Z_error = evaluate_decoder(
d, nsamples,
[preX..., encoding_circ..., physical_noisy_circ..., logX_circ...],
syndrome_bits, logX_bits, O[1:length(logZ_bits),:])
return (X_error, Z_error)
end

for gate in logcirc
type = typeof(gate)
if type == sMRZ
push!(syndrome_circuit, sMRZ(gate.qubit+s, gate.bit+s))
else
push!(syndrome_circuit, type(gate.q1, gate.q2+s))
end
end
"""Evaluate the performance of an error-correcting circuit.
ecirc = encoding_circuit_func(syndrome_circuit)
if isnothing(pre_circuit)
full_circuit = vcat(pre_circuits, ecirc, errors, noisy_syndrome_circuit)
else
full_circuit = vcat(ecirc, errors, noisy_syndrome_circuit)
end
This method requires you give the circuit that performs both syndrome measurements and (probably noiseless) logical state measurements.
The faults matrix that translates an error vector into corresponding logical errors is necessary as well.
frames = PauliFrame(nframes, n+s+k, s+k)
pftrajectories(frames, full_circuit)
syndromes = pfmeasurements(frames)[:, 1:s]
logical_syndromes = pfmeasurements(frames)[:, s+1: s+k]
This is a relatively barebones method that assumes the user prepares necessary circuits, etc.
It is a method that is used internally by more user-frienly methods providing automatic conversion of codes and noise models
to the necessary noisy circuits.
"""
function evaluate_decoder(d::AbstractSyndromeDecoder, nsamples, circuit, syndrome_bits, logical_bits, faults_submatrix)
frames = pftrajectories(circuit;trajectories=nsamples,threads=true)

syndromes = @view pfmeasurements(frames)[:, syndrome_bits]
measured_faults = @view pfmeasurements(frames)[:, logical_bits]
decoded = 0
for i in 1:nsamples
guess = decode(d, syndromes[i])

# result should be concatinated guess of the X and Z checks
result = (O * (guess))[guess_start:guess_stop]

if result == logical_syndromes[i]
guess = decode(d, @view syndromes[i,:])
isnothing(guess) && continue
guess_faults = faults_submatrix * guess
if guess_faults == @view measured_faults[i,:]
decoded += 1
end
end
Expand All @@ -75,19 +105,42 @@ struct TableDecoder <: AbstractSyndromeDecoder
k
"""The lookup table corresponding to the code, slow to create"""
lookup_table
"""The time taken to create the lookup table + decode the code a specified number of time"""
time
end

function TableDecoder(Hx, Hz)
c = CSS(Hx, Hz)
function TableDecoder(c)
H = parity_checks(c)
s, n = size(H)
_, _, r = canonicalize!(Base.copy(H), ranks=true)
k = n - r
lookup_table, time, _ = @timed create_lookup_table(H)
faults_matrix = faults_matrix(H)
return TableDecoder(H, n, s, k, faults_matrix, lookup_table, time)
lookup_table = create_lookup_table(H)
fm = faults_matrix(H)
return TableDecoder(H, n, s, k, fm, lookup_table)
end

parity_checks(d::TableDecoder) = d.H

function create_lookup_table(code::Stabilizer)
lookup_table = Dict()
constraints, qubits = size(code)
# In the case of no errors
lookup_table[ zeros(UInt8, constraints) ] = stab_to_gf2(zero(PauliOperator, qubits))
# In the case of single bit errors
for bit_to_be_flipped in 1:qubits
for error_type in [single_x, single_y, single_z]
# Generate e⃗
error = error_type(qubits, bit_to_be_flipped)
# Calculate s⃗
# (check which stabilizer rows do not commute with the Pauli error)
syndrome = comm(error, code)
# Store s⃗ → e⃗
lookup_table[syndrome] = stab_to_gf2(error)
end
end
lookup_table
end;

function decode(d::TableDecoder, syndrome_sample)
return get(d.lookup_table, syndrome_sample, nothing)
end

struct BeliefPropDecoder <: AbstractSyndromeDecoder
Expand Down Expand Up @@ -156,9 +209,7 @@ function BeliefPropDecoder(Hx, Hz)
return BeliefPropDecoder(H, faults_matrix, n, s, k, log_probabs, channel_probs, numchecks_X, b2c_X, c2b_X, numchecks_Z, b2c_Z, c2b_Z, err, sparse_Cx, sparse_CxT, sparse_Cz, sparse_CzT)
end

function decode(d::TableDecoder, syndrome_sample)
return get(d.lookup_table, syndrome_sample, nothing)
end
parity_checks(d::BeliefPropDecoder) = d.H

function decode(d::BeliefPropDecoder, syndrome_sample)
row_x = syndrome_sample[1:d.numchecks_X]
Expand All @@ -168,61 +219,3 @@ function decode(d::BeliefPropDecoder, syndrome_sample)
KguessZ, success = syndrome_decode(d.sparse_Cz, d.sparse_CzT, d.row_z, d.max_iters, d.channel_probs, d.b2c_Z, d.c2b_Z, d.log_probabs, Base.copy(d.err))
guess = vcat(KguessZ, KguessX)
end


## NOT WORKING
function evaluate_classical_decoder(H, nsamples, init_error, gate_error, syndrome_circuit_func, encoding_circuit_func, logical_view_func, decoder_func, pre_circuit = nothing)
decoded = 0

H_stab = Stabilizer(fill(0x0, size(Hx, 2)), H, zeros(Bool, size(H)))

O = faults_matrix(H_stab)
syndrome_circuit = syndrome_circuit_func(H_stab)

s, n = size(H)
k = n - s

errors = [PauliError(i, init_error) for i in 1:n];

md = MixedDestabilizer(H_stab)

full_circuit = []

logview = logical_view_func(md)
logcirc, _ = syndrome_circuit_func(logview)

noisy_syndrome_circuit = add_two_qubit_gate_noise(syndrome_circuit, gate_error);

for gate in logcirc
type = typeof(gate)
if type == sMRZ
push!(circuit, sMRZ(gate.qubit+s, gate.bit+s))
else
push!(circuit, type(gate.q1, gate.q2+s))
end
end

ecirc = encoding_circuit_func(syndrome_circuit)
if isnothing(pre_circuit)
full_circuit = vcat(pre_circuits, ecirc, errors, noisy_syndrome_circuit)
else
full_circuit = vcat(ecirc, errors, noisy_syndrome_circuit)
end

frames = PauliFrame(nframes, n+s+k, s+k)
pftrajectories(frames, full_circuit)
syndromes = pfmeasurements(frames)[:, 1:s]
logical_syndromes = pfmeasurements(frames)[:, s+1: s+k]

for i in 1:nsamples
guess = decode(decoder_obj, syndromes[i]) # TODO: replace 'decoder_obj' with proper object

result = (O * (guess))

if result == logical_syndromes[i]
decoded += 1
end
end

return (nsamples - decoded) / nsamples
end
2 changes: 1 addition & 1 deletion src/pauli_frames.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ function pftrajectories(circuit;trajectories=5000,threads=true)
end

function _create_pauliframe(ccircuit; trajectories=5000)
qmax=maximum((maximum(affectedqubits(g)) for g in ccircuit))
qmax=maximum((maximum(affectedqubits(g),init=1) for g in ccircuit))
bmax=maximum((maximum(affectedbits(g),init=1) for g in ccircuit))
return PauliFrame(trajectories, qmax, bmax)
end
Expand Down

0 comments on commit 8437690

Please sign in to comment.