Skip to content

Commit

Permalink
remove hcat and vcat from decoders
Browse files Browse the repository at this point in the history
  • Loading branch information
Fe-r-oz committed Sep 26, 2024
1 parent 05491d9 commit 1ac7557
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,25 @@ parity_checks(d::BeliefPropDecoder) = d.H
parity_checks(d::BitFlipDecoder) = d.H

function decode(d::BeliefPropDecoder, syndrome_sample)
row_x = syndrome_sample[1:d.cx]
row_z = syndrome_sample[d.cx+1:d.cx+d.cz]
row_x = @view syndrome_sample[1:d.cx]
row_z = @view syndrome_sample[d.cx+1:d.cx+d.cz]
guess_z, success = LDPCDecoders.decode!(d.bpdecoderx, row_x)
guess_x, success = LDPCDecoders.decode!(d.bpdecoderz, row_z)
return vcat(guess_x, guess_z)
result = Matrix{Int}(undef, 2, length(guess_x))
@inbounds result[1, 1:length(guess_x)] .= guess_x
@inbounds result[2, 1:length(guess_z)] .= guess_z
return result
end

function decode(d::BitFlipDecoder, syndrome_sample)
row_x = syndrome_sample[1:d.cx]
row_z = syndrome_sample[d.cx+1:d.cx+d.cz]
row_x = @view syndrome_sample[1:d.cx]
row_z = @view syndrome_sample[d.cx+1:d.cx+d.cz]
guess_z, success = LDPCDecoders.decode!(d.bfdecoderx, row_x)
guess_x, success = LDPCDecoders.decode!(d.bfdecoderz, row_z)
return vcat(guess_x, guess_z)
result = Matrix{Int}(undef, 2, length(guess_x))
@inbounds result[1, 1:length(guess_x)] .= guess_x
@inbounds result[2, 1:length(guess_z)] .= guess_z
return result
end

end
28 changes: 19 additions & 9 deletions ext/QuantumCliffordPyQDecodersExt/QuantumCliffordPyQDecodersExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,14 @@ end
parity_checks(d::PyBP) = d.H

function decode(d::PyBP, syndrome_sample)
row_x = syndrome_sample[1:d.nx] # TODO These copies and indirections might be costly!
row_z = syndrome_sample[d.nx+1:end]
row_x = @view syndrome_sample[1:d.nx]
row_z = @view syndrome_sample[d.nx+1:end]
guess_z_errors = PythonCall.PyArray(d.pyx.decode(np.array(row_x)))
guess_x_errors = PythonCall.PyArray(d.pyz.decode(np.array(row_z)))
vcat(guess_x_errors, guess_z_errors)
result = Matrix{Int}(undef, 2, length(guess_x_errors))
@inbounds result[1, 1:length(guess_x_errors)] .= guess_x_errors
@inbounds result[2, 1:length(guess_z_errors)] .= guess_z_errors
return result
end

struct PyMatchingDecoder <: AbstractSyndromeDecoder # TODO all these decoders have the same fields, maybe we can factor out a common type
Expand Down Expand Up @@ -106,19 +109,26 @@ end
parity_checks(d::PyMatchingDecoder) = d.H

function decode(d::PyMatchingDecoder, syndrome_sample)
row_x = syndrome_sample[1:d.nx] # TODO This copy is costly!
row_z = syndrome_sample[d.nx+1:end]
row_x = @view syndrome_sample[1:d.nx]
row_z = @view syndrome_sample[d.nx+1:end]
guess_z_errors = PythonCall.PyArray(d.pyx.decode(row_x))
guess_x_errors = PythonCall.PyArray(d.pyz.decode(row_z))
vcat(guess_x_errors, guess_z_errors)
result = Matrix{Int}(undef, 2, length(guess_x_errors))
@inbounds result[1, 1:length(guess_x_errors)] .= guess_x_errors
@inbounds result[2, 1:length(guess_z_errors)] .= guess_z_errors
return result
end

function batchdecode(d::PyMatchingDecoder, syndrome_samples)
row_x = syndrome_samples[:,1:d.nx] # TODO This copy is costly!
row_z = syndrome_samples[:,d.nx+1:end]
row_x = @view syndrome_samples[:,1:d.nx]
row_z = @view syndrome_samples[:,d.nx+1:end]
guess_z_errors = PythonCall.PyArray(d.pyx.decode_batch(row_x))
guess_x_errors = PythonCall.PyArray(d.pyz.decode_batch(row_z))
hcat(guess_x_errors, guess_z_errors)
n_cols_x = size(guess_x_errors, 2)
result = Matrix{Int}(undef, size(guess_x_errors, 1), n_cols_x + size(guess_z_errors, 2))
@inbounds result[:,1:n_cols_x] .= guess_x_errors
@inbounds result[:,n_cols_x+1:end] .= guess_z_errors
return result
end

end

0 comments on commit 1ac7557

Please sign in to comment.