Skip to content

Commit

Permalink
Merge pull request #137 from Julia-Tempering/aaps
Browse files Browse the repository at this point in the history
AAPS implementation
  • Loading branch information
miguelbiron authored Sep 28, 2023
2 parents acac53b + d40e696 commit 0607ecb
Show file tree
Hide file tree
Showing 13 changed files with 383 additions and 34 deletions.
2 changes: 2 additions & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[deps]
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
2 changes: 1 addition & 1 deletion examples/pluto-demo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ PlutoUI = "~0.7.52"
PLUTO_MANIFEST_TOML_CONTENTS = """
# This file is machine-generated - editing it directly is not advised
julia_version = "1.9.2"
julia_version = "1.9.3"
manifest_format = "2.0"
project_hash = "e0ad5955baad6e9bbdcd9af95b539ba026f0b9d1"
Expand Down
2 changes: 1 addition & 1 deletion src/Pigeons.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ export pigeons, Inputs, PT,
# variational references:
GaussianReference,
# samplers
SliceSampler, AutoMALA, Compose
SliceSampler, AutoMALA, Compose, AAPS, MALA
end # End module


200 changes: 200 additions & 0 deletions src/explorers/AAPS.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
###############################################################################
# The Pigeons implementation of AAPS is based on code by
# Naitong Chen and Trevor Campbell (2023). Reused with their permission.
###############################################################################

"""
$SIGNATURES
The Apogee to Apogee Path Sampler (AAPS) by Sherlock et al. (2022).
AAPS is a simple alternative to the No U-Turn Sampler (NUTS).
It serves a similar purpose as NUTS: the method should be robust to its choice
of tuning parameters when compared to standard HMC.
For a given starting position and momentum (x, v), AAPS explores both forward and
backward trajectories. The trajectories are divided into segments, with
segments being separated by apogees (local maxima) in the energy landscape
of -log pi(x). The tuning parameter `K` controls how many segments are explored:
`aaps!` first samples the segment containing the starting point (extending it
both forward and backward) and then `K` further segments, each appended at a
randomly chosen end of the trajectory, i.e. `K+1` segments in total.
"""
Base.@kwdef struct AAPS{T,TPrec <: Preconditioner}
"""
Leapfrog step size. It is kept fixed: `adapt_explorer` copies it unchanged
from one round to the next.
"""
step_size::Float64 = 1.0

"""
Number of segments (regions between apogees) sampled in addition to the
initial segment. Like `step_size`, it is not adapted between rounds.
"""
K::Int = 5

"""
Symbol identifying the autodiff backend handed to `ADgradient` when building
the differentiable log potential. See details in AutoMALA.
"""
default_autodiff_backend::Symbol = :ForwardDiff

"""
A strategy for building a preconditioner.
"""
preconditioner::TPrec = MixDiagonalPreconditioner()

"""
Standard deviation estimates consumed by `build_preconditioner!`.
This gets updated by `adapt_explorer` (via `adapt_preconditioner`); initially
`nothing` in which case an identity mass matrix is used.
"""
estimated_target_std_deviations::T = nothing
end

# Build the explorer for the next round: only the preconditioner statistics are
# refreshed, every tuning parameter is carried over unchanged.
function adapt_explorer(explorer::AAPS, reduced_recorders, current_pt, new_tempering)
    refreshed_std_devs = adapt_preconditioner(explorer.preconditioner, reduced_recorders)
    # TODO: adapt step_size and K
    return AAPS(;
        step_size = explorer.step_size,
        K = explorer.K,
        default_autodiff_backend = explorer.default_autodiff_backend,
        preconditioner = explorer.preconditioner,
        estimated_target_std_deviations = refreshed_std_devs,
    )
end

#=
Shared entry point for every target type: wrap the log potential so that its
gradient is available, then run one AAPS step in place on `state`.
=#
function _extract_commons_and_run!(explorer::AAPS, replica, shared, log_potential, state::AbstractVector)
    recorders = replica.recorders
    differentiable_potential = ADgradient(
        explorer.default_autodiff_backend, log_potential, recorders.buffers
    )
    return aaps!(
        replica.rng, explorer, differentiable_potential, state, recorders, replica.chain
    )
end

"""
Workspace for one direction (forward or backward) of an AAPS trajectory.
The struct itself is immutable, but its three vectors are overwritten in place
by `aaps!` and `sample_segment!`. All three share the same concrete vector type.
"""
struct AAPSState{TV<:Vector{<:Real}}
# current end point of the trajectory in this direction (advanced by leap_frog!)
position::TV
# velocity at the current end point (advanced by leap_frog!)
velocity::TV
# position holding the largest Gumbel-perturbed log joint seen in the segment
# most recently sampled by sample_segment! (reset at the start of each segment)
max_position::TV
end

# Assemble the forward and backward trajectory workspaces from named buffers.
# Each of the six vectors is requested from `buffers` under its own key with
# length `dim`; returns the tuple `(fwd_state, bwd_state)`.
function get_fwd_bwd_states(buffers, dim)
    fetch_buffer(key::Symbol) = get_buffer(buffers, key, dim)
    fwd_state = AAPSState(
        fetch_buffer(:aaps_fwd_position_buffer),
        fetch_buffer(:aaps_fwd_velocity_buffer),
        fetch_buffer(:aaps_fwd_max_position_buffer),
    )
    bwd_state = AAPSState(
        fetch_buffer(:aaps_bwd_position_buffer),
        fetch_buffer(:aaps_bwd_velocity_buffer),
        fetch_buffer(:aaps_bwd_max_position_buffer),
    )
    return fwd_state, bwd_state
end

"""
Run one AAPS exploration step, overwriting `position` in place with the
selected point of the trajectory. Note that this implementation uses scheme (1)
from the AAPS paper, which results in an acceptance probability of one, so
there is no accept/reject step.
"""
function aaps!(
    rng::AbstractRNG,
    explorer::AAPS,
    target_log_potential,
    position::Vector,
    recorders,
    chain)

    # workspace: preconditioner diagonal plus one trajectory state per direction
    n = length(position)
    bufs = recorders.buffers
    precond = get_buffer(bufs, :am_ones_buffer, n)
    fwd, bwd = get_fwd_bwd_states(bufs, n)

    build_preconditioner!(
        precond, explorer.preconditioner, rng, explorer.estimated_target_std_deviations
    )

    # both directions start at the current point, with opposite velocities
    copyto!(fwd.position, position)
    copyto!(bwd.position, position) # same starting point -> bwd must skip it below
    randn!(rng, fwd.velocity) # velocity ~ N(0,I) <=> momentum ~ N(0,diag_precond^2)
    bwd.velocity .= -1 .* fwd.velocity

    # initial segment: extend from the starting point in both directions;
    # skip_first avoids counting the shared starting state twice
    fwd_w = sample_segment!(explorer, fwd, target_log_potential, rng, precond)
    bwd_w = sample_segment!(explorer, bwd, target_log_potential, rng, precond, skip_first=true)

    # Gumbel-max trick: keep the candidate with the largest perturbed weight
    fwd_wins = fwd_w > bwd_w
    wmax = fwd_wins ? fwd_w : bwd_w
    copyto!(position, (fwd_wins ? fwd : bwd).max_position)

    # extend the trajectory by K more segments, each continuing from a previous endpoint
    # note that K+1 segments are sampled in total, as in the original AAPS implementation
    # see https://github.com/ChrisGSherlock/AAPS/blob/c48c59d81031745cf08b6b3d3d9ad53287bf3b34/AAPS.cpp#L311
    for _ in 1:explorer.K
        # a fair coin picks the side to extend, so the forward/backward split
        # does not have to be specified in advance
        side = rand(rng, Bool) ? fwd : bwd
        w = sample_segment!(explorer, side, target_log_potential, rng, precond)
        if w > wmax
            wmax = w
            copyto!(position, side.max_position)
        end
    end
    # w(z,z') = exp(log_joint) => proposal always accepted
    # position was updated in place, so nothing is left to do
    # TODO: accept/reject if other proposal is used
    return nothing
end

"""
Sample a segment of the trajectory until an apogee is reached.

