Skip to content

Commit

Permalink
Merge pull request #135 from Julia-Tempering/fix-leapfrog
Browse files Browse the repository at this point in the history
fix leapfrog
  • Loading branch information
miguelbiron authored Sep 23, 2023
2 parents d97ba35 + 1a59ec8 commit cf76ee7
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 41 deletions.
12 changes: 8 additions & 4 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ jobs:
arch:
- x64
steps:
- uses: actions/checkout@v3
- name: Enable long paths on windows
if: ${{ startsWith(matrix.os, 'windows') }}
run: |
REG ADD "HKLM\SYSTEM\CurrentControlSet\Control\FileSystem" /v LongPathsEnabled /t REG_DWORD /d 1 /f
- uses: actions/checkout@v4
- uses: actions/setup-java@v3
with:
distribution: 'temurin'
Expand Down Expand Up @@ -70,7 +74,7 @@ jobs:
JULIA_MPI_TEST_ABI: OpenMPI
steps:
- name: Checkout
uses: actions/checkout@v3
uses: actions/checkout@v4

- uses: julia-actions/setup-julia@latest
with:
Expand Down Expand Up @@ -110,7 +114,7 @@ jobs:
JULIA_MPI_TEST_BINARY: system
ZES_ENABLE_SYSMAN: 1 # https://github.com/open-mpi/ompi/issues/10142
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Install MPI via homebrew
run: brew install $MPI
Expand Down Expand Up @@ -148,7 +152,7 @@ jobs:
permissions:
contents: write
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- uses: actions/setup-java@v3
with:
distribution: 'temurin'
Expand Down
33 changes: 24 additions & 9 deletions src/explorers/hamiltonian_dynamics.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
###############################################################################
# Utilities for Hamiltonian Dynamics
# Note on working on transformed momentum space: when the momentum follows
# p ~ N(0,M)
# for some positive definite M, the Hamiltonian is
# H(x,p) = -log(pi(x)) + (1/2)p^T M^{-1} p
# The corresponding leapfrog update is
# p*(x,p) = p + (eps/2)grad(log pi)(x)
# x'(x,p*) = x + eps M^{-1}p*
# p'(x',p*) = p* + (eps/2)grad(log pi)(x')
# We work instead with the transformed momentum
# y = M^{-1/2}p => y ~ N(0,I)
# Then, replacing p by M^{1/2}y above gives the modified Leapfrog
# y*(x,y) = y + (eps/2)M^{-1/2}grad(log pi)(x)
# x'(x,y*) = x + eps M^{-1/2}y*
# y'(x',y*) = y* + (eps/2)M^{-1/2}grad(log pi)(x')
# The function `conditioned_target_gradient` returns M^{-1/2}grad(log pi)(x)
###############################################################################

log_joint(target, state, momentum) = log_joint(LogDensityProblems.logdensity(target, state), momentum)
log_joint(logp, momentum) = logp - 0.5 * sqr_norm(momentum)

# We use an implicit linear transformation rescaling
# component i with 1/estimated_target_std_dev[i]
# and use an isotropic normal momentum.
# This is equivalent to having a "mass matrix" in HMC jargon.
function conditioned_target_gradient(target_log_potential, state, estimated_target_std_dev)
logdens, grad = LogDensityProblems.logdensity_and_gradient(target_log_potential, state)
grad .= grad .* estimated_target_std_dev
grad ./= estimated_target_std_dev # M^{-1/2}grad(log pi)(x)
return logdens, grad
end

