From e5b7bd55972799640877d846b34e607cc8cf5bfe Mon Sep 17 00:00:00 2001 From: Feroz Ahmad Date: Fri, 27 Sep 2024 08:23:27 +0500 Subject: [PATCH] some performance optimizations in decoders (#369) --------- Co-authored-by: Stefan Krastanov --- .../QuantumCliffordLDPCDecodersExt.jl | 8 ++++---- .../QuantumCliffordPyQDecodersExt.jl | 13 +++++++------ src/ecc/decoder_pipeline.jl | 19 ++++++++++++++----- 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/ext/QuantumCliffordLDPCDecodersExt/QuantumCliffordLDPCDecodersExt.jl b/ext/QuantumCliffordLDPCDecodersExt/QuantumCliffordLDPCDecodersExt.jl index 8c369cf82..e80c13af6 100644 --- a/ext/QuantumCliffordLDPCDecodersExt/QuantumCliffordLDPCDecodersExt.jl +++ b/ext/QuantumCliffordLDPCDecodersExt/QuantumCliffordLDPCDecodersExt.jl @@ -74,16 +74,16 @@ parity_checks(d::BeliefPropDecoder) = d.H parity_checks(d::BitFlipDecoder) = d.H function decode(d::BeliefPropDecoder, syndrome_sample) - row_x = syndrome_sample[1:d.cx] - row_z = syndrome_sample[d.cx+1:d.cx+d.cz] + row_x = @view syndrome_sample[1:d.cx] + row_z = @view syndrome_sample[d.cx+1:d.cx+d.cz] guess_z, success = LDPCDecoders.decode!(d.bpdecoderx, row_x) guess_x, success = LDPCDecoders.decode!(d.bpdecoderz, row_z) return vcat(guess_x, guess_z) end function decode(d::BitFlipDecoder, syndrome_sample) - row_x = syndrome_sample[1:d.cx] - row_z = syndrome_sample[d.cx+1:d.cx+d.cz] + row_x = @view syndrome_sample[1:d.cx] + row_z = @view syndrome_sample[d.cx+1:d.cx+d.cz] guess_z, success = LDPCDecoders.decode!(d.bfdecoderx, row_x) guess_x, success = LDPCDecoders.decode!(d.bfdecoderz, row_z) return vcat(guess_x, guess_z) diff --git a/ext/QuantumCliffordPyQDecodersExt/QuantumCliffordPyQDecodersExt.jl b/ext/QuantumCliffordPyQDecodersExt/QuantumCliffordPyQDecodersExt.jl index bfa557d05..e3496a733 100644 --- a/ext/QuantumCliffordPyQDecodersExt/QuantumCliffordPyQDecodersExt.jl +++ b/ext/QuantumCliffordPyQDecodersExt/QuantumCliffordPyQDecodersExt.jl @@ -69,8 +69,8 @@ end parity_checks(d::PyBP) = d.H function decode(d::PyBP, syndrome_sample) - row_x = syndrome_sample[1:d.nx] # TODO These copies and indirections might be costly! - row_z = syndrome_sample[d.nx+1:end] + row_x = @view syndrome_sample[1:d.nx] + row_z = @view syndrome_sample[d.nx+1:end] guess_z_errors = PythonCall.PyArray(d.pyx.decode(np.array(row_x))) guess_x_errors = PythonCall.PyArray(d.pyz.decode(np.array(row_z))) vcat(guess_x_errors, guess_z_errors) @@ -106,18 +106,19 @@ end parity_checks(d::PyMatchingDecoder) = d.H function decode(d::PyMatchingDecoder, syndrome_sample) - row_x = syndrome_sample[1:d.nx] # TODO This copy is costly! - row_z = syndrome_sample[d.nx+1:end] + row_x = @view syndrome_sample[1:d.nx] + row_z = @view syndrome_sample[d.nx+1:end] guess_z_errors = PythonCall.PyArray(d.pyx.decode(row_x)) guess_x_errors = PythonCall.PyArray(d.pyz.decode(row_z)) vcat(guess_x_errors, guess_z_errors) end function batchdecode(d::PyMatchingDecoder, syndrome_samples) - row_x = syndrome_samples[:,1:d.nx] # TODO This copy is costly! - row_z = syndrome_samples[:,d.nx+1:end] + row_x = @view syndrome_samples[:,1:d.nx] + row_z = @view syndrome_samples[:,d.nx+1:end] guess_z_errors = PythonCall.PyArray(d.pyx.decode_batch(row_x)) guess_x_errors = PythonCall.PyArray(d.pyz.decode_batch(row_z)) + n_cols_x = size(guess_x_errors, 2) hcat(guess_x_errors, guess_z_errors) end diff --git a/src/ecc/decoder_pipeline.jl b/src/ecc/decoder_pipeline.jl index 0488dd28a..3ac2680cf 100644 --- a/src/ecc/decoder_pipeline.jl +++ b/src/ecc/decoder_pipeline.jl @@ -171,12 +171,21 @@ end function evaluate_guesses(measured_faults, guesses, faults_matrix) nsamples = size(guesses, 1) - guess_faults = (faults_matrix * guesses') .% 2 # TODO this can be faster and non-allocating by turning it into a loop - decoded = 0 - for i in 1:nsamples # TODO this can be faster and non-allocating by having the loop and the matrix multiplication on the line above work together and not store anything - (@view guess_faults[:,i]) == (@view measured_faults[i,:]) && (decoded += 1) + fails = 0 + for i in 1:nsamples + for j in 1:size(faults_matrix, 1) + sum_mod = 0 + @inbounds @simd for k in 1:size(faults_matrix, 2) + sum_mod += faults_matrix[j, k] * guesses[i, k] + end + sum_mod %= 2 + if sum_mod != measured_faults[i, j] + fails += 1 + break + end + end end - return (nsamples - decoded) / nsamples + return fails / nsamples end function evaluate_decoder(d::AbstractSyndromeDecoder, setup::CommutationCheckECCSetup, nsamples::Int)