Skip to content

Commit

Permalink
Support MCMCDiagnosticTools v0.2 (#392)
Browse files Browse the repository at this point in the history
* Create permutedims utility function

* Bump compat

* Permute dims

* Don't worry about building y

* Remove unnecessary version checks

* Remove Compat dependency for eachcols

* Bump minimum tested version

* Bump minor version

* Test against permuted dims

* Add undocumented sections keyword arg

* Bump test min Julia version to v1.3

* Set Julia lower bound to 1.6

* Update minimum tested version
  • Loading branch information
sethaxen authored Dec 13, 2022
1 parent 5f59160 commit e5dbc07
Show file tree
Hide file tree
Showing 11 changed files with 36 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
strategy:
matrix:
version:
- '1.0'
- '1.6'
- '1'
- nightly
os:
Expand Down
8 changes: 3 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "Chain types and utility functions for MCMC simulations."
version = "5.5.0"
version = "5.6.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
Expand All @@ -32,12 +31,11 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
[compat]
AbstractMCMC = "0.4, 0.5, 1.0, 2.0, 3.0, 4"
AxisArrays = "0.4.4"
Compat = "2.2, 3, 4"
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
Formatting = "0.4"
IteratorInterfaceExtensions = "0.1.1, 1"
KernelDensity = "0.6.2"
MCMCDiagnosticTools = "0.1"
MCMCDiagnosticTools = "0.2"
MLJModelInterface = "0.3.5, 0.4, 1.0"
NaturalSort = "1"
OrderedCollections = "1.4"
Expand All @@ -47,4 +45,4 @@ StatsBase = "0.32, 0.33"
StatsFuns = "0.8, 0.9, 1"
TableTraits = "0.4, 1"
Tables = "1"
julia = "1"
julia = "1.6"
30 changes: 8 additions & 22 deletions src/MCMCChains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ using AxisArrays
const axes = Base.axes
import AbstractMCMC
import AbstractMCMC: chainscat
using Compat
using Distributions
using RecipesBase
using Formatting
Expand Down Expand Up @@ -91,26 +90,13 @@ include("rstar.jl")
# so we use the following hack
const _read = Base.read
const _write = Base.write
@static if VERSION < v"1.1"
Base.@deprecate _read(
f::AbstractString,
::Type{T}
) where {T<:Chains} open(Serialization.deserialize, f, "r") false
Base.@deprecate _write(
f::AbstractString,
c::Chains
) open(f, "w") do io
Serialization.serialize(io, c)
end false
else
Base.@deprecate _read(
f::AbstractString,
::Type{T}
) where {T<:Chains} Serialization.deserialize(f) false
Base.@deprecate _write(
f::AbstractString,
c::Chains
) Serialization.serialize(f, c) false
end
Base.@deprecate _read(
f::AbstractString,
::Type{T}
) where {T<:Chains} Serialization.deserialize(f) false
Base.@deprecate _write(
f::AbstractString,
c::Chains
) Serialization.serialize(f, c) false

end # module
2 changes: 1 addition & 1 deletion src/discretediag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function MCMCDiagnosticTools.discretediag(

# Compute statistics.
between_chain_vals, within_chain_vals = MCMCDiagnosticTools.discretediag(
_chains.value.data; kwargs...
_permutedims_diagnostics(_chains.value.data); kwargs...
)

# Create dataframes
Expand Down
5 changes: 4 additions & 1 deletion src/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ function MCMCDiagnosticTools.ess_rhat(
_chains = Chains(chains, _clean_sections(chains, sections))

# Estimate the effective sample size and rhat
ess, rhat = MCMCDiagnosticTools.ess_rhat(_chains.value.data; kwargs...)
ess, rhat = MCMCDiagnosticTools.ess_rhat(
_permutedims_diagnostics(_chains.value.data);
kwargs...,
)

# Calculate ESS/minute if available
dur = duration(chains)
Expand Down
7 changes: 5 additions & 2 deletions src/gelmandiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function MCMCDiagnosticTools.gelmandiag(

# Compute the potential scale reduction factor.
psi = transform ? link(_chains) : _chains.value.data
results = MCMCDiagnosticTools.gelmandiag(psi; kwargs...)
results = MCMCDiagnosticTools.gelmandiag(_permutedims_diagnostics(psi); kwargs...)

# Create a data frame with the results.
df = ChainDataFrame(
Expand All @@ -31,7 +31,10 @@ function MCMCDiagnosticTools.gelmandiag_multivariate(

# Compute the potential scale reduction factor.
psi = transform ? link(_chains) : _chains.value.data
results = MCMCDiagnosticTools.gelmandiag_multivariate(psi; kwargs...)
results = MCMCDiagnosticTools.gelmandiag_multivariate(
_permutedims_diagnostics(psi);
kwargs...,
)

# Create a data frame with the results.
df = ChainDataFrame(
Expand Down
13 changes: 8 additions & 5 deletions src/rstar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,17 @@ function MCMCDiagnosticTools.rstar(
end

function MCMCDiagnosticTools.rstar(
rng::Random.AbstractRNG, classif::MLJModelInterface.Supervised, chn::Chains; kwargs...
rng::Random.AbstractRNG,
classif::MLJModelInterface.Supervised,
chn::Chains;
sections = _default_sections(chn),
kwargs...
)
nchains = size(chn, 3)
nchains <= 1 && throw(DimensionMismatch())

# collect data
x = Array(chn)
y = repeat(chains(chn); inner = size(chn,1))
_chn = Chains(chn, _clean_sections(chn, sections))
x = _permutedims_diagnostics(_chn.value.data)

return MCMCDiagnosticTools.rstar(rng, classif, x, y; kwargs...)
return MCMCDiagnosticTools.rstar(rng, classif, x; kwargs...)
end
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,6 @@ function _isstrictlyincreasing_nonempty(x::AbstractVector{Int})
end
return true
end

# permute dims to match the dimension order of MCMCDiagnosticsTools
_permutedims_diagnostics(x) = PermutedDimsArray(x, (1, 3, 2))
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ StatsPlots = "0.14.17, 0.15"
TableTraits = "1"
Tables = "1.3.1"
UnicodePlots = "2, 3"
julia = "1"
julia = "1.6"
4 changes: 2 additions & 2 deletions test/ess_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ end
ess_df = ess_rhat(chain; method = method)

# analyze array
ess_array, rhat_array = ess_rhat(x; method = method)
ess_array, rhat_array = ess_rhat(permutedims(x, (1, 3, 2)); method = method)

@test ess_df[:,2] == ess_array
@test ess_df[:,3] == rhat_array
Expand All @@ -46,7 +46,7 @@ end
ess_df = ess_rhat(chain; method = method)

# analyze array
ess_array, rhat_array = ess_rhat(val; method = method)
ess_array, rhat_array = ess_rhat(permutedims(val, (1, 3, 2)); method = method)

@test ismissing(ess_df[:,2][1]) # since min(maxlag, niter - 1) = 0
@test ismissing(ess_df[:,3][1])
Expand Down
8 changes: 0 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
# Activate test environment on older Julia versions
if VERSION < v"1.2"
using Pkg
Pkg.activate(@__DIR__)
Pkg.develop(PackageSpec(path=dirname(@__DIR__)))
Pkg.instantiate()
end

using MCMCChains
using Documenter

Expand Down

2 comments on commit e5dbc07

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/74057

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v5.6.0 -m "<description of version>" e5dbc07e2a770618a89a2874b0554c2e64c36b15
git push origin v5.6.0

Please sign in to comment.