diff --git a/src/stats.jl b/src/stats.jl index ba744d64..7fd1e04d 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -299,23 +299,42 @@ function summarystats( _chains = Chains(chains, _clean_sections(chains, sections)) # Calculate MCSE and ESS/R-hat separately. - mcse_df = MCMCDiagnosticTools.mcse( - _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, - ) - ess_rhat_rank_df = MCMCDiagnosticTools.ess_rhat( - _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:rank - ) - ess_tail_df = MCMCDiagnosticTools.ess( - _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:tail - ) - nt_additional = ( - mcse=mcse_df.nt.mcse, - ess_bulk=ess_rhat_rank_df.nt.ess, - ess_tail=ess_tail_df.nt.ess, - rhat=ess_rhat_rank_df.nt.rhat, - ess_per_sec=ess_rhat_rank_df.nt.ess_per_sec, - ) - additional_df = ChainDataFrame("Additional", nt_additional) + nt_additional = NamedTuple() + try + mcse_df = MCMCDiagnosticTools.mcse( + _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, + ) + nt_additional = merge(nt_additional, (; mcse=mcse_df.nt.mcse)) + catch e + @warn "MCSE calculation failed: $e" + end + + try + ess_tail_df = MCMCDiagnosticTools.ess( + _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:tail + ) + nt_additional = merge(nt_additional, (ess_tail=ess_tail_df.nt.ess,)) + catch e + @warn "Tail ESS calculation failed: $e" + end + + try + ess_rhat_rank_df = MCMCDiagnosticTools.ess_rhat( + _chains; sections = nothing, autocov_method = autocov_method, maxlag = maxlag, kind=:rank + ) + nt_ess_rhat_rank = ( + ess_bulk=ess_rhat_rank_df.nt.ess, + rhat=ess_rhat_rank_df.nt.rhat, + ess_per_sec=ess_rhat_rank_df.nt.ess_per_sec + ) + nt_additional = merge(nt_additional, nt_ess_rhat_rank) + catch e + @warn "Bulk ESS/R-hat calculation failed: $e" + end + + # Possibly re-order the columns to stay backwards-compatible. + additional_keys = (:mcse, :ess_bulk, :ess_tail, :rhat, :ess_per_sec) + additional_df = ChainDataFrame("Additional", (; ((k, nt_additional[k]) for k in additional_keys if k ∈ keys(nt_additional))...)) # Summarize. summary_df = summarize(