Starting from `state`, take leapfrog steps of size `explorer.step_size`,
updating `state.position` and `state.velocity` in place. Every visited point
receives the weight `log_joint + Gumbel noise`; the point with the largest
weight in this segment is stored in `state.max_position` and that weight is
returned. When `skip_first` is true the starting point gets weight `-Inf`, so
it can only be selected through another call.

The segment ends when the sign of `dot(velocity, cgrad)` switches from negative
to positive (an apogee was crossed); the point after the apogee is not weighted
here but is the starting point of the next segment. The segment also ends if
that dot product becomes NaN, which signals a diverged integration: comparisons
with NaN are always false, so without this exit the loop could never terminate.
"""
function sample_segment!(
    explorer::AAPS,
    state::AAPSState,
    target_log_potential,
    rng::AbstractRNG,
    diag_precond::Vector;
    skip_first::Bool = false # avoid double counting starting state. same as try0 in https://github.com/ChrisGSherlock/AAPS/blob/c48c59d81031745cf08b6b3d3d9ad53287bf3b34/AAPS.cpp#L268
    )
    logp, cgrad = conditioned_target_gradient(target_log_potential, state.position, diag_precond)
    copyto!(state.max_position, state.position) # reset max to the current position
    wmax = if skip_first
        -typeof(logp)(Inf)
    else
        log_joint(logp, state.velocity) + rand(rng, Gumbel())
    end

    # propagate forward, checking for apogee, keeping track of next state using gumbel-max trick
    # note: since M is sym ⟹ p^T M^{-1} gradU = (M^{1/2}v)^T M^{-1} gradU = v^T M^{-1/2} gradU = -v^T cgrad
    # hence, p^T M^{-1} gradU > 0 ⟺ v^T cgrad < 0, and viceversa
    old_sign = sign(dot(state.velocity, cgrad))
    while true
        leap_frog!(
            target_log_potential, diag_precond, state.position, state.velocity,
            explorer.step_size)
        logp, cgrad = conditioned_target_gradient(target_log_potential, state.position, diag_precond)
        new_sign = sign(dot(state.velocity, cgrad))
        # apogee crossed: the energy was increasing and is now decreasing
        old_sign < 0 && new_sign > 0 && return wmax
        # diverged trajectory: a NaN sign can never satisfy the test above, so
        # stop here instead of looping forever (NaN weights are never selected)
        isnan(new_sign) && return wmax
        old_sign = new_sign
        w = log_joint(logp, state.velocity) + rand(rng, Gumbel())
        if w > wmax
            wmax = w
            copyto!(state.max_position, state.position)
        end
    end
end

# Recorder builders requested by AAPS: acceptance probability, step count and
# scratch buffers, plus whatever add_precond_recorder_if_needed! appends for
# the explorer's preconditioner.
function explorer_recorder_builders(explorer::AAPS)
    builders = [explorer_acceptance_pr, explorer_n_steps, buffers]
    add_precond_recorder_if_needed!(builders, explorer)
    return builders
end
30 changes: 3 additions & 27 deletions src/explorers/AutoMALA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,32 +73,10 @@ function adapt_explorer(explorer::AutoMALA, reduced_recorders, current_pt, new_t
estimated_target_std_deviations)
end

function step!(explorer::AutoMALA, replica, shared)
step!(explorer, replica, shared, replica.state)
end

### Dispatch on state for the behaviours for the different targets ###

step!(explorer::AutoMALA, replica, shared, state::StanState) =
step!(explorer, replica, shared, state.unconstrained_parameters)


function step!(explorer::AutoMALA, replica, shared, state::AbstractVector)
log_potential = find_log_potential(replica, shared.tempering, shared)
_extract_commons_and_run_auto_mala!(explorer, replica, shared, log_potential, state)
end

function step!(explorer::AutoMALA, replica, shared, vi::DynamicPPL.TypedVarInfo)
log_potential = find_log_potential(replica, shared.tempering, shared)
state = DynamicPPL.getall(vi)
_extract_commons_and_run_auto_mala!(explorer, replica, shared, log_potential, state)
DynamicPPL.setall!(replica.state, state)
end

#=
Extract info common to all types of target and perform a step!()
=#
function _extract_commons_and_run_auto_mala!(explorer::AutoMALA, replica, shared, log_potential, state::AbstractVector)
function _extract_commons_and_run!(explorer::AutoMALA, replica, shared, log_potential, state::AbstractVector)

log_potential_autodiff = ADgradient(explorer.default_autodiff_backend, log_potential, replica.recorders.buffers)
is_first_scan_of_round = shared.iterators.scan == 1
Expand Down Expand Up @@ -298,8 +276,6 @@ function explorer_recorder_builders(explorer::AutoMALA)
am_factors,
buffers
]
if explorer.preconditioner isa AdaptedDiagonalPreconditioner
push!(result, _transformed_online) # for mass matrix adaptation
end
add_precond_recorder_if_needed!(result, explorer)
return result
end
end
102 changes: 102 additions & 0 deletions src/explorers/MALA.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""
$SIGNATURES
The Metropolis-Adjusted Langevin Algorithm (MALA).
MALA is based on an approximation to overdamped Langevin dynamics followed by a
Metropolis-Hastings correction to ensure that we target the correct distribution.
This round-based version of MALA allows for the use of a preconditioner,
which is updated after every PT tuning round.
This setting can also be turned off by specifying the type of preconditioner to use.
However, MALA will not automatically adjust the step size.
For such functionality, use `AutoMALA`.
As for `AutoMALA`, the number of steps per exploration is
`base_n_refresh * ceil(Int, dim^exponent_n_refresh)`.
"""
Base.@kwdef struct MALA{TPrec <: Preconditioner, T}
    """
    The base number of steps (equivalently, momentum refreshments) between swaps.
    This base number gets multiplied by `ceil(Int, dim^(exponent_n_refresh))`.
    """
    base_n_refresh::Int = 3

    """
    Used to scale the increase in number of refreshments with dimensionality.
    """
    exponent_n_refresh::Float64 = 0.35

    """
    The default backend to use for autodiff.
    See https://github.com/tpapp/LogDensityProblemsAD.jl#backends
    Certain targets may ignore it, e.g. if a manual differential is
    offered or when calling an external program such as Stan.
    """
    default_autodiff_backend::Symbol = :ForwardDiff

    """
    Step size to use when approximating the Langevin dynamics.
    This is an important tuning parameter of MALA. This implementation of
    MALA does not automatically choose the step size (`adapt_explorer` copies
    it unchanged), so the user should select it carefully.
    """
    step_size::Float64 = 1.0

    """
    A strategy for building a preconditioner.
    """
    preconditioner::TPrec = MixDiagonalPreconditioner()

    """
    This gets updated after the first tuning round; initially it is `nothing`, in
    which case an identity mass matrix is used for the preconditioner.
    """
    estimated_target_std_deviations::T = nothing
