Merge pull request #143 from Julia-Tempering/extended-trace
Add extended traces
alexandrebouchard authored Sep 29, 2023
2 parents 0607ecb + 2a1f0ff commit 3a30ebf
Showing 8 changed files with 176 additions and 51 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Pigeons"
uuid = "0eb8d820-af6a-4919-95ae-11206f830c31"
authors = ["Alexandre Bouchard-Côté <[email protected]>, Nikola Surjanovic <[email protected]>, Paul Tiede <[email protected]>, Trevor Campbell, Miguel Biron-Lattes, Saifuddin Syed"]
version = "0.2.4"
version = "0.2.5"

[deps]
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
1 change: 1 addition & 0 deletions docs/make.jl
@@ -57,6 +57,7 @@ makedocs(;
"Off-memory" => "output-off-memory.md",
"PT diagnostics" => "output-pt.md",
"Custom types" => "output-custom-types.md",
"Extended output" => "output-extended.md",
"MPI output" => "output-mpi-postprocessing.md"
],
"Checkpoints" => "checkpoints.md",
101 changes: 101 additions & 0 deletions docs/src/output-extended.md
@@ -0,0 +1,101 @@
```@meta
CurrentModule = Pigeons
```

# [Extended output (i.e., for all chains)](@id output-extended)

So far when outputting traces (either to memory via [`traces`](@ref) or to disk via [`disk`](@ref)),
we have been storing only the target distribution's samples.
This is the most common scenario and the default.
Here we show how to instead store the samples from all chains.

This can be useful in scenarios where all distributions $$\pi_i$$ are of interest, e.g.
in certain statistical mechanics applications and for Bayesian inference under model
mis-specification.

The key argument to add is `extended_traces = true`, which we demonstrate for
various common scenarios below.


## Posterior densities and trace plots for all chains

Make sure to have the third-party `MCMCChains` and `StatsPlots`
packages installed via

```
using Pkg; Pkg.add(["MCMCChains", "StatsPlots"])
```

Then use the following:

```@example
using Pigeons
using MCMCChains
using StatsPlots
plotlyjs()
# example target: Binomial likelihood with parameter p = p1 * p2
an_unidentifiable_model = Pigeons.toy_turing_unid_target(100, 50)
pt = pigeons(target = an_unidentifiable_model,
n_rounds = 12,
extended_traces = true,
# make sure to record the trace:
record = [traces; round_trip; record_default()])
# collect the statistics and convert to MCMCChains' Chains
# to have axes labels matching variable names in Turing and Stan
samples = Chains(sample_array(pt), variable_names(pt))
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "posterior_densities_and_traces_extended.html");
nothing # hide
```

Here the ten colours correspond to the ten chains interpolating between
the posterior and the prior (here a uniform distribution).

```@raw html
<iframe src="../posterior_densities_and_traces_extended.html" style="height:500px;width:100%;"></iframe>
```
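
As a rough sketch of the array layout involved: `sample_array` fills an array indexed as (iteration, variable, chain), and with `extended_traces = true` the third axis covers all chains rather than only the target. The example below uses a zero-filled stand-in array (not real Pigeons output), with sizes chosen purely as assumptions, to show how the samples of a single chain or a single variable could be sliced out.

```julia
# Stand-in for `sample_array(pt)`; the sizes below are illustrative assumptions.
n_iterations, n_variables, n_chains = 4, 3, 10
all_samples = zeros(n_iterations, n_variables, n_chains)

# Samples of one chain: one row per iteration, one column per variable.
chain_index = n_chains  # hypothetical choice of which chain to inspect
one_chain = all_samples[:, :, chain_index]

# A single variable across all chains: one row per iteration, one column per chain.
first_variable = all_samples[:, 1, :]
```

Slicing along the third axis this way keeps the result compatible with matrix-based tools, since each slice is an ordinary `Matrix{Float64}`.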


## Off-memory processing for all chains

The same option, `extended_traces = true`, can
be used in the same fashion to save samples
from all chains to disk:

```@example
using Pigeons
# example target: a 1000 dimensional target
high_d_target = Pigeons.toy_mvn_target(1000)
pt = pigeons(target = high_d_target,
checkpoint = true,
extended_traces = true,
record = [disk])
first_dim_of_each = zeros(10, 1024)
process_sample(pt) do chain, scan, sample # ordered as if we had an inner loop over scans
# each sample here is a Vector{Float64} of length 1000
# in general, it is produced by extract_sample()
first_dim_of_each[chain, scan] = sample[1]
end
```
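
Once a matrix such as `first_dim_of_each` has been filled, per-chain summaries can be computed with Julia's standard library. The following is a minimal sketch under stated assumptions: it uses a random stand-in matrix with one row per chain and one column per scan (10 by 1024, matching the allocation above) instead of real Pigeons output, and the names `chain_means` and `chain_stds` are hypothetical.

```julia
using Random
using Statistics

# Stand-in for the matrix filled by `process_sample` above (assumed: 10 chains, 1024 scans).
Random.seed!(1)
first_dim_of_each = randn(10, 1024)

# Reduce over the scan axis (dims = 2) to get one summary per chain (row).
chain_means = vec(mean(first_dim_of_each; dims = 2))
chain_stds = vec(std(first_dim_of_each; dims = 2))
```

Because only one coordinate per sample is kept, memory use stays at one number per chain and scan, whatever the dimension of the target.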

## Accessing the annealing parameters

To obtain the annealing parameter used to define each intermediate distribution, use:

```@example schedule
using Pigeons
an_unidentifiable_model = Pigeons.toy_turing_unid_target(100, 50)
pt = pigeons(target = an_unidentifiable_model)
pt.shared.tempering.schedule
```
1 change: 1 addition & 0 deletions docs/src/output-overview.md
@@ -20,4 +20,5 @@ methods using either [the disk](@ref output-off-memory) or
- [PT-specific diagnostics.](@ref output-pt)
- [Post-processing for MPI runs.](@ref output-mpi-postprocessing)
- [Output for custom types.](@ref output-custom-types)
- [Extended output, i.e. including non-target chains](@ref output-extended)
- [Further customization using "recorders".](@ref collecting-statistics)
10 changes: 9 additions & 1 deletion src/pt/Inputs.jl
@@ -78,12 +78,20 @@ $FIELDS
show_report::Bool = true

"""
Type of traces to collect:
Type of traces that the [`traces`](@ref) recorder will collect:
- `:samples` - `extract_sample()` is called on the state, or
- `:log_potential` - `log_potential()` is called on the state
"""
trace_type::Symbol = :samples

"""
Whether the [`traces`](@ref) and [`disk`](@ref) recorders will
store samples for all
the chains (extended = true) or just for the target(s)
(extended = false).
"""
extended_traces::Bool = false
end

"""
4 changes: 4 additions & 0 deletions src/pt/pigeons.jl
@@ -108,8 +108,12 @@ function explore!(pt, replica, explorer)
end
process_ac!(log_potential, replica, before)
if is_target(pt.shared.tempering.swap_graphs, replica.chain)
# for the online stats, we ignore pt.inputs.extended_traces
# because the recorders do not support grouping by chains
@record_if_requested!(replica.recorders, :online, extract_sample(replica.state, log_potential))
@record_if_requested!(replica.recorders, :_transformed_online, replica.state)
end
if pt.inputs.extended_traces || is_target(pt.shared.tempering.swap_graphs, replica.chain)
@record_if_requested!(
replica.recorders,
:traces,
18 changes: 10 additions & 8 deletions src/pt/process_sample.jl
@@ -17,22 +17,24 @@ which can then be used to obtain summary statistics, diagnostics, create trace p
and pair plots (via [PairPlots](https://sefffal.github.io/PairPlots.jl/dev/chains/)).
"""
function sample_array(pt::PT)
targets = target_chains(pt)
dim, size = sample_dim_size(pt, targets)
result = zeros(size, dim, length(targets))
for t_index in eachindex(targets)
t = targets[t_index]
chains = chains_with_samples(pt)
dim, size = sample_dim_size(pt, chains)
result = zeros(size, dim, length(chains))
for chain_index in eachindex(chains)
t = chains[chain_index]
sample = get_sample(pt, t)
for i in 1:size
vector = sample[i]
result[i, :, t_index] .= vector
result[i, :, chain_index] .= vector
end
end
return result
end

function sample_dim_size(pt::PT, targets = target_chains(pt))
sample = get_sample(pt, targets[1])
chains_with_samples(pt) = pt.inputs.extended_traces ? (1:n_chains(pt.inputs)) : target_chains(pt)

function sample_dim_size(pt::PT, chains)
sample = get_sample(pt, chains[1])
return length(sample[1]), length(sample)
end

90 changes: 49 additions & 41 deletions test/test_traces.jl
@@ -1,56 +1,64 @@
using MCMCChains

@testset "Sample matrix" begin
for extended_traces in [true, false]
for use_two_chains in [true, false]
targets = Any[Pigeons.toy_turing_target(3)]
use_two_chains || push!(targets, toy_mvn_target(3))
is_windows_in_CI() || push!(targets, Pigeons.toy_stan_target(3))

for use_two_chains in [true, false]
targets = Any[Pigeons.toy_turing_target(3)]
use_two_chains || push!(targets, toy_mvn_target(3))
is_windows_in_CI() || push!(targets, Pigeons.toy_stan_target(3))
for target in targets
pt = pigeons(;
target,
extended_traces,
record = [traces],
n_rounds = 2,
n_chains_variational = use_two_chains ? 10 : 0,
variational = use_two_chains ? GaussianReference() : nothing
)


for target in targets
pt = pigeons(;
target,
record = [traces],
n_rounds = 2,
n_chains_variational = use_two_chains ? 10 : 0,
variational = use_two_chains ? GaussianReference() : nothing
)

mtx = sample_array(pt)
@test size(mtx) == (4, 3, use_two_chains ? 2 : 1)
@test length(variable_names(pt)) == 3
chain = Chains(sample_array(pt), variable_names(pt))
end
mtx = sample_array(pt)
@test size(mtx) == (4, 3, (use_two_chains ? 2 : 1) * (extended_traces ? 10 : 1))
@test length(variable_names(pt)) == 3
chain = Chains(sample_array(pt), variable_names(pt))
end
end
end

end

@testset "Traces" begin
targets = Any[toy_mvn_target(10), Pigeons.toy_turing_target(10)]
is_windows_in_CI() || push!(targets, toy_stan_target(10))
for target in targets
r = pigeons(;
target,
record = [traces, disk, online],
multithreaded = false, # setting to true puts too much pressure on CI instances? https://github.com/Julia-Tempering/Pigeons.jl/actions/runs/5627897144/job/15251121621?pr=90
checkpoint = true,
on = ChildProcess(n_local_mpi_processes = 2, n_threads = 1)) # setting to more than 1 puts too much pressure on CI instances?
pt = load(r)
@test length(pt.reduced_recorders.traces) == 1024
marginal = [get_sample(pt, 10, i)[1] for i in 1:1024]
s = get_sample(pt, 10)
@test marginal == first.(s)
@test abs(mean(marginal) - 0.0) < 0.05
@test isapprox(mean(marginal), mean(pt)[1], atol = 1e-10)
@test mean(marginal) ≈ mean(s)[1]
@test s[1] == get_sample(pt, 10, 1)
@test size(s)[1] == length(marginal)
@test_throws "You cannot" setindex!(s, s[2], 1)
# check that the disk serialization gives the same result
process_sample(pt) do chain, scan, sample
@test sample == get_sample(pt, chain, scan)
for extended_traces in [true, false]
for target in targets
r = pigeons(;
target,
record = [traces, disk, online],
extended_traces,
multithreaded = false, # setting to true puts too much pressure on CI instances? https://github.com/Julia-Tempering/Pigeons.jl/actions/runs/5627897144/job/15251121621?pr=90
checkpoint = true,
on = ChildProcess(n_local_mpi_processes = 2, n_threads = 1)) # setting to more than 1 puts too much pressure on CI instances?
pt = load(r)
@test length(pt.reduced_recorders.traces) == 1024 * (extended_traces ? 10 : 1)
for chain in Pigeons.chains_with_samples(pt)
marginal = [get_sample(pt, chain, i)[1] for i in 1:1024]
s = get_sample(pt, chain)
@test marginal == first.(s)
@test abs(mean(marginal) - 0.0) < 0.1
if !extended_traces
@test isapprox(mean(marginal), mean(pt)[1], atol = 1e-10)
end
@test mean(marginal) ≈ mean(s)[1]
@test s[1] == get_sample(pt, chain, 1)
@test size(s)[1] == length(marginal)
@test_throws "You cannot" setindex!(s, s[2], 1)
# check that the disk serialization gives the same result
process_sample(pt) do chain, scan, sample
@test sample == get_sample(pt, chain, scan)
end
end
end
end
end


2 comments on commit 3a30ebf

@alexandrebouchard

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/92489

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.5 -m "<description of version>" 3a30ebf1805dc8e6634bbbadad98b055defdfb4c
git push origin v0.2.5
