From 70ac431eeb4d0cd52dbd7ec825a230be9c553781 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 22 Nov 2021 22:59:40 +0100 Subject: [PATCH] Return results in an object (#10) * Return tail distribution * Rename to check_pareto_shape * Implement PSISResult * Add pareto_shape methods * Simplify psis! implementation * Implement psis! for multi-parameter case * Add scale * Run formatter * Update tests to use objects * Rename r_eff to reff * Remove now-unused tests * Add missing end * Refer to shape as shape instead of k * Just print pareto_shape * Remove normalize kwarg * Make return value of tail_dist type-inferrable * Return actual tail_lengths * Document remaining properties * Test PSISResult * Test more cases * Move diagnostic docs to PSISResult * Format docstrings * Increment version number * Test propertynames * Simplify test --- Project.toml | 2 +- src/PSIS.jl | 220 ++++++++++++++++++++++++-------------- src/generalized_pareto.jl | 4 + test/psis.jl | 166 +++++++++++++++++++--------- 4 files changed, 259 insertions(+), 133 deletions(-) diff --git a/Project.toml b/Project.toml index 47ed31ea..e74fa9e9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PSIS" uuid = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04" authors = ["Seth Axen and contributors"] -version = "0.1.7" +version = "0.2.0" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/PSIS.jl b/src/PSIS.jl index dc505748..0628a517 100644 --- a/src/PSIS.jl +++ b/src/PSIS.jl @@ -2,18 +2,101 @@ module PSIS using Distributions: Distributions using LinearAlgebra: dot -using LogExpFunctions: logsumexp, softmax! +using LogExpFunctions: logsumexp, softmax, softmax! using Printf: @sprintf using Statistics: mean, median, quantile using StatsBase: StatsBase +export PSISResult export psis, psis! include("utils.jl") include("generalized_pareto.jl") """ - psis(log_ratios, reff = 1.0; kwargs...) -> (log_weights, k) + PSISResult + +Result of Pareto-smoothed importance sampling (PSIS). + +# Properties + + - `log_weights`: unnormalized Pareto-smoothed log weights + - `weights`: normalized Pareto-smoothed weights (allocates a copy) + - `pareto_shape`: Pareto ``k=ξ`` shape parameter + - `nparams`: number of parameters in `log_weights` + - `ndraws`: number of draws in `log_weights` + - `nchains`: number of chains in `log_weights` + - `reff`: the ratio of the effective sample size of the unsmoothed importance ratios and + the actual sample size. + - `tail_length`: length of the upper tail of `log_weights` that was smoothed + - `tail_dist`: the generalized Pareto distribution that was fit to the tail of + `log_weights` + +# Diagnostic + +The `pareto_shape` parameter ``k=ξ`` of the generalized Pareto distribution `tail_dist` can +be used to diagnose reliability and convergence of estimates using the importance weights +[^VehtariSimpson2021]. + + - if ``k < \\frac{1}{3}``, importance sampling is stable, and importance sampling (IS) and + PSIS both are reliable. + - if ``k < \\frac{1}{2}``, then the importance ratio distributon has finite variance, and + the central limit theorem holds. As ``k`` approaches the upper bound, IS becomes less + reliable, while PSIS still works well but with a higher RMSE. + - if ``\\frac{1}{2} ≤ k < 0.7``, then the variance is infinite, and IS can behave quite + poorly. However, PSIS works well in this regime. + - if ``0.7 ≤ k < 1``, then it quickly becomes impractical to collect enough importance + weights to reliably compute estimates, and importance sampling is not recommended. + - if ``k ≥ 1``, then neither the variance nor the mean of the raw importance ratios + exists. The convergence rate is close to zero, and bias can be large with practical + sample sizes. + +[^VehtariSimpson2021]: Vehtari A, Simpson D, Gelman A, Yao Y, Gabry J. (2021). + Pareto smoothed importance sampling. + [arXiv:1507.02646v7](https://arxiv.org/abs/1507.02646v7) [stat.CO] +""" +struct PSISResult{T,W<:AbstractArray{T},R,L,D} + log_weights::W + reff::R + tail_length::L + tail_dist::D +end + +function Base.propertynames(r::PSISResult) + return [fieldnames(typeof(r))..., :weights, :nparams, :ndraws, :nchains, :pareto_shape] +end + +function Base.getproperty(r::PSISResult, k::Symbol) + if k === :weights + log_weights = getfield(r, :log_weights) + d = ndims(log_weights) + dims = d == 1 ? Colon() : ntuple(Base.Fix1(+, 1), d - 1) + return softmax(log_weights; dims=dims) + end + if k === :nparams + log_weights = getfield(r, :log_weights) + return ndims(log_weights) == 1 ? 1 : size(log_weights, 1) + end + if k === :ndraws + log_weights = getfield(r, :log_weights) + return ndims(log_weights) == 1 ? length(log_weights) : size(log_weights, 2) + end + if k === :nchains + log_weights = getfield(r, :log_weights) + return size(log_weights, 3) + end + k === :pareto_shape && return pareto_shape(r) + return getfield(r, k) +end + +function Base.show(io::IO, ::MIME"text/plain", r::PSISResult) + println(io, typeof(r), ":") + print(io, " pareto_shape: ", r.pareto_shape) + return nothing +end + +""" + psis(log_ratios, reff = 1.0; kwargs...) -> PSISResult Compute Pareto smoothed importance sampling (PSIS) log weights [^VehtariSimpson2021]. @@ -37,38 +120,16 @@ See [`psis!`](@ref) for a version that smoothes the ratios in-place. - `sorted=issorted(vec(log_ratios))`: whether `log_ratios` are already sorted. Only accepted if `nparams==1`. - - `normalize=false`: whether to normalize the log weights so that the resulting weights - for a given parameter sum to one. - `improved=false`: If `true`, use the adaptive empirical prior of [^Zhang2010]. If `false`, use the simpler prior of [^ZhangStephens2009], which is also used in [^VehtariSimpson2021]. # Returns - - `log_weights`: an array of smoothed log weights of the same size as `log_ratios` - - `k`: for each parameter, the estimated shape parameter ``k`` of the generalized Pareto - distribution, which is useful for diagnosing the distribution of importance ratios. - See details below. - -# Diagnostic - -The shape parameter ``k`` of the generalized Pareto distribution can be used to diagnose -reliability and convergence of estimates using the importance weights [^VehtariSimpson2021]: - - - if ``k < \\frac{1}{3}``, importance sampling is stable, and importance sampling (IS) and - PSIS both are reliable. - - if ``k < \\frac{1}{2}``, then the importance ratio distributon has finite variance, and - the central limit theorem holds. As ``k`` approaches the upper bound, IS becomes less - reliable, while PSIS still works well but with a higher RMSE. - - if ``\\frac{1}{2} ≤ k < 0.7``, then the variance is infinite, and IS can behave quite - poorly. However, PSIS works well in this regime. - - if ``0.7 ≤ k < 1``, then it quickly becomes impractical to collect enough importance - weights to reliably compute estimates, and importance sampling is not recommended. - - if ``k ≥ 1``, then neither the variance nor the mean of the raw importance ratios - exists. The convergence rate is close to zero, and bias can be large with practical - sample sizes. + - `result`: a [`PSISResult`](@ref) object containing the results of the Pareto-smoothing. -A warning is raised if ``k ≥ 0.7``. +A warning is raised if the Pareto shape parameter ``k ≥ 0.7``. See [`PSISResult`](@ref) for +details. [^VehtariSimpson2021]: Vehtari A, Simpson D, Gelman A, Yao Y, Gabry J. (2021). Pareto smoothed importance sampling. @@ -95,66 +156,56 @@ In-place compute Pareto smoothed importance sampling (PSIS) log weights. See [`psis`](@ref) for an out-of-place version and for description of arguments and return values. """ -function psis!( - logw::AbstractVector, reff=1; sorted=issorted(logw), normalize=false, improved=false -) - T = eltype(logw) +function psis!(logw::AbstractVector, reff=1; sorted=issorted(logw), improved=false) S = length(logw) - k_hat = T(Inf) - - @assert isone(length(reff)) # support numbers or single-element array reff_val = first(reff) M = tail_length(reff_val, S) if M < 5 @warn "Insufficient tail draws to fit the generalized Pareto distribution." - else - perm = sorted ? eachindex(logw) : sortperm(logw) - @inbounds logw_max = logw[last(perm)] - icut = S - M - tail_range = (icut + 1):S - - @inbounds logw_tail = @views logw[perm[tail_range]] - if logw_max - first(logw_tail) < eps(eltype(logw_tail)) / 100 - @warn "Cannot fit the generalized Pareto distribution because all tail " * - "values are the same" - else - logw_tail .-= logw_max - @inbounds logu = logw[perm[icut]] - logw_max - - _, k_hat = psis_tail!(logw_tail, logu, M, improved) - logw_tail .+= logw_max - - check_pareto_k(k_hat) - end + return PSISResult(logw, reff, M, missing) end - - if normalize - logw .-= logsumexp(logw) - end - - return logw, k_hat + perm = sorted ? collect(eachindex(logw)) : sortperm(logw) + icut = S - M + tail_range = (icut + 1):S + @inbounds logw_tail = @views logw[perm[tail_range]] + @inbounds logu = logw[perm[icut]] + _, tail_dist = psis_tail!(logw_tail, logu, M, improved) + check_pareto_shape(tail_dist) + return PSISResult(logw, reff_val, M, tail_dist) end function psis!(logw::AbstractArray, reff=1; kwargs...) + Tdist = Union{Distributions.GeneralizedPareto{eltype(logw)},Missing} logw_firstcol = view(logw, :, ntuple(_ -> 1, ndims(logw) - 1)...) reff_vec = reff isa Number ? fill!(similar(logw_firstcol), reff) : reff # support both 2D and 3D arrays, flattening the final dimension - _, k_hat = psis!(vec(selectdim(logw, 1, 1)), reff_vec[1]; kwargs...) - # for arrays with named dimensions, this pattern ensures k_hat has the same names - k_hats = similar(logw_firstcol, eltype(k_hat)) - k_hats[1] = k_hat - Threads.@threads for i in eachindex(k_hats, reff_vec) - _, k_hats[i] = psis!(vec(selectdim(logw, 1, i)), reff_vec[i]; kwargs...) + r1 = psis!(vec(selectdim(logw, 1, 1)), reff_vec[1]; kwargs...) + # for arrays with named dimensions, this pattern ensures tail_lengths and tail_dists + # have the same names + tail_lengths = similar(logw_firstcol, Int) + tail_lengths[1] = r1.tail_length + tail_dists = similar(logw_firstcol, Tdist) + tail_dists[1] = r1.tail_dist + Threads.@threads for i in eachindex(tail_dists, reff_vec, tail_lengths, tail_dists) + ri = psis!(vec(selectdim(logw, 1, i)), reff_vec[i]; kwargs...) + tail_lengths[i] = ri.tail_length + tail_dists[i] = ri.tail_dist end - return logw, k_hats + return PSISResult(logw, reff_vec, tail_lengths, map(identity, tail_dists)) end -function check_pareto_k(k) - if k ≥ 1 - @warn "Pareto k=$(@sprintf("%.2g", k)) ≥ 1. Resulting importance sampling " * +pareto_shape(::Missing) = missing +pareto_shape(dist::Distributions.GeneralizedPareto) = Distributions.shape(dist) +pareto_shape(r::PSISResult) = pareto_shape(getfield(r, :tail_dist)) +pareto_shape(dists) = map(pareto_shape, dists) + +function check_pareto_shape(dist::Distributions.GeneralizedPareto) + ξ = pareto_shape(dist) + if ξ ≥ 1 + @warn "Pareto shape=$(@sprintf("%.2g", ξ)) ≥ 1. Resulting importance sampling " * "estimates are likely to be unstable and are unlikely to converge with " * "additional samples." - elseif k ≥ 0.7 - @warn "Pareto k=$(@sprintf("%.2g", k)) ≥ 0.7. Resulting importance sampling " * + elseif ξ ≥ 0.7 + @warn "Pareto shape=$(@sprintf("%.2g", ξ)) ≥ 0.7. Resulting importance sampling " * "estimates are likely to be unstable." end return nothing @@ -162,20 +213,29 @@ end tail_length(reff, S) = min(cld(S, 5), ceil(Int, 3 * sqrt(S / reff))) -function psis_tail!(logw, logu, M=length(logw), improved=false) +function psis_tail!(logw, logμ, M=length(logw), improved=false) T = eltype(logw) - u = exp(logu) - w = (logw .= exp.(logw)) - d_hat = StatsBase.fit(GeneralizedParetoKnownMu(u), w; sorted=true, improved=improved) - d_hat = prior_adjust_shape(d_hat, M) - k_hat = Distributions.shape(d_hat) - if isfinite(k_hat) + logw_max = logw[M] + # to improve numerical stability, we first scale the log-weights to have a maximum of 1, + # equivalent to shifting the log-weights to have a maximum of 0. + μ_scaled = exp(logμ - logw_max) + w = (logw .= exp.(logw .- logw_max)) + tail_dist_scaled = StatsBase.fit( + GeneralizedParetoKnownMu(μ_scaled), w; sorted=true, improved=improved + ) + tail_dist_adjusted = prior_adjust_shape(tail_dist_scaled, M) + # undo the scaling + ξ = Distributions.shape(tail_dist_adjusted) + if isfinite(ξ) p = uniform_probabilities(T, M) @inbounds for i in eachindex(logw, p) - logw[i] = min(log(_quantile(d_hat, p[i])), 0) + # undo scaling in the log-weights + logw[i] = min(log(_quantile(tail_dist_adjusted, p[i])), 0) + logw_max end end - return logw, k_hat + # undo scaling for the tail distribution + tail_dist = scale(tail_dist_adjusted, exp(logw_max)) + return logw, tail_dist end end diff --git a/src/generalized_pareto.jl b/src/generalized_pareto.jl index 888d8659..2ccef555 100644 --- a/src/generalized_pareto.jl +++ b/src/generalized_pareto.jl @@ -184,3 +184,7 @@ function prior_adjust_shape(d::Distributions.GeneralizedPareto, n, ξ_prior=1//2 ξ = (n * d.ξ + nobs * ξ_prior) / (n + nobs) return Distributions.GeneralizedPareto(d.μ, d.σ, ξ) end + +function scale(d::Distributions.GeneralizedPareto, s) + return Distributions.GeneralizedPareto(d.μ * s, d.σ * s, d.ξ) +end diff --git a/test/psis.jl b/test/psis.jl index 9c58f42e..dc34d564 100644 --- a/test/psis.jl +++ b/test/psis.jl @@ -2,11 +2,74 @@ using PSIS using Test using Random using ReferenceTests -using Distributions: Normal, Cauchy, Exponential, logpdf, mean +using Distributions: GeneralizedPareto, Normal, Cauchy, Exponential, logpdf, mean, shape using LogExpFunctions: softmax using Logging: SimpleLogger, with_logger using AxisArrays: AxisArrays +@testset "PSISResult" begin + @testset "vector log-weights" begin + log_weights = randn(500) + tail_length = 100 + reff = 2.0 + tail_dist = GeneralizedPareto(1.0, 1.0, 0.5) + result = PSISResult(log_weights, reff, tail_length, tail_dist) + @test result isa PSISResult{Float64} + @test sort(propertynames(result)) == [ + :log_weights, + :nchains, + :ndraws, + :nparams, + :pareto_shape, + :reff, + :tail_dist, + :tail_length, + :weights, + ] + @test result.log_weights == log_weights + @test result.weights == softmax(log_weights) + @test result.reff == reff + @test result.nparams == 1 + @test result.ndraws == 500 + @test result.nchains == 1 + @test result.tail_length == tail_length + @test result.tail_dist == tail_dist + @test result.pareto_shape == 0.5 + + @testset "show" begin + @test sprint(show, "text/plain", result) == + "$(typeof(result)):\n pareto_shape: 0.5" + end + end + + @testset "array log-weights" begin + log_weights = randn(3, 500, 4) + tail_length = [1600, 1601, 1602] + reff = [0.8, 0.9, 1.1] + tail_dist = [ + GeneralizedPareto(1.0, 1.0, 0.5), + GeneralizedPareto(1.0, 1.0, 0.6), + GeneralizedPareto(1.0, 1.0, 0.7), + ] + result = PSISResult(log_weights, reff, tail_length, tail_dist) + @test result isa PSISResult{Float64} + @test result.log_weights == log_weights + @test result.weights == softmax(log_weights; dims=(2, 3)) + @test result.reff == reff + @test result.nparams == 3 + @test result.ndraws == 500 + @test result.nchains == 4 + @test result.tail_length == tail_length + @test result.tail_dist == tail_dist + @test result.pareto_shape == [0.5, 0.6, 0.7] + + @testset "show" begin + @test sprint(show, "text/plain", result) == + "$(typeof(result)):\n pareto_shape: [0.5, 0.6, 0.7]" + end + end +end + @testset "psis/psis!" begin @testset "importance sampling tests" begin target = Exponential(1) @@ -28,9 +91,30 @@ using AxisArrays: AxisArrays x = rand(rng, proposal, sz) logr = logpdf.(target, x) .- logpdf.(proposal, x) - logw, k = psis(logr) - w = softmax(logr; dims=dims) - @test all(x -> isapprox(x, ξ_exp; atol=0.15), k) + r = psis(logr) + @test r isa PSISResult + logw = r.log_weights + @test logw isa typeof(logr) + + if length(sz) == 3 + @test all(r.tail_length .== PSIS.tail_length(1, 400_000)) + else + @test all(r.tail_length .== PSIS.tail_length(1, 100_000)) + end + + ξ = r.pareto_shape + @test ξ isa (length(sz) == 1 ? Number : AbstractVector) + tail_dist = r.tail_dist + if length(sz) == 1 + @test tail_dist isa GeneralizedPareto + @test shape(tail_dist) == ξ + else + @test tail_dist isa Vector{<:GeneralizedPareto} + @test map(shape, tail_dist) == ξ + end + + w = r.weights + @test all(x -> isapprox(x, ξ_exp; atol=0.15), ξ) @test all(x -> isapprox(x, x_target; atol=atol), sum(x .* w; dims=dims)) @test all( x -> isapprox(x, x²_target; atol=atol), sum(x .^ 2 .* w; dims=dims) @@ -43,62 +127,35 @@ using AxisArrays: AxisArrays @testset "sorted=true" begin x = randn(100) perm = sortperm(x) - @test psis(x)[1] == invpermute!(psis(x[perm]; sorted=true)[1], perm) - @test psis(x)[2] == psis(x[perm]; sorted=true)[2] - end - - @testset "normalize=true" begin - @testset for sz in (100, (5, 100), (5, 100, 4)) - dims = length(sz) == 1 ? Colon() : 2:length(sz) - x = randn(sz) - lw1, k1 = psis(x) - lw2, k2 = psis(x; normalize=true) - @test k1 ≈ k2 - @test !(lw1 ≈ lw2) - - if VERSION ≥ v"1.1" - @test all(abs.(diff(lw1 - lw2; dims=length(sz))) .< sqrt(eps())) - end - @test all(x -> isapprox(x, 1), sum(exp.(lw2); dims=dims)) - end + @test psis(x).log_weights == + invpermute!(psis(x[perm]; sorted=true).log_weights, perm) + @test psis(x).pareto_shape == psis(x[perm]; sorted=true).pareto_shape end end @testset "warnings" begin io = IOBuffer() logr = randn(5) - logw, k = with_logger(SimpleLogger(io)) do + result = with_logger(SimpleLogger(io)) do psis(logr) end - @test logw == logr - @test isinf(k) + @test result.log_weights == logr + @test ismissing(result.tail_dist) + @test ismissing(result.pareto_shape) msg = String(take!(io)) @test occursin( "Warning: Insufficient tail draws to fit the generalized Pareto distribution", msg, ) - io = IOBuffer() - logr = ones(100) - logw, k = with_logger(SimpleLogger(io)) do - psis(logr) - end - @test logw == logr - @test isinf(k) - msg = String(take!(io)) - @test occursin( - "Warning: Cannot fit the generalized Pareto distribution because all tail values are the same", - msg, - ) - io = IOBuffer() x = rand(Exponential(100), 1_000) logr = logpdf.(Exponential(1), x) .- logpdf.(Exponential(1000), x) - logw, k = with_logger(SimpleLogger(io)) do + result = with_logger(SimpleLogger(io)) do psis(logr) end - @test logw != logr - @test k > 0.7 + @test result.log_weights != logr + @test result.pareto_shape > 0.7 msg = String(take!(io)) @test occursin( "Resulting importance sampling estimates are likely to be unstable", msg @@ -106,27 +163,27 @@ using AxisArrays: AxisArrays io = IOBuffer() with_logger(SimpleLogger(io)) do - PSIS.check_pareto_k(1.1) + PSIS.check_pareto_shape(GeneralizedPareto(0.0, 1.0, 1.1)) end msg = String(take!(io)) @test occursin( - "Warning: Pareto k=1.1 ≥ 1. Resulting importance sampling estimates are likely to be unstable and are unlikely to converge with additional samples.", + "Warning: Pareto shape=1.1 ≥ 1. Resulting importance sampling estimates are likely to be unstable and are unlikely to converge with additional samples.", msg, ) io = IOBuffer() with_logger(SimpleLogger(io)) do - PSIS.check_pareto_k(0.8) + PSIS.check_pareto_shape(GeneralizedPareto(0.0, 1.0, 0.8)) end msg = String(take!(io)) @test occursin( - "Warning: Pareto k=0.8 ≥ 0.7. Resulting importance sampling estimates are likely to be unstable.", + "Warning: Pareto shape=0.8 ≥ 0.7. Resulting importance sampling estimates are likely to be unstable.", msg, ) io = IOBuffer() with_logger(SimpleLogger(io)) do - PSIS.check_pareto_k(0.69) + PSIS.check_pareto_shape(GeneralizedPareto(0.0, 1.0, 0.69)) end msg = String(take!(io)) @test isempty(msg) @@ -147,7 +204,9 @@ using AxisArrays: AxisArrays ) @testset for r_eff in (0.7, 1.2), improved in (true, false) r_effs = fill(r_eff, sz[1]) - logw, k = psis(logr, r_effs; improved=improved) + result = psis(logr, r_effs; improved=improved) + logw = result.log_weights + k = result.pareto_shape @test !isapprox(logw, logr) basename = "normal_to_cauchy_reff_$(r_eff)" if improved @@ -176,11 +235,14 @@ using AxisArrays: AxisArrays AxisArrays.Axis{:iter}(iter_names), AxisArrays.Axis{:chain}(chain_names), ) - logw, k = psis(logr) - @test logw isa AxisArrays.AxisArray - @test AxisArrays.axes(logw) == AxisArrays.axes(logr) - @test k isa AxisArrays.AxisArray - @test AxisArrays.axes(k) == (AxisArrays.axes(logr, 1),) + result = psis(logr) + @test result.log_weights isa AxisArrays.AxisArray + @test AxisArrays.axes(result.log_weights) == AxisArrays.axes(logr) + for k in (:pareto_shape, :tail_length, :tail_dist, :reff) + prop = getproperty(result, k) + @test prop isa AxisArrays.AxisArray + @test AxisArrays.axes(prop) == (AxisArrays.axes(logr, 1),) + end end end end