-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #137 from Julia-Tempering/aaps
AAPS implementation
- Loading branch information
Showing
13 changed files
with
383 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[deps] | ||
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
############################################################################### | ||
# The Pigeons implementation of AAPS is based on code by | ||
# Naitong Chen and Trevor Campbell (2023). Reused with their permission. | ||
############################################################################### | ||
|
||
""" | ||
$SIGNATURES | ||
The Apogee to Apogee Path Sampler (AAPS) by Sherlock et al. (2022). | ||
AAPS is a simple alternative to the No U-Turn Sampler (NUTS). | ||
It serves a similar purpose as NUTS: the method should be robust to its choice | ||
of tuning parameters when compared to standard HMC. | ||
For a given starting position and momentum (x, v), AAPS explores both forward and | ||
backward trajectories. The trajectories are divided into segments, with | ||
segments being separated by apogees (local maxima) in the energy landscape | ||
of -log pi(x). The tuning parameter `K` defines the number of segments to explore. | ||
""" | ||
Base.@kwdef struct AAPS{T,TPrec <: Preconditioner} | ||
""" | ||
Leapfrog step size. | ||
""" | ||
step_size::Float64 = 1.0 | ||
|
||
""" | ||
Maximum number of segments (regions between apogees) to explore. | ||
""" | ||
K::Int = 5 | ||
|
||
""" | ||
See details in AutoMALA. | ||
""" | ||
default_autodiff_backend::Symbol = :ForwardDiff | ||
|
||
""" | ||
A strategy for building a preconditioner. | ||
""" | ||
preconditioner::TPrec = MixDiagonalPreconditioner() | ||
|
||
""" | ||
This gets updated after first iteration; initially `nothing` in | ||
which case an identity mass matrix is used. | ||
""" | ||
estimated_target_std_deviations::T = nothing | ||
end | ||
|
||
function adapt_explorer(explorer::AAPS, reduced_recorders, current_pt, new_tempering) | ||
estimated_target_std_deviations = adapt_preconditioner(explorer.preconditioner, reduced_recorders) | ||
# TODO: adapt step_size and K | ||
return AAPS( | ||
explorer.step_size, explorer.K, explorer.default_autodiff_backend, | ||
explorer.preconditioner, estimated_target_std_deviations | ||
) | ||
end | ||
|
||
#= | ||
Extract info common to all types of target and perform a step!() | ||
=# | ||
function _extract_commons_and_run!(explorer::AAPS, replica, shared, log_potential, state::AbstractVector) | ||
log_potential_autodiff = ADgradient( | ||
explorer.default_autodiff_backend, log_potential, replica.recorders.buffers | ||
) | ||
aaps!( | ||
replica.rng, | ||
explorer, | ||
log_potential_autodiff, | ||
state, | ||
replica.recorders, | ||
replica.chain | ||
) | ||
end | ||
|
||
struct AAPSState{TV<:Vector{<:Real}} | ||
position::TV | ||
velocity::TV | ||
max_position::TV | ||
end | ||
|
||
function get_fwd_bwd_states(buffers, dim) | ||
fwd_state = AAPSState( | ||
get_buffer(buffers, :aaps_fwd_position_buffer, dim), | ||
get_buffer(buffers, :aaps_fwd_velocity_buffer, dim), | ||
get_buffer(buffers, :aaps_fwd_max_position_buffer, dim) | ||
) | ||
bwd_state = AAPSState( | ||
get_buffer(buffers, :aaps_bwd_position_buffer, dim), | ||
get_buffer(buffers, :aaps_bwd_velocity_buffer, dim), | ||
get_buffer(buffers, :aaps_bwd_max_position_buffer, dim) | ||
) | ||
fwd_state, bwd_state | ||
end | ||
|
||
""" | ||
Main function for AAPS. Note that this implementation uses scheme (1) | ||
from the AAPS paper, which results in an acceptance probability of one. | ||
""" | ||
function aaps!( | ||
rng::AbstractRNG, | ||
explorer::AAPS, | ||
target_log_potential, | ||
position::Vector, | ||
recorders, | ||
chain) | ||
|
||
# get buffers | ||
dim = length(position) | ||
diag_precond = get_buffer(recorders.buffers, :am_ones_buffer, dim) | ||
fwd_state, bwd_state = get_fwd_bwd_states(recorders.buffers, dim) | ||
|
||
# initialize | ||
build_preconditioner!( | ||
diag_precond, explorer.preconditioner, rng, explorer.estimated_target_std_deviations | ||
) | ||
copyto!(fwd_state.position, position) | ||
copyto!(bwd_state.position, position) # start bwd at same position -> requires skipping | ||
randn!(rng, fwd_state.velocity) # sample velocity ~ N(0,I) <=> sample momentum ~ N(0,diag_precond^2) | ||
bwd_state.velocity .= -1 .* fwd_state.velocity | ||
|
||
# find the initial segment by moving forward and backward | ||
fwd_wmax = sample_segment!(explorer, fwd_state, target_log_potential, rng, diag_precond) | ||
bwd_wmax = sample_segment!(explorer, bwd_state, target_log_potential, rng, diag_precond, skip_first=true) # avoids double counting initial state | ||
|
||
# update the Gumbel-max-trick decision | ||
if fwd_wmax > bwd_wmax | ||
wmax = fwd_wmax | ||
copyto!(position, fwd_state.max_position) | ||
else | ||
wmax = bwd_wmax | ||
copyto!(position, bwd_state.max_position) | ||
end | ||
|
||
# sample segments by continuing from the previous endpoints | ||
# note that K+1 segments are sampled in total, as in the original AAPS implementation | ||
# see https://github.com/ChrisGSherlock/AAPS/blob/c48c59d81031745cf08b6b3d3d9ad53287bf3b34/AAPS.cpp#L311 | ||
for _ in 1:explorer.K | ||
if rand(rng, Bool) # extend forward trajectory. avoids specifying in advance how many times we move forward/backward | ||
fwd_wmax = sample_segment!(explorer, fwd_state, target_log_potential, rng, diag_precond) | ||
if fwd_wmax > wmax | ||
wmax = fwd_wmax | ||
copyto!(position, fwd_state.max_position) | ||
end | ||
else | ||
bwd_wmax = sample_segment!(explorer, bwd_state, target_log_potential, rng, diag_precond) | ||
if bwd_wmax > wmax | ||
wmax = bwd_wmax | ||
copyto!(position, bwd_state.max_position) | ||
end | ||
end | ||
end | ||
# w(z,z') = exp(log_joint) => proposal always accepted | ||
# no need to update position, we work in place | ||
# TODO: accept/reject if other proposal is used | ||
end | ||
|
||
""" | ||
Sample a segment of the trajectory until an apogee is reached. | ||
""" | ||
function sample_segment!( | ||
explorer::AAPS, | ||
state::AAPSState, | ||
target_log_potential, | ||
rng::AbstractRNG, | ||
diag_precond::Vector; | ||
skip_first::Bool = false # avoid double counting starting state. same as try0 in https://github.com/ChrisGSherlock/AAPS/blob/c48c59d81031745cf08b6b3d3d9ad53287bf3b34/AAPS.cpp#L268 | ||
) | ||
logp, cgrad = conditioned_target_gradient(target_log_potential, state.position, diag_precond) | ||
copyto!(state.max_position, state.position) # reset max to the current position | ||
if skip_first | ||
ljoint = wmax = -typeof(logp)(Inf) | ||
else | ||
ljoint = log_joint(logp, state.velocity) | ||
wmax = ljoint + rand(rng, Gumbel()) | ||
end | ||
|
||
# propagate forward, checking for apogee, tracking stats, keeping track of next state using gumbel-max trick | ||
# note: since M is sym ⟹ p^T M^{-1} gradU = (M^{1/2}v)^T M^{-1} gradU = v^T M^{-1/2} gradU = -v^T cgrad | ||
# hence, p^T M^{-1} gradU > 0 ⟺ v^T cgrad < 0, and viceversa | ||
old_sign = sign(dot(state.velocity, cgrad)) | ||
while true | ||
leap_frog!( | ||
target_log_potential, diag_precond, state.position, state.velocity, | ||
explorer.step_size) | ||
logp, cgrad = conditioned_target_gradient(target_log_potential, state.position, diag_precond) | ||
new_sign = sign(dot(state.velocity, cgrad)) | ||
old_sign < 0 && new_sign > 0 && return wmax | ||
old_sign = new_sign | ||
ljoint = log_joint(logp, state.velocity) | ||
w = ljoint + rand(rng, Gumbel()) | ||
if w > wmax | ||
wmax = w | ||
copyto!(state.max_position, state.position) | ||
end | ||
end | ||
end | ||
|
||
function explorer_recorder_builders(explorer::AAPS) | ||
result = [explorer_acceptance_pr, explorer_n_steps, buffers] | ||
add_precond_recorder_if_needed!(result, explorer) | ||
return result | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
""" | ||
$SIGNATURES | ||
The Metropolis-Adjusted Langevin Algorithm (MALA). | ||
MALA is based on an approximation to overdamped Langevin dynamics followed by a | ||
Metropolis-Hastings correction to ensure that we target the correct distribution. | ||
This round-based version of MALA allows for the use of a preconditioner, | ||
which is updated after every PT tuning round. | ||
This setting can also be turned off by specifying the type of preconditioner to use. | ||
However, MALA will not automatically adjust the step size. | ||
For such functionality, use autoMALA. | ||
As for autoMALA, the number of steps per exploration is | ||
`base_n_refresh * ceil(Int, dim^exponent_n_refresh)`. | ||
""" | ||
@kwdef struct MALA{TPrec <: Preconditioner, T} | ||
""" | ||
The base number of steps (equivalently, momentum refreshments) between swaps. | ||
This base number gets multiplied by `ceil(Int, dim^(exponent_n_refresh))`. | ||
""" | ||
base_n_refresh::Int = 3 | ||
|
||
""" | ||
Used to scale the increase in number of refreshments with dimensionality. | ||
""" | ||
exponent_n_refresh::Float64 = 0.35 | ||
|
||
""" | ||
The default backend to use for autodiff. | ||
See https://github.com/tpapp/LogDensityProblemsAD.jl#backends | ||
Certain targets may ignore it, e.g. if a manual differential is | ||
offered or when calling an external program such as Stan. | ||
""" | ||
default_autodiff_backend::Symbol = :ForwardDiff | ||
|
||
""" | ||
Step size to use when approximating the Langevin dynamics. | ||
This is an important tuning parameter of MALA. This implementation of | ||
MALA does not automatically choose the step size so the user should | ||
select it carefully. | ||
""" | ||
step_size::Float64 = 1.0 | ||
|
||
""" | ||
A strategy for building a preconditioner. | ||
""" | ||
preconditioner::TPrec = MixDiagonalPreconditioner() | ||
|
||
""" | ||
This gets updated after the first tuning round; initially it is `nothing`, in | ||
which case an identity mass matrix is used for the preconditioner. | ||
""" | ||
estimated_target_std_deviations::T = nothing | ||
end | ||
|
||
function adapt_explorer(explorer::MALA, reduced_recorders, current_pt, new_tempering) | ||
estimated_target_std_deviations = adapt_preconditioner(explorer.preconditioner, reduced_recorders) | ||
return MALA( | ||
explorer.base_n_refresh, explorer.exponent_n_refresh, | ||
explorer.default_autodiff_backend, explorer.step_size, | ||
explorer.preconditioner, estimated_target_std_deviations) | ||
end | ||
|
||
# Extract info common to all types of target and perform a step!() | ||
function _extract_commons_and_run!(explorer::MALA, replica, shared, log_potential, state::AbstractVector) | ||
log_potential_autodiff = ADgradient(explorer.default_autodiff_backend, log_potential, replica.recorders.buffers) | ||
mala!(replica.rng, explorer, log_potential_autodiff, state, replica.recorders, replica.chain) | ||
end | ||
|
||
# The main exploration function for MALA | ||
function mala!(rng::AbstractRNG, explorer::MALA, target_log_potential, state::Vector, recorders, chain) | ||
dim = length(state) | ||
momentum = get_buffer(recorders.buffers, :mala_momentum_buffer, dim) | ||
diag_precond = get_buffer(recorders.buffers, :mala_ones_buffer, dim) | ||
build_preconditioner!(diag_precond, explorer.preconditioner, rng, explorer.estimated_target_std_deviations) | ||
start_state = get_buffer(recorders.buffers, :mala_state_buffer, dim) | ||
n_refresh = explorer.base_n_refresh * ceil(Int, dim^explorer.exponent_n_refresh) | ||
for i in 1:n_refresh | ||
start_state .= state | ||
randn!(rng, momentum) | ||
init_joint_log = log_joint(target_log_potential, state, momentum) | ||
@assert isfinite(init_joint_log) "MALA can only be called on a configuration of positive density." | ||
leap_frog!(target_log_potential, diag_precond, state, momentum, explorer.step_size) | ||
momentum .*= -1.0 # flip momentum (involution) | ||
final_joint_log = log_joint(target_log_potential, state, momentum) | ||
probability = min(1.0, exp(final_joint_log - init_joint_log)) | ||
@record_if_requested!(recorders, :explorer_acceptance_pr, (chain, probability)) | ||
if rand(rng) < probability # accept: nothing to do, we work in-place | ||
else # reject: go back to start state | ||
state .= start_state # momentum gets resampled at next iteration anyway | ||
end | ||
end | ||
end | ||
|
||
function explorer_recorder_builders(explorer::MALA) | ||
result = [explorer_acceptance_pr, explorer_n_steps, buffers] | ||
add_precond_recorder_if_needed!(result, explorer) | ||
return result | ||
end |
Oops, something went wrong.