From 711a298833705df92821296d13de376abb07ea6a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 30 Jan 2023 15:22:00 +0100 Subject: [PATCH] Better support for `AbstractString` (#397) * Support `AbstractString` * Add tests * Bump version --- Project.toml | 2 +- src/chains.jl | 22 +++++++------- src/utils.jl | 10 +++--- test/diagnostic_tests.jl | 66 ++++++++++++++++++++++++++++++---------- 4 files changed, 67 insertions(+), 33 deletions(-) diff --git a/Project.toml b/Project.toml index ec6bf439..03707e8d 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "Chain types and utility functions for MCMC simulations." -version = "5.6.1" +version = "5.7.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/chains.jl b/src/chains.jl index 7dae00b3..aa652592 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -80,7 +80,7 @@ function Chains( end """ - Chains(c::Chains, section::Union{Symbol,String}) + Chains(c::Chains, section::Union{Symbol,AbstractString}) Chains(c::Chains, sections) Return a new chain with only a specific `section` or multiple `sections` pulled out. @@ -101,7 +101,7 @@ julia> names(chn2) :a ``` """ -Chains(c::Chains, section::Union{Symbol,String}) = Chains(c, (section,)) +Chains(c::Chains, section::Union{Symbol,AbstractString}) = Chains(c, (section,)) function Chains(chn::Chains, sections) # Make sure the sections exist first. all(haskey(chn.name_map, Symbol(x)) for x in sections) || @@ -121,7 +121,7 @@ Chains(chain::Chains, ::Nothing) = chain # Groups of parameters """ - namesingroup(chains::Chains, sym::Symbol; index_type::Symbol=:bracket) + namesingroup(chains::Chains, sym::Union{AbstractString,Symbol}; index_type::Symbol=:bracket) Return the parameters with the same name `sym`, but have a different index. Bracket indexing format in the form of `:sym[index]` is assumed by default. Use `index_type=:dot` for parameters with dot @@ -147,7 +147,7 @@ julia> namesingroup(chn, :A; index_type=:dot) Symbol("A.2") ``` """ -namesingroup(chains::Chains, sym::String; kwargs...) = namesingroup(chains, Symbol(sym); kwargs...) +namesingroup(chains::Chains, sym::AbstractString; kwargs...) = namesingroup(chains, Symbol(sym); kwargs...) function namesingroup(chains::Chains, sym::Symbol; index_type::Symbol=:bracket) if index_type !== :bracket && index_type !== :dot error("index_type must be :bracket or :dot") @@ -161,14 +161,14 @@ function namesingroup(chains::Chains, sym::Symbol; index_type::Symbol=:bracket) end """ - group(chains::Chains, name::Union{String,Symbol}; index_type::Symbol=:bracket) + group(chains::Chains, name::Union{AbstractString,Symbol}; index_type::Symbol=:bracket) Return a subset of the chain containing parameters with the same `name`, but a different index. Bracket indexing format in the form of `:name[index]` is assumed by default. Use `index_type=:dot` for parameters with dot indexing, i.e. `:sym.index`. """ -function group(chains::Chains, name::Union{String,Symbol}; kwargs...) +function group(chains::Chains, name::Union{AbstractString,Symbol}; kwargs...) return chains[:, namesingroup(chains, name; kwargs...), :] end @@ -177,8 +177,8 @@ end Base.getindex(c::Chains, i::Integer) = c[i, :, :] Base.getindex(c::Chains, i::AbstractVector{<:Integer}) = c[i, :, :] -Base.getindex(c::Chains, v::String) = c[:, Symbol(v), :] -Base.getindex(c::Chains, v::AbstractVector{String}) = c[:, Symbol.(v), :] +Base.getindex(c::Chains, v::AbstractString) = c[:, Symbol(v), :] +Base.getindex(c::Chains, v::AbstractVector{<:AbstractString}) = c[:, Symbol.(v), :] Base.getindex(c::Chains, v::Symbol) = c[:, v, :] Base.getindex(c::Chains, v::AbstractVector{Symbol}) = c[:, v, :] @@ -199,7 +199,7 @@ _toindex(i, j, k::Integer) = (i, string2symbol(j), k:k) _toindex(i::Integer, j, k::Integer) = (i:i, string2symbol(j), k:k) # return an array or a number if a single parameter is specified -const SingleIndex = Union{Symbol,String,Integer} +const SingleIndex = Union{Symbol,AbstractString,Integer} _toindex(i, j::SingleIndex, k) = (i, string2symbol(j), k) _toindex(i::Integer, j::SingleIndex, k) = (i, string2symbol(j), k) _toindex(i, j::SingleIndex, k::Integer) = (i, string2symbol(j), k) @@ -542,7 +542,7 @@ Return multiple `Chains` objects, each containing only a single section. function get_sections(chains::Chains, sections = keys(chains.name_map)) return [Chains(chains, section) for section in sections] end -get_sections(chains::Chains, section::Union{Symbol, String}) = Chains(chains, section) +get_sections(chains::Chains, section::Union{Symbol, AbstractString}) = Chains(chains, section) """ sections(c::Chains) @@ -727,7 +727,7 @@ function _clean_sections(chains::Chains, sections) haskey(chains.name_map, Symbol(section)) end end -function _clean_sections(chains::Chains, section::Union{String,Symbol}) +function _clean_sections(chains::Chains, section::Union{AbstractString,Symbol}) return haskey(chains.name_map, Symbol(section)) ? section : () end _clean_sections(::Chains, ::Nothing) = nothing diff --git a/src/utils.jl b/src/utils.jl index 43809700..1eeebcdc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -27,13 +27,13 @@ end Convert strings to symbols. -If `x isa String`, the corresponding `Symbol` is returned. Likewise, if -`x isa AbstractVector{String}`, the corresponding vector of `Symbol`s is returned. In all -other cases, input `x` is returned. +If `x isa AbstractString`, the corresponding `Symbol` is returned. +Likewise, if `x isa AbstractVector{<:AbstractString}`, the corresponding vector of `Symbol`s is returned. +In all other cases, input `x` is returned. """ string2symbol(x) = x -string2symbol(x::String) = Symbol(x) -string2symbol(x::AbstractVector{String}) = Symbol.(x) +string2symbol(x::AbstractString) = Symbol(x) +string2symbol(x::AbstractVector{<:AbstractString}) = Symbol.(x) #################### Mathematical Operators #################### function cummean(x::AbstractArray) diff --git a/test/diagnostic_tests.jl b/test/diagnostic_tests.jl index 8ef8bb15..a7a4d443 100644 --- a/test/diagnostic_tests.jl +++ b/test/diagnostic_tests.jl @@ -101,13 +101,42 @@ end end @testset "indexing tests" begin - @test chn[:,1,:] isa AbstractMatrix - @test chn[200:300, "param_1", :] isa AbstractMatrix - @test chn[200:300, ["param_1", "param_3"], :] isa Chains - @test chn[200:300, "param_1", 1] isa AbstractVector - @test size(chn[:,1,:]) == (niter, nchains) - @test chn[:,1,1] == val[:,1,1] - @test chn[:,1,2] == val[:,1,2] + c = chn[:, 1, :] + @test c isa AbstractMatrix + @test size(c) == (niter, nchains) + @test c == val[:, 1, :] + + for i in 1:2 + c = chn[:, 1, i] + @test c isa AbstractVector + @test length(c) == niter + @test c == val[:, 1, i] + end + + for p in (:param_1, "param_1", SubString("param_1", 1)) + c = chn[200:300, p, :] + @test c isa AbstractMatrix + @test size(c) == (101, size(chn, 3)) + @test c == val[200:300, 1, :] + + c = chn[200:300, p, 1] + @test c isa AbstractVector + @test length(c) == 101 + @test c == val[200:300, 1, 1] + end + + for ps in ( + [:param_1, :param_3], + ["param_1", "param_3"], + [SubString("param_1", 1), "param_3"], + ["param_1", SubString("param_3", 1)], + [SubString("param_1", 1), SubString("param_3", 1)], + ) + c = chn[200:300, ps, :] + @test c isa Chains + @test size(c) == (101, 2, nchains) + @test c.value.data == val[200:300, [1, 3], :] + end end @testset "names and groups tests" begin @@ -116,18 +145,23 @@ end (@inferred replacenames(chn, Dict("param_2" => "param[2]", "param_3" => "param[3]"))).value @test names(chn2) == [:param_1, Symbol("param[2]"), Symbol("param[3]"), :param_4] - @test namesingroup(chn2, "param") == Symbol.(["param[2]", "param[3]"]) + for p in (:param, "param", SubString("param", 1)) + @test namesingroup(chn2, p) == Symbol.(["param[2]", "param[3]"]) + end - chn3 = group(chn2, "param") - @test names(chn3) == Symbol.(["param[2]", "param[3]"]) - @test chn3.value == chn[:, [:param_2, :param_3], :].value + for p in (:param, "param", SubString("param", 1)) + chn3 = group(chn2, p) + @test names(chn3) == Symbol.(["param[2]", "param[3]"]) + @test chn3.value == chn[:, [:param_2, :param_3], :].value + end stan_chn = Chains(rand(100, 3, 1), ["a.1", "a[2]", "b"]) - @test namesingroup(stan_chn, "a"; index_type=:dot) == [Symbol("a.1")] - @test namesingroup(stan_chn, :a; index_type=:dot) == [Symbol("a.1")] - @test names(group(stan_chn, :a; index_type=:dot)) == [Symbol("a.1")] - @test_throws Exception namesingroup(stan_chn, :a; index_type=:x) - @test_throws Exception group(stan_chn, :a; index_type=:x) + for p in (:a, "a", SubString("a", 1)) + @test namesingroup(stan_chn, p; index_type=:dot) == [Symbol("a.1")] + @test names(group(stan_chn, p; index_type=:dot)) == [Symbol("a.1")] + @test_throws Exception namesingroup(stan_chn, p; index_type=:x) + @test_throws Exception group(stan_chn, p; index_type=:x) + end end @testset "function tests" begin