Skip to content

Commit

Permalink
Fixes for corner (#415)
Browse files Browse the repository at this point in the history
* fix for corner plot

* also fixed issues with corner and labels

* version bump
  • Loading branch information
torfjelde authored Apr 17, 2023
1 parent ddac60f commit 738ba9f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "Chain types and utility functions for MCMC simulations."
version = "6.0.0"
version = "6.0.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
7 changes: 5 additions & 2 deletions src/plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,13 @@ struct Corner
end

@recipe function f(corner::Corner)
label --> permutedims(corner.parameters)
# Convert labels to string because `Symbol` is not supported generally supported.
label --> permutedims(map(string, corner.parameters))
compact --> true
size --> (600, 600)
ar = collect(Array(corner.c.value[:, corner.parameters,i]) for i in chains(corner.c))
# NOTE: Don't use the indices from `chains(chains)`.
# See https://github.com/TuringLang/MCMCChains.jl/issues/413.
ar = collect(Array(corner.c.value[:, corner.parameters, i]) for i in 1:length(chains(corner.c)))
RecipesBase.recipetype(:cornerplot, vcat(ar...))
end

Expand Down
7 changes: 6 additions & 1 deletion test/plot_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ unicodeplots()

n_iter = 500
n_name = 3
n_chain = 2
n_chain = 3

val = randn(n_iter, n_name, n_chain) .+ [1, 2, 3]'
val = hcat(val, rand(1:2, n_iter, 1, n_chain))
Expand Down Expand Up @@ -55,6 +55,11 @@ Logging.disable_logging(Logging.Warn)
println("\nmixeddensity")
display(mixeddensity(chn, 1))

println("corner")
display(corner(chn[:, 1:2, :]))
# https://github.com/TuringLang/MCMCChains.jl/issues/413
display(corner(chn[:, 1:2, 2:3]))

# plotting combinations
display(plot(chn))
display(plot(chn, append_chains=true))
Expand Down

0 comments on commit 738ba9f

Please sign in to comment.