From 88f5f71caf95bca8aecc76e6e9461be7045d4daa Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 26 Dec 2023 12:36:53 -0800 Subject: [PATCH 01/30] Implement all Pareto diagnostics --- src/PSIS.jl | 1 + src/diagnostics.jl | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 src/diagnostics.jl diff --git a/src/PSIS.jl b/src/PSIS.jl index f249a936..a48a9209 100644 --- a/src/PSIS.jl +++ b/src/PSIS.jl @@ -10,6 +10,7 @@ export psis, psis!, ess_is include("utils.jl") include("generalized_pareto.jl") +include("diagnostics.jl") include("core.jl") include("ess.jl") include("recipes/plots.jl") diff --git a/src/diagnostics.jl b/src/diagnostics.jl new file mode 100644 index 00000000..0542f39f --- /dev/null +++ b/src/diagnostics.jl @@ -0,0 +1,43 @@ +""" + pareto_shape_threshold(sample_size::Real) -> Real + +Given the `sample_size`, compute the Pareto shape ``k`` threshold needed for a reliable +Pareto-smoothed estimate (i.e. to have small probability of large error). +""" +pareto_shape_threshold(sample_size::Real) = 1 - inv(log10(sample_size)) + +""" + min_sample_size(pareto_shape::Real) -> Real + min_sample_size(pareto_shape::AbstractArray) -> AbstractArray + +Given the Pareto shape values ``k``, compute the minimum sample size needed for a reliable +Pareto-smoothed estimate (i.e. to have small probability of large error). +""" +function min_sample_size end +min_sample_size(pareto_shape::Real) = exp10(inv(1 - max(0, pareto_shape))) +min_sample_size(pareto_shape::AbstractArray) = map(min_sample_size, pareto_shape) + +""" + convergence_rate(pareto_shape::Real, sample_size::Real) -> Real + convergence_rate(pareto_shape::AbstractArray, sample_size::Real) -> AbstractArray + +Given `sample_size` and Pareto shape values ``k``, compute the relative convergence rate of +the RMSE of the Pareto-smoothed estimate. +""" +function convergence_rate end +function convergence_rate(k::AbstractArray{<:Real}, S::Real) + return convergence_rate.(k, S) +end +function convergence_rate(k::Real, S::Real) + T = typeof((one(S) * 1^zero(k) * oneunit(k)) / (one(S) * 1^zero(k))) + k < 0 && return oneunit(T) + k > 1 && return zero(T) + k == 1//2 && return T(1 - inv(log(S))) + return T( + max( + 0, + (2 * (k - 1) * S^(2k) - (2k - 1) * S^(2k - 1) + S) / + ((S - 1) * (1 - S^(2k - 1))), + ), + ) +end From 3ff48a4e0b3c2bd8bc4bf4766dafa92fe952c230 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 26 Dec 2023 12:37:11 -0800 Subject: [PATCH 02/30] Add standalone `pareto_diagnose` method --- src/PSIS.jl | 2 + src/pareto_diagnose.jl | 138 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+) create mode 100644 src/pareto_diagnose.jl diff --git a/src/PSIS.jl b/src/PSIS.jl index a48a9209..7375c1dc 100644 --- a/src/PSIS.jl +++ b/src/PSIS.jl @@ -7,10 +7,12 @@ using Statistics: Statistics export PSISPlots export PSISResult export psis, psis!, ess_is +export pareto_diagnose include("utils.jl") include("generalized_pareto.jl") include("diagnostics.jl") +include("pareto_diagnose.jl") include("core.jl") include("ess.jl") include("recipes/plots.jl") diff --git a/src/pareto_diagnose.jl b/src/pareto_diagnose.jl new file mode 100644 index 00000000..15eb8ef7 --- /dev/null +++ b/src/pareto_diagnose.jl @@ -0,0 +1,138 @@ +@enum Tails LeftTail RightTail BothTails +const TAIL_OPTIONS = (left=LeftTail, right=RightTail, both=BothTails) + +_validate_tails(tails::Tails) = tails +function _validate_tails(tails::Symbol) + if !haskey(TAIL_OPTIONS, tails) + throw(ArgumentError("invalid tails: $tails. Valid values are :left, :right, :both")) + end + return TAIL_OPTIONS[tails] +end + +_default_tails(log::Bool) = log ? RightTail : BothTails + +_as_scale(log::Bool) = log ? Base.log : identity + +""" + pareto_diagnose(x::AbstractArray; warn=false, reff=1, log=false[, tails::Symbol]) + +Compute diagnostics for Pareto-smoothed estimate of the expectand ``x``. + +# Arguments + + - `x`: An array of values of shape `(draws[, chains[, params...]])`. + +# Keywords + + - `warn=false`: Whether to raise an informative warning if the diagnostics indicate that the + Pareto-smoothed estimate may be unreliable. + - `reff=1`: The relative efficiency of the importance weights. Must be either a scalar or an + array of shape `(params...,)`. + - `log=false`: Whether `x` represents the log of the expectand. If `true`, the diagnostics + are computed on the original scale, taking care to avoid numerical overflow. + - `tails`: Which tail(s) to use for the diagnostics. Valid values are `:left`, `:right` and + `:both`. If `log=true`, only `:right` is valid. Defaults to `:both` if `log=false`. + +# Returns + + - `diagnostics::NamedTuple`: A named tuple containing the following fields: + + + `pareto_shape`: The Pareto shape parameter ``k``. + + `min_sample_size`: The minimum sample size needed for a reliable Pareto-smoothed + estimate (i.e. to have small probability of large error). + + `pareto_shape_threshold`: The Pareto shape ``k`` threshold needed for a reliable + Pareto-smoothed estimate (i.e. to have small probability of large error). + + `convergence_rate`: The relative convergence rate of the RMSE of the + Pareto-smoothed estimate. +""" +function pareto_diagnose( + x::AbstractArray{<:Real}; + warn::Bool=false, + reff=1, + log::Bool=false, + tails::Union{Tails,Symbol}=_default_tails(log), +) + _tails = _validate_tails(tails) + if log && _tails !== RightTail + throw(ArgumentError("log can only be true when tails=:right")) + end + sample_size = prod(map(Base.Fix1(size, x), _sample_dims(x))) + if _tails === BothTails + end + diagnostics = _pareto_diagnose(x, reff, _tails, _as_scale(log)) + if warn + # TODO: check diagnostics and raise warning + end + return diagnostics +end + +function _pareto_diagnose(x::AbstractArray, reff, tails::Tails, scale) + tail_dist = _fit_tail_dist(x, reff, tails, scale) + sample_size = prod(map(Base.Fix1(size, x), _sample_dims(x))) + return _compute_diagnostics(pareto_shape(tail_dist), sample_size) +end + +function _compute_diagnostics(pareto_shape, sample_size) + return ( + pareto_shape, + min_sample_size=min_sample_size(pareto_shape), + pareto_shape_threshold=pareto_shape_threshold(sample_size), + convergence_rate=convergence_rate(pareto_shape, sample_size), + ) +end + +@inline function _fit_tail_dist( + x::AbstractArray, + reff::Union{Real,AbstractArray{<:Real}}, + tails::Tails, + scale::Union{typeof(log),typeof(identity)}, +) + return map(_eachparamindex(x)) do i + reff_i = reff isa Real ? reff : _selectparam(reff, i) + return _fit_tail_dist(_selectparam(x, i), reff_i, tails, scale) + end +end +function _fit_tail_dist( + x::AbstractVecOrMat, + reff::Real, + tails::Tails, + scale::Union{typeof(log),typeof(identity)}, +) + S = length(x) + M = tail_length(reff, S) + x_tail = similar(vec(x), M) + _tails = tails === BothTails ? (LeftTail, RightTail) : (tails,) + tail_dists = map(_tails) do tail + _, cutoff = _get_tail!(x_tail, vec(x), tail) + _shift_tail!(x_tail, cutoff, tail, scale) + return _fit_tail_dist(x_tail) + end + tail_dist = argmax(pareto_shape, tail_dists) + return tail_dist +end +_fit_tail_dist(x_tail::AbstractVector) = fit_gpd(x_tail; prior_adjusted=true, sorted=true) + +function _get_tail!(x_tail::AbstractVector, x::AbstractVector, tail::Tails) + S = length(x) + M = length(x_tail) + ind_offset = firstindex(x) - 1 + perm = partialsortperm(x, ind_offset .+ ((S - M):S); rev=tail === LeftTail) + cutoff = x[first(perm)] + tail_inds = @view perm[(firstindex(perm) + 1):end] + copyto!(x_tail, @views x[tail_inds]) + return x_tail, cutoff +end + +function _shift_tail!( + x_tail, cutoff, tails::Tails, scale::Union{typeof(log),typeof(identity)} +) + if scale === log + x_tail_max = x_tail[end] + @. x_tail = exp(x_tail - x_tail_max) - exp(cutoff - x_tail_max) + elseif tails === LeftTail + @. x_tail = cutoff - x_tail + else + @. x_tail = x_tail - cutoff + end + return x_tail +end From 50c0d2dd8f526964b4592add6056ae3497a09ac1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 28 Dec 2023 21:39:01 +0100 Subject: [PATCH 03/30] Remove dead code branches --- src/pareto_diagnose.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/pareto_diagnose.jl b/src/pareto_diagnose.jl index 15eb8ef7..3b2b4654 100644 --- a/src/pareto_diagnose.jl +++ b/src/pareto_diagnose.jl @@ -56,10 +56,9 @@ function pareto_diagnose( if log && _tails !== RightTail throw(ArgumentError("log can only be true when tails=:right")) end + pareto_shape = _pareto_diagnose(x, reff, _tails, _as_scale(log)) sample_size = prod(map(Base.Fix1(size, x), _sample_dims(x))) - if _tails === BothTails - end - diagnostics = _pareto_diagnose(x, reff, _tails, _as_scale(log)) + diagnostics = _compute_diagnostics(pareto_shape, sample_size) if warn # TODO: check diagnostics and raise warning end From 00514efdaead3c8fe1616cf3e8732b0ff710710c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 28 Dec 2023 21:40:24 +0100 Subject: [PATCH 04/30] Add h-specific pareto diagnose --- src/pareto_diagnose.jl | 74 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/src/pareto_diagnose.jl b/src/pareto_diagnose.jl index 3b2b4654..155b9480 100644 --- a/src/pareto_diagnose.jl +++ b/src/pareto_diagnose.jl @@ -65,6 +65,80 @@ function pareto_diagnose( return diagnostics end +""" + pareto_diagnose(x::AbstractArray, ratios::AbstractArray; kwargs...) + +Compute diagnostics for Pareto-smoothed importance-weighted estimate of the expectand ``x``. + +# Arguments + + - `x`: An array of values of shape `(draws[, chains])`. If `log=true`, the values are + assumed to be on the log scale. + - `ratios`: An array of unnormalized importance ratios of shape + `(draws[, chains[, params...]])`. If `log_ratios=true`, the ratios are assumed to be on + the log scale. + +# Keywords + + - `warn=false`: Whether to raise an informative warning if the diagnostics indicate that + the Pareto-smoothed estimate may be unreliable. + - `reff=1`: The relative efficiency of the importance weights on the original scale. Must + be either a scalar or an array of shape `(params...,)`. + - `log=false`: Whether `x` represents the log of the expectand. + - `log_ratios=true`: Whether `ratios` represents the log of the importance ratios. + - `diagnose_ratios=true`: Whether to compute diagnostics for the importance ratios. + - `tails`: Which tail(s) of `x * ratios` to use for the diagnostics. Valid values are + `:left`, `:right`, and `:both`. If `log=true`, only `:right` is valid. Defaults to + `:both` if `log=false`. + +# Returns + + - `diagnostics::NamedTuple`: A named tuple containing the following fields: + + + `pareto_shape`: The Pareto shape parameter ``k``. + + `min_sample_size`: The minimum sample size needed for a reliable Pareto-smoothed + estimate (i.e. to have small probability of large error). + + `pareto_shape_threshold`: The Pareto shape ``k`` threshold needed for a reliable + Pareto-smoothed estimate (i.e. to have small probability of large error). + + `convergence_rate`: The relative convergence rate of the RMSE of the + Pareto-smoothed estimate. +""" +function pareto_diagnose( + x::AbstractArray{<:Real}, + ratios::AbstractArray{<:Real}; + warn::Bool=false, + log::Bool=false, + log_ratios::Bool=true, + diagnose_ratios::Bool=true, + tails::Union{Tails,Symbol}=_default_tails(log), + reff=1, +) + _tails = _validate_tails(tails) + expectand = _compute_expectand(x, ratios; log, log_ratios) + pareto_shape_numerator = _pareto_diagnose(expectand, reff, _tails, _as_scale(log)) + if diagnose_ratios + pareto_shape_denominator = _pareto_diagnose( + ratios, reff, RightTail, _as_scale(log_ratios) + ) + pareto_shape = max.(pareto_shape_numerator, pareto_shape_denominator) + else + pareto_shape = pareto_shape_numerator + end + sample_size = prod(map(Base.Fix1(size, x), _sample_dims(x))) + diagnostics = _compute_diagnostics(pareto_shape, sample_size) + # TODO: check diagnostics and raise warning + return diagnostics +end + +function _compute_expectand(x, ratios; log, log_ratios) + log && log_ratios && return x .+ ratios + !log && !log_ratios && return x .* ratios + log && return x .+ Base.log.(ratios) + dims = _param_dims(ratios) + # scale ratios to maximum of 1 to reduce numerical issues + return x .* exp.(ratios .- dropdims(maximum(ratios; dims); dims)) +end + function _pareto_diagnose(x::AbstractArray, reff, tails::Tails, scale) tail_dist = _fit_tail_dist(x, reff, tails, scale) sample_size = prod(map(Base.Fix1(size, x), _sample_dims(x))) From 4b648e21f155ca1f68132c0c627aa2f6cf4930f9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 28 Dec 2023 21:43:20 +0100 Subject: [PATCH 05/30] Run formatter --- src/pareto_diagnose.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/pareto_diagnose.jl b/src/pareto_diagnose.jl index 155b9480..c03d84e1 100644 --- a/src/pareto_diagnose.jl +++ b/src/pareto_diagnose.jl @@ -95,13 +95,13 @@ Compute diagnostics for Pareto-smoothed importance-weighted estimate of the expe - `diagnostics::NamedTuple`: A named tuple containing the following fields: - + `pareto_shape`: The Pareto shape parameter ``k``. - + `min_sample_size`: The minimum sample size needed for a reliable Pareto-smoothed - estimate (i.e. to have small probability of large error). - + `pareto_shape_threshold`: The Pareto shape ``k`` threshold needed for a reliable - Pareto-smoothed estimate (i.e. to have small probability of large error). - + `convergence_rate`: The relative convergence rate of the RMSE of the - Pareto-smoothed estimate. + + `pareto_shape`: The Pareto shape parameter ``k``. + + `min_sample_size`: The minimum sample size needed for a reliable Pareto-smoothed + estimate (i.e. to have small probability of large error). + + `pareto_shape_threshold`: The Pareto shape ``k`` threshold needed for a reliable + Pareto-smoothed estimate (i.e. to have small probability of large error). + + `convergence_rate`: The relative convergence rate of the RMSE of the + Pareto-smoothed estimate. """ function pareto_diagnose( x::AbstractArray{<:Real}, From a3cfa6ee3261f8ec803025a245748f31687385d5 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 20:08:38 +0100 Subject: [PATCH 06/30] Add object for storing Pareto diagnostics --- Project.toml | 6 ++ src/PSIS.jl | 3 + src/diagnostics.jl | 187 ++++++++++++++++++++++++++++++++++++++++----- 3 files changed, 176 insertions(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index 61540253..0a838489 100644 --- a/Project.toml +++ b/Project.toml @@ -4,8 +4,11 @@ authors = ["Seth Axen and contributors"] version = "0.9.4" [deps] +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -13,10 +16,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] DimensionalData = "0.24" Distributions = "0.25.81" +DocStringExtensions = "0.9" +IntervalSets = "0.7" JLD2 = "0.4" LinearAlgebra = "1.6" LogExpFunctions = "0.2.0, 0.3" Plots = "1" +PrettyTables = "2" Printf = "1.6" RecipesBase = "1" ReferenceTests = "0.9, 0.10" diff --git a/src/PSIS.jl b/src/PSIS.jl index 7375c1dc..3b2c130a 100644 --- a/src/PSIS.jl +++ b/src/PSIS.jl @@ -1,6 +1,9 @@ module PSIS +using DocStringExtensions: FIELDS +using IntervalSets: IntervalSets using LogExpFunctions: LogExpFunctions +using PrettyTables: PrettyTables using Printf: @sprintf using Statistics: Statistics diff --git a/src/diagnostics.jl b/src/diagnostics.jl index 0542f39f..38a113d7 100644 --- a/src/diagnostics.jl +++ b/src/diagnostics.jl @@ -1,30 +1,35 @@ """ - pareto_shape_threshold(sample_size::Real) -> Real + ParetoDiagnostics -Given the `sample_size`, compute the Pareto shape ``k`` threshold needed for a reliable -Pareto-smoothed estimate (i.e. to have small probability of large error). -""" -pareto_shape_threshold(sample_size::Real) = 1 - inv(log10(sample_size)) +Diagnostic information for Pareto-smoothed importance sampling.[^VehtariSimpson2021] -""" - min_sample_size(pareto_shape::Real) -> Real - min_sample_size(pareto_shape::AbstractArray) -> AbstractArray +$FIELDS -Given the Pareto shape values ``k``, compute the minimum sample size needed for a reliable -Pareto-smoothed estimate (i.e. to have small probability of large error). +[^VehtariSimpson2021]: Vehtari A, Simpson D, Gelman A, Yao Y, Gabry J. (2021). + Pareto smoothed importance sampling. + [arXiv:1507.02646v7](https://arxiv.org/abs/1507.02646v7) [stat.CO] """ -function min_sample_size end -min_sample_size(pareto_shape::Real) = exp10(inv(1 - max(0, pareto_shape))) -min_sample_size(pareto_shape::AbstractArray) = map(min_sample_size, pareto_shape) +struct ParetoDiagnostics{TK,TKM,TS,TR} + "The estimated Pareto shape ``\\hat{k}`` for each parameter." + pareto_shape::TK + "The sample-size-dependent Pareto shape threshold ``k_\\mathrm{threshold}`` needed for a + reliable Pareto-smoothed estimate (i.e. to have small probability of large error)." + pareto_shape_threshold::TKM + "The estimated minimum sample size needed for a reliable Pareto-smoothed estimate (i.e. + to have small probability of large error)." + min_sample_size::TS + "The estimated relative convergence rate of the RMSE of the Pareto-smoothed estimate." + convergence_rate::TR +end -""" - convergence_rate(pareto_shape::Real, sample_size::Real) -> Real - convergence_rate(pareto_shape::AbstractArray, sample_size::Real) -> AbstractArray +pareto_shape_threshold(sample_size::Real) = 1 - inv(log10(sample_size)) + +function min_sample_size(pareto_shape::Real) + min_ss = exp10(inv(1 - max(0, pareto_shape))) + return pareto_shape > 1 ? oftype(min_ss, Inf) : min_ss +end +min_sample_size(pareto_shape::AbstractArray) = map(min_sample_size, pareto_shape) -Given `sample_size` and Pareto shape values ``k``, compute the relative convergence rate of -the RMSE of the Pareto-smoothed estimate. -""" -function convergence_rate end function convergence_rate(k::AbstractArray{<:Real}, S::Real) return convergence_rate.(k, S) end @@ -41,3 +46,145 @@ function convergence_rate(k::Real, S::Real) ), ) end + +""" + check_pareto_diagnostics(diagnostics::ParetoDiagnostics) + +Check the diagnostics in [`ParetoDiagnostics`](@ref) and issue warnings if necessary. +""" +function check_pareto_diagnostics(diag::ParetoDiagnostics) + categories = _diagnostic_intervals(diag) + category_assignments = _diagnostic_category_assignments(diag) + nparams = length(diag.pareto_shape) + for (category, inds) in pairs(category_assignments) + count = length(inds) + count > 0 || continue + perc = round(Int, 100 * count / nparams) + msg = if category === :failed + "The generalized Pareto distribution could not be fit to the tail draws. " * + "Total number of draws should in general exceed 25, and the tail draws must " * + "be finite." + elseif category === :very_bad + "All estimates are unreliable. If the distribution of draws is bounded, " * + "further draws may improve the estimates, but it is not possible to predict " * + "whether any feasible sample size is sufficient." + elseif category === :bad + ss_max = ceil(maximum(i -> diag.min_sample_size[i], inds)) + "Sample size is too small and must be larger than " * + "$(@sprintf("%.10g", ss_max)) for all estimates to be reliable." + elseif category === :high_bias + "Bias dominates RMSE, and variance-based MCSE estimates are underestimated." + else + continue + end + suffix = + category === :failed ? "" : " (k ∈ $(_interval_string(categories[category])))" + prefix = if nparams > 1 + msg = lowercasefirst(msg) + prefix = "For $count parameters ($perc%), " + else + "" + end + @warn "$prefix$msg$suffix" + end +end + +function _compute_diagnostics(pareto_shape, sample_size) + return ParetoDiagnostics( + pareto_shape, + pareto_shape_threshold(sample_size), + min_sample_size(pareto_shape), + convergence_rate(pareto_shape, sample_size), + ) +end + +function _interval_string(i::IntervalSets.Interval) + l = IntervalSets.isleftopen(i) || !isfinite(minimum(i)) ? "(" : "[" + r = IntervalSets.isrightopen(i) || !isfinite(maximum(i)) ? ")" : "]" + imin, imax = IntervalSets.endpoints(i) + return "$l$(@sprintf("%.1g", imin)), $(@sprintf("%.1g", imax))$r" +end + +function _diagnostic_intervals(diag::ParetoDiagnostics) + khat_thresh = diag.pareto_shape_threshold + return ( + good=IntervalSets.ClosedInterval(-Inf, khat_thresh), + bad=IntervalSets.Interval{:open,:closed}(khat_thresh, 1), + very_bad=IntervalSets.Interval{:open,:closed}(1, Inf), + high_bias=IntervalSets.Interval{:open,:closed}(0.7, 1), + ) +end + +function _diagnostic_category_assignments(diagnostics) + intervals = _diagnostic_intervals(diagnostics) + result_counts = map(intervals) do interval + return findall(∈(interval), diagnostics.pareto_shape) + end + failed = findall(isnan, diagnostics.pareto_shape) + return merge(result_counts, (; failed)) +end + +function Base.show(io::IO, ::MIME"text/plain", diag::ParetoDiagnostics) + nparams = length(diag.pareto_shape) + println(io, "ParetoDiagnostics with $nparams parameters") + return _print_pareto_diagnostics_summary(io, diag; newline_at_end=false) +end + +function _print_pareto_diagnostics_summary(io::IO, diag::ParetoDiagnostics; kwargs...) + k = as_array(diag.pareto_shape) + category_assignments = NamedTuple{(:good, :bad, :very_bad, :failed)}( + _diagnostic_category_assignments(diag) + ) + category_intervals = _diagnostic_intervals(diag) + npoints = length(k) + rows = map(collect(pairs(category_assignments))) do (desc, inds) + interval = desc === :failed ? "--" : _interval_string(category_intervals[desc]) + return (; interval, desc, count=length(inds)) + end + return _print_pareto_diagnostics_summary(io::IO, rows, npoints; kwargs...) +end + +function _print_pareto_diagnostics_summary(io::IO, _rows, npoints; kwargs...) + rows = filter(r -> r.count > 0, _rows) + header = ["", "", "Count"] + alignment = [:r, :l, :l] + if length(first(rows)) > 3 + push!(header, "Min. ESS") + push!(alignment, :r) + end + formatters = ( + (v, i, j) -> j == 2 ? replace(string(v), '_' => " ") : v, + (v, i, j) -> j == 3 ? "$v ($(round(v * (100 // npoints); digits=1))%)" : v, + (v, i, j) -> j == 4 ? (rows[i].desc === :good ? "$(floor(Int, v))" : "——") : v, + ) + highlighters = ( + PrettyTables.Highlighter( + (data, i, j) -> (j == 3 && data[i][2] === :bad); + bold=true, + foreground=:light_red, + ), + PrettyTables.Highlighter( + (data, i, j) -> (j == 3 && data[i][2] === :very_bad); bold=true, foreground=:red + ), + PrettyTables.Highlighter( + (data, i, j) -> (j == 3 && data[i][2] === :failed); foreground=:red + ), + ) + + PrettyTables.pretty_table( + io, + rows; + header, + alignment, + alignment_anchor_regex=Dict(3 => [r"\s"]), + hlines=:none, + vlines=:none, + formatters, + highlighters, + kwargs..., + ) + return nothing +end + +_pad_left(s, nchars) = " "^max(nchars - length("$s"), 0) * "$s" +_pad_right(s, nchars) = "$s" * " "^max(0, nchars - length("$s")) From 4be297ad4a8a629c7418010db512d0b833fa76e8 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:00:03 +0100 Subject: [PATCH 07/30] Move methods --- src/core.jl | 4 ---- src/generalized_pareto.jl | 3 +++ 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/core.jl b/src/core.jl index 07fe7ccc..3c82acb1 100644 --- a/src/core.jl +++ b/src/core.jl @@ -170,8 +170,6 @@ function _print_pareto_shape_summary(io::IO, r::PSISResult; kwargs...) return nothing end -_pad_left(s, nchars) = " "^(nchars - length("$s")) * "$s" -_pad_right(s, nchars) = "$s" * " "^(nchars - length("$s")) """ psis(log_ratios, reff = 1.0; kwargs...) -> PSISResult @@ -293,9 +291,7 @@ function psis!(logw::AbstractArray, reff=1; normalize::Bool=true, warn::Bool=tru return result end -pareto_shape(dist::GeneralizedPareto) = dist.k pareto_shape(r::PSISResult) = pareto_shape(getfield(r, :tail_dist)) -pareto_shape(dists) = map(pareto_shape, dists) function check_reff(reff) isvalid = all(reff) do r diff --git a/src/generalized_pareto.jl b/src/generalized_pareto.jl index 5d19ca59..f314a20e 100644 --- a/src/generalized_pareto.jl +++ b/src/generalized_pareto.jl @@ -30,6 +30,9 @@ struct GeneralizedPareto{T} end GeneralizedPareto(μ, σ, k) = GeneralizedPareto(Base.promote(μ, σ, k)...) +pareto_shape(dist::GeneralizedPareto) = dist.k +pareto_shape(dists) = map(pareto_shape, dists) + function quantile(d::GeneralizedPareto{T}, p::Real) where {T<:Real} nlog1pp = -log1p(-p * one(T)) k = d.k From a942492ee131e499897d95f6c47fabe719a0e9e7 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:03:41 +0100 Subject: [PATCH 08/30] Move tails functionality to own file --- src/PSIS.jl | 1 + src/core.jl | 7 ------- src/tails.jl | 45 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 7 deletions(-) create mode 100644 src/tails.jl diff --git a/src/PSIS.jl b/src/PSIS.jl index 3b2c130a..cdbe43be 100644 --- a/src/PSIS.jl +++ b/src/PSIS.jl @@ -14,6 +14,7 @@ export pareto_diagnose include("utils.jl") include("generalized_pareto.jl") +include("tails.jl") include("diagnostics.jl") include("pareto_diagnose.jl") include("core.jl") diff --git a/src/core.jl b/src/core.jl index 3c82acb1..77427cf2 100644 --- a/src/core.jl +++ b/src/core.jl @@ -327,13 +327,6 @@ function check_pareto_shape(dists::AbstractArray{<:GeneralizedPareto}) return nothing end -function tail_length(reff, S) - max_length = cld(S, 5) - (isfinite(reff) && reff > 0) || return max_length - min_length = ceil(Int, 3 * sqrt(S / reff)) - return min(max_length, min_length) -end - function psis_tail!(logw, logμ) T = eltype(logw) logw_max = logw[end] diff --git a/src/tails.jl b/src/tails.jl new file mode 100644 index 00000000..8d1187d6 --- /dev/null +++ b/src/tails.jl @@ -0,0 +1,45 @@ +# utilities for specifying or retrieving tails + +@enum Tails LeftTail RightTail BothTails +const TAIL_OPTIONS = (left=LeftTail, right=RightTail, both=BothTails) + +_standardize_tails(tails::Tails) = tails +function _standardize_tails(tails::Symbol) + if !haskey(TAIL_OPTIONS, tails) + throw( + ArgumentError("invalid tails: $tails. Valid values are $(keys(TAIL_OPTIONS)))") + ) + end + return TAIL_OPTIONS[tails] +end + +function tail_length(reff, S) + (isfinite(reff) && reff > 0 && S > 225) || return cld(S, 5) + return ceil(Int, 3 * sqrt(S / reff)) +end + +function _tail_length(reff, S, tails::Tails) + M = tail_length(reff, S) + if tails === BothTails && M > fld(S, 2) + M = Int(fld(S, 2)) + end + return M +end + +function _tail_and_cutoff(x::AbstractVector, M::Integer, tail::Tails) + S = length(x) + ind_offset = firstindex(x) - 1 + perm = partialsortperm(x, ind_offset .+ ((S - M):S); rev=tail === LeftTail) + cutoff = x[first(perm)] + tail_inds = @view perm[(firstindex(perm) + 1):end] + return @views x[tail_inds], cutoff +end + +function _shift_tail!(x_tail_shifted, x_tail, cutoff, tails::Tails) + if tails === LeftTail + @. x_tail_shifted = cutoff - x_tail + else + @. x_tail_shifted = x_tail - cutoff + end + return x_tail_shifted +end From bf8668cc58af6c64473576019eca6faa2e7da676 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:04:58 +0100 Subject: [PATCH 09/30] Add utilities for computing expectands --- Project.toml | 10 +++++++++ ext/PSISStatsBaseExt.jl | 12 +++++++++++ src/PSIS.jl | 10 +++++++++ src/expectand.jl | 48 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 80 insertions(+) create mode 100644 ext/PSISStatsBaseExt.jl create mode 100644 src/expectand.jl diff --git a/Project.toml b/Project.toml index 0a838489..1a66b3ff 100644 --- a/Project.toml +++ b/Project.toml @@ -11,8 +11,15 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +[weakdeps] +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" + +[extensions] +PSISStatsBaseExt = ["StatsBase"] + [compat] DimensionalData = "0.24" Distributions = "0.25.81" @@ -26,7 +33,9 @@ PrettyTables = "2" Printf = "1.6" RecipesBase = "1" ReferenceTests = "0.9, 0.10" +Requires = "1" Statistics = "1.6" +StatsBase = "0.32, 0.33, 0.34" julia = "1.6" [extras] @@ -37,6 +46,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] diff --git a/ext/PSISStatsBaseExt.jl b/ext/PSISStatsBaseExt.jl new file mode 100644 index 00000000..f3a9e215 --- /dev/null +++ b/ext/PSISStatsBaseExt.jl @@ -0,0 +1,12 @@ +module PSISStatsBaseExt + +using PSIS, StatsBase + +PSIS._max_moment_required(::typeof(StatsBase.skewness)) = 3 +PSIS._max_moment_required(::typeof(StatsBase.kurtosis)) = 4 +PSIS._max_moment_required(f::Base.Fix2{typeof(StatsBase.moment),<:Integer}) = f.x +# the pth cumulant is a polynomial of degree p in the moments +PSIS._max_moment_required(f::Base.Fix2{typeof(StatsBase.cumulant),<:Integer}) = f.x +PSIS._max_moment_required(::Base.Fix2{typeof(StatsBase.percentile),<:Real}) = 0 + +end # module diff --git a/src/PSIS.jl b/src/PSIS.jl index cdbe43be..c1449829 100644 --- a/src/PSIS.jl +++ b/src/PSIS.jl @@ -15,10 +15,20 @@ export pareto_diagnose include("utils.jl") include("generalized_pareto.jl") include("tails.jl") +include("expectand.jl") include("diagnostics.jl") include("pareto_diagnose.jl") include("core.jl") include("ess.jl") include("recipes/plots.jl") +@static if !isdefined(Base, :get_extension) + using Requires: @require + function __init__() + @require StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" begin + include("../ext/PSISStatsBaseExt.jl") + end + end +end + end diff --git a/src/expectand.jl b/src/expectand.jl new file mode 100644 index 00000000..fb1c420d --- /dev/null +++ b/src/expectand.jl @@ -0,0 +1,48 @@ +# utilities for computing properties or proxies of an expectand + +_elementwise_transform(f::Base.Fix1{typeof(Statistics.mean)}) = f.x +_elementwise_transform(::Any) = identity + +_max_moment_required(::Base.Fix1{typeof(Statistics.mean)}) = 1 +_max_moment_required(::typeof(Statistics.mean)) = 1 +_max_moment_required(::typeof(Statistics.var)) = 2 +_max_moment_required(::typeof(Statistics.std)) = 2 +_max_moment_required(::Base.Fix2{typeof(Statistics.quantile),<:Real}) = 0 +_max_moment_required(::typeof(Statistics.median)) = 0 + +_requires_moments(f) = _max_moment_required(f) > 0 + +function _check_requires_moments(kind) + _requires_moments(kind) && return nothing + throw( + ArgumentError("kind=$kind requires no moments. Pareto diagnostics are not useful.") + ) +end + +# Compute an expectand `z` such that E[zr] requires the same number of moments as E[xr] +@inline function _expectand_proxy(f, x, r, is_x_log, is_r_log) + fi = _elementwise_transform(f) + p = _max_moment_required(f) + if !is_x_log + if !is_r_log + return fi.(x) .^ p .* r + else + # scale ratios to maximum of 1 to avoid under/overflow + return (fi.(x) .* exp.((r .- maximum(r; dims=_sample_dims(r))) ./ p)) .^ p + end + elseif fi === identity + log_z = if is_r_log + p .* x .+ r + else + p .* x .+ log.(r) + end + # scale to maximum of 1 to avoid overflow + return exp.(log_z .- maximum(log_z; dims=_sample_dims(log_z))) + else + throw( + ArgumentError( + "cannot compute expectand proxy from log with non-identity transform" + ), + ) + end +end From b1fc72505c758e0014ec850558f06e13cb4f0ef9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:05:33 +0100 Subject: [PATCH 10/30] Update pareto_diagnose --- src/pareto_diagnose.jl | 250 ++++++++++++++++++++--------------------- 1 file changed, 120 insertions(+), 130 deletions(-) diff --git a/src/pareto_diagnose.jl b/src/pareto_diagnose.jl index c03d84e1..05cbd40b 100644 --- a/src/pareto_diagnose.jl +++ b/src/pareto_diagnose.jl @@ -1,22 +1,7 @@ -@enum Tails LeftTail RightTail BothTails -const TAIL_OPTIONS = (left=LeftTail, right=RightTail, both=BothTails) - -_validate_tails(tails::Tails) = tails -function _validate_tails(tails::Symbol) - if !haskey(TAIL_OPTIONS, tails) - throw(ArgumentError("invalid tails: $tails. Valid values are :left, :right, :both")) - end - return TAIL_OPTIONS[tails] -end - -_default_tails(log::Bool) = log ? RightTail : BothTails - -_as_scale(log::Bool) = log ? Base.log : identity - """ - pareto_diagnose(x::AbstractArray; warn=false, reff=1, log=false[, tails::Symbol]) + pareto_diagnose(x::AbstractArray; kwargs...) -Compute diagnostics for Pareto-smoothed estimate of the expectand ``x``. +Compute diagnostics for Pareto-smoothed estimate of the expectand `x`. # Arguments @@ -24,18 +9,17 @@ Compute diagnostics for Pareto-smoothed estimate of the expectand ``x``. # Keywords - - `warn=false`: Whether to raise an informative warning if the diagnostics indicate that the - Pareto-smoothed estimate may be unreliable. - - `reff=1`: The relative efficiency of the importance weights. Must be either a scalar or an - array of shape `(params...,)`. - - `log=false`: Whether `x` represents the log of the expectand. If `true`, the diagnostics - are computed on the original scale, taking care to avoid numerical overflow. - - `tails`: Which tail(s) to use for the diagnostics. Valid values are `:left`, `:right` and - `:both`. If `log=true`, only `:right` is valid. Defaults to `:both` if `log=false`. + - `reff=1`: The relative tail efficiency of `x`. Must be either a scalar or an array of + shape `(params...,)`. + - `is_log=false`: Whether `x` represents the log of the expectand. If `true`, the + diagnostics are computed on the original scale, taking care to avoid numerical overflow. + - `tails=:both`: Which tail(s) to diagnose. Valid values are `:left`, `:right`, and + `:both`. If `tails=:both`, diagnostic values correspond to the tail with the worst + properties. # Returns - - `diagnostics::NamedTuple`: A named tuple containing the following fields: + - `diagnostics::ParetoDiagnostics`: A named tuple containing the following fields: + `pareto_shape`: The Pareto shape parameter ``k``. + `min_sample_size`: The minimum sample size needed for a reliable Pareto-smoothed @@ -46,54 +30,52 @@ Compute diagnostics for Pareto-smoothed estimate of the expectand ``x``. Pareto-smoothed estimate. """ function pareto_diagnose( - x::AbstractArray{<:Real}; - warn::Bool=false, + x::AbstractArray; reff=1, - log::Bool=false, - tails::Union{Tails,Symbol}=_default_tails(log), + is_log::Bool=false, + tails::Union{Tails,Symbol}=BothTails, + kind=Statistics.mean, ) - _tails = _validate_tails(tails) - if log && _tails !== RightTail - throw(ArgumentError("log can only be true when tails=:right")) - end - pareto_shape = _pareto_diagnose(x, reff, _tails, _as_scale(log)) - sample_size = prod(map(Base.Fix1(size, x), _sample_dims(x))) + # validate/format inputs + _tails = _standardize_tails(tails) + _check_requires_moments(kind) + + # diagnose the unnormalized expectation + pareto_shape = _compute_pareto_shape(x, reff, _tails, kind, is_log) + + # compute remaining diagnostics + sample_size = _sample_size(x) diagnostics = _compute_diagnostics(pareto_shape, sample_size) - if warn - # TODO: check diagnostics and raise warning - end + return diagnostics end """ pareto_diagnose(x::AbstractArray, ratios::AbstractArray; kwargs...) -Compute diagnostics for Pareto-smoothed importance-weighted estimate of the expectand ``x``. +Compute diagnostics for Pareto-smoothed importance-weighted estimate of the expectand `x`. # Arguments - - `x`: An array of values of shape `(draws[, chains])`. If `log=true`, the values are - assumed to be on the log scale. + - `x`: An array of values of shape `(draws[, chains[, params...]])`. - `ratios`: An array of unnormalized importance ratios of shape - `(draws[, chains[, params...]])`. If `log_ratios=true`, the ratios are assumed to be on - the log scale. + `(draws[, chains[, params...]])`. # Keywords - - `warn=false`: Whether to raise an informative warning if the diagnostics indicate that - the Pareto-smoothed estimate may be unreliable. - `reff=1`: The relative efficiency of the importance weights on the original scale. Must be either a scalar or an array of shape `(params...,)`. - - `log=false`: Whether `x` represents the log of the expectand. - - `log_ratios=true`: Whether `ratios` represents the log of the importance ratios. + - `is_log=false`: Whether `x` represents the log of the expectand. + - `is_ratios_log=true`: Whether `ratios` represents the log of the importance ratios. - `diagnose_ratios=true`: Whether to compute diagnostics for the importance ratios. - - `tails`: Which tail(s) of `x * ratios` to use for the diagnostics. Valid values are - `:left`, `:right`, and `:both`. If `log=true`, only `:right` is valid. Defaults to - `:both` if `log=false`. + This should only be set to `false` if the ratios are by construction normalized, as is + the case if they are are computed from already-normalized densities. + - `tails`: Which tail(s) of `x * ratios` to diagnose. Valid values are `:left`, `:right`, + and `:both`. # Returns - - `diagnostics::NamedTuple`: A named tuple containing the following fields: + - `diagnostics::ParetoDiagnostics`: A named tuple containing the following fields: + `pareto_shape`: The Pareto shape parameter ``k``. + `min_sample_size`: The minimum sample size needed for a reliable Pareto-smoothed @@ -104,108 +86,116 @@ Compute diagnostics for Pareto-smoothed importance-weighted estimate of the expe Pareto-smoothed estimate. """ function pareto_diagnose( - x::AbstractArray{<:Real}, + x::AbstractArray, ratios::AbstractArray{<:Real}; - warn::Bool=false, - log::Bool=false, - log_ratios::Bool=true, + is_log::Bool=false, + is_ratios_log::Bool=true, diagnose_ratios::Bool=true, - tails::Union{Tails,Symbol}=_default_tails(log), + tails::Union{Tails,Symbol}=BothTails, + kind=Statistics.mean, reff=1, ) - _tails = _validate_tails(tails) - expectand = _compute_expectand(x, ratios; log, log_ratios) - pareto_shape_numerator = _pareto_diagnose(expectand, reff, _tails, _as_scale(log)) - if diagnose_ratios - pareto_shape_denominator = _pareto_diagnose( - ratios, reff, RightTail, _as_scale(log_ratios) + + # validate/format inputs + _tails = _standardize_tails(tails) + + # diagnose the unnormalized expectation + pareto_shape_numerator = if _requires_moments(kind) + _compute_pareto_shape(x, ratios, _tails, kind, is_log, is_ratios_log) + elseif diagnose_ratios + nothing + else + throw( + ArgumentError( + "kind=$kind requires no moments. `diagnose_ratios` must be `true`." + ), ) - pareto_shape = max.(pareto_shape_numerator, pareto_shape_denominator) + end + + # diagnose the normalization term + pareto_shape_denominator = if diagnose_ratios + _compute_pareto_shape(ratios, reff, RightTail, Statistics.mean, is_ratios_log) else - pareto_shape = pareto_shape_numerator + nothing end - sample_size = prod(map(Base.Fix1(size, x), _sample_dims(x))) - diagnostics = _compute_diagnostics(pareto_shape, sample_size) - # TODO: check diagnostics and raise warning - return diagnostics -end -function _compute_expectand(x, ratios; log, log_ratios) - log && log_ratios && return x .+ ratios - !log && !log_ratios && return x .* ratios - log && return x .+ Base.log.(ratios) - dims = _param_dims(ratios) - # scale ratios to maximum of 1 to reduce numerical issues - return x .* exp.(ratios .- dropdims(maximum(ratios; dims); dims)) -end + # compute the maximum of the Pareto shapes + pareto_shape = if pareto_shape_numerator === nothing + pareto_shape_denominator + elseif !diagnose_ratios + pareto_shape_numerator + else + max(pareto_shape_numerator, pareto_shape_denominator) + end -function _pareto_diagnose(x::AbstractArray, reff, tails::Tails, scale) - tail_dist = _fit_tail_dist(x, reff, tails, scale) - sample_size = prod(map(Base.Fix1(size, x), _sample_dims(x))) - return _compute_diagnostics(pareto_shape(tail_dist), sample_size) -end + # compute remaining diagnostics + sample_size = _sample_size(x) + diagnostics = _compute_diagnostics(pareto_shape, sample_size) -function _compute_diagnostics(pareto_shape, sample_size) - return ( - pareto_shape, - min_sample_size=min_sample_size(pareto_shape), - pareto_shape_threshold=pareto_shape_threshold(sample_size), - convergence_rate=convergence_rate(pareto_shape, sample_size), - ) + return diagnostics end -@inline function _fit_tail_dist( - x::AbstractArray, - reff::Union{Real,AbstractArray{<:Real}}, - tails::Tails, - scale::Union{typeof(log),typeof(identity)}, +# batch methods +function _compute_pareto_shape(x::AbstractArray, reff, tails::Tails, kind, is_log::Bool) + return _map_params(x, reff) do x_i, reff_i + return _compute_pareto_shape(x_i, reff_i, tails, kind, is_log) + end +end +function _compute_pareto_shape( + x::AbstractArray, r::AbstractArray, tails::Tails, kind, is_x_log::Bool, is_r_log::Bool ) - return map(_eachparamindex(x)) do i - reff_i = reff isa Real ? reff : _selectparam(reff, i) - return _fit_tail_dist(_selectparam(x, i), reff_i, tails, scale) + return _map_params(x, r) do x_i, r_i + return _compute_pareto_shape(x_i, r_i, tails, kind, is_x_log, is_r_log) end end -function _fit_tail_dist( +# single methods +function _compute_pareto_shape( + x::AbstractVecOrMat, reff::Real, tails::Tails, kind, is_log::Bool +) + expectand_proxy = _expectand_proxy(kind, x, !is_log, is_log, is_log) + return _compute_pareto_shape(expectand_proxy, reff, tails) +end +Base.@constprop :aggressive function _compute_pareto_shape( x::AbstractVecOrMat, - reff::Real, + r::AbstractVecOrMat, tails::Tails, - scale::Union{typeof(log),typeof(identity)}, + kind, + is_x_log::Bool, + is_r_log::Bool, ) + expectand_proxy = _expectand_proxy(kind, x, r, is_x_log, is_r_log) + return _compute_pareto_shape(expectand_proxy, true, tails) +end + +# base method +function _compute_pareto_shape(x::AbstractVecOrMat, reff::Real, tails::Tails) S = length(x) - M = tail_length(reff, S) + M = _tail_length(reff, S, tails) + T = float(eltype(x)) + if M < 5 + @warn "Tail must contain at least 5 draws. Generalized Pareto distribution cannot be reliably fit." + return convert(T, NaN) + end x_tail = similar(vec(x), M) + return _compute_pareto_shape!(x_tail, x, tails) +end + +function _compute_pareto_shape!(x_tail::AbstractVector, x::AbstractVecOrMat, tails::Tails) _tails = tails === BothTails ? (LeftTail, RightTail) : (tails,) - tail_dists = map(_tails) do tail - _, cutoff = _get_tail!(x_tail, vec(x), tail) - _shift_tail!(x_tail, cutoff, tail, scale) - return _fit_tail_dist(x_tail) + return maximum(_tails) do tail + tail_dist = _fit_tail_dist!(x_tail, x, tail) + return pareto_shape(tail_dist) end - tail_dist = argmax(pareto_shape, tail_dists) - return tail_dist end -_fit_tail_dist(x_tail::AbstractVector) = fit_gpd(x_tail; prior_adjusted=true, sorted=true) -function _get_tail!(x_tail::AbstractVector, x::AbstractVector, tail::Tails) - S = length(x) +function _fit_tail_dist!(x_tail, x, tail) M = length(x_tail) - ind_offset = firstindex(x) - 1 - perm = partialsortperm(x, ind_offset .+ ((S - M):S); rev=tail === LeftTail) - cutoff = x[first(perm)] - tail_inds = @view perm[(firstindex(perm) + 1):end] - copyto!(x_tail, @views x[tail_inds]) - return x_tail, cutoff -end - -function _shift_tail!( - x_tail, cutoff, tails::Tails, scale::Union{typeof(log),typeof(identity)} -) - if scale === log - x_tail_max = x_tail[end] - @. x_tail = exp(x_tail - x_tail_max) - exp(cutoff - x_tail_max) - elseif tails === LeftTail - @. x_tail = cutoff - x_tail - else - @. x_tail = x_tail - cutoff + x_tail_view, cutoff = _tail_and_cutoff(vec(x), M, tail) + if any(!isfinite, x_tail_view) + @warn "Tail contains non-finite values. Generalized Pareto distribution cannot be reliably fit." + T = float(eltype(x_tail)) + return GeneralizedPareto(zero(T), convert(T, NaN), convert(T, NaN)) end - return x_tail + _shift_tail!(x_tail, x_tail_view, cutoff, tail) + return fit_gpd(x_tail; prior_adjusted=true, sorted=true) end From e1df148cae91d9a1f9d5ca91191bcd3345f2ed5e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:06:05 +0100 Subject: [PATCH 11/30] Add pareto_smooth --- src/PSIS.jl | 1 + src/pareto_smooth.jl | 91 ++++++++++++++++++++++++++++++++++++++++++++ src/utils.jl | 18 +++++++++ 3 files changed, 110 insertions(+) create mode 100644 src/pareto_smooth.jl diff --git a/src/PSIS.jl b/src/PSIS.jl index c1449829..dd0ab8ae 100644 --- a/src/PSIS.jl +++ b/src/PSIS.jl @@ -18,6 +18,7 @@ include("tails.jl") include("expectand.jl") include("diagnostics.jl") include("pareto_diagnose.jl") +include("pareto_smooth.jl") include("core.jl") include("ess.jl") include("recipes/plots.jl") diff --git a/src/pareto_smooth.jl b/src/pareto_smooth.jl new file mode 100644 index 00000000..72085d2e --- /dev/null +++ b/src/pareto_smooth.jl @@ -0,0 +1,91 @@ +""" + pareto_smooth(x::AbstractArray; kwargs...) + +Pareto-smooth the values `x` for computation of the mean. + +# Arguments + + - `x`: An array of values of shape `(draws[, chains[, params...]])`. + +# Keywords + + - `reff=1`: The relative tail efficiency of `x`. Must be either a scalar or an array of + shape `(params...,)`. + - `is_log=false`: Whether `x` represents the log of the expectand. If `true`, the + diagnostics are computed on the original scale, taking care to avoid numerical overflow. + - `tails=:both`: Which tail(s) to smooth. Valid values are `:left`, `:right`, and + `:both`. If `tails=:both`, diagnostic values correspond to the tail with the worst + properties. + +# Returns + + - `x_smoothed`: An array of the same shape as `x` with the specified tails Pareto- + smoothed. + - `diagnostics::ParetoDiagnostics`: Pareto diagnostics for the specified tails. +""" +function pareto_smooth( + x::AbstractArray{<:Real}; + reff=1, + is_log::Bool=false, + tails::Union{Tails,Symbol}=BothTails, + warn::Bool=true, +) + # validate/format inputs + _tails = _standardize_tails(tails) + + # smooth the tails and compute Pareto shape + x_smooth, pareto_shape = _pareto_smooth(x, reff, _tails, is_log) + + # compute remaining diagnostics + diagnostics = _compute_diagnostics(pareto_shape, _sample_size(x)) + + # warn if necessary + warn && check_pareto_diagnostics(diagnostics) + + return x_smooth, diagnostics +end + +function _pareto_smooth(x, reff, tails, is_log) + x_smooth = similar(x, float(eltype(x))) + copyto!(x_smooth, x) + pareto_shape = _pareto_smooth!(x_smooth, reff, tails, is_log) + return x_smooth, pareto_shape +end + +function _pareto_smooth!(x::AbstractArray, reff, tails::Tails, is_log::Bool) + return _map_params(x, reff) do x_i, reff_i + return _pareto_smooth!(x_i, reff_i, tails, is_log) + end +end +function _pareto_smooth!(x::AbstractVecOrMat, reff::Real, tails::Tails, is_log::Bool) + S = length(x) + M = _tail_length(reff, S, tails) + _tails = tails === BothTails ? (LeftTail, RightTail) : (tails,) + return maximum(_tails) do tail + x_tail, cutoff = _tail_and_cutoff(vec(x), M, tail) + dist = _fit_tail_dist_and_smooth!(x_tail, cutoff, tail, is_log) + return pareto_shape(dist) + end +end + +function _fit_tail_dist_and_smooth!(x_tail, cutoff, tail, is_log) + if is_log + x_max = tail === LeftTail ? cutoff : last(x_tail) + x_tail .= exp.(x_tail .- x_max) + cutoff = exp(cutoff - x_max) + end + _shift_tail!(x_tail, x_tail, cutoff, tail) + dist = fit_gpd(x_tail; prior_adjusted=true, sorted=true) + _pareto_smooth_tail!(x_tail, dist) + _shift_tail!(x_tail, x_tail, tail === RightTail ? -cutoff : cutoff, tail) + if is_log + x_tail .= min.(log.(x_tail), 0) .+ x_max + end + return dist +end + +function _pareto_smooth_tail!(x_tail, tail_dist) + p = uniform_probabilities(eltype(x_tail), length(x_tail)) + x_tail .= quantile.(Ref(tail_dist), p) + return x_tail +end diff --git a/src/utils.jl b/src/utils.jl index 3b3cdad2..d4484c1a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,6 +6,12 @@ end as_array(x::AbstractArray) = x as_array(x) = [x] +_as_array2(x) = fill(x) +_as_array2(x::AbstractArray) = x + +_maybe_scalar(x) = x +_maybe_scalar(x::AbstractArray{<:Any,0}) = x[] + missing_to_nan(x::AbstractArray{>:Missing}) = replace(x, missing => NaN) missing_to_nan(::Missing) = NaN missing_to_nan(x) = x @@ -13,6 +19,8 @@ missing_to_nan(x) = x # dimensions corresponding to draws (and maybe chains) _sample_dims(x::AbstractArray) = ntuple(identity, min(2, ndims(x))) +_sample_size(x::AbstractArray) = prod(Base.Fix1(size, x), _sample_dims(x)) + # dimension corresponding to parameters _param_dims(x::AbstractArray) = ntuple(i -> i + 2, max(0, ndims(x) - 2)) @@ -27,6 +35,16 @@ function _selectparam(x::AbstractArray, i::CartesianIndex) sample_dims = ntuple(_ -> Colon(), ndims(x) - length(i)) return view(x, sample_dims..., i) end +_selectparam(x::Real, ::CartesianIndex) = x + +# map function over all parameters. All arguments assumed to have params (or be scalar) +function _map_params(f, x, others...) + return map(_eachparamindex(x)) do i + return f( + _selectparam(x, i), map(_maybe_scalar ∘ Base.Fix2(_selectparam, i), others)... + ) + end +end function _maybe_log_normalize!(x::AbstractArray, normalize::Bool) if normalize From a8330da1c90198bd42e607c7029e5535e032f504 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:06:20 +0100 Subject: [PATCH 12/30] Update psis to wrap pareto_smooth --- src/core.jl | 230 ++++++---------------------------------------------- 1 file changed, 23 insertions(+), 207 deletions(-) diff --git a/src/core.jl b/src/core.jl index 77427cf2..2d992064 100644 --- a/src/core.jl +++ b/src/core.jl @@ -1,14 +1,3 @@ -# range, description, condition -const SHAPE_DIAGNOSTIC_CATEGORIES = ( - ("(-Inf, 0.5]", "good", ≤(0.5)), - ("(0.5, 0.7]", "okay", x -> 0.5 < x ≤ 0.7), - ("(0.7, 1]", "bad", x -> 0.7 < x ≤ 1), - ("(1, Inf)", "very bad", >(1)), - ("——", "failed", isnan), -) -const BAD_SHAPE_SUMMARY = "Resulting importance sampling estimates are likely to be unstable." -const VERY_BAD_SHAPE_SUMMARY = "Corresponding importance sampling estimates are likely to be unstable and are unlikely to converge with additional samples." - """ PSISResult @@ -59,14 +48,15 @@ See [`PSISPlots.paretoshapeplot`](@ref) for a diagnostic plot. Pareto smoothed importance sampling. [arXiv:1507.02646v7](https://arxiv.org/abs/1507.02646v7) [stat.CO] """ -struct PSISResult{T,W<:AbstractArray{T},R,L,D} +struct PSISResult{T,W<:AbstractArray{T},R,D} log_weights::W reff::R - tail_length::L - tail_dist::D normalized::Bool + diagnostics::D end +check_pareto_diagnostics(r::PSISResult) = check_pareto_diagnostics(r.diagnostics) + function Base.propertynames(r::PSISResult) return [fieldnames(typeof(r))..., :weights, :nparams, :ndraws, :nchains, :pareto_shape] end @@ -91,7 +81,7 @@ function Base.getproperty(r::PSISResult, k::Symbol) log_weights = getfield(r, :log_weights) return size(log_weights, 2) end - k === :pareto_shape && return pareto_shape(r) + k === :pareto_shape && return pareto_shape(getfield(r, :diagnostics)) k === :ess && return ess_is(r) return getfield(r, k) end @@ -111,74 +101,31 @@ end function _print_pareto_shape_summary(io::IO, r::PSISResult; kwargs...) k = as_array(pareto_shape(r)) + sample_size = r.ndraws * r.nchains ess = as_array(ess_is(r)) - npoints = r.nparams - rows = map(SHAPE_DIAGNOSTIC_CATEGORIES) do (range, desc, cond) - inds = findall(cond, k) - count = length(inds) - perc = 100 * count / npoints - ess_min = if count == 0 || desc == "failed" - oftype(first(ess), NaN) - else - minimum(view(ess, inds)) - end - return (range=range, desc=desc, count_perc=(count, perc), ess_min=ess_min) - end - rows = filter(r -> r.count_perc[1] > 0, rows) - formats = Dict( - "good" => (), - "okay" => (; color=:yellow), - "bad" => (bold=true, color=:light_red), - "very bad" => (bold=true, color=:red), - "failed" => (; color=:red), - ) - - col_padding = " " - col_delim = "" - col_delim_tot = col_padding * col_delim * col_padding - col_widths = [ - maximum(r -> length(r.range), rows), - maximum(r -> length(r.desc), rows), - maximum(r -> ndigits(r.count_perc[1]), rows), - floor(Int, log10(maximum(r -> r.count_perc[2], rows))) + 6, - ] + diag = _compute_diagnostics(k, sample_size) - println(io, "Pareto shape (k) diagnostic values:") - printstyled( - io, - col_padding, - " "^col_widths[1], - col_delim_tot, - " "^col_widths[2], - col_delim_tot, - _pad_right("Count", col_widths[3] + col_widths[4] + 1), - col_delim_tot, - "Min. ESS"; - bold=true, + category_assignments = NamedTuple{(:good, :bad, :very_bad, :failed)}( + _diagnostic_category_assignments(diag) ) - for r in rows - count, perc = r.count_perc - perc_str = "($(round(perc; digits=1))%)" - println(io) - print(io, col_padding, _pad_left(r.range, col_widths[1]), col_delim_tot) - print(io, _pad_right(r.desc, col_widths[2]), col_delim_tot) - format = formats[r.desc] - printstyled(io, _pad_left(count, col_widths[3]); format...) - printstyled(io, " ", _pad_right(perc_str, col_widths[4]); format...) - print(io, col_delim_tot, isfinite(r.ess_min) ? floor(Int, r.ess_min) : "——") + category_intervals = _diagnostic_intervals(diag) + npoints = length(k) + rows = map(collect(pairs(category_assignments))) do (desc, inds) + interval = desc === :failed ? "--" : _interval_string(category_intervals[desc]) + min_ess = @views isempty(inds) ? NaN : minimum(ess[inds]) + return (; interval, desc, count=length(inds), min_ess) end + _print_pareto_diagnostics_summary(io::IO, rows, npoints; kwargs...) return nothing end +pareto_shape(r::PSISResult) = pareto_shape(r.diagnostics) """ psis(log_ratios, reff = 1.0; kwargs...) -> PSISResult - psis!(log_ratios, reff = 1.0; kwargs...) -> PSISResult Compute Pareto smoothed importance sampling (PSIS) log weights [^VehtariSimpson2021]. -While `psis` computes smoothed log weights out-of-place, `psis!` smooths them in-place. - # Arguments - `log_ratios`: an array of logarithms of importance ratios, with size @@ -206,143 +153,12 @@ details and [`PSISPlots.paretoshapeplot`](@ref) for a diagnostic plot. Pareto smoothed importance sampling. [arXiv:1507.02646v7](https://arxiv.org/abs/1507.02646v7) [stat.CO] """ -psis, psis! - -function psis(logr, reff=1; kwargs...) - T = float(eltype(logr)) - logw = similar(logr, T) - copyto!(logw, logr) - return psis!(logw, reff; kwargs...) -end - -function psis!(logw::AbstractVecOrMat, reff=1; normalize::Bool=true, warn::Bool=true) - T = typeof(float(one(eltype(logw)))) - if length(reff) != 1 - throw(DimensionMismatch("`reff` has length $(length(reff)) but must have length 1")) - end - warn && check_reff(reff) - S = length(logw) - reff_val = first(reff) - M = tail_length(reff_val, S) - if M < 5 - warn && - @warn "$M tail draws is insufficient to fit the generalized Pareto distribution. Total number of draws should in general exceed 25." - _maybe_log_normalize!(logw, normalize) - tail_dist_failed = GeneralizedPareto(0, T(NaN), T(NaN)) - return PSISResult(logw, reff_val, M, tail_dist_failed, normalize) - end - perm = partialsortperm(logw, (S - M):S) - cutoff_ind = perm[1] - tail_inds = @view perm[2:(M + 1)] - logu = logw[cutoff_ind] - logw_tail = @views logw[tail_inds] - if !all(isfinite, logw_tail) - warn && - @warn "Tail contains non-finite values. Generalized Pareto distribution cannot be reliably fit." - _maybe_log_normalize!(logw, normalize) - tail_dist_failed = GeneralizedPareto(0, T(NaN), T(NaN)) - return PSISResult(logw, reff_val, M, tail_dist_failed, normalize) - end - _, tail_dist = psis_tail!(logw_tail, logu) - warn && check_pareto_shape(tail_dist) - _maybe_log_normalize!(logw, normalize) - return PSISResult(logw, reff_val, M, tail_dist, normalize) -end -function psis!(logw::AbstractMatrix, reff=1; kwargs...) - result = psis!(vec(logw), reff; kwargs...) - # unflatten log_weights - return PSISResult( - logw, result.reff, result.tail_length, result.tail_dist, result.normalized - ) -end -function psis!(logw::AbstractArray, reff=1; normalize::Bool=true, warn::Bool=true) - T = typeof(float(one(eltype(logw)))) - # if an array defines custom indices (e.g. AbstractDimArray), we preserve them - param_axes = _param_axes(logw) - param_shape = map(length, param_axes) - if !(length(reff) == 1 || size(reff) == param_shape) - throw( - DimensionMismatch( - "`reff` has shape $(size(reff)) but must have same shape as the parameter axes $(param_shape)", - ), - ) - end - check_reff(reff) - - # allocate containers - reffs = similar(logw, eltype(reff), param_axes) - reffs .= reff - tail_lengths = similar(logw, Int, param_axes) - tail_dists = similar(logw, GeneralizedPareto{T}, param_axes) +psis - # call psis! in parallel for all parameters - Threads.@threads for i in _eachparamindex(logw) - logw_i = _selectparam(logw, i) - result_i = psis!(logw_i, reffs[i]; normalize=normalize, warn=false) - tail_lengths[i] = result_i.tail_length - tail_dists[i] = result_i.tail_dist - end - - # combine results - result = PSISResult(logw, reffs, tail_lengths, map(identity, tail_dists), normalize) - - # warn for bad shape - warn && check_pareto_shape(result) - return result -end - -pareto_shape(r::PSISResult) = pareto_shape(getfield(r, :tail_dist)) - -function check_reff(reff) - isvalid = all(reff) do r - return isfinite(r) && r > 0 - end - isvalid || @warn "All values of `reff` should be finite, but some are not." - return nothing -end - -check_pareto_shape(result::PSISResult) = check_pareto_shape(result.tail_dist) -function check_pareto_shape(dist::GeneralizedPareto) - k = pareto_shape(dist) - if k > 1 - @warn "Pareto shape k = $(@sprintf("%.2g", k)) > 1. $VERY_BAD_SHAPE_SUMMARY" - elseif k > 0.7 - @warn "Pareto shape k = $(@sprintf("%.2g", k)) > 0.7. $BAD_SHAPE_SUMMARY" - end - return nothing -end -function check_pareto_shape(dists::AbstractArray{<:GeneralizedPareto}) - nnan = count(isnan ∘ pareto_shape, dists) - ngt07 = count(>(0.7) ∘ pareto_shape, dists) - ngt1 = iszero(ngt07) ? ngt07 : count(>(1) ∘ pareto_shape, dists) - if ngt07 > ngt1 - @warn "$(ngt07 - ngt1) parameters had Pareto shape values 0.7 < k ≤ 1. $BAD_SHAPE_SUMMARY" - end - if ngt1 > 0 - @warn "$ngt1 parameters had Pareto shape values k > 1. $VERY_BAD_SHAPE_SUMMARY" - end - if nnan > 0 - @warn "For $nnan parameters, the generalized Pareto distribution could not be fit to the tail draws. Total number of draws should in general exceed 25, and the tail draws must be finite." - end - return nothing -end - -function psis_tail!(logw, logμ) - T = eltype(logw) - logw_max = logw[end] - # to improve numerical stability, we first shift the log-weights to have a maximum of 0, - # equivalent to scaling the weights to have a maximum of 1. - μ_scaled = exp(logμ - logw_max) - w_scaled = (logw .= exp.(logw .- logw_max) .- μ_scaled) - tail_dist = fit_gpd(w_scaled; prior_adjusted=true, sorted=true) - # undo the scaling - k = pareto_shape(tail_dist) - if isfinite(k) - p = uniform_probabilities(T, length(logw)) - @inbounds for i in eachindex(logw, p) - # undo scaling in the log-weights - logw[i] = min(log(quantile(tail_dist, p[i]) + μ_scaled), 0) + logw_max - end +function psis(logr::AbstractArray{<:Real}; normalize::Bool=true, reff=1, kwargs...) + logw, diagnostics = pareto_smooth(logr; is_log=true, tails=RightTail, reff, kwargs...) + if normalize + logw .-= LogExpFunctions.logsumexp(logw; dims=_sample_dims(logw)) end - return logw, tail_dist + return PSISResult(logw, reff, normalize, diagnostics) end From 8d5ddc24ba961945178f9d56f619fe7cd03cb055 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:06:28 +0100 Subject: [PATCH 13/30] Update exports --- src/PSIS.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/PSIS.jl b/src/PSIS.jl index dd0ab8ae..c5c7c6eb 100644 --- a/src/PSIS.jl +++ b/src/PSIS.jl @@ -8,9 +8,9 @@ using Printf: @sprintf using Statistics: Statistics export PSISPlots -export PSISResult -export psis, psis!, ess_is -export pareto_diagnose +export ParetoDiagnostics, PSISResult +export pareto_diagnose, pareto_smooth, psis, psis! +export check_pareto_diagnostics, ess_is include("utils.jl") include("generalized_pareto.jl") From a3bb7d24936c16505f85c32a9093fec8f501f921 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:06:56 +0100 Subject: [PATCH 14/30] Refactor ess_is to use S-specific khat threshold --- src/ess.jl | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/src/ess.jl b/src/ess.jl index 9cc45416..f261e67e 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -17,28 +17,31 @@ Estimate ESS for Pareto-smoothed importance sampling. !!! note - ESS estimates for Pareto shape values ``k > 0.7``, which are unreliable and misleadingly - high, are set to `NaN`. To avoid this, set `bad_shape_nan=false`. + ESS estimates for Pareto shape values ``k > k_\\mathrm{threshold}``, which are + unreliable and misleadingly high, are set to `NaN`. To avoid this, set + `bad_shape_nan=false`. """ ess_is function ess_is(r::PSISResult; bad_shape_nan::Bool=true) - neff = ess_is(r.weights; reff=r.reff) - return _apply_nan(neff, r.tail_dist; bad_shape_nan=bad_shape_nan) + ess = ess_is(r.weights; reff=r.reff) + diagnostics = r.diagnostics + khat = diagnostics.pareto_shape + khat_thresh = diagnostics.pareto_shape_threshold + return _apply_nan(ess, khat; khat_thresh, bad_shape_nan=bad_shape_nan) end function ess_is(weights; reff=1) dims = _sample_dims(weights) return reff ./ dropdims(sum(abs2, weights; dims=dims); dims=dims) end -function _apply_nan(neff, dist; bad_shape_nan) - bad_shape_nan || return neff - k = pareto_shape(dist) - (isnan(k) || k > 0.7) && return oftype(neff, NaN) - return neff +function _apply_nan(ess::Real, khat::Real; khat_thresh::Real, bad_shape_nan) + bad_shape_nan || return ess + (isnan(khat) || khat > khat_thresh) && return oftype(ess, NaN) + return ess end -function _apply_nan(ess::AbstractArray, tail_dist::AbstractArray; kwargs...) - return map(ess, tail_dist) do essᵢ, tail_distᵢ - return _apply_nan(essᵢ, tail_distᵢ; kwargs...) +function _apply_nan(ess::AbstractArray, khat::AbstractArray; kwargs...) + return map(ess, khat) do essᵢ, khatᵢ + return _apply_nan(essᵢ, khatᵢ; kwargs...) end end From 937d0ae8460d26c1e87937cc0c05d4380493a663 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:08:49 +0100 Subject: [PATCH 15/30] Add Compat as dependency --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 1a66b3ff..ba4c4493 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Seth Axen and contributors"] version = "0.9.4" [deps] +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -21,6 +22,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" PSISStatsBaseExt = ["StatsBase"] [compat] +Compat = "3, 4" DimensionalData = "0.24" Distributions = "0.25.81" DocStringExtensions = "0.9" From f5d06609812a6f7da4f2141f28407d30fcea827e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:51:02 +0100 Subject: [PATCH 16/30] Document shape intervals --- src/diagnostics.jl | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/diagnostics.jl b/src/diagnostics.jl index 38a113d7..bd94ba58 100644 --- a/src/diagnostics.jl +++ b/src/diagnostics.jl @@ -1,10 +1,30 @@ """ ParetoDiagnostics -Diagnostic information for Pareto-smoothed importance sampling.[^VehtariSimpson2021] +Diagnostic information for Pareto-smoothed importance sampling. $FIELDS +# Diagnostics + +The `pareto_shape` parameter ``k`` of the generalized Pareto distribution when positive +indicates the inverse of the number of finite moments of the distribution. Its estimate +``\\hat{k}`` from the tail(s) can be used to diagnose reliability and convergence of +estimates using [^VehtariSimpson2021]. + + - if ``\\hat{k} ≤ 0.5``, then PSIS behaves like the importance ratios have finite + variance, the resulting estimate will be accurate, and the converge rate is + ``S^{−1/2}``. + - if ``0.5 < \\hat{k} \\lessim 0.7, then the variance is infinite and plain IS can behave + poorly. PSIS works well in this regime, but the convergence rate is between ``S^{−1/2}`` + and ``S^{−3/10}``. + - if ``\\hat{k} \\gtsim k_\\mathrm{threshold}``, then the Pareto smoothed estimate is not + reliable. It may help to increase the sample size. + - if ``\\hat{k} \\gtsim 0.7``, it quickly becomes too expensive to get an accurate + estimate. Importance sampling is not recommended. + +See [`PSISPlots.paretoshapeplot`](@ref) for a diagnostic plot. + [^VehtariSimpson2021]: Vehtari A, Simpson D, Gelman A, Yao Y, Gabry J. (2021). Pareto smoothed importance sampling. [arXiv:1507.02646v7](https://arxiv.org/abs/1507.02646v7) [stat.CO] From a7f7722ed228d41acff0bf0a340d332a2360dc97 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:51:18 +0100 Subject: [PATCH 17/30] Remove old diagnostic docs --- src/core.jl | 41 +---------------------------------------- 1 file changed, 1 insertion(+), 40 deletions(-) diff --git a/src/core.jl b/src/core.jl index 2d992064..e69d7a2b 100644 --- a/src/core.jl +++ b/src/core.jl @@ -3,46 +3,7 @@ Result of Pareto-smoothed importance sampling (PSIS) using [`psis`](@ref). -# Properties - - - `log_weights`: un-normalized Pareto-smoothed log weights - - `weights`: normalized Pareto-smoothed weights (allocates a copy) - - `pareto_shape`: Pareto ``k=ξ`` shape parameter - - `nparams`: number of parameters in `log_weights` - - `ndraws`: number of draws in `log_weights` - - `nchains`: number of chains in `log_weights` - - `reff`: the ratio of the effective sample size of the unsmoothed importance ratios and - the actual sample size. - - `ess`: estimated effective sample size of estimate of mean using smoothed importance - samples (see [`ess_is`](@ref)) - - `tail_length`: length of the upper tail of `log_weights` that was smoothed - - `tail_dist`: the generalized Pareto distribution that was fit to the tail of - `log_weights`. Note that the tail weights are scaled to have a maximum of 1, so - `tail_dist * exp(maximum(log_ratios))` is the corresponding fit directly to the tail of - `log_ratios`. - - `normalized::Bool`:indicates whether `log_weights` are log-normalized along the sample - dimensions. - -# Diagnostic - -The `pareto_shape` parameter ``k=ξ`` of the generalized Pareto distribution `tail_dist` can -be used to diagnose reliability and convergence of estimates using the importance weights -[^VehtariSimpson2021]. - - - if ``k < \\frac{1}{3}``, importance sampling is stable, and importance sampling (IS) and - PSIS both are reliable. - - if ``k ≤ \\frac{1}{2}``, then the importance ratio distributon has finite variance, and - the central limit theorem holds. As ``k`` approaches the upper bound, IS becomes less - reliable, while PSIS still works well but with a higher RMSE. - - if ``\\frac{1}{2} < k ≤ 0.7``, then the variance is infinite, and IS can behave quite - poorly. However, PSIS works well in this regime. - - if ``0.7 < k ≤ 1``, then it quickly becomes impractical to collect enough importance - weights to reliably compute estimates, and importance sampling is not recommended. - - if ``k > 1``, then neither the variance nor the mean of the raw importance ratios - exists. The convergence rate is close to zero, and bias can be large with practical - sample sizes. - -See [`PSISPlots.paretoshapeplot`](@ref) for a diagnostic plot. +$FIELDS [^VehtariSimpson2021]: Vehtari A, Simpson D, Gelman A, Yao Y, Gabry J. (2021). Pareto smoothed importance sampling. From 9620750183875a878b322d51944f6e3b8c100cce Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:51:43 +0100 Subject: [PATCH 18/30] Remove PSISResult properties --- src/core.jl | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/core.jl b/src/core.jl index e69d7a2b..c84219df 100644 --- a/src/core.jl +++ b/src/core.jl @@ -9,10 +9,15 @@ $FIELDS Pareto smoothed importance sampling. [arXiv:1507.02646v7](https://arxiv.org/abs/1507.02646v7) [stat.CO] """ -struct PSISResult{T,W<:AbstractArray{T},R,D} +struct PSISResult{T,W<:AbstractArray{T},R,D<:ParetoDiagnostics} + "Pareto-smoothed log-weights. Log-normalized if `normalized=true`." log_weights::W + "the relative efficiency, i.e. the ratio of the effective sample size of the unsmoothed + importance ratios and the actual sample size." reff::R + "whether `log_weights` are log-normalized along the sample dimensions." normalized::Bool + "diagnostics for the Pareto-smoothing." diagnostics::D end @@ -48,21 +53,19 @@ function Base.getproperty(r::PSISResult, k::Symbol) end function Base.show(io::IO, ::MIME"text/plain", r::PSISResult) - npoints = r.nparams - nchains = r.nchains + log_weights = r.log_weights + ndraws = size(log_weights, 1) + nchains = size(log_weights, 2) + npoints = prod(_param_sizes(log_weights)) println( - io, "PSISResult with $(r.ndraws) draws, $nchains chains, and $npoints parameters" + io, "PSISResult with $ndraws draws, $nchains chains, and $npoints parameters" ) return _print_pareto_shape_summary(io, r; newline_at_end=false) end -function pareto_shape_summary(r::PSISResult; kwargs...) - return _print_pareto_shape_summary(stdout, r; kwargs...) -end - function _print_pareto_shape_summary(io::IO, r::PSISResult; kwargs...) k = as_array(pareto_shape(r)) - sample_size = r.ndraws * r.nchains + sample_size = _sample_size(r.log_weights) ess = as_array(ess_is(r)) diag = _compute_diagnostics(k, sample_size) From 239c35ca3b7635cee590cfd20e31018a0e1c6f65 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:51:52 +0100 Subject: [PATCH 19/30] Update ess_is --- src/ess.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/ess.jl b/src/ess.jl index f261e67e..973b35e8 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -24,7 +24,13 @@ Estimate ESS for Pareto-smoothed importance sampling. ess_is function ess_is(r::PSISResult; bad_shape_nan::Bool=true) - ess = ess_is(r.weights; reff=r.reff) + log_weights = r.log_weights + if r.normalized + weights = exp.(log_weights) + else + weights = LogExpFunctions.softmax(log_weights; dims=_sample_dims(log_weights)) + end + ess = ess_is(weights; reff=r.reff) diagnostics = r.diagnostics khat = diagnostics.pareto_shape khat_thresh = diagnostics.pareto_shape_threshold From 8cf295432d15785be8fa6db8ac6c381fdc3e7ffe Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:52:04 +0100 Subject: [PATCH 20/30] Add param_sizes utility --- src/utils.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index d4484c1a..2bfd84cf 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -24,6 +24,8 @@ _sample_size(x::AbstractArray) = prod(Base.Fix1(size, x), _sample_dims(x)) # dimension corresponding to parameters _param_dims(x::AbstractArray) = ntuple(i -> i + 2, max(0, ndims(x) - 2)) +_param_sizes(x::AbstractArray) = mal(Base.Fix1(size, x), _param_dims(x)) + # axes corresponding to parameters _param_axes(x::AbstractArray) = map(Base.Fix1(axes, x), _param_dims(x)) From cd2dbd9af32c88a3836afe22e0d8f836a36fbe04 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:52:11 +0100 Subject: [PATCH 21/30] Update paretoshapeplot --- src/recipes/plots.jl | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/recipes/plots.jl b/src/recipes/plots.jl index 04ca79dd..3b76241b 100644 --- a/src/recipes/plots.jl +++ b/src/recipes/plots.jl @@ -14,7 +14,10 @@ using RecipesBase: RecipesBase Plot shape parameters of fitted Pareto tail distributions for diagnosing convergence. -`values` may be either a vector of Pareto shape parameters or a [`PSIS.PSISResult`](@ref). +`values` may be: +- a vector of Pareto shape parameters +- a [`PSIS.PSISResult`](@ref) +- a [`PSIS.ParetoDiagnostics`](@ref) If `showlines==true`, horizontal lines indicating relevant Pareto shape thresholds are drawn. See [`PSIS.PSISResult`](@ref) for an explanation of the thresholds. @@ -64,17 +67,21 @@ RecipesBase.@recipe function f(plt::ParetoShapePlot; showlines=false) yguide --> "Pareto shape" seriestype --> :scatter arg = first(plt.args) - k = arg isa PSIS.PSISResult ? PSIS.pareto_shape(arg) : arg - return (PSIS.as_array(PSIS.missing_to_nan(k)),) + k = _pareto_shape(arg) + return (vec(PSIS.as_array(PSIS.missing_to_nan(k))),) end +_pareto_shape(r::PSIS.PSISResult) = PSIS.pareto_shape(r.diagnostics) +_pareto_shape(d::PSIS.ParetoDiagnostics) = PSIS.pareto_shape(d) +_pareto_shape(k) = k + # plot PSISResult using paretoshapeplot if seriestype not specified -RecipesBase.@recipe function f(result::PSISResult) +RecipesBase.@recipe function f(r::Union{PSIS.PSISResult,PSIS.ParetoDiagnostics}) if haskey(plotattributes, :seriestype) - k = PSIS.as_array(PSIS.missing_to_nan(PSIS.pareto_shape(result))) + k = PSIS.as_array(PSIS.missing_to_nan(_pareto_shape(r))) return (k,) else - return ParetoShapePlot((result,)) + return ParetoShapePlot((r,)) end end From 77a59833ba2d2051ce90df8ad50d2dff29c5d852 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 21:53:06 +0100 Subject: [PATCH 22/30] Remove properties accessors --- src/core.jl | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/src/core.jl b/src/core.jl index c84219df..f3809a77 100644 --- a/src/core.jl +++ b/src/core.jl @@ -23,35 +23,6 @@ end check_pareto_diagnostics(r::PSISResult) = check_pareto_diagnostics(r.diagnostics) -function Base.propertynames(r::PSISResult) - return [fieldnames(typeof(r))..., :weights, :nparams, :ndraws, :nchains, :pareto_shape] -end - -function Base.getproperty(r::PSISResult, k::Symbol) - if k === :weights - log_weights = getfield(r, :log_weights) - getfield(r, :normalized) && return exp.(log_weights) - return LogExpFunctions.softmax(log_weights; dims=_sample_dims(log_weights)) - elseif k === :nparams - log_weights = getfield(r, :log_weights) - return if ndims(log_weights) == 1 - 1 - else - param_dims = _param_dims(log_weights) - prod(Base.Fix1(size, log_weights), param_dims; init=1) - end - elseif k === :ndraws - log_weights = getfield(r, :log_weights) - return size(log_weights, 1) - elseif k === :nchains - log_weights = getfield(r, :log_weights) - return size(log_weights, 2) - end - k === :pareto_shape && return pareto_shape(getfield(r, :diagnostics)) - k === :ess && return ess_is(r) - return getfield(r, k) -end - function Base.show(io::IO, ::MIME"text/plain", r::PSISResult) log_weights = r.log_weights ndraws = size(log_weights, 1) From 7df6efec570f85e8ef0d1126f5f9efaef9497374 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 22:28:12 +0100 Subject: [PATCH 23/30] Fix bug --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 2bfd84cf..c9d9b8c4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -24,7 +24,7 @@ _sample_size(x::AbstractArray) = prod(Base.Fix1(size, x), _sample_dims(x)) # dimension corresponding to parameters _param_dims(x::AbstractArray) = ntuple(i -> i + 2, max(0, ndims(x) - 2)) -_param_sizes(x::AbstractArray) = mal(Base.Fix1(size, x), _param_dims(x)) +_param_sizes(x::AbstractArray) = map(Base.Fix1(size, x), _param_dims(x)) # axes corresponding to parameters _param_axes(x::AbstractArray) = map(Base.Fix1(axes, x), _param_dims(x)) From 93abf46d1c1fbf5af94abdc431b5ff7f4fc4d79b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 22:28:51 +0100 Subject: [PATCH 24/30] Run formatter --- src/core.jl | 9 +++------ src/recipes/plots.jl | 7 ++++--- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/core.jl b/src/core.jl index f3809a77..472d17a4 100644 --- a/src/core.jl +++ b/src/core.jl @@ -28,9 +28,7 @@ function Base.show(io::IO, ::MIME"text/plain", r::PSISResult) ndraws = size(log_weights, 1) nchains = size(log_weights, 2) npoints = prod(_param_sizes(log_weights)) - println( - io, "PSISResult with $ndraws draws, $nchains chains, and $npoints parameters" - ) + println(io, "PSISResult with $ndraws draws, $nchains chains, and $npoints parameters") return _print_pareto_shape_summary(io, r; newline_at_end=false) end @@ -61,6 +59,8 @@ pareto_shape(r::PSISResult) = pareto_shape(r.diagnostics) Compute Pareto smoothed importance sampling (PSIS) log weights [^VehtariSimpson2021]. +Internally the function calls [`pareto_smooth`](@ref). + # Arguments - `log_ratios`: an array of logarithms of importance ratios, with size @@ -81,9 +81,6 @@ Compute Pareto smoothed importance sampling (PSIS) log weights [^VehtariSimpson2 - `result`: a [`PSISResult`](@ref) object containing the results of the Pareto-smoothing. -A warning is raised if the Pareto shape parameter ``k ≥ 0.7``. See [`PSISResult`](@ref) for -details and [`PSISPlots.paretoshapeplot`](@ref) for a diagnostic plot. - [^VehtariSimpson2021]: Vehtari A, Simpson D, Gelman A, Yao Y, Gabry J. (2021). Pareto smoothed importance sampling. [arXiv:1507.02646v7](https://arxiv.org/abs/1507.02646v7) [stat.CO] diff --git a/src/recipes/plots.jl b/src/recipes/plots.jl index 3b76241b..2f6cd727 100644 --- a/src/recipes/plots.jl +++ b/src/recipes/plots.jl @@ -15,9 +15,10 @@ using RecipesBase: RecipesBase Plot shape parameters of fitted Pareto tail distributions for diagnosing convergence. `values` may be: -- a vector of Pareto shape parameters -- a [`PSIS.PSISResult`](@ref) -- a [`PSIS.ParetoDiagnostics`](@ref) + + - a vector of Pareto shape parameters + - a [`PSIS.PSISResult`](@ref) + - a [`PSIS.ParetoDiagnostics`](@ref) If `showlines==true`, horizontal lines indicating relevant Pareto shape thresholds are drawn. See [`PSIS.PSISResult`](@ref) for an explanation of the thresholds. From e603f57843c403067883a3f255d172d00ca96d03 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 22:29:13 +0100 Subject: [PATCH 25/30] Preserve older table formatting --- src/diagnostics.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diagnostics.jl b/src/diagnostics.jl index bd94ba58..8ecde9d5 100644 --- a/src/diagnostics.jl +++ b/src/diagnostics.jl @@ -168,9 +168,11 @@ function _print_pareto_diagnostics_summary(io::IO, _rows, npoints; kwargs...) rows = filter(r -> r.count > 0, _rows) header = ["", "", "Count"] alignment = [:r, :l, :l] + alignment_anchor_regex = Dict(3 => [r"\s"]) if length(first(rows)) > 3 push!(header, "Min. ESS") - push!(alignment, :r) + push!(alignment, :l) + alignment_anchor_regex[4] = [r"\d"] end formatters = ( (v, i, j) -> j == 2 ? replace(string(v), '_' => " ") : v, @@ -196,11 +198,12 @@ function _print_pareto_diagnostics_summary(io::IO, _rows, npoints; kwargs...) rows; header, alignment, - alignment_anchor_regex=Dict(3 => [r"\s"]), + alignment_anchor_regex, hlines=:none, vlines=:none, formatters, highlighters, + title="Pareto shape (k) diagnostic values:", kwargs..., ) return nothing From 9a5fef7b4653befaaa08a3a9a8e0f277ca6e3bce Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 22:29:21 +0100 Subject: [PATCH 26/30] Update docstring --- src/core.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/core.jl b/src/core.jl index 472d17a4..6f205638 100644 --- a/src/core.jl +++ b/src/core.jl @@ -5,9 +5,7 @@ Result of Pareto-smoothed importance sampling (PSIS) using [`psis`](@ref). $FIELDS -[^VehtariSimpson2021]: Vehtari A, Simpson D, Gelman A, Yao Y, Gabry J. (2021). - Pareto smoothed importance sampling. - [arXiv:1507.02646v7](https://arxiv.org/abs/1507.02646v7) [stat.CO] +See [`ParetoDiagnose`](@ref) for a description of the diagnostics. """ struct PSISResult{T,W<:AbstractArray{T},R,D<:ParetoDiagnostics} "Pareto-smoothed log-weights. Log-normalized if `normalized=true`." From 9a54cde2a6dd9dd6e08538dffb69cd60ba94281d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 22:29:30 +0100 Subject: [PATCH 27/30] Overload pareto shape accessor --- src/diagnostics.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diagnostics.jl b/src/diagnostics.jl index 8ecde9d5..dbd41ab8 100644 --- a/src/diagnostics.jl +++ b/src/diagnostics.jl @@ -42,6 +42,8 @@ struct ParetoDiagnostics{TK,TKM,TS,TR} convergence_rate::TR end +pareto_shape(diagnostics::ParetoDiagnostics) = diagnostics.pareto_shape + pareto_shape_threshold(sample_size::Real) = 1 - inv(log10(sample_size)) function min_sample_size(pareto_shape::Real) From 54efca0955993058eaf0f44bae45ea2ecb5d340f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 22:41:29 +0100 Subject: [PATCH 28/30] Make ESS column elements flush on the right --- src/diagnostics.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diagnostics.jl b/src/diagnostics.jl index dbd41ab8..ddf31cdd 100644 --- a/src/diagnostics.jl +++ b/src/diagnostics.jl @@ -174,7 +174,7 @@ function _print_pareto_diagnostics_summary(io::IO, _rows, npoints; kwargs...) if length(first(rows)) > 3 push!(header, "Min. ESS") push!(alignment, :l) - alignment_anchor_regex[4] = [r"\d"] + alignment_anchor_regex[4] = [r"[\d—]$"] end formatters = ( (v, i, j) -> j == 2 ? replace(string(v), '_' => " ") : v, From 4c739df7d8bf33d2d7d51fc83c08113eb3a9470a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 9 Jan 2024 22:41:39 +0100 Subject: [PATCH 29/30] Update PSISResult tests --- test/core.jl | 518 ++++++++++++++++++++++++--------------------------- 1 file changed, 245 insertions(+), 273 deletions(-) diff --git a/test/core.jl b/test/core.jl index dd310f1d..6f2714d5 100644 --- a/test/core.jl +++ b/test/core.jl @@ -14,40 +14,22 @@ using DimensionalData: Dimensions, DimArray tail_length = 100 reff = 2.0 tail_dist = PSIS.GeneralizedPareto(1.0, 1.0, 0.5) - result = PSISResult(log_weights, reff, tail_length, tail_dist, false) + diag = PSIS.ParetoDiagnostics(0.5, 0.7, 100, 1) + result = PSISResult(log_weights, reff, false, diag) @test result isa PSISResult{Float64} - @test issetequal( - propertynames(result), - [ - :log_weights, - :nchains, - :ndraws, - :normalized, - :nparams, - :pareto_shape, - :reff, - :tail_dist, - :tail_length, - :weights, - ], - ) @test result.log_weights == log_weights - @test result.weights ≈ softmax(log_weights) @test result.reff == reff - @test result.nparams == 1 - @test result.ndraws == 500 - @test result.nchains == 1 - @test result.tail_length == tail_length - @test result.tail_dist == tail_dist - @test result.pareto_shape == 0.5 - @test result.ess ≈ ess_is(result) + @test !result.normalized + @test result.diagnostics == diag + + ess = ess_is(result) @testset "show" begin @test sprint(show, "text/plain", result) == """ PSISResult with 500 draws, 1 chains, and 1 parameters Pareto shape (k) diagnostic values: Count Min. ESS - (-Inf, 0.5] good 1 (100.0%) $(floor(Int, result.ess))""" + (-Inf, 0.6] good 1 (100.0%) $(floor(Int, ess))""" end end @@ -55,24 +37,15 @@ using DimensionalData: Dimensions, DimArray log_weights = randn(500, 4, 3) log_weights_norm = logsumexp(log_weights; dims=(1, 2)) log_weights .-= log_weights_norm - tail_length = [1600, 1601, 1602] + # tail_length = [1600, 1601, 1602] reff = [0.8, 0.9, 1.1] - tail_dist = [ - PSIS.GeneralizedPareto(1.0, 1.0, 0.5), - PSIS.GeneralizedPareto(1.0, 1.0, 0.6), - PSIS.GeneralizedPareto(1.0, 1.0, 0.7), - ] - result = PSISResult(log_weights, reff, tail_length, tail_dist, true) + pareto_shape = [0.5, 0.6, 0.7] + diag = ParetoDiagnostics(pareto_shape, 0.7, nothing, nothing) + result = PSISResult(log_weights, reff, true, diag) @test result isa PSISResult{Float64} @test result.log_weights == log_weights - @test result.weights ≈ softmax(log_weights; dims=(1, 2)) @test result.reff == reff - @test result.nparams == 3 - @test result.ndraws == 500 - @test result.nchains == 4 - @test result.tail_length == tail_length - @test result.tail_dist == tail_dist - @test result.pareto_shape == [0.5, 0.6, 0.7] + @test result.diagnostics == diag @testset "show" begin proposal = Normal() @@ -80,268 +53,267 @@ using DimensionalData: Dimensions, DimArray rng = MersenneTwister(42) x = rand(rng, proposal, 100, 1, 30) log_ratios = logpdf.(target, x) .- logpdf.(proposal, x) - reff = [100; ones(29)] - result = psis(log_ratios, reff) + log_ratios[1, 1, 1] = NaN + result = psis(log_ratios) @test sprint(show, "text/plain", result) == """ PSISResult with 100 draws, 1 chains, and 30 parameters Pareto shape (k) diagnostic values: Count Min. ESS (-Inf, 0.5] good 2 (6.7%) 98 - (0.5, 0.7] okay 6 (20.0%) 92 - (0.7, 1] bad 4 (13.3%) —— + (0.5, 1] bad 10 (33.3%) —— (1, Inf) very bad 17 (56.7%) —— - —— failed 1 (3.3%) ——""" + -- failed 1 (3.3%) ——""" end end end -@testset "psis/psis!" begin - @testset "importance sampling tests" begin - target = Exponential(1) - x_target = 1 # 𝔼[x] with x ~ Exponential(1) - x²_target = 2 # 𝔼[x²] with x ~ Exponential(1) - # For θ < 1, the closed-form distribution of importance ratios with k = 1 - θ is - # GeneralizedPareto(θ, θ * k, k), and the closed-form distribution of tail ratios is - # GeneralizedPareto(5^k * θ, θ * k, k). - # For θ < 0.5, the tail distribution has no variance, and estimates with importance - # weights become unstable - @testset "Exponential($θ) → Exponential(1)" for (θ, atol) in [ - (0.8, 0.05), (0.55, 0.2), (0.3, 0.7) - ] - proposal = Exponential(θ) - k_exp = 1 - θ - for sz in ((100_000,), (100_000, 4), (100_000, 4, 5)) - dims = length(sz) < 3 ? Colon() : 1:(length(sz) - 1) - rng = MersenneTwister(42) - x = rand(rng, proposal, sz) - logr = logpdf.(target, x) .- logpdf.(proposal, x) +# @testset "psis/psis!" begin +# @testset "importance sampling tests" begin +# target = Exponential(1) +# x_target = 1 # 𝔼[x] with x ~ Exponential(1) +# x²_target = 2 # 𝔼[x²] with x ~ Exponential(1) +# # For θ < 1, the closed-form distribution of importance ratios with k = 1 - θ is +# # GeneralizedPareto(θ, θ * k, k), and the closed-form distribution of tail ratios is +# # GeneralizedPareto(5^k * θ, θ * k, k). +# # For θ < 0.5, the tail distribution has no variance, and estimates with importance +# # weights become unstable +# @testset "Exponential($θ) → Exponential(1)" for (θ, atol) in [ +# (0.8, 0.05), (0.55, 0.2), (0.3, 0.7) +# ] +# proposal = Exponential(θ) +# k_exp = 1 - θ +# for sz in ((100_000,), (100_000, 4), (100_000, 4, 5)) +# dims = length(sz) < 3 ? Colon() : 1:(length(sz) - 1) +# rng = MersenneTwister(42) +# x = rand(rng, proposal, sz) +# logr = logpdf.(target, x) .- logpdf.(proposal, x) - r = @inferred psis(logr) - @test r isa PSISResult - logw = r.log_weights - @test logw isa typeof(logr) - @test exp.(logw) == r.weights +# r = @inferred psis(logr) +# @test r isa PSISResult +# logw = r.log_weights +# @test logw isa typeof(logr) +# @test exp.(logw) == r.weights - r2 = psis(logr; normalize=false) - @test !(r2.log_weights ≈ r.log_weights) - @test r2.weights ≈ r.weights +# r2 = psis(logr; normalize=false) +# @test !(r2.log_weights ≈ r.log_weights) +# @test r2.weights ≈ r.weights - if length(sz) > 1 - @test all(r.tail_length .== PSIS.tail_length(1, 400_000)) - else - @test all(r.tail_length .== PSIS.tail_length(1, 100_000)) - end +# if length(sz) > 1 +# @test all(r.tail_length .== PSIS.tail_length(1, 400_000)) +# else +# @test all(r.tail_length .== PSIS.tail_length(1, 100_000)) +# end - k = r.pareto_shape - @test k isa (length(sz) < 3 ? Number : AbstractVector) - tail_dist = r.tail_dist - if length(sz) < 3 - @test tail_dist isa PSIS.GeneralizedPareto - @test tail_dist.k == k - else - @test tail_dist isa Vector{<:PSIS.GeneralizedPareto} - @test map(d -> d.k, tail_dist) == k - end +# k = r.pareto_shape +# @test k isa (length(sz) < 3 ? Number : AbstractVector) +# tail_dist = r.tail_dist +# if length(sz) < 3 +# @test tail_dist isa PSIS.GeneralizedPareto +# @test tail_dist.k == k +# else +# @test tail_dist isa Vector{<:PSIS.GeneralizedPareto} +# @test map(d -> d.k, tail_dist) == k +# end - w = r.weights - @test all(x -> isapprox(x, k_exp; atol=0.15), k) - @test all(x -> isapprox(x, x_target; atol=atol), sum(x .* w; dims=dims)) - @test all( - x -> isapprox(x, x²_target; atol=atol), sum(x .^ 2 .* w; dims=dims) - ) - end - end - end +# w = r.weights +# @test all(x -> isapprox(x, k_exp; atol=0.15), k) +# @test all(x -> isapprox(x, x_target; atol=atol), sum(x .* w; dims=dims)) +# @test all( +# x -> isapprox(x, x²_target; atol=atol), sum(x .^ 2 .* w; dims=dims) +# ) +# end +# end +# end - @testset "reff combinations" begin - reffs_uniform = [rand(), fill(rand()), [rand()]] - x = randn(1000) - for r in reffs_uniform - psis(x, r) - end - @test_throws DimensionMismatch psis(x, rand(2)) +# @testset "reff combinations" begin +# reffs_uniform = [rand(), fill(rand()), [rand()]] +# x = randn(1000) +# for r in reffs_uniform +# psis(x, r) +# end +# @test_throws DimensionMismatch psis(x, rand(2)) - x = randn(1000, 4) - for r in reffs_uniform - psis(x, r) - end - @test_throws DimensionMismatch psis(x, rand(2)) +# x = randn(1000, 4) +# for r in reffs_uniform +# psis(x, r) +# end +# @test_throws DimensionMismatch psis(x, rand(2)) - x = randn(1000, 4, 2) - for r in reffs_uniform - psis(x, r) - end - psis(x, rand(2)) - @test_throws DimensionMismatch psis(x, rand(3)) +# x = randn(1000, 4, 2) +# for r in reffs_uniform +# psis(x, r) +# end +# psis(x, rand(2)) +# @test_throws DimensionMismatch psis(x, rand(3)) - x = randn(1000, 4, 2, 3) - for r in reffs_uniform - psis(x, r) - end - psis(x, rand(2, 3)) - @test_throws DimensionMismatch psis(x, rand(3)) - end +# x = randn(1000, 4, 2, 3) +# for r in reffs_uniform +# psis(x, r) +# end +# psis(x, rand(2, 3)) +# @test_throws DimensionMismatch psis(x, rand(3)) +# end - @testset "warnings" begin - io = IOBuffer() - @testset for sz in (100, (100, 4, 3)), rbad in (-1, 0, NaN) - logr = randn(sz) - result = with_logger(SimpleLogger(io)) do - psis(logr, rbad) - end - msg = String(take!(io)) - @test occursin("All values of `reff` should be finite, but some are not.", msg) - end +# @testset "warnings" begin +# io = IOBuffer() +# @testset for sz in (100, (100, 4, 3)), rbad in (-1, 0, NaN) +# logr = randn(sz) +# result = with_logger(SimpleLogger(io)) do +# psis(logr, rbad) +# end +# msg = String(take!(io)) +# @test occursin("All values of `reff` should be finite, but some are not.", msg) +# end - io = IOBuffer() - logr = randn(5) - result = with_logger(SimpleLogger(io)) do - psis(logr; normalize=false) - end - @test result.log_weights == logr - @test isnan(result.tail_dist.σ) - @test isnan(result.pareto_shape) - msg = String(take!(io)) - @test occursin( - "Warning: 1 tail draws is insufficient to fit the generalized Pareto distribution.", - msg, - ) +# io = IOBuffer() +# logr = randn(5) +# result = with_logger(SimpleLogger(io)) do +# psis(logr; normalize=false) +# end +# @test result.log_weights == logr +# @test isnan(result.tail_dist.σ) +# @test isnan(result.pareto_shape) +# msg = String(take!(io)) +# @test occursin( +# "Warning: 1 tail draws is insufficient to fit the generalized Pareto distribution.", +# msg, +# ) - skipnan(x) = filter(!isnan, x) - io = IOBuffer() - for logr in [ - [NaN; randn(100)], - [Inf; randn(100)], - fill(-Inf, 100), - vcat(ones(50), fill(-Inf, 435)), - ] - result = with_logger(SimpleLogger(io)) do - psis(logr; normalize=false) - end - @test skipnan(result.log_weights) == skipnan(logr) - @test isnan(result.tail_dist.σ) - @test isnan(result.pareto_shape) - msg = String(take!(io)) - @test occursin("Warning: Tail contains non-finite values.", msg) - end +# skipnan(x) = filter(!isnan, x) +# io = IOBuffer() +# for logr in [ +# [NaN; randn(100)], +# [Inf; randn(100)], +# fill(-Inf, 100), +# vcat(ones(50), fill(-Inf, 435)), +# ] +# result = with_logger(SimpleLogger(io)) do +# psis(logr; normalize=false) +# end +# @test skipnan(result.log_weights) == skipnan(logr) +# @test isnan(result.tail_dist.σ) +# @test isnan(result.pareto_shape) +# msg = String(take!(io)) +# @test occursin("Warning: Tail contains non-finite values.", msg) +# end - io = IOBuffer() - rng = MersenneTwister(42) - x = rand(rng, Exponential(50), 1_000) - logr = logpdf.(Exponential(1), x) .- logpdf.(Exponential(50), x) - result = with_logger(SimpleLogger(io)) do - psis(logr; normalize=false) - end - @test result.log_weights != logr - @test result.pareto_shape > 0.7 - msg = String(take!(io)) - @test occursin( - "Warning: Pareto shape k = 0.73 > 0.7. $(PSIS.BAD_SHAPE_SUMMARY)", msg - ) +# io = IOBuffer() +# rng = MersenneTwister(42) +# x = rand(rng, Exponential(50), 1_000) +# logr = logpdf.(Exponential(1), x) .- logpdf.(Exponential(50), x) +# result = with_logger(SimpleLogger(io)) do +# psis(logr; normalize=false) +# end +# @test result.log_weights != logr +# @test result.pareto_shape > 0.7 +# msg = String(take!(io)) +# @test occursin( +# "Warning: Pareto shape k = 0.73 > 0.7. $(PSIS.BAD_SHAPE_SUMMARY)", msg +# ) - io = IOBuffer() - with_logger(SimpleLogger(io)) do - PSIS.check_pareto_shape(PSIS.GeneralizedPareto(0.0, 1.0, 1.1)) - end - msg = String(take!(io)) - @test occursin( - "Warning: Pareto shape k = 1.1 > 1. $(PSIS.VERY_BAD_SHAPE_SUMMARY)", msg - ) +# io = IOBuffer() +# with_logger(SimpleLogger(io)) do +# PSIS.check_pareto_shape(PSIS.GeneralizedPareto(0.0, 1.0, 1.1)) +# end +# msg = String(take!(io)) +# @test occursin( +# "Warning: Pareto shape k = 1.1 > 1. $(PSIS.VERY_BAD_SHAPE_SUMMARY)", msg +# ) - io = IOBuffer() - with_logger(SimpleLogger(io)) do - PSIS.check_pareto_shape(PSIS.GeneralizedPareto(0.0, 1.0, 0.8)) - end - msg = String(take!(io)) - @test occursin( - "Warning: Pareto shape k = 0.8 > 0.7. $(PSIS.BAD_SHAPE_SUMMARY)", msg - ) +# io = IOBuffer() +# with_logger(SimpleLogger(io)) do +# PSIS.check_pareto_shape(PSIS.GeneralizedPareto(0.0, 1.0, 0.8)) +# end +# msg = String(take!(io)) +# @test occursin( +# "Warning: Pareto shape k = 0.8 > 0.7. $(PSIS.BAD_SHAPE_SUMMARY)", msg +# ) - io = IOBuffer() - with_logger(SimpleLogger(io)) do - PSIS.check_pareto_shape(PSIS.GeneralizedPareto(0.0, 1.0, 0.69)) - end - msg = String(take!(io)) - @test isempty(msg) +# io = IOBuffer() +# with_logger(SimpleLogger(io)) do +# PSIS.check_pareto_shape(PSIS.GeneralizedPareto(0.0, 1.0, 0.69)) +# end +# msg = String(take!(io)) +# @test isempty(msg) - tail_dist = [ - PSIS.GeneralizedPareto(0, NaN, NaN), - PSIS.GeneralizedPareto(0, 1, 0.69), - PSIS.GeneralizedPareto(0, 1, 0.71), - PSIS.GeneralizedPareto(0, 1, 1.1), - ] - io = IOBuffer() - with_logger(SimpleLogger(io)) do - PSIS.check_pareto_shape(tail_dist) - end - msg = String(take!(io)) - @test occursin( - "Warning: 1 parameters had Pareto shape values 0.7 < k ≤ 1. $(PSIS.BAD_SHAPE_SUMMARY)", - msg, - ) - @test occursin( - "Warning: 1 parameters had Pareto shape values k > 1. $(PSIS.VERY_BAD_SHAPE_SUMMARY)", - msg, - ) - @test occursin( - "Warning: For 1 parameters, the generalized Pareto distribution could not be fit to the tail draws.", - msg, - ) - end +# tail_dist = [ +# PSIS.GeneralizedPareto(0, NaN, NaN), +# PSIS.GeneralizedPareto(0, 1, 0.69), +# PSIS.GeneralizedPareto(0, 1, 0.71), +# PSIS.GeneralizedPareto(0, 1, 1.1), +# ] +# io = IOBuffer() +# with_logger(SimpleLogger(io)) do +# PSIS.check_pareto_shape(tail_dist) +# end +# msg = String(take!(io)) +# @test occursin( +# "Warning: 1 parameters had Pareto shape values 0.7 < k ≤ 1. $(PSIS.BAD_SHAPE_SUMMARY)", +# msg, +# ) +# @test occursin( +# "Warning: 1 parameters had Pareto shape values k > 1. $(PSIS.VERY_BAD_SHAPE_SUMMARY)", +# msg, +# ) +# @test occursin( +# "Warning: For 1 parameters, the generalized Pareto distribution could not be fit to the tail draws.", +# msg, +# ) +# end - @testset "test against reference values" begin - rng = MersenneTwister(42) - proposal = Normal() - target = Cauchy() - sz = (5, 1_000, 4) - x = rand(rng, proposal, sz) - logr = logpdf.(target, x) .- logpdf.(proposal, x) - logr = permutedims(logr, (2, 3, 1)) - @testset for r_eff in (0.7, 1.2) - r_effs = fill(r_eff, sz[1]) - result = @inferred psis(logr, r_effs; normalize=false) - logw = result.log_weights - @test !isapprox(logw, logr) - basename = "normal_to_cauchy_reff_$(r_eff)" - @test_reference( - "references/$basename.jld2", - Dict("log_weights" => logw, "pareto_shape" => result.pareto_shape), - by = - (ref, x) -> - isapprox(ref["log_weights"], x["log_weights"]; rtol=1e-6) && - isapprox(ref["pareto_shape"], x["pareto_shape"]; rtol=1e-6), - ) - end - end +# @testset "test against reference values" begin +# rng = MersenneTwister(42) +# proposal = Normal() +# target = Cauchy() +# sz = (5, 1_000, 4) +# x = rand(rng, proposal, sz) +# logr = logpdf.(target, x) .- logpdf.(proposal, x) +# logr = permutedims(logr, (2, 3, 1)) +# @testset for r_eff in (0.7, 1.2) +# r_effs = fill(r_eff, sz[1]) +# result = @inferred psis(logr, r_effs; normalize=false) +# logw = result.log_weights +# @test !isapprox(logw, logr) +# basename = "normal_to_cauchy_reff_$(r_eff)" +# @test_reference( +# "references/$basename.jld2", +# Dict("log_weights" => logw, "pareto_shape" => result.pareto_shape), +# by = +# (ref, x) -> +# isapprox(ref["log_weights"], x["log_weights"]; rtol=1e-6) && +# isapprox(ref["pareto_shape"], x["pareto_shape"]; rtol=1e-6), +# ) +# end +# end - # https://github.com/arviz-devs/PSIS.jl/issues/27 - @testset "no failure for very low log-weights" begin - psis(rand(1000) .- 1500) - end +# # https://github.com/arviz-devs/PSIS.jl/issues/27 +# @testset "no failure for very low log-weights" begin +# psis(rand(1000) .- 1500) +# end - @testset "compatibility with arrays with named axes/dims" begin - param_names = [Symbol("x[$i]") for i in 1:10] - iter_names = 101:200 - chain_names = 1:4 - x = randn(length(iter_names), length(chain_names), length(param_names)) +# @testset "compatibility with arrays with named axes/dims" begin +# param_names = [Symbol("x[$i]") for i in 1:10] +# iter_names = 101:200 +# chain_names = 1:4 +# x = randn(length(iter_names), length(chain_names), length(param_names)) - @testset "DimensionalData" begin - logr = DimArray( - x, - ( - Dimensions.Dim{:iter}(iter_names), - Dimensions.Dim{:chain}(chain_names), - Dimensions.Dim{:param}(param_names), - ), - ) - result = @inferred psis(logr) - @test result.log_weights isa DimArray - @test Dimensions.dims(result.log_weights) == Dimensions.dims(logr) - for k in (:pareto_shape, :tail_length, :tail_dist, :reff) - prop = getproperty(result, k) - @test prop isa DimArray - @test Dimensions.dims(prop) == Dimensions.dims(logr, (:param,)) - end - end - end -end +# @testset "DimensionalData" begin +# logr = DimArray( +# x, +# ( +# Dimensions.Dim{:iter}(iter_names), +# Dimensions.Dim{:chain}(chain_names), +# Dimensions.Dim{:param}(param_names), +# ), +# ) +# result = @inferred psis(logr) +# @test result.log_weights isa DimArray +# @test Dimensions.dims(result.log_weights) == Dimensions.dims(logr) +# for k in (:pareto_shape, :tail_length, :tail_dist, :reff) +# prop = getproperty(result, k) +# @test prop isa DimArray +# @test Dimensions.dims(prop) == Dimensions.dims(logr, (:param,)) +# end +# end +# end +# end From 8df64f64aacd1b6e6f2e145f4729bcfddef8ae5b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 22 Aug 2024 21:25:37 +0200 Subject: [PATCH 30/30] Fix errors for older Julia versions --- src/PSIS.jl | 10 ++++++++-- src/pareto_diagnose.jl | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/PSIS.jl b/src/PSIS.jl index c5c7c6eb..886f921e 100644 --- a/src/PSIS.jl +++ b/src/PSIS.jl @@ -1,5 +1,6 @@ module PSIS +using Compat: @constprop using DocStringExtensions: FIELDS using IntervalSets: IntervalSets using LogExpFunctions: LogExpFunctions @@ -7,6 +8,8 @@ using PrettyTables: PrettyTables using Printf: @sprintf using Statistics: Statistics +const EXTENSIONS_SUPPORTED = isdefined(Base, :get_extension) + export PSISPlots export ParetoDiagnostics, PSISResult export pareto_diagnose, pareto_smooth, psis, psis! @@ -23,9 +26,12 @@ include("core.jl") include("ess.jl") include("recipes/plots.jl") -@static if !isdefined(Base, :get_extension) +if !EXTENSIONS_SUPPORTED using Requires: @require - function __init__() +end + +function __init__() + @static if EXTENSIONS_SUPPORTED @require StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" begin include("../ext/PSISStatsBaseExt.jl") end diff --git a/src/pareto_diagnose.jl b/src/pareto_diagnose.jl index 05cbd40b..8a1f0649 100644 --- a/src/pareto_diagnose.jl +++ b/src/pareto_diagnose.jl @@ -155,7 +155,7 @@ function _compute_pareto_shape( expectand_proxy = _expectand_proxy(kind, x, !is_log, is_log, is_log) return _compute_pareto_shape(expectand_proxy, reff, tails) end -Base.@constprop :aggressive function _compute_pareto_shape( +@constprop :aggressive function _compute_pareto_shape( x::AbstractVecOrMat, r::AbstractVecOrMat, tails::Tails,