Skip to content

Commit 72ef0e2

Browse files
authored
Add custom warning for k greater than 1 (#2)
* Add Printf as dependency * Customize warning for Pareto k greater than 1 * Increment version number
1 parent b77ad1c commit 72ef0e2

File tree

3 files changed

+45
-7
lines changed

3 files changed

+45
-7
lines changed

Project.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "PSIS"
22
uuid = "ce719bf2-d5d0-4fb9-925d-10a81b42ad04"
33
authors = ["Seth Axen <[email protected]> and contributors"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
89
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
910

1011
[compat]
@@ -14,8 +15,8 @@ julia = "1"
1415
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1516
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1617
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
17-
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1818
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
19+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1920
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2021

2122
[targets]

src/PSIS.jl

+15-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module PSIS
22

33
using Statistics: mean
44
using LinearAlgebra: dot
5+
using Printf: @sprintf
56

67
export psis, psis!
78

@@ -88,17 +89,15 @@ function psis!(logw, r_eff=1.0; sorted=issorted(logw), normalize=false)
8889
@inbounds logw_tail = @views logw[perm[tail_range]]
8990
if logw_max - first(logw_tail) < eps(eltype(logw_tail)) / 100
9091
@warn "Cannot fit the generalized Pareto distribution because all tail " *
91-
"values are the same"
92+
"values are the same"
9293
else
9394
logw_tail .-= logw_max
9495
@inbounds logu = logw[perm[icut]] - logw_max
9596

9697
_, k_hat = psis_tail!(logw_tail, logu, M)
9798
logw_tail .+= logw_max
9899

99-
k_hat 0.7 &&
100-
@warn "Pareto k statistic exceeded 0.7. Resulting importance sampling estimates " *
101-
"are likely to be unstable."
100+
check_pareto_k(k_hat)
102101
end
103102
end
104103

@@ -109,6 +108,18 @@ function psis!(logw, r_eff=1.0; sorted=issorted(logw), normalize=false)
109108
return logw, k_hat
110109
end
111110

111+
function check_pareto_k(k)
112+
if k 1
113+
@warn "Pareto k=$(@sprintf("%.2g", k)) ≥ 1. Resulting importance sampling " *
114+
"estimates are likely to be unstable and are unlikely to converge with " *
115+
"additional samples."
116+
elseif k 0.7
117+
@warn "Pareto k=$(@sprintf("%.2g", k)) ≥ 0.7. Resulting importance sampling " *
118+
"estimates are likely to be unstable."
119+
end
120+
return nothing
121+
end
122+
112123
tail_length(r_eff, S) = min(cld(S, 5), ceil(Int, 3 * sqrt(S / r_eff)))
113124

114125
function psis_tail!(logw, logu, M=length(logw))

test/psis.jl

+27-1
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,35 @@ end
9595
@test k > 0.7
9696
msg = String(take!(io))
9797
@test occursin(
98-
"Warning: Pareto k statistic exceeded 0.7. Resulting importance sampling estimates are likely to be unstable",
98+
"Resulting importance sampling estimates are likely to be unstable", msg
99+
)
100+
101+
io = IOBuffer()
102+
with_logger(SimpleLogger(io)) do
103+
PSIS.check_pareto_k(1.1)
104+
end
105+
msg = String(take!(io))
106+
@test occursin(
107+
"Warning: Pareto k=1.1 ≥ 1. Resulting importance sampling estimates are likely to be unstable and are unlikely to converge with additional samples.",
99108
msg,
100109
)
110+
111+
io = IOBuffer()
112+
with_logger(SimpleLogger(io)) do
113+
PSIS.check_pareto_k(0.8)
114+
end
115+
msg = String(take!(io))
116+
@test occursin(
117+
"Warning: Pareto k=0.8 ≥ 0.7. Resulting importance sampling estimates are likely to be unstable.",
118+
msg,
119+
)
120+
121+
io = IOBuffer()
122+
with_logger(SimpleLogger(io)) do
123+
PSIS.check_pareto_k(0.69)
124+
end
125+
msg = String(take!(io))
126+
@test isempty(msg)
101127
end
102128

103129
has_loo() && @testset "consistent with loo" begin

0 commit comments

Comments
 (0)