Return results in an object (#10)
* Return tail distribution

* Rename to check_pareto_shape

* Implement PSISResult

* Add pareto_shape methods

* Simplify psis! implementation

* Implement psis! for multi-parameter case

* Add scale

* Run formatter

* Update tests to use objects

* Rename r_eff to reff

* Remove now-unused tests

* Add missing end

* Refer to shape as shape instead of k

* Just print pareto_shape

* Remove normalize kwarg

* Make return value of tail_dist type-inferrable

* Return actual tail_lengths

* Document remaining properties

* Test PSISResult

* Test more cases

* Move diagnostic docs to PSISResult

* Format docstrings

* Increment version number

* Test propertynames

* Simplify test
sethaxen authored Nov 22, 2021
1 parent c596dfe commit 70ac431
Showing 4 changed files with 259 additions and 133 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "PSIS"
uuid = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04"
authors = ["Seth Axen <[email protected]> and contributors"]
-version = "0.1.7"
+version = "0.2.0"

Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
220 changes: 140 additions & 80 deletions src/PSIS.jl
@@ -2,18 +2,101 @@ module PSIS

using Distributions: Distributions
using LinearAlgebra: dot
-using LogExpFunctions: logsumexp, softmax!
+using LogExpFunctions: logsumexp, softmax, softmax!
using Printf: @sprintf
using Statistics: mean, median, quantile
using StatsBase: StatsBase

+export PSISResult
export psis, psis!


 """
-    psis(log_ratios, reff = 1.0; kwargs...) -> (log_weights, k)
+Result of Pareto-smoothed importance sampling (PSIS).
+
+# Properties
+
+- `log_weights`: unnormalized Pareto-smoothed log weights
+- `weights`: normalized Pareto-smoothed weights (allocates a copy)
+- `pareto_shape`: Pareto ``k=ξ`` shape parameter
+- `nparams`: number of parameters in `log_weights`
+- `ndraws`: number of draws in `log_weights`
+- `nchains`: number of chains in `log_weights`
+- `reff`: the ratio of the effective sample size of the unsmoothed importance ratios and
+    the actual sample size.
+- `tail_length`: length of the upper tail of `log_weights` that was smoothed
+- `tail_dist`: the generalized Pareto distribution that was fit to the tail of
+    `log_weights`
+
+# Diagnostic
+
+The `pareto_shape` parameter ``k=ξ`` of the generalized Pareto distribution `tail_dist` can
+be used to diagnose reliability and convergence of estimates using the importance weights
+[^VehtariSimpson2021]:
+
+- if ``k < \\frac{1}{3}``, importance sampling is stable, and importance sampling (IS) and
+    PSIS both are reliable.
+- if ``k < \\frac{1}{2}``, then the importance ratio distribution has finite variance, and
+    the central limit theorem holds. As ``k`` approaches the upper bound, IS becomes less
+    reliable, while PSIS still works well but with a higher RMSE.
+- if ``\\frac{1}{2} ≤ k < 0.7``, then the variance is infinite, and IS can behave quite
+    poorly. However, PSIS works well in this regime.
+- if ``0.7 ≤ k < 1``, then it quickly becomes impractical to collect enough importance
+    weights to reliably compute estimates, and importance sampling is not recommended.
+- if ``k ≥ 1``, then neither the variance nor the mean of the raw importance ratios
+    exists. The convergence rate is close to zero, and bias can be large with practical
+    sample sizes.
+
+[^VehtariSimpson2021]: Vehtari A, Simpson D, Gelman A, Yao Y, Gabry J. (2021).
+    Pareto smoothed importance sampling.
+    arXiv:1507.02646v7 [stat.CO]
+"""
+struct PSISResult{T,W<:AbstractArray{T},R,L,D}
+    log_weights::W
+    reff::R
+    tail_length::L
+    tail_dist::D
+end
+
+function Base.propertynames(r::PSISResult)
+    return [fieldnames(typeof(r))..., :weights, :nparams, :ndraws, :nchains, :pareto_shape]
+end
+
+function Base.getproperty(r::PSISResult, k::Symbol)
+    if k === :weights
+        log_weights = getfield(r, :log_weights)
+        d = ndims(log_weights)
+        dims = d == 1 ? Colon() : ntuple(Base.Fix1(+, 1), d - 1)
+        return softmax(log_weights; dims=dims)
+    end
+    if k === :nparams
+        log_weights = getfield(r, :log_weights)
+        return ndims(log_weights) == 1 ? 1 : size(log_weights, 1)
+    end
+    if k === :ndraws
+        log_weights = getfield(r, :log_weights)
+        return ndims(log_weights) == 1 ? length(log_weights) : size(log_weights, 2)
+    end
+    if k === :nchains
+        log_weights = getfield(r, :log_weights)
+        return size(log_weights, 3)
+    end
+    k === :pareto_shape && return pareto_shape(r)
+    return getfield(r, k)
+end
+
+function Base.show(io::IO, ::MIME"text/plain", r::PSISResult)
+    println(io, typeof(r), ":")
+    print(io, " pareto_shape: ", r.pareto_shape)
+    return nothing
+end

+"""
+    psis(log_ratios, reff = 1.0; kwargs...) -> PSISResult
Compute Pareto smoothed importance sampling (PSIS) log weights [^VehtariSimpson2021].
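
The `weights` property defined in the hunk above applies `softmax` over every dimension except the first, so that each parameter's weights sum to one across draws and chains. A minimal Base-only sketch of that reduction follows, assuming the `(nparams, ndraws, nchains)` layout implied by the `nparams`, `ndraws`, and `nchains` properties. The name `normalized_weights` is hypothetical and not part of PSIS.jl.

```julia
# Hypothetical stand-in for `softmax(log_weights; dims=dims)`; not part of PSIS.jl.
function normalized_weights(log_weights::AbstractArray)
    d = ndims(log_weights)
    # A vector is normalized as a whole; otherwise reduce over dimensions 2:d.
    dims = d == 1 ? Colon() : ntuple(i -> i + 1, d - 1)
    # Subtract the maximum before exponentiating for numerical stability.
    m = maximum(log_weights; dims=dims)
    w = exp.(log_weights .- m)
    return w ./ sum(w; dims=dims)
end

log_weights = [0.0 log(3.0); 0.0 0.0]  # 2 parameters × 2 draws
w = normalized_weights(log_weights)
```

Each row of `w` sums to one, which is the per-parameter normalization that the removed `normalize` keyword used to provide in log space.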
@@ -37,38 +120,16 @@ See [`psis!`](@ref) for a version that smoothes the ratios in-place.
- `sorted=issorted(vec(log_ratios))`: whether `log_ratios` are already sorted. Only
accepted if `nparams==1`.
-- `normalize=false`: whether to normalize the log weights so that the resulting weights
-    for a given parameter sum to one.
- `improved=false`: If `true`, use the adaptive empirical prior of [^Zhang2010].
If `false`, use the simpler prior of [^ZhangStephens2009], which is also used in
# Returns
-- `log_weights`: an array of smoothed log weights of the same size as `log_ratios`
-- `k`: for each parameter, the estimated shape parameter ``k`` of the generalized Pareto
-    distribution, which is useful for diagnosing the distribution of importance ratios.
-    See details below.
-
-# Diagnostic
-
-The shape parameter ``k`` of the generalized Pareto distribution can be used to diagnose
-reliability and convergence of estimates using the importance weights [^VehtariSimpson2021]:
-
-- if ``k < \\frac{1}{3}``, importance sampling is stable, and importance sampling (IS) and
-    PSIS both are reliable.
-- if ``k < \\frac{1}{2}``, then the importance ratio distribution has finite variance, and
-    the central limit theorem holds. As ``k`` approaches the upper bound, IS becomes less
-    reliable, while PSIS still works well but with a higher RMSE.
-- if ``\\frac{1}{2} ≤ k < 0.7``, then the variance is infinite, and IS can behave quite
-    poorly. However, PSIS works well in this regime.
-- if ``0.7 ≤ k < 1``, then it quickly becomes impractical to collect enough importance
-    weights to reliably compute estimates, and importance sampling is not recommended.
-- if ``k ≥ 1``, then neither the variance nor the mean of the raw importance ratios
-    exists. The convergence rate is close to zero, and bias can be large with practical
-    sample sizes.
+- `result`: a [`PSISResult`](@ref) object containing the results of the Pareto-smoothing.
 
-A warning is raised if ``k ≥ 0.7``.
+A warning is raised if the Pareto shape parameter ``k ≥ 0.7``. See [`PSISResult`](@ref) for
+details.
[^VehtariSimpson2021]: Vehtari A, Simpson D, Gelman A, Yao Y, Gabry J. (2021).
Pareto smoothed importance sampling.
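
The thresholds in the `PSISResult` diagnostic list can be summarized as a small lookup. The sketch below is illustrative only: `shape_regime` and its symbol names are hypothetical and not part of PSIS.jl, which reports problems through the warnings in `check_pareto_shape` instead.

```julia
# Hypothetical helper (not part of PSIS.jl): map a Pareto shape value to the
# regime described in the PSISResult docstring.
function shape_regime(k::Real)
    k < 1 / 3 && return :stable                     # IS and PSIS both reliable
    k < 1 / 2 && return :finite_variance            # CLT holds, IS degrading
    k < 0.7 && return :infinite_variance_psis_ok    # PSIS still works well
    k < 1 && return :unreliable                     # warning threshold reached
    return :no_finite_mean                          # neither mean nor variance exists
end

# Mirrors `pareto_shape(::Missing) = missing` for the case where no tail was fit.
shape_regime(::Missing) = missing
```

A caller could apply this elementwise, for example `map(shape_regime, result.pareto_shape)`, in the multi-parameter case.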
@@ -95,87 +156,86 @@ In-place compute Pareto smoothed importance sampling (PSIS) log weights.
See [`psis`](@ref) for an out-of-place version and for description of arguments and return
-function psis!(
-    logw::AbstractVector, reff=1; sorted=issorted(logw), normalize=false, improved=false
-)
-    T = eltype(logw)
+function psis!(logw::AbstractVector, reff=1; sorted=issorted(logw), improved=false)
     S = length(logw)
-    k_hat = T(Inf)
 
     @assert isone(length(reff)) # support numbers or single-element array
     reff_val = first(reff)
     M = tail_length(reff_val, S)
     if M < 5
         @warn "Insufficient tail draws to fit the generalized Pareto distribution."
-    else
-        perm = sorted ? eachindex(logw) : sortperm(logw)
-        @inbounds logw_max = logw[last(perm)]
-        icut = S - M
-        tail_range = (icut + 1):S
-
-        @inbounds logw_tail = @views logw[perm[tail_range]]
-        if logw_max - first(logw_tail) < eps(eltype(logw_tail)) / 100
-            @warn "Cannot fit the generalized Pareto distribution because all tail " *
-                "values are the same"
-        else
-            logw_tail .-= logw_max
-            @inbounds logu = logw[perm[icut]] - logw_max
-
-            _, k_hat = psis_tail!(logw_tail, logu, M, improved)
-            logw_tail .+= logw_max
-        end
+        return PSISResult(logw, reff, M, missing)
     end
-
-    if normalize
-        logw .-= logsumexp(logw)
-    end
-
-    return logw, k_hat
+    perm = sorted ? collect(eachindex(logw)) : sortperm(logw)
+    icut = S - M
+    tail_range = (icut + 1):S
+    @inbounds logw_tail = @views logw[perm[tail_range]]
+    @inbounds logu = logw[perm[icut]]
+    _, tail_dist = psis_tail!(logw_tail, logu, M, improved)
+    return PSISResult(logw, reff_val, M, tail_dist)
 end

 function psis!(logw::AbstractArray, reff=1; kwargs...)
+    Tdist = Union{Distributions.GeneralizedPareto{eltype(logw)},Missing}
     logw_firstcol = view(logw, :, ntuple(_ -> 1, ndims(logw) - 1)...)
     reff_vec = reff isa Number ? fill!(similar(logw_firstcol), reff) : reff
     # support both 2D and 3D arrays, flattening the final dimension
-    _, k_hat = psis!(vec(selectdim(logw, 1, 1)), reff_vec[1]; kwargs...)
-    # for arrays with named dimensions, this pattern ensures k_hat has the same names
-    k_hats = similar(logw_firstcol, eltype(k_hat))
-    k_hats[1] = k_hat
-    Threads.@threads for i in eachindex(k_hats, reff_vec)
-        _, k_hats[i] = psis!(vec(selectdim(logw, 1, i)), reff_vec[i]; kwargs...)
+    r1 = psis!(vec(selectdim(logw, 1, 1)), reff_vec[1]; kwargs...)
+    # for arrays with named dimensions, this pattern ensures tail_lengths and tail_dists
+    # have the same names
+    tail_lengths = similar(logw_firstcol, Int)
+    tail_lengths[1] = r1.tail_length
+    tail_dists = similar(logw_firstcol, Tdist)
+    tail_dists[1] = r1.tail_dist
+    Threads.@threads for i in eachindex(tail_dists, reff_vec, tail_lengths, tail_dists)
+        ri = psis!(vec(selectdim(logw, 1, i)), reff_vec[i]; kwargs...)
+        tail_lengths[i] = ri.tail_length
+        tail_dists[i] = ri.tail_dist
     end
-    return logw, k_hats
+    return PSISResult(logw, reff_vec, tail_lengths, map(identity, tail_dists))
 end

-function check_pareto_k(k)
-    if k ≥ 1
-        @warn "Pareto k=$(@sprintf("%.2g", k)) ≥ 1. Resulting importance sampling " *
+pareto_shape(::Missing) = missing
+pareto_shape(dist::Distributions.GeneralizedPareto) = Distributions.shape(dist)
+pareto_shape(r::PSISResult) = pareto_shape(getfield(r, :tail_dist))
+pareto_shape(dists) = map(pareto_shape, dists)
+
+function check_pareto_shape(dist::Distributions.GeneralizedPareto)
+    ξ = pareto_shape(dist)
+    if ξ ≥ 1
+        @warn "Pareto shape=$(@sprintf("%.2g", ξ)) ≥ 1. Resulting importance sampling " *
             "estimates are likely to be unstable and are unlikely to converge with " *
             "additional samples."
-    elseif k ≥ 0.7
-        @warn "Pareto k=$(@sprintf("%.2g", k)) ≥ 0.7. Resulting importance sampling " *
+    elseif ξ ≥ 0.7
+        @warn "Pareto shape=$(@sprintf("%.2g", ξ)) ≥ 0.7. Resulting importance sampling " *
             "estimates are likely to be unstable."
     end
     return nothing
 end

tail_length(reff, S) = min(cld(S, 5), ceil(Int, 3 * sqrt(S / reff)))

-function psis_tail!(logw, logu, M=length(logw), improved=false)
+function psis_tail!(logw, logμ, M=length(logw), improved=false)
     T = eltype(logw)
-    u = exp(logu)
-    w = (logw .= exp.(logw))
-    d_hat = StatsBase.fit(GeneralizedParetoKnownMu(u), w; sorted=true, improved=improved)
-    d_hat = prior_adjust_shape(d_hat, M)
-    k_hat = Distributions.shape(d_hat)
-    if isfinite(k_hat)
+    logw_max = logw[M]
+    # to improve numerical stability, we first scale the log-weights to have a maximum of 1,
+    # equivalent to shifting the log-weights to have a maximum of 0.
+    μ_scaled = exp(logμ - logw_max)
+    w = (logw .= exp.(logw .- logw_max))
+    tail_dist_scaled = StatsBase.fit(
+        GeneralizedParetoKnownMu(μ_scaled), w; sorted=true, improved=improved
+    )
+    tail_dist_adjusted = prior_adjust_shape(tail_dist_scaled, M)
+    # undo the scaling
+    ξ = Distributions.shape(tail_dist_adjusted)
+    if isfinite(ξ)
         p = uniform_probabilities(T, M)
         @inbounds for i in eachindex(logw, p)
-            logw[i] = min(log(_quantile(d_hat, p[i])), 0)
+            # undo scaling in the log-weights
+            logw[i] = min(log(_quantile(tail_dist_adjusted, p[i])), 0) + logw_max
         end
     end
-    return logw, k_hat
+    # undo scaling for the tail distribution
+    tail_dist = scale(tail_dist_adjusted, exp(logw_max))
+    return logw, tail_dist
 end
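
Two pieces of this hunk are easy to check in isolation: the tail-length rule and the shift by the maximum log-weight that `psis_tail!` now performs internally. The sketch below copies `tail_length` from the diff and adds a hypothetical helper, `shifted_roundtrip`, which is not part of PSIS.jl and only illustrates why the shift avoids underflow.

```julia
# Copied from the diff: number of tail draws used for the generalized Pareto fit.
tail_length(reff, S) = min(cld(S, 5), ceil(Int, 3 * sqrt(S / reff)))

# Hypothetical helper illustrating the shift used in `psis_tail!`: subtract the
# maximum log-weight before exponentiating, then add it back in log space.
function shifted_roundtrip(logw::AbstractVector)
    logw_max = maximum(logw)
    w_scaled = exp.(logw .- logw_max)   # largest scaled weight is exactly 1
    return log.(w_scaled) .+ logw_max   # undo the scaling
end

logw = [-1000.0, -999.0, -998.5]
naive = exp.(logw)                   # underflows to 0.0 in Float64
recovered = shifted_roundtrip(logw)  # matches logw up to rounding
```

With `S = 20` draws and `reff = 1`, `tail_length` gives 4, which is below 5 and triggers the "Insufficient tail draws" branch in `psis!`.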

4 changes: 4 additions & 0 deletions src/generalized_pareto.jl
@@ -184,3 +184,7 @@ function prior_adjust_shape(d::Distributions.GeneralizedPareto, n, ξ_prior=1//2
     ξ = (n * d.ξ + nobs * ξ_prior) / (n + nobs)
     return Distributions.GeneralizedPareto(d.μ, d.σ, ξ)
 end
+
+function scale(d::Distributions.GeneralizedPareto, s)
+    return Distributions.GeneralizedPareto(d.μ * s, d.σ * s, d.ξ)
+end
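
The new `scale` method relies on the generalized Pareto family being closed under multiplication by a positive constant. As a short check, assuming the standard parameterization with location μ, scale σ, shape ξ ≠ 0, and a factor s > 0 (here s corresponds to `exp(logw_max)`):

```latex
\begin{align*}
F_X(x) &= 1 - \left(1 + \xi\,\frac{x - \mu}{\sigma}\right)^{-1/\xi}, \qquad \xi \neq 0, \\
F_{sX}(y) &= F_X\!\left(\frac{y}{s}\right)
  = 1 - \left(1 + \xi\,\frac{y/s - \mu}{\sigma}\right)^{-1/\xi}
  = 1 - \left(1 + \xi\,\frac{y - s\mu}{s\sigma}\right)^{-1/\xi}.
\end{align*}
```

So sX follows a generalized Pareto distribution with parameters (sμ, sσ, ξ): the location and scale are multiplied by s and the shape is unchanged, which is why the Pareto shape diagnostic is unaffected by the shift applied in `psis_tail!`.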

2 comments on commit 70ac431

Registration pull request created: JuliaRegistries/General/49207

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.0 -m "<description of version>" 70ac431eeb4d0cd52dbd7ec825a230be9c553781
git push origin v0.2.0
