Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SliceSampler: simplify code + more detailed errors #183

Merged
merged 2 commits into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions ext/PigeonsBridgeStanExt/state.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

Pigeons.continuous_variables(state::Pigeons.StanState) = Pigeons.SINGLETON_VAR # all Stan variables should be continuous
Pigeons.discrete_variables(state::Pigeons.StanState) = []
# Continuous variables of a `StanState`: always `Pigeons.SINGLETON_VAR`.
function Pigeons.continuous_variables(::Pigeons.StanState)
    # all Stan variables should be continuous
    return Pigeons.SINGLETON_VAR
end
# Discrete variables of a `StanState`: an empty collection.
function Pigeons.discrete_variables(::Pigeons.StanState)
    return []
end

Pigeons.extract_sample(state::Pigeons.StanState, log_potential) =
[
Expand All @@ -22,30 +22,27 @@ function Pigeons.variable(state::Pigeons.StanState, name::Symbol)
end
end

Pigeons.step!(explorer::Pigeons.HamiltonianSampler, replica, shared, state::Pigeons.StanState) =
Pigeons.step!(explorer, replica, shared, state.unconstrained_parameters)

# Names attached to an extracted sample: the names reported by
# `BridgeStan.param_names` (with `include_tp` and `include_gq` both enabled),
# followed by `:log_density`.
function Pigeons.sample_names(::Pigeons.StanState, log_potential)
    model = Pigeons.stan_model(log_potential)
    parameter_names = BridgeStan.param_names(model; include_tp = true, include_gq = true)
    return [parameter_names; :log_density]
end


function Pigeons.slice_sample!(h::SliceSampler, state::Pigeons.StanState, log_potential, cached_lp, replica)
cached_lp = Pigeons.cached_log_potential(log_potential, state, cached_lp)
for i in eachindex(state.unconstrained_parameters)
pointer = Ref(state.unconstrained_parameters, i)
cached_lp = Pigeons.slice_sample_coord!(h, replica, pointer, log_potential, cached_lp)
end
return cached_lp
end


#=
specialized equality checks
=#
# Two `StanLogPotential`s are considered equal when their `data` fields compare
# equal and their models report the same `BridgeStan.name`.
function Pigeons.recursive_equal(a::StanLogPotential, b::StanLogPotential)
    # compare data first; model names are only queried when the data matches
    same_data = a.data == b.data
    return same_data && BridgeStan.name(a.model) == BridgeStan.name(b.model)
end
# `StanRNG` equality is delegated to `Pigeons._recursive_equal`.
function Pigeons.recursive_equal(a::StanRNG, b::StanRNG)
    return Pigeons._recursive_equal(a, b)
end

#=
explorer implementations
=#
# Slice sampling on a `StanState`: forward to the method that handles its
# `unconstrained_parameters`, passing all remaining arguments through unchanged.
function Pigeons.slice_sample!(h::SliceSampler, state::Pigeons.StanState, args...)
    return Pigeons.slice_sample!(h, state.unconstrained_parameters, args...)
end
# Hamiltonian step on a `StanState`: forward to the method that handles its
# `unconstrained_parameters`.
function Pigeons.step!(
    explorer::Pigeons.HamiltonianSampler,
    replica,
    shared,
    state::Pigeons.StanState,
)
    return Pigeons.step!(explorer, replica, shared, state.unconstrained_parameters)
end

# Evaluate the log potential on a `StanState` by applying it to the state's
# `unconstrained_parameters`.
function (log_potential::Pigeons.ScaledPrecisionNormalLogPotential)(x::Pigeons.StanState)
    return log_potential(x.unconstrained_parameters)
end
# In-place draw for a `StanState` backed by a `Vector{Float64}`: forward to the
# `rand!` method that takes the state's `unconstrained_parameters`.
function Random.rand!(
    rng::AbstractRNG,
    state::Pigeons.StanState{Vector{Float64}},
    log_potential::Pigeons.ScaledPrecisionNormalLogPotential,
)
    return rand!(rng, state.unconstrained_parameters, log_potential)
end
3 changes: 2 additions & 1 deletion ext/PigeonsDynamicPPLExt/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
DynamicPPL.logjoint(log_potential.model, vi)
end
catch e
(isa(e, DomainError) | isa(e, BoundsError)) ? -Inf : error("Unknown error in evaluation of the Turing log_potential.")
(isa(e, DomainError) || isa(e, BoundsError)) && return -Inf
rethrow(e)
end

"""
Expand Down
21 changes: 10 additions & 11 deletions ext/PigeonsDynamicPPLExt/state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,25 +50,24 @@ function Pigeons.sample_names(state::DynamicPPL.TypedVarInfo, _)
return result
end

function Pigeons.slice_sample!(h::SliceSampler, state::DynamicPPL.TypedVarInfo, log_potential, cached_lp, replica)
cached_lp = Pigeons.cached_log_potential(log_potential, state, cached_lp)
for i in 1:length(state.metadata)
for c in 1:length(state.metadata[i].vals)
pointer = Ref(state.metadata[i].vals, c)
cached_lp = Pigeons.slice_sample_coord!(h, replica, pointer, log_potential, cached_lp)
end
#=
explorer implementations
=#
# Slice sampling on a `TypedVarInfo`: visit every entry of `vi.metadata` in
# iteration order and slice-sample its `vals`, threading the cached log
# potential from one entry to the next. Returns the last cached value.
function Pigeons.slice_sample!(h::SliceSampler, vi::DynamicPPL.TypedVarInfo, log_potential, cached_lp, replica)
    lp = cached_lp
    for variable_meta in vi.metadata
        lp = Pigeons.slice_sample!(h, variable_meta.vals, log_potential, lp, replica)
    end
    return lp
end

function Pigeons.step!(explorer::Pigeons.HamiltonianSampler, replica, shared, vi::DynamicPPL.TypedVarInfo)
log_potential = find_log_potential(replica, shared.tempering, shared)
state = DynamicPPL.getall(vi)
_extract_commons_and_run!(explorer, replica, shared, log_potential, state)
Pigeons.step!(explorer, replica, shared, state)
DynamicPPL.setall!(replica.state, state)
end


#=
specialized equality checks
=#
Pigeons.recursive_equal(a::DynamicPPL.TypedVarInfo, b::DynamicPPL.TypedVarInfo) =
# as of Nov 2023, DynamicPPL does not supply == for TypedVarInfo
length(a.metadata) == length(b.metadata) &&
Expand Down
96 changes: 52 additions & 44 deletions src/explorers/SliceSampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,82 +16,89 @@ $FIELDS
n_passes::Int = 3

""" Maximum number of iterations inside slice_shrink! before erroring out """
max_iter::Int = 4_096
max_iter::Int = 1_024 # == log2(prevfloat(Inf))
end

# Recorder builders associated with a `SliceSampler`.
function explorer_recorder_builders(::SliceSampler)
    return [explorer_acceptance_pr, explorer_n_steps]
end

function step!(explorer::SliceSampler, replica, shared)
log_potential = find_log_potential(replica, shared.tempering, shared)
cached_lp = -Inf
for i in 1:explorer.n_passes
for _ in 1:explorer.n_passes
cached_lp = slice_sample!(explorer, replica.state, log_potential, cached_lp, replica)
end
end

function cached_log_potential(log_potential, state, cached_lp)
cached_log_potential(log_potential, state, cached_lp) =
return if cached_lp == -Inf
result = log_potential(state)
if result == -Inf
error("SliceSampler supports contrained target, but the sampler should be initialized in the support: $state")
end
return result
result
else
cached_lp
end
end

function slice_sample!(h::SliceSampler, state::AbstractVector, log_potential, cached_lp, replica)
cached_lp = cached_log_potential(log_potential, state, cached_lp)
cached_lp = cached_log_potential(log_potential, replica.state, cached_lp) # note: we pass `replica.state` instead of `state` in case the latter is the vector version of a non-vector state (e.g. stan and dppl models)

# iterate over coordinates
for c in 1:length(state)
for c in eachindex(state)
pointer = Ref(state, c)
cached_lp = slice_sample_coord!(h, replica, pointer, log_potential, cached_lp)
end
return cached_lp
end



function slice_sample_coord!(h, replica, pointer, log_potential, cached_lp)
rng = replica.rng
if pointer[] isa Bool
cached_lp = Bernoulli_sample_coord!(replica, pointer, log_potential, cached_lp) # don't slice sample for {0,1} variables
else
z = cached_lp - rand(rng, Exponential(1.0)) # log(vertical draw)
L, R, lp_L, lp_R = slice_double(h, replica, z, pointer, log_potential)
cached_lp = slice_shrink!(h, replica, z, L, R, lp_L, lp_R, pointer, log_potential)
cached_lp = slice_sample_coord!(h, replica, pointer, log_potential, cached_lp, typeof(pointer[])) # note: when state is mixed, pointer is RefArray{generic common type} for all coordinates, so can't use it to dispatch

# check we still have a healthy state
if !isfinite(cached_lp)
error("""Got an invalid log density after updating state at index $c:
- log density = $cached_lp
- state[$c] = $(pointer[])
Dumping full replica state:
$(replica.state)
""")
end
end
return cached_lp
end

function Bernoulli_sample_coord!(replica, pointer, log_potential, cached_lp)
# handle Bools separately: sample from the full conditional (requires 1 density eval)
function slice_sample_coord!(h, replica, pointer, log_potential, cached_lp, ::Type{Bool})
state = replica.state
rng = replica.rng
if pointer[] == Bool(0)
lp0 = cached_lp
pointer[] = Bool(1)
lp1 = log_potential(state)
else
if pointer[] # currently true => already have lp1
lp1 = cached_lp
pointer[] = Bool(0)
pointer[] = false
lp0 = log_potential(state)
else # currently false => already have lp0
lp0 = cached_lp
pointer[] = true
lp1 = log_potential(state)
end

if rand(rng) < exp(lp0-lp1)/(1.0 + exp(lp0-lp1))
pointer[] = Bool(0)
prob_ratio = exp(lp1-lp0)
prob_zero = inv(1 + prob_ratio) # r = p1/p0 => p1 = p0r and p0 + p1=1 => p0(1+r) = 1 => p0=1/(1+r)
if rand(rng) < prob_zero
pointer[] = false
return lp0
else
pointer[] = Bool(1)
pointer[] = true
return lp1
end
end

# generic case: use slicing
# Generic (fallback) coordinate update: draw a slice level, build an interval
# with `slice_double`, then pick a new position with `slice_shrink!`.
# Returns the value produced by `slice_shrink!`, which the caller uses as the
# new cached log potential.
function slice_sample_coord!(h, replica, pointer, log_potential, cached_lp, ::Type)
    # log(vertical draw): current log potential minus a standard exponential
    z = cached_lp - randexp(replica.rng)
    endpoints = slice_double(h, replica, z, pointer, log_potential) # (L, R, lp_L, lp_R)
    return slice_shrink!(h, replica, z, endpoints..., pointer, log_potential)
end

function slice_double(h::SliceSampler, replica, z, pointer, log_potential)
rng = replica.rng
state = replica.state
old_position = pointer[] # store old position while avoiding memory allocation
L, R = initialize_slice_endpoints(h.w, pointer, rng, typeof(pointer[])) # dispatch on either float or int
L, R = initialize_slice_endpoints(pointer[], h.w, rng)
K = h.p

pointer[] = L
Expand All @@ -110,29 +117,30 @@ function slice_double(h::SliceSampler, replica, z, pointer, log_potential)
pointer[] = R
potent_R = log_potential(state)
end
K = K - 1
K -= 1
end
@record_if_requested!(replica.recorders, :explorer_n_steps, (replica.chain, h.p - K))

pointer[] = old_position # return the state back to where it was before
return (L, R, potent_L, potent_R)
end

function initialize_slice_endpoints(width, pointer, rng, ::Type{T}) where T <: AbstractFloat
L = pointer[] - width * rand(rng)
# generic case
"""
    initialize_slice_endpoints(current, width, rng)

Generic case. Return `(L, R)` where `L = current - width * rand(rng)` and
`R = L + width`, i.e. an interval of length `width` shifted by a random
fraction of `width` below `current`.
"""
function initialize_slice_endpoints(current, width, rng)
    offset = width * rand(rng)
    lower = current - offset
    return (lower, lower + width)
end

function initialize_slice_endpoints(width, pointer, rng, ::Type{T}) where T <: Integer
width = convert(T, ceil(width))
L = pointer[] - rand(rng, 0:width)
# handle integers separately
"""
    initialize_slice_endpoints(current::Integer, width, rng)

Integer case. Round `width` up to the integer type of `current`, draw an
offset from `0:width` using `rng`, and return `(L, R)` with
`L = current - offset` and `R = L + width`.
"""
function initialize_slice_endpoints(current::T, width, rng) where {T<:Integer}
    int_width = ceil(T, width) # keep the endpoints in the integer type of `current`
    lower = current - rand(rng, 0:int_width)
    upper = lower + int_width
    return (lower, upper)
end

function slice_shrink!(h::SliceSampler, replica, z, L, R, lp_L, lp_R, pointer, log_potential)
@assert isfinite(z)
rng = replica.rng
state = replica.state
old_position = pointer[]
Expand All @@ -142,7 +150,7 @@ function slice_shrink!(h::SliceSampler, replica, z, L, R, lp_L, lp_R, pointer, l
n = 1

while n <= h.max_iter
new_position = draw_new_position(Lbar, Rbar, rng, typeof(pointer[]))
new_position = draw_new_position(Lbar, Rbar, rng)
pointer[] = new_position
new_lp = log_potential(state)
consider = z < new_lp
Expand Down Expand Up @@ -170,8 +178,8 @@ function slice_shrink!(h::SliceSampler, replica, z, L, R, lp_L, lp_R, pointer, l
return 0.0
end

draw_new_position(L, R, rng, ::Type{T}) where T <: AbstractFloat = L + rand(rng) * (R-L)
draw_new_position(L, R, rng, ::Type{T}) where T <: Integer = rand(rng, L:R)
"""
    draw_new_position(L, R, rng)

Generic case. Return `L + rand(rng) * (R - L)`, a point between `L` and `R`.
"""
function draw_new_position(L, R, rng)
    span = R - L
    return L + rand(rng) * span
end
"""
    draw_new_position(L::Integer, R::Integer, rng)

Integer case. Return an element of the range `L:R` drawn with `rng`.
"""
function draw_new_position(L::Integer, R::Integer, rng)
    candidates = L:R
    return rand(rng, candidates)
end


function slice_accept(h::SliceSampler, replica, new_position, z, L, R, lp_L, lp_R, pointer, log_potential)
Expand Down
29 changes: 25 additions & 4 deletions test/test_slice_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,23 @@ Pigeons.initialization(log_potential::UnitInterval, ::AbstractRNG, ::Int) = [0.5
pigeons(target = UnitInterval(true))
end

@testset "Check inf potential throws" begin
    # scalar rule: returns `x` itself at zero, `Inf` everywhere else
    log_potential(x) = iszero(x) ? x : Inf
    # vector states are scored through their first coordinate
    log_potential(x::AbstractVector) = log_potential(first(x))

    state = [0.0]
    replica = Replica(state, 1, SplittableRandom(1), (;), 1)
    initial_lp = -Inf
    # the sampler is expected to raise an `ErrorException` on this target
    @test_throws ErrorException slice_sample!(SliceSampler(), state, log_potential, initial_lp, replica)
end

@testset "Check slice_shrink! throws on unattainable z level" begin
    # constant log potential: always zero, in the element type of the state
    function log_potential(x)
        return zero(eltype(x))
    end

    state = [0.0]
    replica = Replica(state, 1, SplittableRandom(1), (;), 1)
    # start from the largest finite Float64 as the cached log potential
    huge_lp = prevfloat(Inf)
    @test_throws ErrorException slice_sample!(SliceSampler(), state, log_potential, huge_lp, replica)
end

include("supporting/turing_models.jl")

function test_slice_sampler_logprob_counts()
Expand Down Expand Up @@ -42,9 +59,13 @@ end

function test_slice_sampler_vector()
rng = SplittableRandom(1)
log_potential = (x) -> logpdf(Bernoulli(0.5), x[1]) + logpdf(Normal(0.0, 1.0), x[2])
log_potential(x) = begin
logpdf(Bernoulli(0.5), first(x)) +
logpdf(Binomial(10), x[2]) +
logpdf(Normal(0.0, 1.0), last(x))
end
h = SliceSampler()
state = Number[0, 0.0]
state = Number[false, 0, 0.0]
n = 1000
states = Vector{typeof(state)}(undef, n)
cached_lp = -Inf
Expand All @@ -53,8 +74,8 @@ function test_slice_sampler_vector()
cached_lp = slice_sample!(h, state, log_potential, cached_lp, replica)
states[i] = copy(state)
end
@test all(abs.(mean(states) - [0.5, 0.0]) .≤ 0.2)
@test all(abs.(std(states) - [0.5, 1.0]) .≤ 0.2)
@test all(abs.(mean(states) - [0.5, 5.0, 0.0]) .≤ 0.2)
@test all(abs.(std(states) - [0.5, std(Binomial(10)), 1.0]) .≤ 0.2)
end

function test_slice_sampler_Turing()
Expand Down
Loading