Change dimension ordering #50

Merged: 47 commits, Dec 12, 2022
Changes from 45 commits
Commits (47)
3cdde3e
Update discretediag
sethaxen Nov 19, 2022
7764bb2
Update ess_rhat
sethaxen Nov 19, 2022
f2b3f5e
Update gelmandiag
sethaxen Nov 19, 2022
10e3814
Update rstar
sethaxen Nov 19, 2022
86297e9
Add 3d array method for rstar
sethaxen Nov 19, 2022
d5609e2
Add 3d array method for mcse
sethaxen Nov 19, 2022
79adf82
Update docstrings
sethaxen Nov 19, 2022
1370c7a
Change dimension order in tests
sethaxen Nov 19, 2022
29cba7f
Update rstar tests with permuted dims
sethaxen Nov 19, 2022
201bdec
Test rstar with 3d array
sethaxen Nov 19, 2022
163ea88
Test mcse with 3d array
sethaxen Nov 19, 2022
66d2c70
Increment version number
sethaxen Nov 19, 2022
9215e2f
Bump compat
sethaxen Nov 19, 2022
b733691
Run formatter
sethaxen Nov 19, 2022
06b4d9f
Update seed
sethaxen Nov 19, 2022
0917971
Add Compat as dependency
sethaxen Nov 21, 2022
3230e08
Use eachslice
sethaxen Nov 21, 2022
61653f9
Replace mapslices with an explicit loop
sethaxen Nov 21, 2022
1d0a086
Run formatter
sethaxen Nov 21, 2022
f1e05db
Avoid explicit axis check
sethaxen Nov 21, 2022
e7bd124
Remove type-instability
sethaxen Nov 21, 2022
1941044
Accept any table to rstar
sethaxen Nov 21, 2022
1dbc4f2
Update rstar documentation
sethaxen Nov 21, 2022
87d5f88
Release type constraint
sethaxen Nov 21, 2022
78d256b
Support rstar taking matrices or vectors
sethaxen Nov 21, 2022
aa01a1e
Update rstar tests
sethaxen Nov 21, 2022
e04fd46
Add type consistency check
sethaxen Nov 21, 2022
1a561f0
Don't permutedims in discretediag
sethaxen Nov 23, 2022
e6ce9b2
Revert all changes to mcse
sethaxen Nov 25, 2022
d1924cf
Reorder dimensions to (draw, chain, params)
sethaxen Dec 1, 2022
ea01e3f
Apply suggestions from code review
sethaxen Dec 12, 2022
d7437a0
Split rstar docstring
sethaxen Dec 12, 2022
ac5f2cb
Convert to table once
sethaxen Dec 12, 2022
be46258
Clean up language
sethaxen Dec 12, 2022
586b8f1
Use correct variable name
sethaxen Dec 12, 2022
365267f
Allow Tables v1
sethaxen Dec 12, 2022
9438e34
Bump Julia compat to v1.3
sethaxen Dec 12, 2022
a5e784c
Remove compat dependency
sethaxen Dec 12, 2022
76317b9
Remove all special-casing for less than v1.2
sethaxen Dec 12, 2022
859f658
Test on v1.3
sethaxen Dec 12, 2022
c48380c
Use PackageSpec for v1.3
sethaxen Dec 12, 2022
7666803
Merge test Project.tomls
sethaxen Dec 12, 2022
d3942c3
Move rstar test file out of own directory
sethaxen Dec 12, 2022
5e0bd5b
Fix version numbers
sethaxen Dec 12, 2022
11b7eb4
Merge branch 'main' into unifydimorder
sethaxen Dec 12, 2022
ff7e5e0
Apply suggestions from code review
sethaxen Dec 12, 2022
7ae2231
Merge branch 'main' into unifydimorder
sethaxen Dec 12, 2022
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -14,7 +14,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.0'
- '1.3'
- '1'
- 'nightly'
os:
5 changes: 3 additions & 2 deletions Project.toml
@@ -1,10 +1,11 @@
name = "MCMCDiagnosticTools"
uuid = "be115224-59cd-429b-ad48-344e309966f0"
authors = ["David Widmann"]
version = "0.1.5"
version = "0.2.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -23,7 +24,7 @@ MLJModelInterface = "1.6"
SpecialFunctions = "0.8, 0.9, 0.10, 1, 2"
StatsBase = "0.33"
Tables = "1"
julia = "1"
julia = "1.3"

[extras]
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2 changes: 1 addition & 1 deletion docs/Project.toml
@@ -8,7 +8,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
Documenter = "0.27"
MCMCDiagnosticTools = "0.1"
MCMCDiagnosticTools = "0.2"
MLJBase = "0.19, 0.20, 0.21"
MLJXGBoostInterface = "0.1, 0.2, 0.3"
julia = "1.3"
14 changes: 8 additions & 6 deletions src/discretediag.jl
@@ -372,7 +372,7 @@ function discretediag_sub(
start_iter::Int,
step_size::Int,
)
num_iters, num_vars, num_chains = size(c)
num_iters, num_chains, num_vars = size(c)

## Between-chain diagnostic
length_results = length(start_iter:step_size:num_iters)
@@ -384,7 +384,7 @@
pvalue=Vector{Float64}(undef, num_vars),
)
for j in 1:num_vars
X = convert(AbstractMatrix{Int}, c[:, j, :])
X = convert(AbstractMatrix{Int}, c[:, :, j])
result = diag_all(X, method, nsim, start_iter, step_size)

plot_vals_stat[:, j] .= result.stat ./ result.df
@@ -403,7 +403,7 @@
)
for k in 1:num_chains
for j in 1:num_vars
x = convert(AbstractVector{Int}, c[:, j, k])
x = convert(AbstractVector{Int}, c[:, k, j])

idx1 = 1:round(Int, frac * num_iters)
idx2 = round(Int, num_iters - frac * num_iters + 1):num_iters
@@ -423,14 +423,16 @@
end

"""
discretediag(chains::AbstractArray{<:Real,3}; frac=0.3, method=:weiss, nsim=1_000)
discretediag(samples::AbstractArray{<:Real,3}; frac=0.3, method=:weiss, nsim=1_000)

Compute discrete diagnostic where `method` can be one of `:weiss`, `:hangartner`,
Compute discrete diagnostic on `samples` with shape `(draws, chains, parameters)`.

`method` can be one of `:weiss`, `:hangartner`,
`:DARBOOT`, `:MCBOOT`, `:billinsgley`, and `:billingsleyBOOT`.

# References

Benjamin E. Deonovic, & Brian J. Smith. (2017). Convergence diagnostics for MCMC draws of a categorical variable.
Benjamin E. Deonovic, & Brian J. Smith. (2017). Convergence diagnostics for MCMC draws of a categorical variable.
"""
function discretediag(
chains::AbstractArray{<:Real,3}; frac::Real=0.3, method::Symbol=:weiss, nsim::Int=1000
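The hunk above switches `discretediag_sub` to a `(draw, chain, parameter)` indexing convention. A minimal sketch of that convention (array sizes and contents here are illustrative, not taken from the PR):

```julia
# New (draw, chain, parameter) indexing convention used by discretediag_sub.
c = rand(1:3, 100, 4, 2)                   # 100 draws, 4 chains, 2 parameters
num_iters, num_chains, num_vars = size(c)
X = c[:, :, 1]                             # draws × chains matrix for parameter 1
x = c[:, 2, 1]                             # draws vector for chain 2, parameter 1
```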
8 changes: 4 additions & 4 deletions src/ess.jl
@@ -201,7 +201,7 @@ end
)

Estimate the effective sample size and the potential scale reduction of the `samples` of
shape (draws, parameters, chains) with the `method` and a maximum lag of `maxlag`.
shape `(draws, chains, parameters)` with the `method` and a maximum lag of `maxlag`.

See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref)
"""
@@ -212,8 +212,8 @@ function ess_rhat(
)
# compute size of matrices (each chain is split!)
niter = size(chains, 1) ÷ 2
nparams = size(chains, 2)
nchains = 2 * size(chains, 3)
nparams = size(chains, 3)
nchains = 2 * size(chains, 2)
ntotal = niter * nchains

# do not compute estimates if there is only one sample or lag
@@ -238,7 +238,7 @@
rhat = Vector{T}(undef, nparams)

# for each parameter
for (i, chains_slice) in enumerate((view(chains, :, i, :) for i in axes(chains, 2)))
for (i, chains_slice) in enumerate(eachslice(chains; dims=3))
# check that no values are missing
if any(x -> x === missing, chains_slice)
rhat[i] = missing
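The size bookkeeping in the `ess_rhat` hunk above (each chain split in half, parameters now on the third axis) can be traced on a toy array; the numbers are illustrative:

```julia
chains = randn(1000, 4, 3)     # 1000 draws, 4 chains, 3 parameters
niter = size(chains, 1) ÷ 2    # each chain is split in half
nparams = size(chains, 3)      # parameters moved to the third axis
nchains = 2 * size(chains, 2)  # splitting doubles the effective chain count
ntotal = niter * nchains
```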
19 changes: 11 additions & 8 deletions src/gelmandiag.jl
@@ -1,14 +1,15 @@
function _gelmandiag(psi::AbstractArray{<:Real,3}; alpha::Real=0.05)
niters, nparams, nchains = size(psi)
niters, nchains, nparams = size(psi)
nchains > 1 || error("Gelman diagnostic requires at least 2 chains")

rfixed = (niters - 1) / niters
rrandomscale = (nchains + 1) / (nchains * niters)

S2 = map(Statistics.cov, (view(psi, :, :, i) for i in axes(psi, 3)))
# `eachslice(psi; dims=2)` breaks type inference
S2 = map(x -> Statistics.cov(x; dims=1), (view(psi, :, i, :) for i in axes(psi, 2)))
W = Statistics.mean(S2)

psibar = dropdims(Statistics.mean(psi; dims=1); dims=1)'
psibar = dropdims(Statistics.mean(psi; dims=1); dims=1)
B = niters .* Statistics.cov(psibar)

w = LinearAlgebra.diag(W)
@@ -52,9 +53,10 @@ function _gelmandiag(psi::AbstractArray{<:Real,3}; alpha::Real=0.05)
end

"""
gelmandiag(chains::AbstractArray{<:Real,3}; alpha::Real=0.95)
gelmandiag(samples::AbstractArray{<:Real,3}; alpha::Real=0.95)

Compute the Gelman, Rubin and Brooks diagnostics [^Gelman1992] [^Brooks1998]. Values of the
Compute the Gelman, Rubin and Brooks diagnostics [^Gelman1992] [^Brooks1998] on `samples`
with shape `(draws, chains, parameters)`. Values of the
diagnostic’s potential scale reduction factor (PSRF) that are close to one suggest
convergence. As a rule-of-thumb, convergence is rejected if the 97.5 percentile of a PSRF
is greater than 1.2.
@@ -70,12 +72,13 @@ function gelmandiag(chains::AbstractArray{<:Real,3}; kwargs...)
end

"""
gelmandiag_multivariate(chains::AbstractArray{<:Real,3}; alpha::Real=0.05)
gelmandiag_multivariate(samples::AbstractArray{<:Real,3}; alpha::Real=0.05)

Compute the multivariate Gelman, Rubin and Brooks diagnostics.
Compute the multivariate Gelman, Rubin and Brooks diagnostics on `samples` with shape
`(draws, chains, parameters)`.
"""
function gelmandiag_multivariate(chains::AbstractArray{<:Real,3}; kwargs...)
niters, nparams, nchains = size(chains)
niters, nchains, nparams = size(chains)
if nparams < 2
error(
"computation of the multivariate potential scale reduction factor requires ",
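The per-chain covariance line in the `_gelmandiag` hunk above (the `view(psi, :, i, :)` generator that replaces `eachslice` to preserve type inference) works out as follows on an illustrative array:

```julia
using Statistics

psi = randn(200, 4, 3)   # draws × chains × parameters
# one covariance matrix per chain, as in the hunk above
S2 = map(x -> cov(x; dims=1), (view(psi, :, i, :) for i in axes(psi, 2)))
W = mean(S2)             # within-chain covariance estimate
```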
2 changes: 1 addition & 1 deletion src/rafterydiag.jl
@@ -38,7 +38,7 @@ function rafterydiag(
dichot = Int[(x .<= StatsBase.quantile(x, q))...]
kthin = 0
bic = 1.0
local test , ntest
local test, ntest
while bic >= 0.0
kthin += 1
test = dichot[1:kthin:nx]
154 changes: 94 additions & 60 deletions src/rstar.jl
@@ -1,15 +1,91 @@
"""
rstar(
rng=Random.GLOBAL_RNG,
classifier,
samples::AbstractMatrix,
rng::Random.AbstractRNG=Random.default_rng(),
classifier::MLJModelInterface.Supervised,
samples,
chain_indices::AbstractVector{Int};
subset::Real=0.8,
verbosity::Int=0,
)

Compute the ``R^*`` convergence statistic of the `samples` with shape (draws, parameters)
and corresponding chains `chain_indices` with the `classifier`.
Compute the ``R^*`` convergence statistic of the table `samples` with the `classifier`.

`samples` must be either an `AbstractMatrix`, an `AbstractVector`, or a table
(i.e. implements the Tables.jl interface) whose rows are draws and whose columns are
parameters.

`chain_indices` indicates the chain ids of each row of `samples`.

This method supports ragged chains, i.e. chains of nonequal lengths.
"""
function rstar(
rng::Random.AbstractRNG,
classifier::MLJModelInterface.Supervised,
x,
y::AbstractVector{Int};
subset::Real=0.8,
verbosity::Int=0,
)
# checks
MLJModelInterface.nrows(x) != length(y) && throw(DimensionMismatch())
0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)"))

# randomly sub-select training and testing set
N = length(y)
Ntrain = round(Int, N * subset)
0 < Ntrain < N ||
throw(ArgumentError("training and test data subsets must not be empty"))
ids = Random.randperm(rng, N)
train_ids = view(ids, 1:Ntrain)
test_ids = view(ids, (Ntrain + 1):N)

xtable = _astable(x)

# train classifier on training data
ycategorical = MLJModelInterface.categorical(y)
xtrain = MLJModelInterface.selectrows(xtable, train_ids)
fitresult, _ = MLJModelInterface.fit(
classifier, verbosity, xtrain, ycategorical[train_ids]
)

# compute predictions on test data
xtest = MLJModelInterface.selectrows(xtable, test_ids)
predictions = _predict(classifier, fitresult, xtest)

# compute statistic
ytest = ycategorical[test_ids]
result = _rstar(predictions, ytest)

return result
end

_astable(x::AbstractVecOrMat) = Tables.table(x)
_astable(x) = Tables.istable(x) ? x : throw(ArgumentError("Argument is not a valid table"))

# Workaround for https://github.com/JuliaAI/MLJBase.jl/issues/863
# `MLJModelInterface.predict` sometimes returns predictions and sometimes predictions + additional information
# TODO: Remove once the upstream issue is fixed
function _predict(model::MLJModelInterface.Model, fitresult, x)
y = MLJModelInterface.predict(model, fitresult, x)
return if :predict in MLJModelInterface.reporting_operations(model)
first(y)
else
y
end
end
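An illustrative set of ragged-chain inputs for the table method documented above; the sizes and the commented classifier call are assumptions, not taken from the PR:

```julia
samples = randn(250, 2)                           # 250 total draws, 2 parameters
chain_indices = vcat(fill(1, 100), fill(2, 150))  # chain 1: 100 draws, chain 2: 150 draws
# With an MLJ classifier loaded, one would then call (not run here):
# rstar(XGBoostClassifier(), samples, chain_indices)
```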

"""
rstar(
rng::Random.AbstractRNG=Random.default_rng(),
classifier::MLJModelInterface.Supervised,
samples::AbstractArray{<:Real,3};
subset::Real=0.8,
verbosity::Int=0,
)

Compute the ``R^*`` convergence statistic of the `samples` with the `classifier`.

`samples` is an array of draws with the shape `(draws, chains, parameters)`.

This implementation is an adaptation of algorithms 1 and 2 described by Lambert and Vehtari.

@@ -29,19 +105,17 @@ is returned (algorithm 2).

# Examples

```jldoctest rstar; setup = :(using Random; Random.seed!(100))
```jldoctest rstar; setup = :(using Random; Random.seed!(101))
julia> using MLJBase, MLJXGBoostInterface, Statistics

julia> samples = fill(4.0, 300, 2);

julia> chain_indices = repeat(1:3; outer=100);
julia> samples = fill(4.0, 100, 3, 2);
```

One can compute the distribution of the ``R^*`` statistic (algorithm 2) with the
probabilistic classifier.

```jldoctest rstar
julia> distribution = rstar(XGBoostClassifier(), samples, chain_indices);
julia> distribution = rstar(XGBoostClassifier(), samples);

julia> isapprox(mean(distribution), 1; atol=0.1)
true
Expand All @@ -54,7 +128,7 @@ predicting the mode. In MLJ this corresponds to a pipeline of models.
```jldoctest rstar
julia> xgboost_deterministic = Pipeline(XGBoostClassifier(); operation=predict_mode);

julia> value = rstar(xgboost_deterministic, samples, chain_indices);
julia> value = rstar(xgboost_deterministic, samples);

julia> isapprox(value, 1; atol=0.2)
true
@@ -67,60 +141,20 @@ Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic
function rstar(
rng::Random.AbstractRNG,
classifier::MLJModelInterface.Supervised,
x::AbstractMatrix,
y::AbstractVector{Int};
subset::Real=0.8,
verbosity::Int=0,
x::AbstractArray{<:Any,3};
kwargs...,
)
# checks
size(x, 1) != length(y) && throw(DimensionMismatch())
0 < subset < 1 || throw(ArgumentError("`subset` must be a number in (0, 1)"))

# randomly sub-select training and testing set
N = length(y)
Ntrain = round(Int, N * subset)
0 < Ntrain < N ||
throw(ArgumentError("training and test data subsets must not be empty"))
ids = Random.randperm(rng, N)
train_ids = view(ids, 1:Ntrain)
test_ids = view(ids, (Ntrain + 1):N)

# train classifier on training data
ycategorical = MLJModelInterface.categorical(y)
fitresult, _ = MLJModelInterface.fit(
classifier, verbosity, Tables.table(x[train_ids, :]), ycategorical[train_ids]
)

# compute predictions on test data
xtest = Tables.table(x[test_ids, :])
predictions = _predict(classifier, fitresult, xtest)

# compute statistic
ytest = ycategorical[test_ids]
result = _rstar(predictions, ytest)

return result
samples = reshape(x, :, size(x, 3))
chain_inds = repeat(axes(x, 2); inner=size(x, 1))
return rstar(rng, classifier, samples, chain_inds; kwargs...)
end
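The reshape/repeat flattening in the 3-d `rstar` method above can be checked on a small array (values are illustrative):

```julia
x = reshape(1:12, 2, 3, 2)          # 2 draws, 3 chains, 2 parameters
samples = reshape(x, :, size(x, 3)) # 6×2 matrix: chains stacked along the rows
chain_inds = repeat(axes(x, 2); inner=size(x, 1))  # draw-to-chain labels
```

For this array `chain_inds` is `[1, 1, 2, 2, 3, 3]`: the first two rows of `samples` come from chain 1, the next two from chain 2, and so on, which is exactly the pairing the matrix method expects.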

# Workaround for https://github.com/JuliaAI/MLJBase.jl/issues/863
# `MLJModelInterface.predict` sometimes returns predictions and sometimes predictions + additional information
# TODO: Remove once the upstream issue is fixed
function _predict(model::MLJModelInterface.Model, fitresult, x)
y = MLJModelInterface.predict(model, fitresult, x)
return if :predict in MLJModelInterface.reporting_operations(model)
first(y)
else
y
end
function rstar(classif::MLJModelInterface.Supervised, x, y::AbstractVector{Int}; kwargs...)
return rstar(Random.default_rng(), classif, x, y; kwargs...)
end

function rstar(
classif::MLJModelInterface.Supervised,
x::AbstractMatrix,
y::AbstractVector{Int};
kwargs...,
)
return rstar(Random.GLOBAL_RNG, classif, x, y; kwargs...)
function rstar(classif::MLJModelInterface.Supervised, x::AbstractArray{<:Any,3}; kwargs...)
return rstar(Random.default_rng(), classif, x; kwargs...)
end

# R⋆ for deterministic predictions (algorithm 1)
14 changes: 13 additions & 1 deletion test/Project.toml
@@ -1,10 +1,22 @@
[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52"
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Distributions = "0.25"
FFTW = "1.1"
julia = "1"
MCMCDiagnosticTools = "0.2"
MLJBase = "0.19, 0.20, 0.21"
MLJLIBSVMInterface = "0.1, 0.2"
MLJXGBoostInterface = "0.1, 0.2"
Tables = "1"
julia = "1.3"