Skip to content

Commit

Permalink
keep parameter order (#466)
Browse files Browse the repository at this point in the history
* keep parameter order

* Test order of parameters for get(c; section)

---------

Co-authored-by: Markus Hauru <[email protected]>
  • Loading branch information
a1ix2 and mhauru authored Dec 6, 2024
1 parent aa6161b commit 69bbd98
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ julia> get(chn, :param_1; flatten=true)
```
"""
function Base.get(c::Chains, vs::Vector{Symbol}; flatten=false)
pairs = Dict()
pairs = OrderedCollections.OrderedDict()
for v in vs
syms = namesingroup(c, v)
len = length(syms)
Expand Down Expand Up @@ -289,7 +289,7 @@ function Base.get(
section::Union{Symbol,AbstractVector{Symbol}},
flatten = false
)
names = Set(Symbol[])
names = OrderedCollections.OrderedSet(Symbol[])
regex = r"[^\[]*"
_section = section isa Symbol ? (section,) : section
for v in _section
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function cummean(x::AbstractVector)
return y
end

function _dict2namedtuple(d::Dict)
function _dict2namedtuple(d::AbstractDict)
t_keys = ntuple(x -> Symbol(collect(keys(d))[x]), length(keys(d)))
t_vals = ntuple(x -> collect(values(d))[x], length(values(d)))
return NamedTuple{t_keys}(t_vals)
Expand Down
7 changes: 5 additions & 2 deletions test/sections_tests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
using MCMCChains, Test

# https://github.com/TuringLang/AdvancedMH.jl/pull/63
# https://github.com/TuringLang/MCMCChains.jl/issues/443
@testset "order of parameters" begin
chains = Chains(rand(2, 10, 4), vcat(["μ[$i]" for i in 1:9], :lp), (internals=[:lp],))
@test names(chains, :parameters) == [Symbol("μ[$i]") for i in 1:9]
params = vcat(["μ[$i]" for i in 1:9], :p1, :p2, :p3)
chains = Chains(rand(2, length(params)+1, 4), vcat(params, :lp), (internals=[:lp],))
@test names(chains, :parameters) == map(Symbol, params)
@test collect(keys(get(chains; section=:parameters))) == [, :p1, :p2, :p3]
end

@testset "describe sections" begin
Expand Down

0 comments on commit 69bbd98

Please sign in to comment.