Expand All @@ -20,12 +35,12 @@ function hamiltonian_dynamics!(

# first half-step
_, grad = conditioned_target_gradient(target_log_potential, state, estimated_target_std_dev)
momentum .= momentum .+ (step_size/2) .* grad
momentum .+= (step_size/2) .* grad

for i in 1:n_steps

# full step on position
state .= state .+ step_size .* momentum .* estimated_target_std_dev
state .+= step_size .* (momentum ./ estimated_target_std_dev) # eps M^{-1/2}y*

logp, grad = conditioned_target_gradient(target_log_potential, state, estimated_target_std_dev)

Expand All @@ -36,12 +51,12 @@ function hamiltonian_dynamics!(

# Neal's trick to merge successive half-steps
if i != n_steps
momentum .= momentum .+ step_size .* grad
momentum .+= step_size .* grad
end
end

# last half-step
momentum .= momentum .+ (step_size/2) .* grad
momentum .+= (step_size/2) .* grad

if !isfinite(sqr_norm(momentum))
return false
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
include("supporting/setup.jl")

is_windows_in_CI() = Sys.iswindows() && (get(ENV, "CI", "false") == "true")

# check we are testing the checked-out version of the repo, not e.g. latest released version
test_dir = @__DIR__
@assert basename(test_dir) == "test"
Expand Down
26 changes: 14 additions & 12 deletions test/test_allocs.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
@testset "Allocs-Stan" begin
#=
Despite best effort, just can't get automala + stan bridge
all the way to zero allocs in the inner loop. See e.g.
9d85645b3422260043a59b89d049f54d782f76bc
However what allocs are left are now dimension-independent.
This checks that a 100-fold increase in dim only increases allocs
by a small factor.
=#
allocs_1d = Pigeons.last_round_max_allocation(pigeons(variational = GaussianReference(), n_chains = 1, n_rounds = 10, target = Pigeons.toy_stan_target(1), explorer = AutoMALA(exponent_n_refresh = 0.0)))
allocs_100d = Pigeons.last_round_max_allocation(pigeons(variational = GaussianReference(), n_chains = 1, n_rounds = 10, target = Pigeons.toy_stan_target(100), explorer = AutoMALA(exponent_n_refresh = 0.0)))
@static if !is_windows_in_CI()
@testset "Allocs-Stan" begin
#=
Despite best effort, just can't get automala + stan bridge
all the way to zero allocs in the inner loop. See e.g.
9d85645b3422260043a59b89d049f54d782f76bc
However what allocs are left are now dimension-independent.
This checks that a 100-fold increase in dim only increases allocs
by a small factor.
=#
allocs_1d = Pigeons.last_round_max_allocation(pigeons(variational = GaussianReference(), n_chains = 1, n_rounds = 10, target = Pigeons.toy_stan_target(1), explorer = AutoMALA(exponent_n_refresh = 0.0)))
allocs_100d = Pigeons.last_round_max_allocation(pigeons(variational = GaussianReference(), n_chains = 1, n_rounds = 10, target = Pigeons.toy_stan_target(100), explorer = AutoMALA(exponent_n_refresh = 0.0)))

@test abs(allocs_1d - allocs_100d)/allocs_1d < 3
@test abs(allocs_1d - allocs_100d)/allocs_1d < 3
end
end

@testset "Allocs-SliceSampler" begin
Expand Down
4 changes: 3 additions & 1 deletion test/test_auto_mala.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ automala(target) =
end

@testset "Step size convergence" begin
for t in [toy_mvn_target(1), toy_stan_target(1)]
targets = Any[toy_mvn_target(1)]
is_windows_in_CI() || push!(targets, toy_stan_target(1))
for t in targets
step10rounds = pigeons(target = t, explorer = AutoMALA(), n_chains = 1, n_rounds = 10).shared.explorer.step_size
step15rounds = pigeons(target = t, explorer = AutoMALA(), n_chains = 1, n_rounds = 15).shared.explorer.step_size
@test isapprox(step10rounds, step15rounds, rtol = 0.1)
Expand Down
5 changes: 4 additions & 1 deletion test/test_moments.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
@testset "Moments" begin
targets = Any[toy_mvn_target(2)]
is_windows_in_CI() || push!(targets, toy_stan_target(2))
for variational in [nothing, GaussianReference()]
for target in [toy_mvn_target(2), toy_stan_target(2)]
for target in targets
@show variational, target
if !(variational isa GaussianReference) || !(target isa Pigeons.ScaledPrecisionNormalPath)
pt = pigeons(;
Expand All @@ -20,6 +22,7 @@
end
end
end

end
end
end
7 changes: 5 additions & 2 deletions test/test_parallelism_invariance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@ include("supporting/mpi_test_utils.jl")
@testset "Parallelism Invariance" begin
n_mpis = set_n_mpis_to_one_on_windows(2)
record = [swap_acceptance_pr, index_process, log_sum_ratio, round_trip, energy_ac1]
targets = Any[toy_mvn_target(1)]
is_windows_in_CI() || push!(targets, toy_stan_target(1))

# various explorers on a Julia function and on a Stan model
for explorer in [SliceSampler(), AutoMALA(), Compose(SliceSampler(), AutoMALA())]
for target in [toy_mvn_target(1), toy_stan_target(1)]
for target in targets
@show explorer, target
@show is_stan = target isa Pigeons.StanLogPotential

# setting to true puts too much pressure on CI instances? https://github.com/Julia-Tempering/Pigeons.jl/actions/runs/5627897144/job/15251121621?pr=90
# setting to true puts too much pressure on CI instances? https://github.com/Julia-Tempering/Pigeons.jl/actions/runs/5627897144/job/15251121621?pr=90
multithreaded = is_stan ? false : true

pigeons(;
Expand All @@ -26,6 +28,7 @@ include("supporting/mpi_test_utils.jl")
n_local_mpi_processes = n_mpis,
n_threads = multithreaded ? 2 : 1,
mpiexec_args = extra_mpi_args()))

end
end

Expand Down
4 changes: 3 additions & 1 deletion test/test_recorders.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
@testset "Default recorder groups" begin
for target in [toy_mvn_target(10), toy_stan_target(10), Pigeons.toy_turing_target(10)]
targets = Any[toy_mvn_target(10), Pigeons.toy_turing_target(10)]
is_windows_in_CI() || push!(targets, toy_stan_target(10))
for target in targets
for record in [record_online(), record_default(), []]
pigeons(; target, record)
end
Expand Down
4 changes: 3 additions & 1 deletion test/test_stan.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
if !is_windows_in_CI()
@testset "Stan examples" begin
pigeons(target = Pigeons.stan_eight_schools(true), n_rounds = 2, n_chains = 2)
pigeons(target = Pigeons.stan_eight_schools(false), n_rounds = 2, n_chains = 2)
Expand Down Expand Up @@ -29,4 +30,5 @@ end
=#
@test n_restarts > 40 # 100
end
end
end
end
16 changes: 7 additions & 9 deletions test/test_traces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,12 @@ using MCMCChains
@testset "Sample matrix" begin

for use_two_chains in [true, false]
targets = [
Pigeons.toy_stan_target(3),
Pigeons.toy_turing_target(3)
]
if !use_two_chains
push!(targets, toy_mvn_target(3))
end
targets = Any[Pigeons.toy_turing_target(3)]
use_two_chains || push!(targets, toy_mvn_target(3))
is_windows_in_CI() || push!(targets, Pigeons.toy_stan_target(3))


for target in targets

pt = pigeons(;
target,
record = [traces],
Expand All @@ -31,7 +27,9 @@ using MCMCChains
end

@testset "Traces" begin
for target in [toy_mvn_target(10), toy_stan_target(10), Pigeons.toy_turing_target(10)]
targets = Any[toy_mvn_target(10), Pigeons.toy_turing_target(10)]
is_windows_in_CI() || push!(targets, toy_stan_target(10))
for target in targets
r = pigeons(;
target,
record = [traces, disk, online],
Expand Down
3 changes: 2 additions & 1 deletion test/test_turing_stan_agree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ function logdensity_and_gradient(target, x)
g = LogDensityProblemsAD.ADgradient(:ForwardDiff, target, Pigeons.buffers())
return LogDensityProblems.logdensity_and_gradient(g, x)
end

if !is_windows_in_CI()
@testset "Gradient agreement" begin
turing_target = Pigeons.toy_turing_unid_target(10)
stan_target = Pigeons.toy_stan_unid_target(10)
Expand All @@ -25,3 +25,4 @@ end
end
end

end

0 comments on commit cf76ee7

Please sign in to comment.