diff --git a/Project.toml b/Project.toml index 76cc29b9c..a85d0e734 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Pigeons" uuid = "0eb8d820-af6a-4919-95ae-11206f830c31" authors = ["Alexandre Bouchard-Côté , Nikola Surjanovic , Paul Tiede , Trevor Campbell, Miguel Biron-Lattes, Saifuddin Syed"] -version = "0.2.4" +version = "0.2.5" [deps] BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a" diff --git a/docs/make.jl b/docs/make.jl index c9954f8f4..529330351 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -57,6 +57,7 @@ makedocs(; "Off-memory" => "output-off-memory.md", "PT diagnostics" => "output-pt.md", "Custom types" => "output-custom-types.md", + "Extended output" => "output-extended.md", "MPI output" => "output-mpi-postprocessing.md" ], "Checkpoints" => "checkpoints.md", diff --git a/docs/src/output-extended.md b/docs/src/output-extended.md new file mode 100644 index 000000000..764648ce4 --- /dev/null +++ b/docs/src/output-extended.md @@ -0,0 +1,101 @@ +```@meta +CurrentModule = Pigeons +``` + +# [Extended output (i.e., for all chains)](@id output-extended) + +So far when outputting traces (either to memory via [`traces`](@ref) or to disk via [`disk`](@ref)), +we have been storing only the target distribution's samples. +This is the most common scenario and the default. +Here we show how to instead store the samples from all chains. + +This can be useful in scenarios where all distributions $$\pi_i$$ are of interest, e.g. +in certain statistical mechanics applications and for Bayesian inference under model +mis-specification. + +The key argument to add is `extended_traces = true`, which we demonstrate for +various common scenarios below. + + +## Posterior densities and trace plots for all chains + +Make sure to have the third party `MCMCChains` and `StatsPlots` +packages installed via + +``` +using Pkg; Pkg.add("MCMCChains", "StatsPlots") +``` + +Then use the following: + +```@example +using Pigeons +using MCMCChains +using StatsPlots +plotlyjs() + +# example target: Binomial likelihood with parameter p = p1 * p2 +an_unidentifiable_model = Pigeons.toy_turing_unid_target(100, 50) + +pt = pigeons(target = an_unidentifiable_model, + n_rounds = 12, + extended_traces = true, + # make sure to record the trace: + record = [traces; round_trip; record_default()]) + +# collect the statistics and convert to MCMCChains' Chains +# to have axes labels matching variable names in Turing and Stan +samples = Chains(sample_array(pt), variable_names(pt)) + +# create the trace plots +my_plot = StatsPlots.plot(samples) +StatsPlots.savefig(my_plot, "posterior_densities_and_traces_extended.html"); +nothing # hide +``` + +Here the ten different colours correspond to the 10 chains interpolating between +the posterior and the prior (here a uniform distribution). + +```@raw html + +``` + + +## Off-memory processing for all chains + +The same option, `extended_traces = true` can +be used in the same fashion to save to disk +samples from all chains: + +```@example +using Pigeons + +# example target: a 1000 dimensional target +high_d_target = Pigeons.toy_mvn_target(1000) + +pt = pigeons(target = high_d_target, + checkpoint = true, + extended_traces = true, + record = [disk]) + +first_dim_of_each = zeros(10, 1024) +process_sample(pt) do chain, scan, sample # ordered as if we had an inner loop over scans + # each sample here is a Vector{Float64} of length 1000 + # in general, it will is produced by extract_sample() + first_dim_of_each[chain, scan] = sample[1] +end +``` + +## Accessing the annealing parameters + +To obtain the annealing parameter used to define each intermediate distribution, use: + +```@example schedule +using Pigeons + +an_unidentifiable_model = Pigeons.toy_turing_unid_target(100, 50) + +pt = pigeons(target = an_unidentifiable_model) + +pt.shared.tempering.schedule +``` diff --git a/docs/src/output-overview.md b/docs/src/output-overview.md index b5cf2ac59..7eca36663 100644 --- a/docs/src/output-overview.md +++ b/docs/src/output-overview.md @@ -20,4 +20,5 @@ methods using either [the disk](@ref output-off-memory) or - [PT-specific diagnostics.](@ref output-pt) - [Post-processing for MPI runs.](@ref output-mpi-postprocessing) - [Output for custom types.](@ref output-custom-types) +- [Extended output, i.e. including non-target chains](@ref output-extended) - [Further customization using "recorders".](@ref collecting-statistics) \ No newline at end of file diff --git a/src/pt/Inputs.jl b/src/pt/Inputs.jl index 400a74da8..8bfe04163 100644 --- a/src/pt/Inputs.jl +++ b/src/pt/Inputs.jl @@ -78,12 +78,20 @@ $FIELDS show_report::Bool = true """ - Type of traces to collect: + Type of traces that the [`traces`](@ref) recorder will collect: - `:samples` - `extract_sample()` is called on the state, or - `:log_potential` - `log_potential()` is called on the state """ trace_type::Symbol = :samples + + """ + Whether the [`traces`](@ref) and [`disk`](@ref) recorders will + store samples for all + the chains (extended = true) or just for the target(s) + (extended = false). + """ + extended_traces::Bool = false end """ diff --git a/src/pt/pigeons.jl b/src/pt/pigeons.jl index 74c45cbd9..3e3a30ddc 100644 --- a/src/pt/pigeons.jl +++ b/src/pt/pigeons.jl @@ -108,8 +108,12 @@ function explore!(pt, replica, explorer) end process_ac!(log_potential, replica, before) if is_target(pt.shared.tempering.swap_graphs, replica.chain) + # for the online stats, we ignore pt.inputs.extended_traces + # because the recorders do not support grouping by chains @record_if_requested!(replica.recorders, :online, extract_sample(replica.state, log_potential)) @record_if_requested!(replica.recorders, :_transformed_online, replica.state) + end + if pt.inputs.extended_traces || is_target(pt.shared.tempering.swap_graphs, replica.chain) @record_if_requested!( replica.recorders, :traces, diff --git a/src/pt/process_sample.jl b/src/pt/process_sample.jl index dcc4afe30..0aacc61b1 100644 --- a/src/pt/process_sample.jl +++ b/src/pt/process_sample.jl @@ -17,22 +17,24 @@ which can then be used to obtain summary statistics, diagnostics, create trace p and pair plots (via [PairPlots](https://sefffal.github.io/PairPlots.jl/dev/chains/)). """ function sample_array(pt::PT) - targets = target_chains(pt) - dim, size = sample_dim_size(pt, targets) - result = zeros(size, dim, length(targets)) - for t_index in eachindex(targets) - t = targets[t_index] + chains = chains_with_samples(pt) + dim, size = sample_dim_size(pt, chains) + result = zeros(size, dim, length(chains)) + for chain_index in eachindex(chains) + t = chains[chain_index] sample = get_sample(pt, t) for i in 1:size vector = sample[i] - result[i, :, t_index] .= vector + result[i, :, chain_index] .= vector end end return result end -function sample_dim_size(pt::PT, targets = target_chains(pt)) - sample = get_sample(pt, targets[1]) +chains_with_samples(pt) = pt.inputs.extended_traces ? (1:n_chains(pt.inputs)) : target_chains(pt) + +function sample_dim_size(pt::PT, chains) + sample = get_sample(pt, chains[1]) return length(sample[1]), length(sample) end diff --git a/test/test_traces.jl b/test/test_traces.jl index 58fa48ff2..6dca5d6ea 100644 --- a/test/test_traces.jl +++ b/test/test_traces.jl @@ -1,56 +1,64 @@ using MCMCChains @testset "Sample matrix" begin + for extended_traces in [true, false] + for use_two_chains in [true, false] + targets = Any[Pigeons.toy_turing_target(3)] + use_two_chains || push!(targets, toy_mvn_target(3)) + is_windows_in_CI() || push!(targets, Pigeons.toy_stan_target(3)) - for use_two_chains in [true, false] - targets = Any[Pigeons.toy_turing_target(3)] - use_two_chains || push!(targets, toy_mvn_target(3)) - is_windows_in_CI() || push!(targets, Pigeons.toy_stan_target(3)) + for target in targets + pt = pigeons(; + target, + extended_traces, + record = [traces], + n_rounds = 2, + n_chains_variational = use_two_chains ? 10 : 0, + variational = use_two_chains ? GaussianReference() : nothing + ) - - for target in targets - pt = pigeons(; - target, - record = [traces], - n_rounds = 2, - n_chains_variational = use_two_chains ? 10 : 0, - variational = use_two_chains ? GaussianReference() : nothing - ) - - mtx = sample_array(pt) - @test size(mtx) == (4, 3, use_two_chains ? 2 : 1) - @test length(variable_names(pt)) == 3 - chain = Chains(sample_array(pt), variable_names(pt)) - end + mtx = sample_array(pt) + @test size(mtx) == (4, 3, (use_two_chains ? 2 : 1) * (extended_traces ? 10 : 1)) + @test length(variable_names(pt)) == 3 + chain = Chains(sample_array(pt), variable_names(pt)) + end + end end - end @testset "Traces" begin targets = Any[toy_mvn_target(10), Pigeons.toy_turing_target(10)] is_windows_in_CI() || push!(targets, toy_stan_target(10)) - for target in targets - r = pigeons(; - target, - record = [traces, disk, online], - multithreaded = false, # setting to true puts too much pressure on CI instances? https://github.com/Julia-Tempering/Pigeons.jl/actions/runs/5627897144/job/15251121621?pr=90 - checkpoint = true, - on = ChildProcess(n_local_mpi_processes = 2, n_threads = 1)) # setting to more than 1 puts too much pressure on CI instances? - pt = load(r) - @test length(pt.reduced_recorders.traces) == 1024 - marginal = [get_sample(pt, 10, i)[1] for i in 1:1024] - s = get_sample(pt, 10) - @test marginal == first.(s) - @test abs(mean(marginal) - 0.0) < 0.05 - @test isapprox(mean(marginal), mean(pt)[1], atol = 1e-10) - @test mean(marginal) ≈ mean(s)[1] - @test s[1] == get_sample(pt, 10, 1) - @test size(s)[1] == length(marginal) - @test_throws "You cannot" setindex!(s, s[2], 1) - # check that the disk serialization gives the same result - process_sample(pt) do chain, scan, sample - @test sample == get_sample(pt, chain, scan) + for extended_traces in [true, false] + for target in targets + r = pigeons(; + target, + record = [traces, disk, online], + extended_traces, + multithreaded = false, # setting to true puts too much pressure on CI instances? https://github.com/Julia-Tempering/Pigeons.jl/actions/runs/5627897144/job/15251121621?pr=90 + checkpoint = true, + on = ChildProcess(n_local_mpi_processes = 2, n_threads = 1)) # setting to more than 1 puts too much pressure on CI instances? + pt = load(r) + @test length(pt.reduced_recorders.traces) == 1024 * (extended_traces ? 10 : 1) + for chain in Pigeons.chains_with_samples(pt) + marginal = [get_sample(pt, chain, i)[1] for i in 1:1024] + s = get_sample(pt, chain) + @test marginal == first.(s) + @test abs(mean(marginal) - 0.0) < 0.1 + if !extended_traces + @test isapprox(mean(marginal), mean(pt)[1], atol = 1e-10) + end + @test mean(marginal) ≈ mean(s)[1] + @test s[1] == get_sample(pt, chain, 1) + @test size(s)[1] == length(marginal) + @test_throws "You cannot" setindex!(s, s[2], 1) + # check that the disk serialization gives the same result + process_sample(pt) do chain, scan, sample + @test sample == get_sample(pt, chain, scan) + end + end end end end +