Skip to content

Commit

Permalink
add BP-OSD (#22)
Browse files Browse the repository at this point in the history

Co-authored-by: Stefan Krastanov <[email protected]>
  • Loading branch information
royess and Krastanov authored Nov 27, 2024
1 parent d58ed11 commit 7b74dde
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 5 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# News

## v0.3.2 - 2024-11-15

- Add a (still unoptimized) implementation of a BP OSD decoder.

## Older - before 2021-10-28 unrecorded
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LDPCDecoders"
uuid = "3c486d74-64b9-4c60-8b1a-13a564e77efb"
authors = ["Krishna Praneet Gudipaty", "Stefan Krastanov", "QuantumSavory contributors"]
version = "0.3.1"
version = "0.3.2"

[deps]
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Expand Down
3 changes: 3 additions & 0 deletions src/LDPCDecoders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using RowEchelon
export
decode!, batchdecode!,
BeliefPropagationDecoder,
BeliefPropagationOSDDecoder,
BitFlipDecoder

include("generator.jl")
Expand All @@ -22,7 +23,9 @@ include("parity_generator.jl")

include("decoders/abstract_decoder.jl")
include("decoders/belief_propagation.jl")
include("decoders/belief_propagation_osd.jl")
include("decoders/iterative_bitflip.jl")

include("syndrome_bp_decoder.jl")
include("syndrome_simulator.jl")
include("syndrome_it_decoder.jl")
Expand Down
8 changes: 4 additions & 4 deletions src/decoders/belief_propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct BeliefPropagationDecoder <: AbstractDecoder
scratch::BeliefPropagationScratchSpace
end

function BeliefPropagationDecoder(H, per::Float64, max_iters::Int)
function BeliefPropagationDecoder(H::Union{SparseArrays.SparseMatrixCSC{Bool,Int}, BitMatrix}, per::Float64, max_iters::Int)
s, n = size(H)
sparse_H = sparse(H)
sparse_HT = sparse(H')
Expand Down Expand Up @@ -108,7 +108,7 @@ true
function decode!(decoder::BeliefPropagationDecoder, syndrome::AbstractVector) # TODO check if casting to bitarrays helps with performance -- if it does, set up warnings to the user for cases where they have not done the casting
reset!(decoder)
rows::Vector{Int} = rowvals(decoder.sparse_H);
rowsT::Vector{Int} = rowvals(decoder.sparse_HT);
rowsT::Vector{Int} = rowvals(decoder.sparse_HT);
setup = decoder.scratch

for j in 1:decoder.n
Expand Down Expand Up @@ -138,7 +138,7 @@ function decode!(decoder::BeliefPropagationDecoder, syndrome::AbstractVector) #

for j in 1:decoder.n
temp::Float64 = setup.channel_probs[j] / (1 - setup.channel_probs[j])

for k in nzrange(decoder.sparse_H, j)
setup.bit_2_check[rows[k],j] = temp
temp *= setup.check_2_bit[rows[k],j]
Expand Down Expand Up @@ -166,7 +166,7 @@ function decode!(decoder::BeliefPropagationDecoder, syndrome::AbstractVector) #

syndrome_decoded = (decoder.sparse_H * setup.err) .% 2
if all(syndrome_decoded .== syndrome)
converged = true
converged = true
break # Break if converged
end
end
Expand Down
113 changes: 113 additions & 0 deletions src/decoders/belief_propagation_osd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
struct BeliefPropagationOSDDecoder <: AbstractDecoder
"""A belief propagation decoder as a subroutine"""
bp_decoder::BeliefPropagationDecoder
"""Dense form of the parity check matrix"""
H::BitMatrix
"""The order of OSD; defaulted to be 0 in the constructor"""
osd_order::Int
end

function BeliefPropagationOSDDecoder(H::BitMatrix, per::Float64, max_iters::Int; osd_order::Int=0)
bp_decoder = BeliefPropagationDecoder(H, per, max_iters)
return BeliefPropagationOSDDecoder(bp_decoder, H, osd_order)
end

function rowswap!(H::BitMatrix, i, j)
@inbounds H[i, :], H[j, :] = H[j, :], H[i, :] # TODO This could be further optimized?
end

function decode!(decoder::BeliefPropagationOSDDecoder, syndrome::AbstractVector)
# use BP to get hard and soft decisions
bp_err, converged = decode!(decoder.bp_decoder, syndrome) # hard decisions
bp_log_probabs = decoder.bp_decoder.scratch.log_probabs # soft decisions
bp_probabs = exp.(bp_log_probabs)
# sort columns by reliability, less reliable columns first
sort_by_reliability = sortperm(max.(bp_probabs, 1 .- bp_probabs), rev=true)
H_sorted = decoder.H[:, sort_by_reliability]
bp_err_sorted = bp_err[sort_by_reliability]
# TODO an optimized version of OSD can be implemented when osd_order = 0, see Algorithm 2 in https://doi.org/10.22331/q-2021-11-22-585
err = osd(H_sorted, syndrome, bp_err_sorted, decoder.osd_order)
return err[invperm(sort_by_reliability)], converged # also return whether BP is converged
end

function osd(H, syndrome, bp_err, osd_order)
m, n = size(H)
# diagnolize the submatrix corresponding to independent columns via Gaussian elimination
# first obtain the row canonical form
# and find least reliable indices, i.e., the first r pivot columns (assume H is rearranged by reliability)
least_reliable_rows = [] # row indices of pivot elements
least_reliable_cols = [] # column indices of pivot elements
r = 0 # compute rank of H
i, j = 1, 1
s = copy(syndrome) # transform syndrome along with H in Gaussian elimination

while i <= m && j <= n
k = findfirst(H[i:end, j])
if isnothing(k) # not an independent column
j += 1
else
if k > 1
ii = i + k - 1 # the first row after `i` with 1 in column `j`
rowswap!(H, i, ii) # TODO For optimization: Is this swap necessary? We may just track the row index
s[i], s[ii] = s[ii], s[i]
end
for ii in i+1:m
if H[ii, j]
H[ii, :] .⊻= H[i, :]
s[ii] ⊻= s[i]
end
end
push!(least_reliable_rows, i)
push!(least_reliable_cols, j)
i += 1
j += 1
r += 1
end
end

# then obtain a diagonal submatrix on the least reliable part
for (i, j) in zip(reverse(least_reliable_rows), reverse(least_reliable_cols))
for ii in 1:i-1
if H[ii, j]
H[ii, :] .⊻= H[i, :]
s[ii] ⊻= s[i]
end
end
end

if osd_order > n - r
@warn "The order of OSD $osd_order is greater than the size of the information set $(n-r). We set osd_order = $(n-r)."
osd_order = n - r
end

best_err = copy(bp_err)
err = Bool.(copy(bp_err)) # TODO why error is in Float in BP?
most_reliable_cols = setdiff(1:n, least_reliable_cols)
min_weight = n + 1

for x in 0:2^osd_order-1
# first compute the `most_reliable_cols` part of errors
# try all possible errors on the first `osd_order` bits within `most_reliable_cols`
if x != 0
trial_err = BitArray([x >> i & 1 for i in 0:osd_order-1])
for j in 1:osd_order
err[most_reliable_cols[j]] = trial_err[j]
end
end
# then based on the `most_reliable_cols` part of errors, compute the `least_reliable_cols` part of errors
for (i, j) in zip(least_reliable_rows, least_reliable_cols)
err[j] = s[i]
for k in most_reliable_cols
err[j] ⊻= H[i, k] * err[k]
end
end
weight = sum(err) # This weight is set for depolarizing noise
# TODO More generally, it should be a function depending on the noise model
if weight < min_weight
min_weight = weight
best_err = copy(err)
end
end

return best_err
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ println("ENV[\"PYTHON\"] = \"$(get(ENV,"PYTHON",nothing))\"")

@doset "oldtests"
@doset "bp_decoder"
@doset "bposd_decoder"
@doset "bf_decoder"

VERSION >= v"1.10" && @doset "doctests"
Expand Down
53 changes: 53 additions & 0 deletions test/test_bposd_decoder.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using Test
using LDPCDecoders

@testset "test_bposd_decoder.jl" begin

"""Test for BP-OSD decoder"""
function test_bposd_decoder()
H = LDPCDecoders.parity_check_matrix(1000, 10, 9)
per = 0.01
err = rand(1000) .< per
syn = (H * err) .% 2

bposd = BeliefPropagationOSDDecoder(H, per, 100)
guess, success = decode!(bposd, syn)

return guess == err
end

"""Test high order OSD"""
function test_bposd_decoder_high_order()
H = LDPCDecoders.parity_check_matrix(1000, 10, 9)
per = 0.01
err = rand(1000) .< per
syn = (H * err) .% 2

orders = 2:5
succ = true
for osd_order in orders
bposd = BeliefPropagationOSDDecoder(H, per, 100; osd_order=osd_order)
guess, success = decode!(bposd, syn)
succ = succ & (guess == err)
end

return succ
end

"""Test for BP-OSD decoder with large error rate. Even if the decoding is not accurate, OSD will still ensure consistency between guess and syndromes."""
function test_bposd_decoder_large_error_rate()
H = LDPCDecoders.parity_check_matrix(1000, 10, 9)
per = 0.2
err = rand(1000) .< per
syn = (H * err) .% 2

bposd = BeliefPropagationOSDDecoder(H, per, 100)
guess, success = decode!(bposd, syn)

return syn == (H * guess) .% 2
end

@test test_bposd_decoder()
@test test_bposd_decoder_high_order()
@test test_bposd_decoder_large_error_rate()
end

0 comments on commit 7b74dde

Please sign in to comment.