end

# Build the explorer for the next round: only the preconditioner statistics are
# refreshed, all user-chosen settings (including step_size) are carried over.
function adapt_explorer(explorer::MALA, reduced_recorders, current_pt, new_tempering)
    refreshed_std_devs = adapt_preconditioner(explorer.preconditioner, reduced_recorders)
    return MALA(;
        base_n_refresh = explorer.base_n_refresh,
        exponent_n_refresh = explorer.exponent_n_refresh,
        default_autodiff_backend = explorer.default_autodiff_backend,
        step_size = explorer.step_size,
        preconditioner = explorer.preconditioner,
        estimated_target_std_deviations = refreshed_std_devs,
    )
end

# Shared entry point for every target type: wrap the log potential so that its
# gradient is available, then run the MALA moves in place on `state`.
function _extract_commons_and_run!(explorer::MALA, replica, shared, log_potential, state::AbstractVector)
    recorders = replica.recorders
    differentiable_potential = ADgradient(
        explorer.default_autodiff_backend, log_potential, recorders.buffers
    )
    return mala!(
        replica.rng, explorer, differentiable_potential, state, recorders, replica.chain
    )
end

# The main exploration function for MALA.
# Performs `base_n_refresh * ceil(Int, dim^exponent_n_refresh)` moves, each made of
# a momentum refreshment, one leapfrog step and a Metropolis-Hastings
# accept/reject. `state` is updated in place; the acceptance probability of every
# move is recorded under :explorer_acceptance_pr when that recorder is requested.
# Throws an AssertionError if a move starts from a non-finite log joint.
function mala!(rng::AbstractRNG, explorer::MALA, target_log_potential, state::Vector, recorders, chain)
    dim = length(state)
    momentum = get_buffer(recorders.buffers, :mala_momentum_buffer, dim)
    diag_precond = get_buffer(recorders.buffers, :mala_ones_buffer, dim)
    build_preconditioner!(diag_precond, explorer.preconditioner, rng, explorer.estimated_target_std_deviations)
    start_state = get_buffer(recorders.buffers, :mala_state_buffer, dim)
    n_refresh = explorer.base_n_refresh * ceil(Int, dim^explorer.exponent_n_refresh)
    for _ in 1:n_refresh
        start_state .= state # saved so that a rejection can restore it
        randn!(rng, momentum)
        init_joint_log = log_joint(target_log_potential, state, momentum)
        @assert isfinite(init_joint_log) "MALA can only be called on a configuration of positive density."
        leap_frog!(target_log_potential, diag_precond, state, momentum, explorer.step_size)
        momentum .*= -1.0 # flip momentum (involution)
        final_joint_log = log_joint(target_log_potential, state, momentum)
        # min(1.0, NaN) is NaN in Julia; a NaN ratio always leads to a rejection
        # below, so report it as probability zero rather than recording NaN
        ratio = exp(final_joint_log - init_joint_log)
        probability = isnan(ratio) ? 0.0 : min(1.0, ratio)
        @record_if_requested!(recorders, :explorer_acceptance_pr, (chain, probability))
        if rand(rng) >= probability
            # reject: go back to start state
            state .= start_state # momentum gets resampled at next iteration anyway
        end
        # accept: nothing to do, we work in-place
    end
    return nothing
end

# Recorder builders requested by MALA: acceptance probability, step count and
# scratch buffers, plus whatever add_precond_recorder_if_needed! appends for
# the explorer's preconditioner.
function explorer_recorder_builders(explorer::MALA)
    builders = [explorer_acceptance_pr, explorer_n_steps, buffers]
    add_precond_recorder_if_needed!(builders, explorer)
    return builders
end
Loading

0 comments on commit 0607ecb

Please sign in to comment.