
Commit

Return results in an object (#10)
* Return tail distribution

* Rename to check_pareto_shape

* Implement PSISResult

* Add pareto_shape methods

* Simplify psis! implementation

* Implement psis! for multi-parameter case

* Add scale

* Run formatter

* Update tests to use objects

* Rename r_eff to reff

* Remove now-unused tests

* Add missing end

* Refer to shape as shape instead of k

* Just print pareto_shape

* Remove normalize kwarg

* Make return value of tail_dist type-inferrable

* Return actual tail_lengths

* Document remaining properties

* Test PSISResult

* Test more cases

* Move diagnostic docs to PSISResult

* Format docstrings

* Increment version number

* Test propertynames

* Simplify test
sethaxen authored Nov 22, 2021
1 parent c596dfe commit 70ac431
Showing 4 changed files with 259 additions and 133 deletions.
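The headline change: `psis` and `psis!` now return a `PSISResult` object rather than a `(log_weights, k)` tuple. A minimal sketch of the new interface, using randomly generated log ratios as placeholder data and only the properties documented in the diff below:

using PSIS

log_ratios = randn(1_000)     # placeholder log importance ratios for a single parameter
result = psis(log_ratios)     # PSISResult instead of a (log_weights, k) tuple
result.pareto_shape           # Pareto shape diagnostic (previously returned as `k`)
result.weights                # normalized Pareto-smoothed weights (computed on access)
result.tail_length            # number of tail draws that were smoothed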
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "PSIS"
uuid = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04"
authors = ["Seth Axen <[email protected]> and contributors"]
version = "0.1.7"
version = "0.2.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
220 changes: 140 additions & 80 deletions src/PSIS.jl
@@ -2,18 +2,101 @@ module PSIS

using Distributions: Distributions
using LinearAlgebra: dot
using LogExpFunctions: logsumexp, softmax!
using LogExpFunctions: logsumexp, softmax, softmax!
using Printf: @sprintf
using Statistics: mean, median, quantile
using StatsBase: StatsBase

export PSISResult
export psis, psis!

include("utils.jl")
include("generalized_pareto.jl")

"""
psis(log_ratios, reff = 1.0; kwargs...) -> (log_weights, k)
PSISResult
Result of Pareto-smoothed importance sampling (PSIS).
# Properties
- `log_weights`: unnormalized Pareto-smoothed log weights
- `weights`: normalized Pareto-smoothed weights (allocates a copy)
- `pareto_shape`: Pareto ``k=ξ`` shape parameter
- `nparams`: number of parameters in `log_weights`
- `ndraws`: number of draws in `log_weights`
- `nchains`: number of chains in `log_weights`
- `reff`: the ratio of the effective sample size of the unsmoothed importance ratios to
    the actual sample size.
- `tail_length`: length of the upper tail of `log_weights` that was smoothed
- `tail_dist`: the generalized Pareto distribution that was fit to the tail of
`log_weights`
# Diagnostic
The `pareto_shape` parameter ``k=ξ`` of the generalized Pareto distribution `tail_dist` can
be used to diagnose reliability and convergence of estimates using the importance weights
[^VehtariSimpson2021].
- if ``k < \\frac{1}{3}``, importance sampling is stable, and importance sampling (IS) and
PSIS both are reliable.
- if ``k < \\frac{1}{2}``, then the importance ratio distribution has finite variance, and
the central limit theorem holds. As ``k`` approaches the upper bound, IS becomes less
reliable, while PSIS still works well but with a higher RMSE.
- if ``\\frac{1}{2} ≤ k < 0.7``, then the variance is infinite, and IS can behave quite
poorly. However, PSIS works well in this regime.
- if ``0.7 ≤ k < 1``, then it quickly becomes impractical to collect enough importance
weights to reliably compute estimates, and importance sampling is not recommended.
- if ``k ≥ 1``, then neither the variance nor the mean of the raw importance ratios
exists. The convergence rate is close to zero, and bias can be large with practical
sample sizes.
[^VehtariSimpson2021]: Vehtari A, Simpson D, Gelman A, Yao Y, Gabry J. (2021).
Pareto smoothed importance sampling.
[arXiv:1507.02646v7](https://arxiv.org/abs/1507.02646v7) [stat.CO]
"""
struct PSISResult{T,W<:AbstractArray{T},R,L,D}
log_weights::W
reff::R
tail_length::L
tail_dist::D
end

function Base.propertynames(r::PSISResult)
return [fieldnames(typeof(r))..., :weights, :nparams, :ndraws, :nchains, :pareto_shape]
end

function Base.getproperty(r::PSISResult, k::Symbol)
if k === :weights
log_weights = getfield(r, :log_weights)
d = ndims(log_weights)
dims = d == 1 ? Colon() : ntuple(Base.Fix1(+, 1), d - 1)
return softmax(log_weights; dims=dims)
end
if k === :nparams
log_weights = getfield(r, :log_weights)
return ndims(log_weights) == 1 ? 1 : size(log_weights, 1)
end
if k === :ndraws
log_weights = getfield(r, :log_weights)
return ndims(log_weights) == 1 ? length(log_weights) : size(log_weights, 2)
end
if k === :nchains
log_weights = getfield(r, :log_weights)
return size(log_weights, 3)
end
k === :pareto_shape && return pareto_shape(r)
return getfield(r, k)
end

function Base.show(io::IO, ::MIME"text/plain", r::PSISResult)
println(io, typeof(r), ":")
print(io, " pareto_shape: ", r.pareto_shape)
return nothing
end
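As a hedged illustration of the computed properties above (an editor's sketch, not part of this commit): for a hypothetical 2-parameter by 1,000-draw log-weight matrix, the `weights` property normalizes over every dimension except the first, so each parameter's weights sum to one.

using LogExpFunctions: softmax

log_weights = randn(2, 1_000)                            # hypothetical (nparams, ndraws) input
dims = ntuple(Base.Fix1(+, 1), ndims(log_weights) - 1)   # (2,): every dimension but the first
weights = softmax(log_weights; dims=dims)
vec(sum(weights; dims=2))                                # ≈ [1.0, 1.0]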

"""
psis(log_ratios, reff = 1.0; kwargs...) -> PSISResult
Compute Pareto smoothed importance sampling (PSIS) log weights [^VehtariSimpson2021].
@@ -37,38 +120,16 @@ See [`psis!`](@ref) for a version that smooths the ratios in-place.
- `sorted=issorted(vec(log_ratios))`: whether `log_ratios` are already sorted. Only
accepted if `nparams==1`.
- `normalize=false`: whether to normalize the log weights so that the resulting weights
for a given parameter sum to one.
- `improved=false`: If `true`, use the adaptive empirical prior of [^Zhang2010].
If `false`, use the simpler prior of [^ZhangStephens2009], which is also used in
[^VehtariSimpson2021].
# Returns
- `log_weights`: an array of smoothed log weights of the same size as `log_ratios`
- `k`: for each parameter, the estimated shape parameter ``k`` of the generalized Pareto
distribution, which is useful for diagnosing the distribution of importance ratios.
See details below.
# Diagnostic
The shape parameter ``k`` of the generalized Pareto distribution can be used to diagnose
reliability and convergence of estimates using the importance weights [^VehtariSimpson2021]:
- if ``k < \\frac{1}{3}``, importance sampling is stable, and importance sampling (IS) and
PSIS both are reliable.
- if ``k < \\frac{1}{2}``, then the importance ratio distribution has finite variance, and
the central limit theorem holds. As ``k`` approaches the upper bound, IS becomes less
reliable, while PSIS still works well but with a higher RMSE.
- if ``\\frac{1}{2} ≤ k < 0.7``, then the variance is infinite, and IS can behave quite
poorly. However, PSIS works well in this regime.
- if ``0.7 ≤ k < 1``, then it quickly becomes impractical to collect enough importance
weights to reliably compute estimates, and importance sampling is not recommended.
- if ``k ≥ 1``, then neither the variance nor the mean of the raw importance ratios
exists. The convergence rate is close to zero, and bias can be large with practical
sample sizes.
- `result`: a [`PSISResult`](@ref) object containing the results of Pareto smoothing.
A warning is raised if ``k ≥ 0.7``.
A warning is raised if the Pareto shape parameter ``k ≥ 0.7``. See [`PSISResult`](@ref) for
details.
[^VehtariSimpson2021]: Vehtari A, Simpson D, Gelman A, Yao Y, Gabry J. (2021).
Pareto smoothed importance sampling.
@@ -95,87 +156,86 @@ In-place compute Pareto smoothed importance sampling (PSIS) log weights.
See [`psis`](@ref) for an out-of-place version and for description of arguments and return
values.
"""
function psis!(
logw::AbstractVector, reff=1; sorted=issorted(logw), normalize=false, improved=false
)
T = eltype(logw)
function psis!(logw::AbstractVector, reff=1; sorted=issorted(logw), improved=false)
S = length(logw)
k_hat = T(Inf)

@assert isone(length(reff)) # support numbers or single-element array
reff_val = first(reff)
M = tail_length(reff_val, S)
if M < 5
@warn "Insufficient tail draws to fit the generalized Pareto distribution."
else
perm = sorted ? eachindex(logw) : sortperm(logw)
@inbounds logw_max = logw[last(perm)]
icut = S - M
tail_range = (icut + 1):S

@inbounds logw_tail = @views logw[perm[tail_range]]
if logw_max - first(logw_tail) < eps(eltype(logw_tail)) / 100
@warn "Cannot fit the generalized Pareto distribution because all tail " *
"values are the same"
else
logw_tail .-= logw_max
@inbounds logu = logw[perm[icut]] - logw_max

_, k_hat = psis_tail!(logw_tail, logu, M, improved)
logw_tail .+= logw_max

check_pareto_k(k_hat)
end
return PSISResult(logw, reff, M, missing)
end

if normalize
logw .-= logsumexp(logw)
end

return logw, k_hat
perm = sorted ? collect(eachindex(logw)) : sortperm(logw)
icut = S - M
tail_range = (icut + 1):S
@inbounds logw_tail = @views logw[perm[tail_range]]
@inbounds logu = logw[perm[icut]]
_, tail_dist = psis_tail!(logw_tail, logu, M, improved)
check_pareto_shape(tail_dist)
return PSISResult(logw, reff_val, M, tail_dist)
end
function psis!(logw::AbstractArray, reff=1; kwargs...)
Tdist = Union{Distributions.GeneralizedPareto{eltype(logw)},Missing}
logw_firstcol = view(logw, :, ntuple(_ -> 1, ndims(logw) - 1)...)
reff_vec = reff isa Number ? fill!(similar(logw_firstcol), reff) : reff
# support both 2D and 3D arrays, flattening the final dimension
_, k_hat = psis!(vec(selectdim(logw, 1, 1)), reff_vec[1]; kwargs...)
# for arrays with named dimensions, this pattern ensures k_hat has the same names
k_hats = similar(logw_firstcol, eltype(k_hat))
k_hats[1] = k_hat
Threads.@threads for i in eachindex(k_hats, reff_vec)
_, k_hats[i] = psis!(vec(selectdim(logw, 1, i)), reff_vec[i]; kwargs...)
r1 = psis!(vec(selectdim(logw, 1, 1)), reff_vec[1]; kwargs...)
# for arrays with named dimensions, this pattern ensures tail_lengths and tail_dists
# have the same names
tail_lengths = similar(logw_firstcol, Int)
tail_lengths[1] = r1.tail_length
tail_dists = similar(logw_firstcol, Tdist)
tail_dists[1] = r1.tail_dist
Threads.@threads for i in eachindex(tail_dists, reff_vec, tail_lengths, tail_dists)
ri = psis!(vec(selectdim(logw, 1, i)), reff_vec[i]; kwargs...)
tail_lengths[i] = ri.tail_length
tail_dists[i] = ri.tail_dist
end
return logw, k_hats
return PSISResult(logw, reff_vec, tail_lengths, map(identity, tail_dists))
end
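For the multi-parameter method just above, each slice along the first dimension gets its own tail fit, and the per-parameter tail lengths and tail distributions are collected into the result. A hedged sketch with placeholder data:

log_ratios = randn(3, 2_000)       # placeholder (nparams, ndraws) log importance ratios
result = psis(log_ratios)          # out-of-place counterpart of psis! (see docstring above)
result.pareto_shape                # three per-parameter shape estimates
result.tail_length                 # three per-parameter tail lengths
(result.nparams, result.ndraws)    # (3, 2000)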

function check_pareto_k(k)
if k ≥ 1
@warn "Pareto k=$(@sprintf("%.2g", k)) ≥ 1. Resulting importance sampling " *
pareto_shape(::Missing) = missing
pareto_shape(dist::Distributions.GeneralizedPareto) = Distributions.shape(dist)
pareto_shape(r::PSISResult) = pareto_shape(getfield(r, :tail_dist))
pareto_shape(dists) = map(pareto_shape, dists)

function check_pareto_shape(dist::Distributions.GeneralizedPareto)
ξ = pareto_shape(dist)
if ξ ≥ 1
@warn "Pareto shape=$(@sprintf("%.2g", ξ)) ≥ 1. Resulting importance sampling " *
"estimates are likely to be unstable and are unlikely to converge with " *
"additional samples."
elseif k ≥ 0.7
@warn "Pareto k=$(@sprintf("%.2g", k)) ≥ 0.7. Resulting importance sampling " *
elseif ξ ≥ 0.7
@warn "Pareto shape=$(@sprintf("%.2g", ξ)) ≥ 0.7. Resulting importance sampling " *
"estimates are likely to be unstable."
end
return nothing
end

tail_length(reff, S) = min(cld(S, 5), ceil(Int, 3 * sqrt(S / reff)))
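Two worked examples of the tail-length heuristic above (editor's illustration):

tail_length(1, 1_000)   # min(cld(1000, 5), ceil(Int, 3 * sqrt(1000 / 1))) = min(200, 95) = 95
tail_length(1, 100)     # min(20, 30) = 20, capped at a fifth of the draws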

function psis_tail!(logw, logu, M=length(logw), improved=false)
function psis_tail!(logw, logμ, M=length(logw), improved=false)
T = eltype(logw)
u = exp(logu)
w = (logw .= exp.(logw))
d_hat = StatsBase.fit(GeneralizedParetoKnownMu(u), w; sorted=true, improved=improved)
d_hat = prior_adjust_shape(d_hat, M)
k_hat = Distributions.shape(d_hat)
if isfinite(k_hat)
logw_max = logw[M]
# to improve numerical stability, we first scale the log-weights to have a maximum of 1,
# equivalent to shifting the log-weights to have a maximum of 0.
μ_scaled = exp(logμ - logw_max)
w = (logw .= exp.(logw .- logw_max))
tail_dist_scaled = StatsBase.fit(
GeneralizedParetoKnownMu(μ_scaled), w; sorted=true, improved=improved
)
tail_dist_adjusted = prior_adjust_shape(tail_dist_scaled, M)
# undo the scaling
ξ = Distributions.shape(tail_dist_adjusted)
if isfinite(ξ)
p = uniform_probabilities(T, M)
@inbounds for i in eachindex(logw, p)
logw[i] = min(log(_quantile(d_hat, p[i])), 0)
# undo scaling in the log-weights
logw[i] = min(log(_quantile(tail_dist_adjusted, p[i])), 0) + logw_max
end
end
return logw, k_hat
# undo scaling for the tail distribution
tail_dist = scale(tail_dist_adjusted, exp(logw_max))
return logw, tail_dist
end
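The shift by `logw_max` in `psis_tail!` exists purely for numerical stability; a hedged illustration of why exponentiating unshifted log weights can overflow:

logw = [705.0, 708.0, 711.0]
exp.(logw)                     # last entry overflows to Inf
exp.(logw .- maximum(logw))    # ≈ [0.0025, 0.0498, 1.0], finite with maximum 1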

end
4 changes: 4 additions & 0 deletions src/generalized_pareto.jl
Original file line number Diff line number Diff line change
@@ -184,3 +184,7 @@ function prior_adjust_shape(d::Distributions.GeneralizedPareto, n, ξ_prior=1//2
ξ = (n * d.ξ + nobs * ξ_prior) / (n + nobs)
return Distributions.GeneralizedPareto(d.μ, d.σ, ξ)
end

function scale(d::Distributions.GeneralizedPareto, s)
return Distributions.GeneralizedPareto(d.μ * s, d.σ * s, d.ξ)
end
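The new `scale` helper relies on the generalized Pareto family being closed under scaling: if `w ~ GeneralizedPareto(μ, σ, ξ)`, then `s * w ~ GeneralizedPareto(s * μ, s * σ, ξ)`, which is how the tail fit is mapped back to the original scale in `psis_tail!`. A quick hedged check of that identity:

using Distributions

d = GeneralizedPareto(0.5, 2.0, 0.3)
s = 10.0
quantile(GeneralizedPareto(s * 0.5, s * 2.0, 0.3), 0.9) ≈ s * quantile(d, 0.9)   # true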

2 comments on commit 70ac431

@sethaxen (Member Author)
@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/49207

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.0 -m "<description of version>" 70ac431eeb4d0cd52dbd7ec825a230be9c553781
git push origin v0.2.0
