Skip to content

Commit

Permalink
Allow show to function properly even if some statistic computations f…
Browse files Browse the repository at this point in the history
…ail (#412)

* allow show to function properly even if some statistic computations fail

* added patch version

* fixed ordering of stats in case some are missing

* added missing )

---------

Co-authored-by: Cameron Pfiffer <[email protected]>
  • Loading branch information
torfjelde and cpfiffer authored Apr 18, 2023
1 parent 738ba9f commit e1041f5
Showing 1 changed file with 36 additions and 17 deletions.
53 changes: 36 additions & 17 deletions src/stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -299,23 +299,42 @@ function summarystats(
_chains = Chains(chains, _clean_sections(chains, sections))

# Calculate MCSE and ESS/R-hat separately.
mcse_df = MCMCDiagnosticTools.mcse(
_chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag,
)
ess_rhat_rank_df = MCMCDiagnosticTools.ess_rhat(
_chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:rank
)
ess_tail_df = MCMCDiagnosticTools.ess(
_chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:tail
)
nt_additional = (
mcse=mcse_df.nt.mcse,
ess_bulk=ess_rhat_rank_df.nt.ess,
ess_tail=ess_tail_df.nt.ess,
rhat=ess_rhat_rank_df.nt.rhat,
ess_per_sec=ess_rhat_rank_df.nt.ess_per_sec,
)
additional_df = ChainDataFrame("Additional", nt_additional)
nt_additional = NamedTuple()
try
mcse_df = MCMCDiagnosticTools.mcse(
_chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag,
)
nt_additional = merge(nt_additional, (; mcse=mcse_df.nt.mcse))
catch e
@warn "MCSE calculation failed: $e"
end

try
ess_tail_df = MCMCDiagnosticTools.ess(
_chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:tail
)
nt_additional = merge(nt_additional, (ess_tail=ess_tail_df.nt.ess,))
catch e
@warn "Tail ESS calculation failed: $e"
end

try
ess_rhat_rank_df = MCMCDiagnosticTools.ess_rhat(
_chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:rank
)
nt_ess_rhat_rank = (
ess_bulk=ess_rhat_rank_df.nt.ess,
rhat=ess_rhat_rank_df.nt.rhat,
ess_per_sec=ess_rhat_rank_df.nt.ess_per_sec
)
nt_additional = merge(nt_additional, nt_ess_rhat_rank)
catch e
@warn "Bulk ESS/R-hat calculation failed: $e"
end

# Possibly re-order the columns to stay backwards-compatible.
additional_keys = (:mcse, :ess_bulk, :ess_tail, :rhat, :ess_per_sec)
additional_df = ChainDataFrame("Additional", (; ((k, nt_additional[k]) for k in additional_keys if k keys(nt_additional))...))

# Summarize.
summary_df = summarize(
Expand Down

2 comments on commit e1041f5

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/81799

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v6.0.1 -m "<description of version>" e1041f552542eb9442f96f368023eb8fa287fb37
git push origin v6.0.1

Please sign in to comment.