diff --git a/Project.toml b/Project.toml index 9464bc08..6a6a51ab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MCMCDiagnosticTools" uuid = "be115224-59cd-429b-ad48-344e309966f0" authors = ["David Widmann"] -version = "0.3.4" +version = "0.3.5" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/README.md b/README.md index d94fe42c..9561bee1 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ [![Build Status](https://github.com/TuringLang/MCMCDiagnosticTools.jl/workflows/CI/badge.svg?branch=main)](https://github.com/TuringLang/MCMCDiagnosticTools.jl/actions?query=workflow%3ACI+branch%3Amain) [![Coverage](https://codecov.io/gh/TuringLang/MCMCDiagnosticTools.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/TuringLang/MCMCDiagnosticTools.jl) [![Coverage](https://coveralls.io/repos/github/TuringLang/MCMCDiagnosticTools.jl/badge.svg?branch=main)](https://coveralls.io/github/TuringLang/MCMCDiagnosticTools.jl?branch=main) +[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) diff --git a/src/rstar.jl b/src/rstar.jl index 41de1233..9b16b299 100644 --- a/src/rstar.jl +++ b/src/rstar.jl @@ -211,10 +211,7 @@ julia> round(value; digits=2) Lambert, B., & Vehtari, A. (2020). ``R^*``: A robust MCMC convergence diagnostic with uncertainty using decision tree classifiers. """ function rstar(rng::Random.AbstractRNG, classifier, x::AbstractArray; kwargs...) - return rstar(rng, classifier, _params_array(x); kwargs...) -end -function rstar(rng::Random.AbstractRNG, classifier, x::AbstractArray{<:Any,3}; kwargs...) - samples = reshape(x, :, size(x, 3)) + samples = reshape(x, size(x, 1) * size(x, 2), :) chain_inds = repeat(axes(x, 2); inner=size(x, 1)) return rstar(rng, classifier, samples, chain_inds; kwargs...) end @@ -222,6 +219,12 @@ end function rstar(classifier, x, y::AbstractVector{Int}; kwargs...) return rstar(Random.default_rng(), classifier, x, y; kwargs...) end +# Fix method ambiguity issue +function rstar(rng::Random.AbstractRNG, classifier, x::AbstractVector{Int}; kwargs...) + samples = reshape(x, length(x), :) + chain_inds = ones(Int, length(x)) + return rstar(rng, classifier, samples, chain_inds; kwargs...) +end function rstar(classifier, x::AbstractArray; kwargs...) return rstar(Random.default_rng(), classifier, x; kwargs...) diff --git a/test/Project.toml b/test/Project.toml index e1909c9c..553ae630 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" @@ -19,6 +20,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] +Aqua = "0.6.5" Distributions = "0.25" DynamicHMC = "3" EvoTrees = "0.14.7, 0.15" diff --git a/test/aqua.jl b/test/aqua.jl new file mode 100644 index 00000000..e6226e77 --- /dev/null +++ b/test/aqua.jl @@ -0,0 +1,16 @@ +using MCMCDiagnosticTools +using Aqua +using Test + +@testset "Aqua" begin + # Test ambiguities separately without Base and Core + # Ref: https://github.com/JuliaTesting/Aqua.jl/issues/77 + # Only test Project.toml formatting on Julia > 1.6 when running Github action + # Ref: https://github.com/JuliaTesting/Aqua.jl/issues/105 + Aqua.test_all( + MCMCDiagnosticTools; + ambiguities=false, + project_toml_formatting=VERSION >= v"1.7" || !haskey(ENV, "GITHUB_ACTIONS"), + ) + Aqua.test_ambiguities([MCMCDiagnosticTools]) +end diff --git a/test/rstar.jl b/test/rstar.jl index 848fb221..1550c38b 100644 --- a/test/rstar.jl +++ b/test/rstar.jl @@ -188,4 +188,14 @@ end ) @test_throws ArgumentError MCMCDiagnosticTools._rstar(1.0, rand(2), rand(2)) end + + @testset "single chain: method ambiguity issue" begin + samples = rand(1:5, N) + rng = MersenneTwister(42) + dist = rstar(rng, DecisionTreeClassifier(), samples) + @test mean(dist) ≈ 1 atol = 0.15 + Random.seed!(rng, 42) + dist2 = rstar(rng, DecisionTreeClassifier(), samples, ones(Int, N)) + @test dist2 == dist + end end diff --git a/test/runtests.jl b/test/runtests.jl index b756fbe2..07f52402 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,10 @@ Random.seed!(1) @testset "MCMCDiagnosticTools.jl" begin include("helpers.jl") + @testset "Aqua" begin + include("aqua.jl") + end + @testset "utils" begin include("utils.jl") end