Skip to content

Commit

Permalink
Add warnings/errors if relative efficiency is invalid (#52)
Browse files Browse the repository at this point in the history
* Add informative warnings for relative efficiency shape mismatch

* Increment patch number

* Warn if reff value is invalid
  • Loading branch information
sethaxen authored Jul 1, 2023
1 parent e0f8756 commit a6eff24
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PSIS"
uuid = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04"
authors = ["Seth Axen <[email protected]> and contributors"]
version = "0.9.0"
version = "0.9.1"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
30 changes: 28 additions & 2 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@ end

function psis!(logw::AbstractVecOrMat, reff=1; normalize::Bool=true, warn::Bool=true)
T = typeof(float(one(eltype(logw))))
if length(reff) != 1
throw(DimensionMismatch("`reff` has length $(length(reff)) but must have length 1"))
end
warn && check_reff(reff)
S = length(logw)
reff_val = first(reff)
M = tail_length(reff_val, S)
Expand Down Expand Up @@ -247,7 +251,7 @@ function psis!(logw::AbstractVecOrMat, reff=1; normalize::Bool=true, warn::Bool=
return PSISResult(logw, reff_val, M, tail_dist, normalize)
end
function psis!(logw::AbstractMatrix, reff=1; kwargs...)
result = psis!(vec(logw), only(reff); kwargs...)
result = psis!(vec(logw), reff; kwargs...)
# unflatten log_weights
return PSISResult(
logw, result.reff, result.tail_length, result.tail_dist, result.normalized
Expand All @@ -257,6 +261,15 @@ function psis!(logw::AbstractArray, reff=1; normalize::Bool=true, warn::Bool=tru
T = typeof(float(one(eltype(logw))))
# if an array defines custom indices (e.g. AbstractDimArray), we preserve them
param_axes = _param_axes(logw)
param_shape = map(length, param_axes)
if !(length(reff) == 1 || size(reff) == param_shape)
throw(
DimensionMismatch(
"`reff` has shape $(size(reff)) but must have same shape as the parameter axes $(param_shape)",
),
)
end
check_reff(reff)

# allocate containers
reffs = similar(logw, eltype(reff), param_axes)
Expand Down Expand Up @@ -284,6 +297,14 @@ pareto_shape(dist::GeneralizedPareto) = dist.k
pareto_shape(r::PSISResult) = pareto_shape(getfield(r, :tail_dist))
pareto_shape(dists) = map(pareto_shape, dists)

function check_reff(reff)
isvalid = all(reff) do r
return isfinite(r) && r > 0
end
isvalid || @warn "All values of `reff` should be finite, but some are not."
return nothing
end

check_pareto_shape(result::PSISResult) = check_pareto_shape(result.tail_dist)
function check_pareto_shape(dist::GeneralizedPareto)
k = pareto_shape(dist)
Expand All @@ -310,7 +331,12 @@ function check_pareto_shape(dists::AbstractArray{<:GeneralizedPareto})
return nothing
end

tail_length(reff, S) = min(cld(S, 5), ceil(Int, 3 * sqrt(S / reff)))
function tail_length(reff, S)
max_length = cld(S, 5)
(isfinite(reff) && reff > 0) || return max_length
min_length = ceil(Int, 3 * sqrt(S / reff))
return min(max_length, min_length)
end

function psis_tail!(logw, logμ)
T = eltype(logw)
Expand Down
39 changes: 39 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,46 @@ end
end
end

@testset "reff combinations" begin
reffs_uniform = [rand(), fill(rand()), [rand()]]
x = randn(1000)
for r in reffs_uniform
psis(x, r)
end
@test_throws DimensionMismatch psis(x, rand(2))

x = randn(1000, 4)
for r in reffs_uniform
psis(x, r)
end
@test_throws DimensionMismatch psis(x, rand(2))

x = randn(1000, 4, 2)
for r in reffs_uniform
psis(x, r)
end
psis(x, rand(2))
@test_throws DimensionMismatch psis(x, rand(3))

x = randn(1000, 4, 2, 3)
for r in reffs_uniform
psis(x, r)
end
psis(x, rand(2, 3))
@test_throws DimensionMismatch psis(x, rand(3))
end

@testset "warnings" begin
io = IOBuffer()
@testset for sz in (100, (100, 4, 3)), rbad in (-1, 0, NaN)
logr = randn(sz)
result = with_logger(SimpleLogger(io)) do
psis(logr, rbad)
end
msg = String(take!(io))
@test occursin("All values of `reff` should be finite, but some are not.", msg)
end

io = IOBuffer()
logr = randn(5)
result = with_logger(SimpleLogger(io)) do
Expand Down

2 comments on commit a6eff24

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/86686

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.1 -m "<description of version>" a6eff24f8001f4bff817e2da85954a6858168ca3
git push origin v0.9.1

Please sign in to comment.