Skip to content

Commit

Permalink
some performance optimizations in decoders (#369)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Stefan Krastanov <[email protected]>
  • Loading branch information
Fe-r-oz and Krastanov authored Sep 27, 2024
1 parent 6431659 commit e5b7bd5
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,16 @@ parity_checks(d::BeliefPropDecoder) = d.H
parity_checks(d::BitFlipDecoder) = d.H

function decode(d::BeliefPropDecoder, syndrome_sample)
row_x = syndrome_sample[1:d.cx]
row_z = syndrome_sample[d.cx+1:d.cx+d.cz]
row_x = @view syndrome_sample[1:d.cx]
row_z = @view syndrome_sample[d.cx+1:d.cx+d.cz]
guess_z, success = LDPCDecoders.decode!(d.bpdecoderx, row_x)
guess_x, success = LDPCDecoders.decode!(d.bpdecoderz, row_z)
return vcat(guess_x, guess_z)
end

function decode(d::BitFlipDecoder, syndrome_sample)
row_x = syndrome_sample[1:d.cx]
row_z = syndrome_sample[d.cx+1:d.cx+d.cz]
row_x = @view syndrome_sample[1:d.cx]
row_z = @view syndrome_sample[d.cx+1:d.cx+d.cz]
guess_z, success = LDPCDecoders.decode!(d.bfdecoderx, row_x)
guess_x, success = LDPCDecoders.decode!(d.bfdecoderz, row_z)
return vcat(guess_x, guess_z)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ end
parity_checks(d::PyBP) = d.H

function decode(d::PyBP, syndrome_sample)
row_x = syndrome_sample[1:d.nx] # TODO These copies and indirections might be costly!
row_z = syndrome_sample[d.nx+1:end]
row_x = @view syndrome_sample[1:d.nx]
row_z = @view syndrome_sample[d.nx+1:end]
guess_z_errors = PythonCall.PyArray(d.pyx.decode(np.array(row_x)))
guess_x_errors = PythonCall.PyArray(d.pyz.decode(np.array(row_z)))
vcat(guess_x_errors, guess_z_errors)
Expand Down Expand Up @@ -106,18 +106,19 @@ end
parity_checks(d::PyMatchingDecoder) = d.H

function decode(d::PyMatchingDecoder, syndrome_sample)
row_x = syndrome_sample[1:d.nx] # TODO This copy is costly!
row_z = syndrome_sample[d.nx+1:end]
row_x = @view syndrome_sample[1:d.nx]
row_z = @view syndrome_sample[d.nx+1:end]
guess_z_errors = PythonCall.PyArray(d.pyx.decode(row_x))
guess_x_errors = PythonCall.PyArray(d.pyz.decode(row_z))
vcat(guess_x_errors, guess_z_errors)
end

function batchdecode(d::PyMatchingDecoder, syndrome_samples)
row_x = syndrome_samples[:,1:d.nx] # TODO This copy is costly!
row_z = syndrome_samples[:,d.nx+1:end]
row_x = @view syndrome_samples[:,1:d.nx]
row_z = @view syndrome_samples[:,d.nx+1:end]
guess_z_errors = PythonCall.PyArray(d.pyx.decode_batch(row_x))
guess_x_errors = PythonCall.PyArray(d.pyz.decode_batch(row_z))
n_cols_x = size(guess_x_errors, 2)
hcat(guess_x_errors, guess_z_errors)
end

Expand Down
19 changes: 14 additions & 5 deletions src/ecc/decoder_pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,21 @@ end

function evaluate_guesses(measured_faults, guesses, faults_matrix)
nsamples = size(guesses, 1)
guess_faults = (faults_matrix * guesses') .% 2 # TODO this can be faster and non-allocating by turning it into a loop
decoded = 0
for i in 1:nsamples # TODO this can be faster and non-allocating by having the loop and the matrix multiplication on the line above work together and not store anything
(@view guess_faults[:,i]) == (@view measured_faults[i,:]) && (decoded += 1)
fails = 0
for i in 1:nsamples
for j in 1:size(faults_matrix, 1)
sum_mod = 0
@inbounds @simd for k in 1:size(faults_matrix, 2)
sum_mod += faults_matrix[j, k] * guesses[i, k]
end
sum_mod %= 2
if sum_mod != measured_faults[i, j]
fails += 1
break
end
end
end
return (nsamples - decoded) / nsamples
return fails / nsamples
end

function evaluate_decoder(d::AbstractSyndromeDecoder, setup::CommutationCheckECCSetup, nsamples::Int)
Expand Down

0 comments on commit e5b7bd5

Please sign in to comment.