From 3cdde3e21032b3fd8c6b8640453fcbd1c665b790 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 19 Nov 2022 21:51:44 +0100 Subject: [PATCH 01/45] Update discretediag --- src/discretediag.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/discretediag.jl b/src/discretediag.jl index ddf17248..ba6661a4 100644 --- a/src/discretediag.jl +++ b/src/discretediag.jl @@ -441,9 +441,9 @@ function discretediag( ) 0 < frac < 1 || throw(ArgumentError("`frac` must be in (0,1)")) - num_iters = size(chains, 1) + num_iters = size(chains, 2) between_chain_vals, within_chain_vals, _, _ = discretediag_sub( - chains, frac, method, nsim, num_iters, num_iters + permutedims(chains, (2, 1, 3)), frac, method, nsim, num_iters, num_iters ) return between_chain_vals, within_chain_vals From 7764bb216d6f8adcd27d535f76f73acdabb490ba Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 19 Nov 2022 21:57:57 +0100 Subject: [PATCH 02/45] Update ess_rhat --- src/ess.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ess.jl b/src/ess.jl index ee06b13c..72ad8133 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -211,8 +211,8 @@ function ess_rhat( maxlag::Int=250, ) # compute size of matrices (each chain is split!) - niter = size(chains, 1) ÷ 2 - nparams = size(chains, 2) + niter = size(chains, 2) ÷ 2 + nparams = size(chains, 1) nchains = 2 * size(chains, 3) ntotal = niter * nchains @@ -238,7 +238,7 @@ function ess_rhat( rhat = Vector{T}(undef, nparams) # for each parameter - for (i, chains_slice) in enumerate((view(chains, :, i, :) for i in axes(chains, 2))) + for (i, chains_slice) in enumerate((selectdim(chains, 1, i) for i in axes(chains, 1))) # check that no values are missing if any(x -> x === missing, chains_slice) rhat[i] = missing From f2b3f5e049afcd80e83da5fb98899515ec8ae24c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 19 Nov 2022 21:58:20 +0100 Subject: [PATCH 03/45] Update gelmandiag --- src/gelmandiag.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gelmandiag.jl b/src/gelmandiag.jl index 87626672..e4981a30 100644 --- a/src/gelmandiag.jl +++ b/src/gelmandiag.jl @@ -1,14 +1,14 @@ function _gelmandiag(psi::AbstractArray{<:Real,3}; alpha::Real=0.05) - niters, nparams, nchains = size(psi) + nparams, niters, nchains = size(psi) nchains > 1 || error("Gelman diagnostic requires at least 2 chains") rfixed = (niters - 1) / niters rrandomscale = (nchains + 1) / (nchains * niters) - S2 = map(Statistics.cov, (view(psi, :, :, i) for i in axes(psi, 3))) + S2 = map(x -> Statistics.cov(x; dims=2), (view(psi, :, :, i) for i in axes(psi, 3))) W = Statistics.mean(S2) - psibar = dropdims(Statistics.mean(psi; dims=1); dims=1)' + psibar = dropdims(Statistics.mean(psi; dims=2); dims=2)' B = niters .* Statistics.cov(psibar) w = LinearAlgebra.diag(W) @@ -75,7 +75,7 @@ end Compute the multivariate Gelman, Rubin and Brooks diagnostics. """ function gelmandiag_multivariate(chains::AbstractArray{<:Real,3}; kwargs...) - niters, nparams, nchains = size(chains) + nparams, niters, nchains = size(chains) if nparams < 2 error( "computation of the multivariate potential scale reduction factor requires ", From 10e3814aa801753fefb28b1bab000133939ddef9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 19 Nov 2022 21:58:44 +0100 Subject: [PATCH 04/45] Update rstar --- src/rstar.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rstar.jl b/src/rstar.jl index 2def08d4..2ce0e659 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -73,7 +73,7 @@ function rstar( verbosity::Int=0, ) # checks - size(x, 1) != length(y) && throw(DimensionMismatch()) + size(x, 2) != length(y) && throw(DimensionMismatch()) 0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)")) # randomly sub-select training and testing set @@ -88,11 +88,11 @@ function rstar( # train classifier on training data ycategorical = MLJModelInterface.categorical(y) fitresult, _ = MLJModelInterface.fit( - classifier, verbosity, Tables.table(x[train_ids, :]), ycategorical[train_ids] + classifier, verbosity, Tables.table(x[:, train_ids]'), ycategorical[train_ids] ) # compute predictions on test data - xtest = Tables.table(x[test_ids, :]) + xtest = Tables.table(x[:, test_ids]') predictions = _predict(classifier, fitresult, xtest) # compute statistic From 86297e9da4748af4f3b24dac96a51876a2e255ca Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 19 Nov 2022 21:59:43 +0100 Subject: [PATCH 05/45] Add 3d array method for rstar --- src/rstar.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/rstar.jl b/src/rstar.jl index 2ce0e659..964031eb 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -114,6 +114,17 @@ function _predict(model::MLJModelInterface.Model, fitresult, x) end end +function rstar( + rng::Random.AbstractRNG, + classifier::MLJModelInterface.Supervised, + x::AbstractArray{<:Any,3}; + kwargs... +) + samples = reshape(x, size(x, 1), :) + chain_inds = repeat(axes(x, 3); inner=size(x, 2)) + return rstar(rng, classifier, samples, chain_inds; kwargs...) +end + function rstar( classif::MLJModelInterface.Supervised, x::AbstractMatrix, @@ -123,6 +134,14 @@ function rstar( return rstar(Random.GLOBAL_RNG, classif, x, y; kwargs...) end +function rstar( + classif::MLJModelInterface.Supervised, + x::AbstractArray{<:Any,3}; + kwargs..., +) + return rstar(Random.GLOBAL_RNG, classif, x; kwargs...) +end + # R⋆ for deterministic predictions (algorithm 1) function _rstar(predictions::AbstractVector{T}, ytest::AbstractVector{T}) where {T} length(predictions) == length(ytest) || From d5609e2eaf7751de3d18b5581585c9e8d7556f52 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 19 Nov 2022 21:59:59 +0100 Subject: [PATCH 06/45] Add 3d array method for mcse --- src/mcse.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/mcse.jl b/src/mcse.jl index 063ab8b4..5900be0d 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -23,6 +23,9 @@ function mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...) throw(ArgumentError("unsupported MCSE method $method")) end end +function mcse(x::AbstractArray{<:Real,3}; kwargs...) + return dropdims(mapslices(xi -> mcse(vec(xi); kwargs...), x; dims=(2, 3)); dims=(2, 3)) +end function mcse_bm(x::AbstractVector{<:Real}; size::Int=floor(Int, sqrt(length(x)))) n = length(x) From 79adf82302cd5cc678cb39d38640b3674789b8c5 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 19 Nov 2022 22:00:39 +0100 Subject: [PATCH 07/45] Update docstrings --- src/discretediag.jl | 6 ++++-- src/ess.jl | 2 +- src/gelmandiag.jl | 10 ++++++---- src/mcse.jl | 6 ++++-- src/rstar.jl | 18 +++++++++--------- 5 files changed, 24 insertions(+), 18 deletions(-) diff --git a/src/discretediag.jl b/src/discretediag.jl index ba6661a4..7b5c7434 100644 --- a/src/discretediag.jl +++ b/src/discretediag.jl @@ -423,9 +423,11 @@ function discretediag_sub( end """ - discretediag(chains::AbstractArray{<:Real,3}; frac=0.3, method=:weiss, nsim=1_000) + discretediag(samples::AbstractArray{<:Real,3}; frac=0.3, method=:weiss, nsim=1_000) -Compute discrete diagnostic where `method` can be one of `:weiss`, `:hangartner`, +Compute discrete diagnostic on `samples` with shape `(parameters, draws, chains)`. + +`method` can be one of `:weiss`, `:hangartner`, `:DARBOOT`, `:MCBOOT`, `:billinsgley`, and `:billingsleyBOOT`. # References diff --git a/src/ess.jl b/src/ess.jl index 72ad8133..3edb667a 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -201,7 +201,7 @@ end ) Estimate the effective sample size and the potential scale reduction of the `samples` of -shape (draws, parameters, chains) with the `method` and a maximum lag of `maxlag`. +shape `(parameters, draws, chains)` with the `method` and a maximum lag of `maxlag`. See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref) """ diff --git a/src/gelmandiag.jl b/src/gelmandiag.jl index e4981a30..31a13607 100644 --- a/src/gelmandiag.jl +++ b/src/gelmandiag.jl @@ -52,9 +52,10 @@ function _gelmandiag(psi::AbstractArray{<:Real,3}; alpha::Real=0.05) end """ - gelmandiag(chains::AbstractArray{<:Real,3}; alpha::Real=0.95) + gelmandiag(samples::AbstractArray{<:Real,3}; alpha::Real=0.95) -Compute the Gelman, Rubin and Brooks diagnostics [^Gelman1992] [^Brooks1998]. Values of the +Compute the Gelman, Rubin and Brooks diagnostics [^Gelman1992] [^Brooks1998] on `samples` +with shape `(parameters, draws, chains)`. Values of the diagnostic’s potential scale reduction factor (PSRF) that are close to one suggest convergence. As a rule-of-thumb, convergence is rejected if the 97.5 percentile of a PSRF is greater than 1.2. @@ -70,9 +71,10 @@ function gelmandiag(chains::AbstractArray{<:Real,3}; kwargs...) end """ - gelmandiag_multivariate(chains::AbstractArray{<:Real,3}; alpha::Real=0.05) + gelmandiag_multivariate(samples::AbstractArray{<:Real,3}; alpha::Real=0.05) -Compute the multivariate Gelman, Rubin and Brooks diagnostics. +Compute the multivariate Gelman, Rubin and Brooks diagnostics on `samples` with shape +`(parameters, draws, chains)`. """ function gelmandiag_multivariate(chains::AbstractArray{<:Real,3}; kwargs...) nparams, niters, nchains = size(chains) diff --git a/src/mcse.jl b/src/mcse.jl index 5900be0d..f092345f 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -1,7 +1,9 @@ """ - mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...) + mcse(samples::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...) + mcse(samples::AbstractArray{<:Real,3}; method::Symbol=:imse, kwargs...) -Compute the Monte Carlo standard error (MCSE) of samples `x`. +Compute the Monte Carlo standard error (MCSE) of `samples` of shape `(draws,)` or +`(parameters, draws, chains)` The optional argument `method` describes how the errors are estimated. Possible options are: - `:bm` for batch means [^Glynn1991] diff --git a/src/rstar.jl b/src/rstar.jl index 964031eb..eff0ee83 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -2,14 +2,16 @@ rstar( rng=Random.GLOBAL_RNG, classifier, - samples::AbstractMatrix, - chain_indices::AbstractVector{Int}; + samples::AbstractArray, + [chain_indices::AbstractVector{Int}]; subset::Real=0.8, verbosity::Int=0, ) -Compute the ``R^*`` convergence statistic of the `samples` with shape (draws, parameters) -and corresponding chains `chain_indices` with the `classifier`. +Compute the ``R^*`` convergence statistic of the `samples` with the `classifier`. + +Either `samples` has shape `(parameters, draws, chains)`, or `samples` has shape +`(parameters, draws)` and `chain_indices` must be provided. This implementation is an adaption of algorithms 1 and 2 described by Lambert and Vehtari. @@ -32,16 +34,14 @@ is returned (algorithm 2). ```jldoctest rstar; setup = :(using Random; Random.seed!(100)) julia> using MLJBase, MLJXGBoostInterface, Statistics -julia> samples = fill(4.0, 300, 2); - -julia> chain_indices = repeat(1:3; outer=100); +julia> samples = fill(4.0, 2, 100, 3); ``` One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the probabilistic classifier. ```jldoctest rstar -julia> distribution = rstar(XGBoostClassifier(), samples, chain_indices); +julia> distribution = rstar(XGBoostClassifier(), samples); julia> isapprox(mean(distribution), 1; atol=0.1) true @@ -54,7 +54,7 @@ predicting the mode. In MLJ this corresponds to a pipeline of models. ```jldoctest rstar julia> xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode); -julia> value = rstar(xgboost_deterministic, samples, chain_indices); +julia> value = rstar(xgboost_deterministic, samples); julia> isapprox(value, 1; atol=0.2) true From 1370c7ac71c0bf14895cc13967f811de11c31e8e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 19 Nov 2022 22:01:06 +0100 Subject: [PATCH 08/45] Change dimension order in tests --- test/discretediag.jl | 2 +- test/ess.jl | 10 +++++----- test/gelmandiag.jl | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/discretediag.jl b/test/discretediag.jl index 8415614f..24d2f509 100644 --- a/test/discretediag.jl +++ b/test/discretediag.jl @@ -1,7 +1,7 @@ @testset "discretediag.jl" begin nparams = 4 nchains = 2 - samples = rand(-100:100, 100, nparams, nchains) + samples = rand(-100:100, nparams, 100, nchains) @testset "results" begin for method in diff --git a/test/ess.jl b/test/ess.jl index 4528c735..8b14309c 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -31,7 +31,7 @@ end @testset "ESS and R̂ (IID samples)" begin - rawx = randn(10_000, 40, 10) + rawx = randn(40, 10_000, 10) # Repeat tests with different scales for scale in (1, 50, 100) @@ -58,7 +58,7 @@ end @testset "ESS and R̂ (identical samples)" begin - x = ones(10_000, 40, 10) + x = ones(40, 10_000, 10) ess_standard, rhat_standard = ess_rhat(x) ess_standard2, rhat_standard2 = ess_rhat(x; method=ESSMethod()) @@ -75,15 +75,15 @@ end @testset "ESS and R̂ (single sample)" begin # check that issue #137 is fixed - x = rand(1, 5, 3) + x = rand(5, 1, 3) for method in (ESSMethod(), FFTESSMethod(), BDAESSMethod()) # analyze array ess_array, rhat_array = ess_rhat(x; method=method) - @test length(ess_array) == size(x, 2) + @test length(ess_array) == size(x, 1) @test all(ismissing, ess_array) # since min(maxlag, niter - 1) = 0 - @test length(rhat_array) == size(x, 2) + @test length(rhat_array) == size(x, 1) @test all(ismissing, rhat_array) end end diff --git a/test/gelmandiag.jl b/test/gelmandiag.jl index 045459e1..ab5b4ba0 100644 --- a/test/gelmandiag.jl +++ b/test/gelmandiag.jl @@ -1,7 +1,7 @@ @testset "gelmandiag.jl" begin nparams = 4 nchains = 2 - samples = randn(100, nparams, nchains) + samples = randn(nparams, 100, nchains) @testset "results" begin result = @inferred(gelmandiag(samples)) @@ -24,6 +24,6 @@ @testset "exceptions" begin @test_throws ErrorException gelmandiag(samples[:, :, 1:1]) - @test_throws ErrorException gelmandiag_multivariate(samples[:, 1:1, :]) + @test_throws ErrorException gelmandiag_multivariate(samples[1:1, :, :]) end end From 29cba7f04d0b7a623813ad4aeee1d22c68d5c3f3 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 19 Nov 2022 22:14:45 +0100 Subject: [PATCH 09/45] Update rstar tests with permuted dims --- test/rstar/runtests.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/rstar/runtests.jl b/test/rstar/runtests.jl index 00869743..af28fdae 100644 --- a/test/rstar/runtests.jl +++ b/test/rstar/runtests.jl @@ -15,8 +15,8 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo @testset "examples (classifier = $classifier)" for classifier in classifiers # Compute R⋆ statistic for a mixed chain. - samples = randn(N, 2) - dist = rstar(classifier, randn(N, 2), rand(1:3, N)) + samples = randn(2, N) + dist = rstar(classifier, samples, rand(1:3, N)) # Mean of the statistic should be focused around 1, i.e., the classifier does not # perform better than random guessing. @@ -31,7 +31,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo @test mean(dist) ≈ 1 rtol = 0.2 # Compute R⋆ statistic for a mixed chain. - samples = randn(4 * N, 8) + samples = randn(8, 4 * N) chain_indices = repeat(1:4, N) dist = rstar(classifier, samples, chain_indices) @@ -48,10 +48,10 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo @test mean(dist) ≈ 1 rtol = 0.15 # Compute the R⋆ statistic for a non-mixed chain. - samples = [ + samples = permutedims([ sin.(1:N) cos.(1:N) 100 .* cos.(1:N) 100 .* sin.(1:N) - ] + ]) chain_indices = repeat(1:2; inner=N) dist = rstar(classifier, samples, chain_indices) @@ -69,10 +69,10 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo end @testset "exceptions (classifier = $classifier)" for classifier in classifiers - @test_throws DimensionMismatch rstar(classifier, randn(N - 1, 2), rand(1:3, N)) + @test_throws DimensionMismatch rstar(classifier, randn(2, N - 1), rand(1:3, N)) for subset in (-0.3, 0, 1 / (3 * N), 1 - 1 / (3 * N), 1, 1.9) @test_throws ArgumentError rstar( - classifier, randn(N, 2), rand(1:3, N); subset=subset + classifier, randn(2, N), rand(1:3, N); subset=subset ) end end From 201bdec06aad316d5b6120660e8c0a3391429f5f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 19 Nov 2022 22:15:17 +0100 Subject: [PATCH 10/45] Test rstar with 3d array --- test/rstar/Project.toml | 1 + test/rstar/runtests.jl | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/test/rstar/Project.toml b/test/rstar/Project.toml index d68653cc..267a56f2 100644 --- a/test/rstar/Project.toml +++ b/test/rstar/Project.toml @@ -4,6 +4,7 @@ MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52" MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/rstar/runtests.jl b/test/rstar/runtests.jl index af28fdae..e91beda4 100644 --- a/test/rstar/runtests.jl +++ b/test/rstar/runtests.jl @@ -5,6 +5,7 @@ using MLJBase using MLJLIBSVMInterface using MLJXGBoostInterface +using Random using Test const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode) @@ -76,4 +77,28 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo ) end end + + @testset "matrix with chain_inds produces same result as 3d array" begin + nparams = 2 + nchains = 3 + samples = randn(nparams, N, nchains) + + # manually construct samples_mat and chain_inds for comparison + samples_mat = Matrix{Float64}(undef, nparams, N * nchains) + chain_inds = Vector{Int}(undef, N * nchains) + i = 1 + for chain in 1:nchains, draw in 1:N + samples_mat[:, i] = samples[:, draw, chain] + chain_inds[i] = chain + i += 1 + end + + @testset "classifier = $classifier" for classifier in classifiers + rng = MersenneTwister(42) + dist1 = rstar(rng, classifier, samples_mat, chain_inds) + Random.seed!(rng, 42) + dist2 = rstar(rng, classifier, samples) + @test dist1 == dist2 + end + end end From 163ea884a6622987f6828b9a79de0541a5e21b6b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 19 Nov 2022 22:34:58 +0100 Subject: [PATCH 11/45] Test mcse with 3d array --- test/mcse.jl | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/test/mcse.jl b/test/mcse.jl index 3e54d447..42037e89 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -1,7 +1,6 @@ @testset "mcse.jl" begin - samples = randn(100) - - @testset "results" begin + @testset "results 1d" begin + samples = randn(100) result = @inferred(mcse(samples)) @test result isa Float64 @test result > 0 @@ -13,13 +12,32 @@ end end + @testset "results 3d" begin + nparams = 2 + nchains = 4 + samples = randn(nparams, 100, nchains) + result = mcse(samples) # mapslices is not type-inferrable + @test result isa Vector{Float64} + @test length(result) == nparams + @test all(r -> r > 0, result) + + for method in (:imse, :ipse, :bm) + result = mcse(samples) # mapslices is not type-inferrable + @test result isa Vector{Float64} + @test length(result) == nparams + @test all(r -> r > 0, result) + end + end + @testset "warning" begin + samples = randn(100) for size in (51, 75, 100, 153) @test_logs (:warn,) mcse(samples; method=:bm, size=size) end end @testset "exception" begin + samples = randn(100) @test_throws ArgumentError mcse(samples; method=:somemethod) end end From 66d2c70802238cf3e07f285039edaceba906923f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 19 Nov 2022 22:35:20 +0100 Subject: [PATCH 12/45] Increment version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c6298096..54afb63f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCMCDiagnosticTools" uuid = "be115224-59cd-429b-ad48-344e309966f0" authors = ["David Widmann"] -version = "0.1.5" +version = "0.2.0" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" From 9215e2f03ac7877aef8d88420ae7d4480b553d10 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 19 Nov 2022 22:43:27 +0100 Subject: [PATCH 13/45] Bump compat --- docs/Project.toml | 2 +- test/rstar/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 8e5438fb..c447a6ea 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -8,7 +8,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] Documenter = "0.27" -MCMCDiagnosticTools = "0.1" +MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" MLJXGBoostInterface = "0.1, 0.2" julia = "1.3" diff --git a/test/rstar/Project.toml b/test/rstar/Project.toml index 267a56f2..24c04fb3 100644 --- a/test/rstar/Project.toml +++ b/test/rstar/Project.toml @@ -9,7 +9,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Distributions = "0.25" -MCMCDiagnosticTools = "0.1" +MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" MLJLIBSVMInterface = "0.1, 0.2" MLJXGBoostInterface = "0.1, 0.2" From b733691d5820826362cd83fd68887523ef0f74fe Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 19 Nov 2022 22:44:29 +0100 Subject: [PATCH 14/45] Run formatter --- src/rafterydiag.jl | 2 +- src/rstar.jl | 8 ++------ test/gewekediag.jl | 2 +- test/mcse.jl | 2 +- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/rafterydiag.jl b/src/rafterydiag.jl index 8f34f0b8..a6960838 100644 --- a/src/rafterydiag.jl +++ b/src/rafterydiag.jl @@ -38,7 +38,7 @@ function rafterydiag( dichot = Int[(x .<= StatsBase.quantile(x, q))...] kthin = 0 bic = 1.0 - local test , ntest + local test, ntest while bic >= 0.0 kthin += 1 test = dichot[1:kthin:nx] diff --git a/src/rstar.jl b/src/rstar.jl index eff0ee83..da69bd5a 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -118,7 +118,7 @@ function rstar( rng::Random.AbstractRNG, classifier::MLJModelInterface.Supervised, x::AbstractArray{<:Any,3}; - kwargs... + kwargs..., ) samples = reshape(x, size(x, 1), :) chain_inds = repeat(axes(x, 3); inner=size(x, 2)) @@ -134,11 +134,7 @@ function rstar( return rstar(Random.GLOBAL_RNG, classif, x, y; kwargs...) end -function rstar( - classif::MLJModelInterface.Supervised, - x::AbstractArray{<:Any,3}; - kwargs..., -) +function rstar(classif::MLJModelInterface.Supervised, x::AbstractArray{<:Any,3}; kwargs...) return rstar(Random.GLOBAL_RNG, classif, x; kwargs...) end diff --git a/test/gewekediag.jl b/test/gewekediag.jl index f877ccde..5cd8f211 100644 --- a/test/gewekediag.jl +++ b/test/gewekediag.jl @@ -3,7 +3,7 @@ @testset "results" begin @test @inferred(gewekediag(samples)) isa - NamedTuple{(:zscore, :pvalue),Tuple{Float64,Float64}} + NamedTuple{(:zscore, :pvalue),Tuple{Float64,Float64}} end @testset "exceptions" begin diff --git a/test/mcse.jl b/test/mcse.jl index 42037e89..9b3f4941 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -12,7 +12,7 @@ end end - @testset "results 3d" begin + @testset "results 3d" begin nparams = 2 nchains = 4 samples = randn(nparams, 100, nchains) From 06b4d9f2d48c685129216378922d073d640c932b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 19 Nov 2022 22:54:39 +0100 Subject: [PATCH 15/45] Update seed Necessary because chain_inds are not identical to those in the previous example (now repeating with inner instead of outer) --- src/rstar.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rstar.jl b/src/rstar.jl index da69bd5a..ef2497f7 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -31,7 +31,7 @@ is returned (algorithm 2). # Examples -```jldoctest rstar; setup = :(using Random; Random.seed!(100)) +```jldoctest rstar; setup = :(using Random; Random.seed!(101)) julia> using MLJBase, MLJXGBoostInterface, Statistics julia> samples = fill(4.0, 2, 100, 3); From 0917971d29cc1504633a4e3b0332d63b0eba42a0 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 21 Nov 2022 16:45:17 +0100 Subject: [PATCH 16/45] Add Compat as dependency --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 54afb63f..faf11245 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.2.0" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -17,6 +18,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] AbstractFFTs = "0.5, 1" +Compat = "3, 4" DataAPI = "1.6" Distributions = "0.25" MLJModelInterface = "1.6" From 3230e0861ed9aec84a82b1d4d153a6a9ebfbdd89 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 21 Nov 2022 16:47:47 +0100 Subject: [PATCH 17/45] Use eachslice --- src/MCMCDiagnosticTools.jl | 1 + src/ess.jl | 2 +- src/gelmandiag.jl | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index 6aca4543..7e37082c 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -1,6 +1,7 @@ module MCMCDiagnosticTools using AbstractFFTs: AbstractFFTs +using Compat: eachslice using DataAPI: DataAPI using Distributions: Distributions using MLJModelInterface: MLJModelInterface diff --git a/src/ess.jl b/src/ess.jl index 3edb667a..cabc1f99 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -238,7 +238,7 @@ function ess_rhat( rhat = Vector{T}(undef, nparams) # for each parameter - for (i, chains_slice) in enumerate((selectdim(chains, 1, i) for i in axes(chains, 1))) + for (i, chains_slice) in enumerate(eachslice(chains; dims=1)) # check that no values are missing if any(x -> x === missing, chains_slice) rhat[i] = missing diff --git a/src/gelmandiag.jl b/src/gelmandiag.jl index 31a13607..6a82708f 100644 --- a/src/gelmandiag.jl +++ b/src/gelmandiag.jl @@ -5,7 +5,7 @@ function _gelmandiag(psi::AbstractArray{<:Real,3}; alpha::Real=0.05) rfixed = (niters - 1) / niters rrandomscale = (nchains + 1) / (nchains * niters) - S2 = map(x -> Statistics.cov(x; dims=2), (view(psi, :, :, i) for i in axes(psi, 3))) + S2 = map(x -> Statistics.cov(x; dims=2), eachslice(psi; dims=3)) W = Statistics.mean(S2) psibar = dropdims(Statistics.mean(psi; dims=2); dims=2)' From 61653f91883e106a94389d3c98ce9880af1eb01f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 21 Nov 2022 17:11:19 +0100 Subject: [PATCH 18/45] Replace mapslices with an explicit loop --- src/mcse.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/mcse.jl b/src/mcse.jl index f092345f..1305c99c 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -26,7 +26,14 @@ function mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...) end end function mcse(x::AbstractArray{<:Real,3}; kwargs...) - return dropdims(mapslices(xi -> mcse(vec(xi); kwargs...), x; dims=(2, 3)); dims=(2, 3)) + T = promote_type(eltype(x), typeof(zero(eltype(x)) / 1)) + # allocate container for type-stability and for dimensional x to have dimension output + values = similar(view(x, :, 1, 1), T) + axes(values, 1) == axes(x, 1) || @error "First axis of input and output containers do not match: $(axes(x, 1)) != $(axes(values, 1))" + for i in axes(x, 1) + values[i] = mcse(vec(view(x, i, :, :)); kwargs...) + end + return values end function mcse_bm(x::AbstractVector{<:Real}; size::Int=floor(Int, sqrt(length(x)))) From 1d0a0862e6938bb3cefd55d3e98df2483763c4eb Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 21 Nov 2022 17:12:11 +0100 Subject: [PATCH 19/45] Run formatter --- src/mcse.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mcse.jl b/src/mcse.jl index 1305c99c..5b02192f 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -29,7 +29,8 @@ function mcse(x::AbstractArray{<:Real,3}; kwargs...) T = promote_type(eltype(x), typeof(zero(eltype(x)) / 1)) # allocate container for type-stability and for dimensional x to have dimension output values = similar(view(x, :, 1, 1), T) - axes(values, 1) == axes(x, 1) || @error "First axis of input and output containers do not match: $(axes(x, 1)) != $(axes(values, 1))" + axes(values, 1) == axes(x, 1) || + @error "First axis of input and output containers do not match: $(axes(x, 1)) != $(axes(values, 1))" for i in axes(x, 1) values[i] = mcse(vec(view(x, i, :, :)); kwargs...) end From f1e05db7be1ece739a6e6698f4ae1f43684a984b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 21 Nov 2022 17:16:10 +0100 Subject: [PATCH 20/45] Avoid explicit axis check --- src/mcse.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/mcse.jl b/src/mcse.jl index 5b02192f..37c47cc8 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -29,10 +29,8 @@ function mcse(x::AbstractArray{<:Real,3}; kwargs...) T = promote_type(eltype(x), typeof(zero(eltype(x)) / 1)) # allocate container for type-stability and for dimensional x to have dimension output values = similar(view(x, :, 1, 1), T) - axes(values, 1) == axes(x, 1) || - @error "First axis of input and output containers do not match: $(axes(x, 1)) != $(axes(values, 1))" - for i in axes(x, 1) - values[i] = mcse(vec(view(x, i, :, :)); kwargs...) + for (i, xi) in zip(eachindex(values), eachslice(x; dims=1)) + values[i] = mcse(vec(xi); kwargs...) end return values end From e7bd1240e93fc09f139fdd0fe1528f9666315120 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 21 Nov 2022 20:07:35 +0100 Subject: [PATCH 21/45] Remove type-instability --- src/gelmandiag.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gelmandiag.jl b/src/gelmandiag.jl index 6a82708f..31a13607 100644 --- a/src/gelmandiag.jl +++ b/src/gelmandiag.jl @@ -5,7 +5,7 @@ function _gelmandiag(psi::AbstractArray{<:Real,3}; alpha::Real=0.05) rfixed = (niters - 1) / niters rrandomscale = (nchains + 1) / (nchains * niters) - S2 = map(x -> Statistics.cov(x; dims=2), eachslice(psi; dims=3)) + S2 = map(x -> Statistics.cov(x; dims=2), (view(psi, :, :, i) for i in axes(psi, 3))) W = Statistics.mean(S2) psibar = dropdims(Statistics.mean(psi; dims=2); dims=2)' From 194104464fb3cd0203605b35c483e981b45be7e7 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 21 Nov 2022 20:56:21 +0100 Subject: [PATCH 22/45] Accept any table to rstar --- Project.toml | 2 +- src/rstar.jl | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index faf11245..61f7764b 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,7 @@ Distributions = "0.25" MLJModelInterface = "1.6" SpecialFunctions = "0.8, 0.9, 0.10, 1, 2" StatsBase = "0.33" -Tables = "1" +Tables = "1.9" julia = "1" [extras] diff --git a/src/rstar.jl b/src/rstar.jl index ef2497f7..f5819849 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -67,13 +67,13 @@ Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic function rstar( rng::Random.AbstractRNG, classifier::MLJModelInterface.Supervised, - x::AbstractMatrix, + x, y::AbstractVector{Int}; subset::Real=0.8, verbosity::Int=0, ) # checks - size(x, 2) != length(y) && throw(DimensionMismatch()) + MLJModelInterface.nrows(x) != length(y) && throw(DimensionMismatch()) 0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)")) # randomly sub-select training and testing set @@ -87,12 +87,13 @@ function rstar( # train classifier on training data ycategorical = MLJModelInterface.categorical(y) + xtrain = Tables.subset(x, train_ids; viewhint=true) fitresult, _ = MLJModelInterface.fit( - classifier, verbosity, Tables.table(x[:, train_ids]'), ycategorical[train_ids] + classifier, verbosity, xtrain, ycategorical[train_ids] ) # compute predictions on test data - xtest = Tables.table(x[:, test_ids]') + xtest = Tables.subset(x, test_ids; viewhint=true) predictions = _predict(classifier, fitresult, xtest) # compute statistic @@ -120,9 +121,9 @@ function rstar( x::AbstractArray{<:Any,3}; kwargs..., ) - samples = reshape(x, size(x, 1), :) + table = Tables.table(reshape(x, size(x, 1), :)') chain_inds = repeat(axes(x, 3); inner=size(x, 2)) - return rstar(rng, classifier, samples, chain_inds; kwargs...) + return rstar(rng, classifier, table, chain_inds; kwargs...) end function rstar( From 1dbc4f2a3eadd4b5a089d73ed2a97d5e8c665374 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 21 Nov 2022 21:07:20 +0100 Subject: [PATCH 23/45] Update rstar documentation --- src/rstar.jl | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/rstar.jl b/src/rstar.jl index f5819849..fd8ef47a 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -1,17 +1,15 @@ """ rstar( - rng=Random.GLOBAL_RNG, - classifier, - samples::AbstractArray, - [chain_indices::AbstractVector{Int}]; + rng::Random.AbstractRNG=Random.GLOBAL_RNG, + classifier::MLJModelInterface.Supervised, + samples::AbstractArray{<:Real,3}; subset::Real=0.8, verbosity::Int=0, ) Compute the ``R^*`` convergence statistic of the `samples` with the `classifier`. -Either `samples` has shape `(parameters, draws, chains)`, or `samples` has shape -`(parameters, draws)` and `chain_indices` must be provided. +`samples` is an array of draws with the shape `(parameters, draws, chains)`.` This implementation is an adaption of algorithms 1 and 2 described by Lambert and Vehtari. @@ -63,7 +61,27 @@ true # References Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic with uncertainty using decision tree classifiers. + + + rstar( + rng::Random.AbstractRNG=Random.GLOBAL_RNG, + classifier::MLJModelInterface.Supervised, + samples, + chain_indices::AbstractVector{Int}; + subset::Real=0.8, + verbosity::Int=0, + ) + +Compute the ``R^*`` convergence statistic of the table `samples` with the `classifier`. + +`samples` must be a table (i.e. implements the Tables.jl interface) whose columns correspond +to parameters and whose rows correspond to ordered draws. `chain_indices` indicates the +chain ids of each row of `samples`. + +This method supports ragged chains, i.e. chains of nonequal lengths. """ +function rstar end + function rstar( rng::Random.AbstractRNG, classifier::MLJModelInterface.Supervised, From 87d5f884c1735d77b800bbbfab67d7319a9dd475 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 21 Nov 2022 21:40:29 +0100 Subject: [PATCH 24/45] Release type constraint --- src/rstar.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rstar.jl b/src/rstar.jl index fd8ef47a..85ce41ec 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -146,7 +146,7 @@ end function rstar( classif::MLJModelInterface.Supervised, - x::AbstractMatrix, + x, y::AbstractVector{Int}; kwargs..., ) From 78d256b06838c16ada27c3ddae019df8a666a5e5 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 21 Nov 2022 22:01:32 +0100 Subject: [PATCH 25/45] Support rstar taking matrices or vectors --- src/rstar.jl | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/rstar.jl b/src/rstar.jl index 85ce41ec..d0420e9f 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -74,9 +74,12 @@ Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic Compute the ``R^*`` convergence statistic of the table `samples` with the `classifier`. -`samples` must be a table (i.e. implements the Tables.jl interface) whose columns correspond -to parameters and whose rows correspond to ordered draws. `chain_indices` indicates the -chain ids of each row of `samples`. +`samples` must be either an `AbstractMatrix`, an `AbstractVector`, or a table +(i.e. implements the Tables.jl interface) with shape `(draws, parameters)`. Note that these +dimensions are swapped with respect to the first 2 dimensions in the method with +3-dimensional array input. + +`chain_indices` indicates the chain ids of each row of `samples`. This method supports ragged chains, i.e. chains of nonequal lengths. """ @@ -105,14 +108,14 @@ function rstar( # train classifier on training data ycategorical = MLJModelInterface.categorical(y) - xtrain = Tables.subset(x, train_ids; viewhint=true) + xtrain = MLJModelInterface.selectrows(x, train_ids) fitresult, _ = MLJModelInterface.fit( - classifier, verbosity, xtrain, ycategorical[train_ids] + classifier, verbosity, _astable(xtrain), ycategorical[train_ids] ) # compute predictions on test data - xtest = Tables.subset(x, test_ids; viewhint=true) - predictions = _predict(classifier, fitresult, xtest) + xtest = MLJModelInterface.selectrows(x, test_ids) + predictions = _predict(classifier, fitresult, _astable(xtest)) # compute statistic ytest = ycategorical[test_ids] @@ -121,6 +124,9 @@ function rstar( return result end +_astable(x::AbstractVecOrMat) = Tables.table(x) +_astable(x) = Tables.istable(x) ? x : throw(ArgumentError("Argument is not a valid table")) + # Workaround for https://github.com/JuliaAI/MLJBase.jl/issues/863 # `MLJModelInterface.predict` sometimes returns predictions and sometimes predictions + additional information # TODO: Remove once the upstream issue is fixed @@ -139,17 +145,12 @@ function rstar( x::AbstractArray{<:Any,3}; kwargs..., ) - table = Tables.table(reshape(x, size(x, 1), :)') + samples = transpose(reshape(x, size(x, 1), :)) chain_inds = repeat(axes(x, 3); inner=size(x, 2)) - return rstar(rng, classifier, table, chain_inds; kwargs...) + return rstar(rng, classifier, samples, chain_inds; kwargs...) end -function rstar( - classif::MLJModelInterface.Supervised, - x, - y::AbstractVector{Int}; - kwargs..., -) +function rstar(classif::MLJModelInterface.Supervised, x, y::AbstractVector{Int}; kwargs...) return rstar(Random.GLOBAL_RNG, classif, x, y; kwargs...) end From aa01a1e4a0bb9a0c108e48013ee2443ebb81142c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 21 Nov 2022 22:01:55 +0100 Subject: [PATCH 26/45] Update rstar tests --- test/rstar/Project.toml | 2 + test/rstar/runtests.jl | 122 +++++++++++++++++++++------------------- 2 files changed, 67 insertions(+), 57 deletions(-) diff --git a/test/rstar/Project.toml b/test/rstar/Project.toml index 24c04fb3..e0b6a797 100644 --- a/test/rstar/Project.toml +++ b/test/rstar/Project.toml @@ -5,6 +5,7 @@ MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52" MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] @@ -13,4 +14,5 @@ MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" MLJLIBSVMInterface = "0.1, 0.2" MLJXGBoostInterface = "0.1, 0.2" +Tables = "1" julia = "1.3" diff --git a/test/rstar/runtests.jl b/test/rstar/runtests.jl index e91beda4..8c53bfb2 100644 --- a/test/rstar/runtests.jl +++ b/test/rstar/runtests.jl @@ -4,6 +4,7 @@ using Distributions using MLJBase using MLJLIBSVMInterface using MLJXGBoostInterface +using Tables using Random using Test @@ -14,81 +15,88 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo classifiers = (XGBoostClassifier(), xgboost_deterministic, SVC()) N = 1_000 - @testset "examples (classifier = $classifier)" for classifier in classifiers - # Compute R⋆ statistic for a mixed chain. - samples = randn(2, N) - dist = rstar(classifier, samples, rand(1:3, N)) + @testset "samples input type: $wrapper" for wrapper in [Vector, Array, Tables.table] + @testset "examples (classifier = $classifier)" for classifier in classifiers + sz = wrapper === Vector ? N : (N, 2) + # Compute R⋆ statistic for a mixed chain. + samples = wrapper(randn(sz...)) + dist = rstar(classifier, samples, rand(1:3, N)) - # Mean of the statistic should be focused around 1, i.e., the classifier does not - # perform better than random guessing. - if classifier isa MLJBase.Deterministic - @test dist isa Float64 - else - @test dist isa LocationScale - @test dist.ρ isa PoissonBinomial - @test minimum(dist) == 0 - @test maximum(dist) == 3 - end - @test mean(dist) ≈ 1 rtol = 0.2 + # Mean of the statistic should be focused around 1, i.e., the classifier does not + # perform better than random guessing. + if classifier isa MLJBase.Deterministic + @test dist isa Float64 + else + @test dist isa LocationScale + @test dist.ρ isa PoissonBinomial + @test minimum(dist) == 0 + @test maximum(dist) == 3 + end + @test mean(dist) ≈ 1 rtol = 0.2 + wrapper === Vector && break - # Compute R⋆ statistic for a mixed chain. - samples = randn(8, 4 * N) - chain_indices = repeat(1:4, N) - dist = rstar(classifier, samples, chain_indices) + # Compute R⋆ statistic for a mixed chain. + samples = wrapper(randn(4 * N, 8)) + chain_indices = repeat(1:4, N) + dist = rstar(classifier, samples, chain_indices) - # Mean of the statistic should be closte to 1, i.e., the classifier does not perform - # better than random guessing. - if classifier isa MLJBase.Deterministic - @test dist isa Float64 - else - @test dist isa LocationScale - @test dist.ρ isa PoissonBinomial - @test minimum(dist) == 0 - @test maximum(dist) == 4 - end - @test mean(dist) ≈ 1 rtol = 0.15 + # Mean of the statistic should be closte to 1, i.e., the classifier does not perform + # better than random guessing. + if classifier isa MLJBase.Deterministic + @test dist isa Float64 + else + @test dist isa LocationScale + @test dist.ρ isa PoissonBinomial + @test minimum(dist) == 0 + @test maximum(dist) == 4 + end + @test mean(dist) ≈ 1 rtol = 0.15 - # Compute the R⋆ statistic for a non-mixed chain. - samples = permutedims([ - sin.(1:N) cos.(1:N) - 100 .* cos.(1:N) 100 .* sin.(1:N) - ]) - chain_indices = repeat(1:2; inner=N) - dist = rstar(classifier, samples, chain_indices) + # Compute the R⋆ statistic for a non-mixed chain. + samples = wrapper([ + sin.(1:N) cos.(1:N) + 100 .* cos.(1:N) 100 .* sin.(1:N) + ]) + chain_indices = repeat(1:2; inner=N) + dist = rstar(classifier, samples, chain_indices) - # Mean of the statistic should be close to 2, i.e., the classifier should be able to - # learn an almost perfect decision boundary between chains. - if classifier isa MLJBase.Deterministic - @test dist isa Float64 - else - @test dist isa LocationScale - @test dist.ρ isa PoissonBinomial - @test minimum(dist) == 0 - @test maximum(dist) == 2 + # Mean of the statistic should be close to 2, i.e., the classifier should be able to + # learn an almost perfect decision boundary between chains. + if classifier isa MLJBase.Deterministic + @test dist isa Float64 + else + @test dist isa LocationScale + @test dist.ρ isa PoissonBinomial + @test minimum(dist) == 0 + @test maximum(dist) == 2 + end + @test mean(dist) ≈ 2 rtol = 0.15 end - @test mean(dist) ≈ 2 rtol = 0.15 - end + wrapper === Vector && continue - @testset "exceptions (classifier = $classifier)" for classifier in classifiers - @test_throws DimensionMismatch rstar(classifier, randn(2, N - 1), rand(1:3, N)) - for subset in (-0.3, 0, 1 / (3 * N), 1 - 1 / (3 * N), 1, 1.9) - @test_throws ArgumentError rstar( - classifier, randn(2, N), rand(1:3, N); subset=subset - ) + @testset "exceptions (classifier = $classifier)" for classifier in classifiers + samples = wrapper(randn(N - 1, 2)) + @test_throws DimensionMismatch rstar(classifier, samples, rand(1:3, N)) + for subset in (-0.3, 0, 1 / (3 * N), 1 - 1 / (3 * N), 1, 1.9) + samples = wrapper(randn(N, 2)) + @test_throws ArgumentError rstar( + classifier, samples, rand(1:3, N); subset=subset + ) + end end end - @testset "matrix with chain_inds produces same result as 3d array" begin + @testset "table with chain_ids produces same result as 3d array" begin nparams = 2 nchains = 3 samples = randn(nparams, N, nchains) # manually construct samples_mat and chain_inds for comparison - samples_mat = Matrix{Float64}(undef, nparams, N * nchains) + samples_mat = Matrix{Float64}(undef, N * nchains, nparams) chain_inds = Vector{Int}(undef, N * nchains) i = 1 for chain in 1:nchains, draw in 1:N - samples_mat[:, i] = samples[:, draw, chain] + samples_mat[i, :] = samples[:, draw, chain] chain_inds[i] = chain i += 1 end From e04fd4601bb51d724d6162e4ba3b8aa616a51e78 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 21 Nov 2022 22:29:10 +0100 Subject: [PATCH 27/45] Add type consistency check --- test/rstar/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/rstar/runtests.jl b/test/rstar/runtests.jl index 8c53bfb2..f12fc11a 100644 --- a/test/rstar/runtests.jl +++ b/test/rstar/runtests.jl @@ -107,6 +107,7 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo Random.seed!(rng, 42) dist2 = rstar(rng, classifier, samples) @test dist1 == dist2 + @test typeof(rstar(classifier, samples)) === typeof(dist2) end end end From 1a561f0fdf08d85777b9d2f273699ee4be25dec7 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 23 Nov 2022 12:58:14 +0100 Subject: [PATCH 28/45] Don't permutedims in discretediag --- src/discretediag.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/discretediag.jl b/src/discretediag.jl index 7b5c7434..b8523e11 100644 --- a/src/discretediag.jl +++ b/src/discretediag.jl @@ -372,7 +372,7 @@ function discretediag_sub( start_iter::Int, step_size::Int, ) - num_iters, num_vars, num_chains = size(c) + num_vars, num_iters, num_chains = size(c) ## Between-chain diagnostic length_results = length(start_iter:step_size:num_iters) @@ -384,7 +384,7 @@ function discretediag_sub( pvalue=Vector{Float64}(undef, num_vars), ) for j in 1:num_vars - X = convert(AbstractMatrix{Int}, c[:, j, :]) + X = convert(AbstractMatrix{Int}, c[j, :, :]) result = diag_all(X, method, nsim, start_iter, step_size) plot_vals_stat[:, j] .= result.stat ./ result.df @@ -403,7 +403,7 @@ function discretediag_sub( ) for k in 1:num_chains for j in 1:num_vars - x = convert(AbstractVector{Int}, c[:, j, k]) + x = convert(AbstractVector{Int}, c[j, :, k]) idx1 = 1:round(Int, frac * num_iters) idx2 = round(Int, num_iters - frac * num_iters + 1):num_iters @@ -432,7 +432,7 @@ Compute discrete diagnostic on `samples` with shape `(parameters, draws, chains) # References -Benjamin E. Deonovic, & Brian J. Smith. (2017). Convergence diagnostics for MCMC draws of a categorical variable. +Benjamin E. Deonovic, & Brian J. Smith. (2017). Convergence diagnostics for MCMC draws of a categorical variable. """ function discretediag( chains::AbstractArray{<:Real,3}; frac::Real=0.3, method::Symbol=:weiss, nsim::Int=1000 @@ -445,7 +445,7 @@ function discretediag( num_iters = size(chains, 2) between_chain_vals, within_chain_vals, _, _ = discretediag_sub( - permutedims(chains, (2, 1, 3)), frac, method, nsim, num_iters, num_iters + chains, frac, method, nsim, num_iters, num_iters ) return between_chain_vals, within_chain_vals From e6ce9b249e775a977c632abd8ece49430cf7e4a2 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 25 Nov 2022 11:55:18 +0100 Subject: [PATCH 29/45] Revert all changes to mcse --- src/mcse.jl | 15 ++------------- test/mcse.jl | 24 +++--------------------- 2 files changed, 5 insertions(+), 34 deletions(-) diff --git a/src/mcse.jl b/src/mcse.jl index 37c47cc8..063ab8b4 100644 --- a/src/mcse.jl +++ b/src/mcse.jl @@ -1,9 +1,7 @@ """ - mcse(samples::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...) - mcse(samples::AbstractArray{<:Real,3}; method::Symbol=:imse, kwargs...) + mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...) -Compute the Monte Carlo standard error (MCSE) of `samples` of shape `(draws,)` or -`(parameters, draws, chains)` +Compute the Monte Carlo standard error (MCSE) of samples `x`. The optional argument `method` describes how the errors are estimated. Possible options are: - `:bm` for batch means [^Glynn1991] @@ -25,15 +23,6 @@ function mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...) throw(ArgumentError("unsupported MCSE method $method")) end end -function mcse(x::AbstractArray{<:Real,3}; kwargs...) - T = promote_type(eltype(x), typeof(zero(eltype(x)) / 1)) - # allocate container for type-stability and for dimensional x to have dimension output - values = similar(view(x, :, 1, 1), T) - for (i, xi) in zip(eachindex(values), eachslice(x; dims=1)) - values[i] = mcse(vec(xi); kwargs...) - end - return values -end function mcse_bm(x::AbstractVector{<:Real}; size::Int=floor(Int, sqrt(length(x)))) n = length(x) diff --git a/test/mcse.jl b/test/mcse.jl index 9b3f4941..3e54d447 100644 --- a/test/mcse.jl +++ b/test/mcse.jl @@ -1,6 +1,7 @@ @testset "mcse.jl" begin - @testset "results 1d" begin - samples = randn(100) + samples = randn(100) + + @testset "results" begin result = @inferred(mcse(samples)) @test result isa Float64 @test result > 0 @@ -12,32 +13,13 @@ end end - @testset "results 3d" begin - nparams = 2 - nchains = 4 - samples = randn(nparams, 100, nchains) - result = mcse(samples) # mapslices is not type-inferrable - @test result isa Vector{Float64} - @test length(result) == nparams - @test all(r -> r > 0, result) - - for method in (:imse, :ipse, :bm) - result = mcse(samples) # mapslices is not type-inferrable - @test result isa Vector{Float64} - @test length(result) == nparams - @test all(r -> r > 0, result) - end - end - @testset "warning" begin - samples = randn(100) for size in (51, 75, 100, 153) @test_logs (:warn,) mcse(samples; method=:bm, size=size) end end @testset "exception" begin - samples = randn(100) @test_throws ArgumentError mcse(samples; method=:somemethod) end end From d1924cf7b237efb4e22d3a0244395dbce40b7c5c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 1 Dec 2022 20:29:28 +0100 Subject: [PATCH 30/45] Reorder dimensions to (draw, chain, params) --- src/discretediag.jl | 10 +++++----- src/ess.jl | 10 +++++----- src/gelmandiag.jl | 12 ++++++------ src/rstar.jl | 12 +++++------- test/discretediag.jl | 3 ++- test/ess.jl | 10 +++++----- test/gelmandiag.jl | 7 ++++--- test/rstar/runtests.jl | 5 ++--- 8 files changed, 34 insertions(+), 35 deletions(-) diff --git a/src/discretediag.jl b/src/discretediag.jl index b8523e11..72e9ca6e 100644 --- a/src/discretediag.jl +++ b/src/discretediag.jl @@ -372,7 +372,7 @@ function discretediag_sub( start_iter::Int, step_size::Int, ) - num_vars, num_iters, num_chains = size(c) + num_iters, num_chains, num_vars = size(c) ## Between-chain diagnostic length_results = length(start_iter:step_size:num_iters) @@ -384,7 +384,7 @@ function discretediag_sub( pvalue=Vector{Float64}(undef, num_vars), ) for j in 1:num_vars - X = convert(AbstractMatrix{Int}, c[j, :, :]) + X = convert(AbstractMatrix{Int}, c[:, :, j]) result = diag_all(X, method, nsim, start_iter, step_size) plot_vals_stat[:, j] .= result.stat ./ result.df @@ -403,7 +403,7 @@ function discretediag_sub( ) for k in 1:num_chains for j in 1:num_vars - x = convert(AbstractVector{Int}, c[j, :, k]) + x = convert(AbstractVector{Int}, c[:, k, j]) idx1 = 1:round(Int, frac * num_iters) idx2 = round(Int, num_iters - frac * num_iters + 1):num_iters @@ -425,7 +425,7 @@ end """ discretediag(samples::AbstractArray{<:Real,3}; frac=0.3, method=:weiss, nsim=1_000) -Compute discrete diagnostic on `samples` with shape `(parameters, draws, chains)`. +Compute discrete diagnostic on `samples` with shape `(draws, chains, parameters)`. `method` can be one of `:weiss`, `:hangartner`, `:DARBOOT`, `:MCBOOT`, `:billinsgley`, and `:billingsleyBOOT`. @@ -443,7 +443,7 @@ function discretediag( ) 0 < frac < 1 || throw(ArgumentError("`frac` must be in (0,1)")) - num_iters = size(chains, 2) + num_iters = size(chains, 1) between_chain_vals, within_chain_vals, _, _ = discretediag_sub( chains, frac, method, nsim, num_iters, num_iters ) diff --git a/src/ess.jl b/src/ess.jl index cabc1f99..d385ad8d 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -201,7 +201,7 @@ end ) Estimate the effective sample size and the potential scale reduction of the `samples` of -shape `(parameters, draws, chains)` with the `method` and a maximum lag of `maxlag`. +shape `(draws, chains, parameters)` with the `method` and a maximum lag of `maxlag`. See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref) """ @@ -211,9 +211,9 @@ function ess_rhat( maxlag::Int=250, ) # compute size of matrices (each chain is split!) - niter = size(chains, 2) ÷ 2 - nparams = size(chains, 1) - nchains = 2 * size(chains, 3) + niter = size(chains, 1) ÷ 2 + nparams = size(chains, 3) + nchains = 2 * size(chains, 2) ntotal = niter * nchains # do not compute estimates if there is only one sample or lag @@ -238,7 +238,7 @@ function ess_rhat( rhat = Vector{T}(undef, nparams) # for each parameter - for (i, chains_slice) in enumerate(eachslice(chains; dims=1)) + for (i, chains_slice) in enumerate(eachslice(chains; dims=3)) # check that no values are missing if any(x -> x === missing, chains_slice) rhat[i] = missing diff --git a/src/gelmandiag.jl b/src/gelmandiag.jl index 31a13607..ce11083d 100644 --- a/src/gelmandiag.jl +++ b/src/gelmandiag.jl @@ -1,14 +1,14 @@ function _gelmandiag(psi::AbstractArray{<:Real,3}; alpha::Real=0.05) - nparams, niters, nchains = size(psi) + niters, nchains, nparams = size(psi) nchains > 1 || error("Gelman diagnostic requires at least 2 chains") rfixed = (niters - 1) / niters rrandomscale = (nchains + 1) / (nchains * niters) - S2 = map(x -> Statistics.cov(x; dims=2), (view(psi, :, :, i) for i in axes(psi, 3))) + S2 = map(x -> Statistics.cov(x; dims=1), (view(psi, :, i, :) for i in axes(psi, 2))) W = Statistics.mean(S2) - psibar = dropdims(Statistics.mean(psi; dims=2); dims=2)' + psibar = dropdims(Statistics.mean(psi; dims=1); dims=1) B = niters .* Statistics.cov(psibar) w = LinearAlgebra.diag(W) @@ -55,7 +55,7 @@ end gelmandiag(samples::AbstractArray{<:Real,3}; alpha::Real=0.95) Compute the Gelman, Rubin and Brooks diagnostics [^Gelman1992] [^Brooks1998] on `samples` -with shape `(parameters, draws, chains)`. Values of the +with shape `(draws, chains, parameters)`. Values of the diagnostic’s potential scale reduction factor (PSRF) that are close to one suggest convergence. As a rule-of-thumb, convergence is rejected if the 97.5 percentile of a PSRF is greater than 1.2. @@ -74,10 +74,10 @@ end gelmandiag_multivariate(samples::AbstractArray{<:Real,3}; alpha::Real=0.05) Compute the multivariate Gelman, Rubin and Brooks diagnostics on `samples` with shape -`(parameters, draws, chains)`. +`(draws, chains, parameters)`. """ function gelmandiag_multivariate(chains::AbstractArray{<:Real,3}; kwargs...) - nparams, niters, nchains = size(chains) + niters, nchains, nparams = size(chains) if nparams < 2 error( "computation of the multivariate potential scale reduction factor requires ", diff --git a/src/rstar.jl b/src/rstar.jl index d0420e9f..b10a3e27 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -9,7 +9,7 @@ Compute the ``R^*`` convergence statistic of the `samples` with the `classifier`. -`samples` is an array of draws with the shape `(parameters, draws, chains)`.` +`samples` is an array of draws with the shape `(draws, chains, parameters)`.` This implementation is an adaption of algorithms 1 and 2 described by Lambert and Vehtari. @@ -32,7 +32,7 @@ is returned (algorithm 2). ```jldoctest rstar; setup = :(using Random; Random.seed!(101)) julia> using MLJBase, MLJXGBoostInterface, Statistics -julia> samples = fill(4.0, 2, 100, 3); +julia> samples = fill(4.0, 100, 3, 2); ``` One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the @@ -75,9 +75,7 @@ Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic Compute the ``R^*`` convergence statistic of the table `samples` with the `classifier`. `samples` must be either an `AbstractMatrix`, an `AbstractVector`, or a table -(i.e. implements the Tables.jl interface) with shape `(draws, parameters)`. Note that these -dimensions are swapped with respect to the first 2 dimensions in the method with -3-dimensional array input. +(i.e. implements the Tables.jl interface) with shape `(draws, parameters)`. `chain_indices` indicates the chain ids of each row of `samples`. @@ -145,8 +143,8 @@ function rstar( x::AbstractArray{<:Any,3}; kwargs..., ) - samples = transpose(reshape(x, size(x, 1), :)) - chain_inds = repeat(axes(x, 3); inner=size(x, 2)) + samples = reshape(x, :, size(x, 3)) + chain_inds = repeat(axes(x, 2); inner=size(x, 1)) return rstar(rng, classifier, samples, chain_inds; kwargs...) end diff --git a/test/discretediag.jl b/test/discretediag.jl index 24d2f509..2796aa96 100644 --- a/test/discretediag.jl +++ b/test/discretediag.jl @@ -1,7 +1,8 @@ @testset "discretediag.jl" begin nparams = 4 + ndraws = 100 nchains = 2 - samples = rand(-100:100, nparams, 100, nchains) + samples = rand(-100:100, ndraws, nchains, nparams) @testset "results" begin for method in diff --git a/test/ess.jl b/test/ess.jl index 8b14309c..c58c00b7 100644 --- a/test/ess.jl +++ b/test/ess.jl @@ -31,7 +31,7 @@ end @testset "ESS and R̂ (IID samples)" begin - rawx = randn(40, 10_000, 10) + rawx = randn(10_000, 10, 40) # Repeat tests with different scales for scale in (1, 50, 100) @@ -58,7 +58,7 @@ end @testset "ESS and R̂ (identical samples)" begin - x = ones(40, 10_000, 10) + x = ones(10_000, 10, 40) ess_standard, rhat_standard = ess_rhat(x) ess_standard2, rhat_standard2 = ess_rhat(x; method=ESSMethod()) @@ -75,15 +75,15 @@ end @testset "ESS and R̂ (single sample)" begin # check that issue #137 is fixed - x = rand(5, 1, 3) + x = rand(1, 3, 5) for method in (ESSMethod(), FFTESSMethod(), BDAESSMethod()) # analyze array ess_array, rhat_array = ess_rhat(x; method=method) - @test length(ess_array) == size(x, 1) + @test length(ess_array) == size(x, 3) @test all(ismissing, ess_array) # since min(maxlag, niter - 1) = 0 - @test length(rhat_array) == size(x, 1) + @test length(rhat_array) == size(x, 3) @test all(ismissing, rhat_array) end end diff --git a/test/gelmandiag.jl b/test/gelmandiag.jl index ab5b4ba0..23e1a3dd 100644 --- a/test/gelmandiag.jl +++ b/test/gelmandiag.jl @@ -1,7 +1,8 @@ @testset "gelmandiag.jl" begin nparams = 4 + ndraws = 100 nchains = 2 - samples = randn(nparams, 100, nchains) + samples = randn(ndraws, nchains, nparams) @testset "results" begin result = @inferred(gelmandiag(samples)) @@ -23,7 +24,7 @@ end @testset "exceptions" begin - @test_throws ErrorException gelmandiag(samples[:, :, 1:1]) - @test_throws ErrorException gelmandiag_multivariate(samples[1:1, :, :]) + @test_throws ErrorException gelmandiag(samples[:, 1:1, :]) + @test_throws ErrorException gelmandiag_multivariate(samples[:, :, 1:1]) end end diff --git a/test/rstar/runtests.jl b/test/rstar/runtests.jl index f12fc11a..928e50a4 100644 --- a/test/rstar/runtests.jl +++ b/test/rstar/runtests.jl @@ -89,14 +89,13 @@ const xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mo @testset "table with chain_ids produces same result as 3d array" begin nparams = 2 nchains = 3 - samples = randn(nparams, N, nchains) + samples = randn(N, nchains, nparams) # manually construct samples_mat and chain_inds for comparison - samples_mat = Matrix{Float64}(undef, N * nchains, nparams) + samples_mat = reshape(samples, N * nchains, nparams) chain_inds = Vector{Int}(undef, N * nchains) i = 1 for chain in 1:nchains, draw in 1:N - samples_mat[i, :] = samples[:, draw, chain] chain_inds[i] = chain i += 1 end From ea01e3fc0bb93700b070ee24a5eabd27dabcffb6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 12 Dec 2022 19:44:50 +0100 Subject: [PATCH 31/45] Apply suggestions from code review Co-authored-by: David Widmann --- src/gelmandiag.jl | 1 + src/rstar.jl | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/gelmandiag.jl b/src/gelmandiag.jl index ce11083d..001ab282 100644 --- a/src/gelmandiag.jl +++ b/src/gelmandiag.jl @@ -5,6 +5,7 @@ function _gelmandiag(psi::AbstractArray{<:Real,3}; alpha::Real=0.05) rfixed = (niters - 1) / niters rrandomscale = (nchains + 1) / (nchains * niters) + # `eachslice(psi; dims=2)` breaks type inference S2 = map(x -> Statistics.cov(x; dims=1), (view(psi, :, i, :) for i in axes(psi, 2))) W = Statistics.mean(S2) diff --git a/src/rstar.jl b/src/rstar.jl index b10a3e27..8894c641 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -1,6 +1,6 @@ """ rstar( - rng::Random.AbstractRNG=Random.GLOBAL_RNG, + rng::Random.AbstractRNG=Random.default_rng(), classifier::MLJModelInterface.Supervised, samples::AbstractArray{<:Real,3}; subset::Real=0.8, @@ -64,7 +64,7 @@ Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic rstar( - rng::Random.AbstractRNG=Random.GLOBAL_RNG, + rng::Random.AbstractRNG=Random.default_rng(), classifier::MLJModelInterface.Supervised, samples, chain_indices::AbstractVector{Int}; @@ -149,11 +149,11 @@ function rstar( end function rstar(classif::MLJModelInterface.Supervised, x, y::AbstractVector{Int}; kwargs...) - return rstar(Random.GLOBAL_RNG, classif, x, y; kwargs...) + return rstar(Random.default_rng(), classif, x, y; kwargs...) end function rstar(classif::MLJModelInterface.Supervised, x::AbstractArray{<:Any,3}; kwargs...) - return rstar(Random.GLOBAL_RNG, classif, x; kwargs...) + return rstar(Random.default_rng(), classif, x; kwargs...) end # R⋆ for deterministic predictions (algorithm 1) From d7437a03a8d6685c4394b0438f2a0440596ee378 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 12 Dec 2022 20:29:20 +0100 Subject: [PATCH 32/45] Split rstar docstring --- src/rstar.jl | 130 +++++++++++++++++++++++++-------------------------- 1 file changed, 64 insertions(+), 66 deletions(-) diff --git a/src/rstar.jl b/src/rstar.jl index 8894c641..c08b7298 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -1,68 +1,4 @@ """ - rstar( - rng::Random.AbstractRNG=Random.default_rng(), - classifier::MLJModelInterface.Supervised, - samples::AbstractArray{<:Real,3}; - subset::Real=0.8, - verbosity::Int=0, - ) - -Compute the ``R^*`` convergence statistic of the `samples` with the `classifier`. - -`samples` is an array of draws with the shape `(draws, chains, parameters)`.` - -This implementation is an adaption of algorithms 1 and 2 described by Lambert and Vehtari. - -The `classifier` has to be a supervised classifier of the MLJ framework (see the -[MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/list_of_supported_models/#model_list) -for a list of supported models). It is trained with a `subset` of the samples. The training -of the classifier can be inspected by adjusting the `verbosity` level. - -If the classifier is deterministic, i.e., if it predicts a class, the value of the ``R^*`` -statistic is returned (algorithm 1). If the classifier is probabilistic, i.e., if it outputs -probabilities of classes, the scaled Poisson-binomial distribution of the ``R^*`` statistic -is returned (algorithm 2). - -!!! note - The correctness of the statistic depends on the convergence of the `classifier` used - internally in the statistic. - -# Examples - -```jldoctest rstar; setup = :(using Random; Random.seed!(101)) -julia> using MLJBase, MLJXGBoostInterface, Statistics - -julia> samples = fill(4.0, 100, 3, 2); -``` - -One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the -probabilistic classifier. - -```jldoctest rstar -julia> distribution = rstar(XGBoostClassifier(), samples); - -julia> isapprox(mean(distribution), 1; atol=0.1) -true -``` - -For deterministic classifiers, a single ``R^*`` statistic (algorithm 1) is returned. -Deterministic classifiers can also be derived from probabilistic classifiers by e.g. -predicting the mode. In MLJ this corresponds to a pipeline of models. - -```jldoctest rstar -julia> xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode); - -julia> value = rstar(xgboost_deterministic, samples); - -julia> isapprox(value, 1; atol=0.2) -true -``` - -# References - -Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic with uncertainty using decision tree classifiers. - - rstar( rng::Random.AbstractRNG=Random.default_rng(), classifier::MLJModelInterface.Supervised, @@ -81,8 +17,6 @@ Compute the ``R^*`` convergence statistic of the table `samples` with the `class This method supports ragged chains, i.e. chains of nonequal lengths. """ -function rstar end - function rstar( rng::Random.AbstractRNG, classifier::MLJModelInterface.Supervised, @@ -137,6 +71,70 @@ function _predict(model::MLJModelInterface.Model, fitresult, x) end end +""" + rstar( + rng::Random.AbstractRNG=Random.default_rng(), + classifier::MLJModelInterface.Supervised, + samples::AbstractArray{<:Real,3}; + subset::Real=0.8, + verbosity::Int=0, + ) + +Compute the ``R^*`` convergence statistic of the `samples` with the `classifier`. + +`samples` is an array of draws with the shape `(draws, chains, parameters)`.` + +This implementation is an adaption of algorithms 1 and 2 described by Lambert and Vehtari. + +The `classifier` has to be a supervised classifier of the MLJ framework (see the +[MLJ documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/list_of_supported_models/#model_list) +for a list of supported models). It is trained with a `subset` of the samples. The training +of the classifier can be inspected by adjusting the `verbosity` level. + +If the classifier is deterministic, i.e., if it predicts a class, the value of the ``R^*`` +statistic is returned (algorithm 1). If the classifier is probabilistic, i.e., if it outputs +probabilities of classes, the scaled Poisson-binomial distribution of the ``R^*`` statistic +is returned (algorithm 2). + +!!! note + The correctness of the statistic depends on the convergence of the `classifier` used + internally in the statistic. + +# Examples + +```jldoctest rstar; setup = :(using Random; Random.seed!(101)) +julia> using MLJBase, MLJXGBoostInterface, Statistics + +julia> samples = fill(4.0, 100, 3, 2); +``` + +One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the +probabilistic classifier. + +```jldoctest rstar +julia> distribution = rstar(XGBoostClassifier(), samples); + +julia> isapprox(mean(distribution), 1; atol=0.1) +true +``` + +For deterministic classifiers, a single ``R^*`` statistic (algorithm 1) is returned. +Deterministic classifiers can also be derived from probabilistic classifiers by e.g. +predicting the mode. In MLJ this corresponds to a pipeline of models. + +```jldoctest rstar +julia> xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode); + +julia> value = rstar(xgboost_deterministic, samples); + +julia> isapprox(value, 1; atol=0.2) +true +``` + +# References + +Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic with uncertainty using decision tree classifiers. +""" function rstar( rng::Random.AbstractRNG, classifier::MLJModelInterface.Supervised, From ac5f2cb93394b35c46e3d8af4233b8cc459c179e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 12 Dec 2022 20:34:57 +0100 Subject: [PATCH 33/45] Convert to table once --- src/rstar.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/rstar.jl b/src/rstar.jl index c08b7298..ed88de1f 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -38,16 +38,18 @@ function rstar( train_ids = view(ids, 1:Ntrain) test_ids = view(ids, (Ntrain + 1):N) + xtable = _astable(x) + # train classifier on training data ycategorical = MLJModelInterface.categorical(y) - xtrain = MLJModelInterface.selectrows(x, train_ids) + xtrain = MLJModelInterface.selectrows(xtable, train_ids) fitresult, _ = MLJModelInterface.fit( - classifier, verbosity, _astable(xtrain), ycategorical[train_ids] + classifier, verbosity, xtrain, ycategorical[train_ids] ) # compute predictions on test data xtest = MLJModelInterface.selectrows(x, test_ids) - predictions = _predict(classifier, fitresult, _astable(xtest)) + predictions = _predict(classifier, fitresult, xtest) # compute statistic ytest = ycategorical[test_ids] From be462589cde80748a59474aa3fe894a782b0277e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 12 Dec 2022 20:35:03 +0100 Subject: [PATCH 34/45] Clean up language --- src/rstar.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/rstar.jl b/src/rstar.jl index ed88de1f..23ff41d0 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -11,7 +11,8 @@ Compute the ``R^*`` convergence statistic of the table `samples` with the `classifier`. `samples` must be either an `AbstractMatrix`, an `AbstractVector`, or a table -(i.e. implements the Tables.jl interface) with shape `(draws, parameters)`. +(i.e. implements the Tables.jl interface) whose rows are draws and whose columns are +parameters. `chain_indices` indicates the chain ids of each row of `samples`. From 586b8f1a48914302c66943d4981df3594bffd7b2 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 12 Dec 2022 20:56:27 +0100 Subject: [PATCH 35/45] Use correct variable name --- src/rstar.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rstar.jl b/src/rstar.jl index 23ff41d0..7eb61c9b 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -49,7 +49,7 @@ function rstar( ) # compute predictions on test data - xtest = MLJModelInterface.selectrows(x, test_ids) + xtest = MLJModelInterface.selectrows(xtable, test_ids) predictions = _predict(classifier, fitresult, xtest) # compute statistic From 365267f9d09bb134d415791baff06e1f76d71d0a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 12 Dec 2022 21:40:12 +0100 Subject: [PATCH 36/45] Allow Tables v1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 61f7764b..faf11245 100644 --- a/Project.toml +++ b/Project.toml @@ -24,7 +24,7 @@ Distributions = "0.25" MLJModelInterface = "1.6" SpecialFunctions = "0.8, 0.9, 0.10, 1, 2" StatsBase = "0.33" -Tables = "1.9" +Tables = "1" julia = "1" [extras] From 9438e34389cd9f0e1b9b0ab67f8fb5afcd178702 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 12 Dec 2022 21:41:15 +0100 Subject: [PATCH 37/45] Bump Julia compat to v1.3 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index faf11245..ddb46428 100644 --- a/Project.toml +++ b/Project.toml @@ -25,7 +25,7 @@ MLJModelInterface = "1.6" SpecialFunctions = "0.8, 0.9, 0.10, 1, 2" StatsBase = "0.33" Tables = "1" -julia = "1" +julia = "1.3" [extras] Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" From a5e784ccec64f33e6486263fde70596825f1c5d2 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 12 Dec 2022 21:42:02 +0100 Subject: [PATCH 38/45] Remove compat dependency --- Project.toml | 1 - src/MCMCDiagnosticTools.jl | 1 - 2 files changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index ddb46428..b5959ef1 100644 --- a/Project.toml +++ b/Project.toml @@ -18,7 +18,6 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] AbstractFFTs = "0.5, 1" -Compat = "3, 4" DataAPI = "1.6" Distributions = "0.25" MLJModelInterface = "1.6" diff --git a/src/MCMCDiagnosticTools.jl b/src/MCMCDiagnosticTools.jl index 7e37082c..6aca4543 100644 --- a/src/MCMCDiagnosticTools.jl +++ b/src/MCMCDiagnosticTools.jl @@ -1,7 +1,6 @@ module MCMCDiagnosticTools using AbstractFFTs: AbstractFFTs -using Compat: eachslice using DataAPI: DataAPI using Distributions: Distributions using MLJModelInterface: MLJModelInterface From 76317b9ebc927ed6c9491870a472aabe4e7f3fd6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 12 Dec 2022 21:43:41 +0100 Subject: [PATCH 39/45] Remove all special-casing for less than v1.2 --- test/runtests.jl | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 54c058a6..2aa9a047 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,12 +1,5 @@ using Pkg -# Activate test environment on older Julia versions -@static if VERSION < v"1.2" - Pkg.activate(@__DIR__) - Pkg.develop(PackageSpec(; path=dirname(@__DIR__))) - Pkg.instantiate() -end - using MCMCDiagnosticTools using FFTW @@ -43,16 +36,15 @@ Random.seed!(1) include("rafterydiag.jl") end @testset "R⋆ diagnostic" begin - # MLJXGBoostInterface requires Julia >= 1.3 # XGBoost errors on 32bit systems: https://github.com/dmlc/XGBoost.jl/issues/92 - if VERSION >= v"1.3" && Sys.WORD_SIZE == 64 + if Sys.WORD_SIZE == 64 # run tests related to rstar statistic Pkg.activate("rstar") Pkg.develop(; path=dirname(dirname(pathof(MCMCDiagnosticTools)))) Pkg.instantiate() include(joinpath("rstar", "runtests.jl")) else - @info "R⋆ not tested: requires Julia >= 1.3 and a 64bit architecture" + @info "R⋆ not tested: requires 64bit architecture" end end end From 859f65871aa26bb141af5fb9849c8fe49f31a2dc Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 12 Dec 2022 21:44:24 +0100 Subject: [PATCH 40/45] Test on v1.3 --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 4867c046..b96ee980 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -14,7 +14,7 @@ jobs: fail-fast: false matrix: version: - - '1.0' + - '1.3' - '1' - 'nightly' os: From c48380c6aca45142239def311a207fb8ebe990d6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 12 Dec 2022 22:05:11 +0100 Subject: [PATCH 41/45] Use PackageSpec for v1.3 --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 2aa9a047..8901fa6e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -40,7 +40,7 @@ Random.seed!(1) if Sys.WORD_SIZE == 64 # run tests related to rstar statistic Pkg.activate("rstar") - Pkg.develop(; path=dirname(dirname(pathof(MCMCDiagnosticTools)))) + Pkg.develop(PackageSpec(path=dirname(dirname(pathof(MCMCDiagnosticTools))))) Pkg.instantiate() include(joinpath("rstar", "runtests.jl")) else From 7666803f203a861fa12a7aee34525e1ded149ec8 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 12 Dec 2022 22:08:49 +0100 Subject: [PATCH 42/45] Merge test Project.tomls --- test/Project.toml | 16 +++++++++++++++- test/rstar/Project.toml | 18 ------------------ 2 files changed, 15 insertions(+), 19 deletions(-) delete mode 100644 test/rstar/Project.toml diff --git a/test/Project.toml b/test/Project.toml index c3a16aaf..0c6ce1be 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,10 +1,24 @@ [deps] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" +MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52" +MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FFTW = "1.1" -julia = "1" +MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52" +MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +julia = "1.3" diff --git a/test/rstar/Project.toml b/test/rstar/Project.toml deleted file mode 100644 index e0b6a797..00000000 --- a/test/rstar/Project.toml +++ /dev/null @@ -1,18 +0,0 @@ -[deps] -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" -MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52" -MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[compat] -Distributions = "0.25" -MCMCDiagnosticTools = "0.2" -MLJBase = "0.19, 0.20, 0.21" -MLJLIBSVMInterface = "0.1, 0.2" -MLJXGBoostInterface = "0.1, 0.2" -Tables = "1" -julia = "1.3" From d3942c3eb518d0bb7ece13a487ac7ac06b472a16 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 12 Dec 2022 22:10:03 +0100 Subject: [PATCH 43/45] Move rstar test file out of own directory --- test/{rstar/runtests.jl => rstar.jl} | 0 test/runtests.jl | 6 +----- 2 files changed, 1 insertion(+), 5 deletions(-) rename test/{rstar/runtests.jl => rstar.jl} (100%) diff --git a/test/rstar/runtests.jl b/test/rstar.jl similarity index 100% rename from test/rstar/runtests.jl rename to test/rstar.jl diff --git a/test/runtests.jl b/test/runtests.jl index 8901fa6e..63fdade0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,11 +38,7 @@ Random.seed!(1) @testset "R⋆ diagnostic" begin # XGBoost errors on 32bit systems: https://github.com/dmlc/XGBoost.jl/issues/92 if Sys.WORD_SIZE == 64 - # run tests related to rstar statistic - Pkg.activate("rstar") - Pkg.develop(PackageSpec(path=dirname(dirname(pathof(MCMCDiagnosticTools))))) - Pkg.instantiate() - include(joinpath("rstar", "runtests.jl")) + include("rstar.jl") else @info "R⋆ not tested: requires 64bit architecture" end From 5e0bd5bd8d0b6df8ee8d10cac358101d057693cd Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 12 Dec 2022 22:11:32 +0100 Subject: [PATCH 44/45] Fix version numbers --- test/Project.toml | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 0c6ce1be..a0e407c1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,13 +12,11 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Distributions = "0.25" FFTW = "1.1" -MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0" -MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52" -MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +MCMCDiagnosticTools = "0.2" +MLJBase = "0.19, 0.20, 0.21" +MLJLIBSVMInterface = "0.1, 0.2" +MLJXGBoostInterface = "0.1, 0.2" +Tables = "1" julia = "1.3" From ff7e5e045b2ac554ca26ec7def9584bfa4381e1f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 12 Dec 2022 23:18:34 +0100 Subject: [PATCH 45/45] Apply suggestions from code review Co-authored-by: David Widmann --- Project.toml | 1 - test/Project.toml | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index b5959ef1..05009511 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "0.2.0" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" -Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/test/Project.toml b/test/Project.toml index a0e407c1..09fca5b2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -17,6 +17,6 @@ FFTW = "1.1" MCMCDiagnosticTools = "0.2" MLJBase = "0.19, 0.20, 0.21" MLJLIBSVMInterface = "0.1, 0.2" -MLJXGBoostInterface = "0.1, 0.2" +MLJXGBoostInterface = "0.1, 0.2, 0.3" Tables = "1" julia = "1.3"