This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #143 from Julia-Tempering/extended-trace
Add extended traces
Showing 8 changed files with 176 additions and 51 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,7 @@ | ||
name = "Pigeons" | ||
uuid = "0eb8d820-af6a-4919-95ae-11206f830c31" | ||
authors = ["Alexandre Bouchard-Côté <[email protected]>, Nikola Surjanovic <[email protected]>, Paul Tiede <[email protected]>, Trevor Campbell, Miguel Biron-Lattes, Saifuddin Syed"] | ||
version = "0.2.4" | ||
version = "0.2.5" | ||
|
||
[deps] | ||
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
```@meta | ||
CurrentModule = Pigeons | ||
``` | ||
|
||
# [Extended output (i.e., for all chains)](@id output-extended) | ||
|
||
So far, when outputting traces (either to memory via [`traces`](@ref) or to disk via [`disk`](@ref)), | ||
we have stored only the samples from the target distribution. | ||
This is the most common scenario and the default. | ||
Here we show how to store the samples from all chains instead. | ||
|
||
This can be useful in scenarios where all distributions $$\pi_i$$ are of interest, e.g. | ||
in certain statistical mechanics applications and for Bayesian inference under model | ||
mis-specification. | ||
|
||
The key argument to add is `extended_traces = true`, which we demonstrate for | ||
various common scenarios below. | ||
|
||
|
||
## Posterior densities and trace plots for all chains | ||
|
||
Make sure the third-party `MCMCChains` and `StatsPlots` | ||
packages are installed via | ||
|
||
``` | ||
using Pkg; Pkg.add(["MCMCChains", "StatsPlots"]) | ||
``` | ||
|
||
Then use the following: | ||
|
||
```@example | ||
using Pigeons | ||
using MCMCChains | ||
using StatsPlots | ||
plotlyjs() | ||
# example target: Binomial likelihood with parameter p = p1 * p2 | ||
an_unidentifiable_model = Pigeons.toy_turing_unid_target(100, 50) | ||
pt = pigeons(target = an_unidentifiable_model, | ||
n_rounds = 12, | ||
extended_traces = true, | ||
# make sure to record the trace: | ||
record = [traces; round_trip; record_default()]) | ||
# collect the statistics and convert to MCMCChains' Chains | ||
# to have axes labels matching variable names in Turing and Stan | ||
samples = Chains(sample_array(pt), variable_names(pt)) | ||
# create the trace plots | ||
my_plot = StatsPlots.plot(samples) | ||
StatsPlots.savefig(my_plot, "posterior_densities_and_traces_extended.html"); | ||
nothing # hide | ||
``` | ||
|
||
Here the ten colours correspond to the ten chains interpolating between | ||
the posterior and the prior (here a uniform distribution). | ||
|
||
```@raw html | ||
<iframe src="../posterior_densities_and_traces_extended.html" style="height:500px;width:100%;"></iframe> | ||
``` | ||
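
With `extended_traces = true`, the array handed to `Chains` gains one slice per chain along its third dimension (the test suite in this commit checks a size of `(4, 3, 10)` for such a run). The sketch below shows how such an array could be sliced per chain. It uses a stand-in array rather than a real `sample_array(pt)` result, so the values, the name `mtx`, and the name `per_chain_means` are illustrative assumptions; which slice corresponds to which distribution is not specified here.

```julia
using Statistics

# Stand-in for the output of sample_array(pt) (hypothetical data):
# dimensions are (iterations, variables, chains).
n_iters, n_vars, n_chains = 4, 3, 10
mtx = reshape(collect(1.0:(n_iters * n_vars * n_chains)), n_iters, n_vars, n_chains)

# Mean of each variable within each chain; the result is n_vars x n_chains.
per_chain_means = dropdims(mean(mtx; dims = 1); dims = 1)

# All recorded values of variable 1 in the 10th slice.
chain_10_var_1 = mtx[:, 1, 10]
```

Keeping the chain index as the last dimension means each `mtx[:, :, k]` is a plain iterations-by-variables matrix, which is the layout `Chains` expects for a single chain.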
|
||
|
||
## Off-memory processing for all chains | ||
|
||
The same option, `extended_traces = true`, can | ||
be used in the same fashion to save | ||
samples from all chains to disk: | ||
|
||
```@example | ||
using Pigeons | ||
# example target: a 1000 dimensional target | ||
high_d_target = Pigeons.toy_mvn_target(1000) | ||
pt = pigeons(target = high_d_target, | ||
checkpoint = true, | ||
extended_traces = true, | ||
record = [disk]) | ||
first_dim_of_each = zeros(10, 1024) | ||
process_sample(pt) do chain, scan, sample # ordered as if we had an inner loop over scans | ||
# each sample here is a Vector{Float64} of length 1000 | ||
# in general, it is produced by extract_sample() | ||
first_dim_of_each[chain, scan] = sample[1] | ||
end | ||
``` | ||
|
||
## Accessing the annealing parameters | ||
|
||
To obtain the annealing parameter used to define each intermediate distribution, use: | ||
|
||
```@example schedule | ||
using Pigeons | ||
an_unidentifiable_model = Pigeons.toy_turing_unid_target(100, 50) | ||
pt = pigeons(target = an_unidentifiable_model) | ||
pt.shared.tempering.schedule | ||
``` |
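
For orientation, the role of the annealing parameters can be written out as follows. This is a sketch assuming the usual geometric annealing path between a reference distribution $\pi_0$ (here the prior) and the target $\pi$ (here the posterior); the symbols $\beta_i$ and $N$, and the indexing convention, are notation introduced for this illustration and may differ from the one used internally.

```math
\pi_i(x) \propto \pi_0(x)^{1 - \beta_i} \, \pi(x)^{\beta_i},
\qquad 0 = \beta_1 < \beta_2 < \dots < \beta_N = 1
```

Under this reading, each entry of the schedule is one $\beta_i$, so $\beta_i = 0$ recovers the reference and $\beta_i = 1$ recovers the target.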
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,56 +1,64 @@ | ||
using MCMCChains | ||
|
||
@testset "Sample matrix" begin | ||
for extended_traces in [true, false] | ||
for use_two_chains in [true, false] | ||
targets = Any[Pigeons.toy_turing_target(3)] | ||
use_two_chains || push!(targets, toy_mvn_target(3)) | ||
is_windows_in_CI() || push!(targets, Pigeons.toy_stan_target(3)) | ||
|
||
for use_two_chains in [true, false] | ||
targets = Any[Pigeons.toy_turing_target(3)] | ||
use_two_chains || push!(targets, toy_mvn_target(3)) | ||
is_windows_in_CI() || push!(targets, Pigeons.toy_stan_target(3)) | ||
for target in targets | ||
pt = pigeons(; | ||
target, | ||
extended_traces, | ||
record = [traces], | ||
n_rounds = 2, | ||
n_chains_variational = use_two_chains ? 10 : 0, | ||
variational = use_two_chains ? GaussianReference() : nothing | ||
) | ||
|
||
|
||
for target in targets | ||
pt = pigeons(; | ||
target, | ||
record = [traces], | ||
n_rounds = 2, | ||
n_chains_variational = use_two_chains ? 10 : 0, | ||
variational = use_two_chains ? GaussianReference() : nothing | ||
) | ||
|
||
mtx = sample_array(pt) | ||
@test size(mtx) == (4, 3, use_two_chains ? 2 : 1) | ||
@test length(variable_names(pt)) == 3 | ||
chain = Chains(sample_array(pt), variable_names(pt)) | ||
end | ||
mtx = sample_array(pt) | ||
@test size(mtx) == (4, 3, (use_two_chains ? 2 : 1) * (extended_traces ? 10 : 1)) | ||
@test length(variable_names(pt)) == 3 | ||
chain = Chains(sample_array(pt), variable_names(pt)) | ||
end | ||
end | ||
end | ||
|
||
end | ||
|
||
@testset "Traces" begin | ||
targets = Any[toy_mvn_target(10), Pigeons.toy_turing_target(10)] | ||
is_windows_in_CI() || push!(targets, toy_stan_target(10)) | ||
for target in targets | ||
r = pigeons(; | ||
target, | ||
record = [traces, disk, online], | ||
multithreaded = false, # setting to true puts too much pressure on CI instances? https://github.com/Julia-Tempering/Pigeons.jl/actions/runs/5627897144/job/15251121621?pr=90 | ||
checkpoint = true, | ||
on = ChildProcess(n_local_mpi_processes = 2, n_threads = 1)) # setting to more than 1 puts too much pressure on CI instances? | ||
pt = load(r) | ||
@test length(pt.reduced_recorders.traces) == 1024 | ||
marginal = [get_sample(pt, 10, i)[1] for i in 1:1024] | ||
s = get_sample(pt, 10) | ||
@test marginal == first.(s) | ||
@test abs(mean(marginal) - 0.0) < 0.05 | ||
@test isapprox(mean(marginal), mean(pt)[1], atol = 1e-10) | ||
@test mean(marginal) ≈ mean(s)[1] | ||
@test s[1] == get_sample(pt, 10, 1) | ||
@test size(s)[1] == length(marginal) | ||
@test_throws "You cannot" setindex!(s, s[2], 1) | ||
# check that the disk serialization gives the same result | ||
process_sample(pt) do chain, scan, sample | ||
@test sample == get_sample(pt, chain, scan) | ||
for extended_traces in [true, false] | ||
for target in targets | ||
r = pigeons(; | ||
target, | ||
record = [traces, disk, online], | ||
extended_traces, | ||
multithreaded = false, # setting to true puts too much pressure on CI instances? https://github.com/Julia-Tempering/Pigeons.jl/actions/runs/5627897144/job/15251121621?pr=90 | ||
checkpoint = true, | ||
on = ChildProcess(n_local_mpi_processes = 2, n_threads = 1)) # setting to more than 1 puts too much pressure on CI instances? | ||
pt = load(r) | ||
@test length(pt.reduced_recorders.traces) == 1024 * (extended_traces ? 10 : 1) | ||
for chain in Pigeons.chains_with_samples(pt) | ||
marginal = [get_sample(pt, chain, i)[1] for i in 1:1024] | ||
s = get_sample(pt, chain) | ||
@test marginal == first.(s) | ||
@test abs(mean(marginal) - 0.0) < 0.1 | ||
if !extended_traces | ||
@test isapprox(mean(marginal), mean(pt)[1], atol = 1e-10) | ||
end | ||
@test mean(marginal) ≈ mean(s)[1] | ||
@test s[1] == get_sample(pt, chain, 1) | ||
@test size(s)[1] == length(marginal) | ||
@test_throws "You cannot" setindex!(s, s[2], 1) | ||
# check that the disk serialization gives the same result | ||
process_sample(pt) do chain, scan, sample | ||
@test sample == get_sample(pt, chain, scan) | ||
end | ||
end | ||
end | ||
end | ||
end | ||
|
||
|
3a30ebf
@JuliaRegistrator register
3a30ebf
Registration pull request created: JuliaRegistries/General/92489
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: