From 738ba9fe7e616b0b753b3807490d26aeda51b059 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 17 Apr 2023 23:28:50 +0100 Subject: [PATCH] Fixes for `corner` (#415) * fix for corner plot * also fixed issues with corner and labels * version bump --- Project.toml | 2 +- src/plot.jl | 7 +++++-- test/plot_test.jl | 7 ++++++- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 5fc4b7ce..b756946f 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "Chain types and utility functions for MCMC simulations." -version = "6.0.0" +version = "6.0.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/plot.jl b/src/plot.jl index 9be73bd1..1575846c 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -181,10 +181,13 @@ struct Corner end @recipe function f(corner::Corner) - label --> permutedims(corner.parameters) + # Convert labels to string because `Symbol` is not supported generally supported. + label --> permutedims(map(string, corner.parameters)) compact --> true size --> (600, 600) - ar = collect(Array(corner.c.value[:, corner.parameters,i]) for i in chains(corner.c)) + # NOTE: Don't use the indices from `chains(chains)`. + # See https://github.com/TuringLang/MCMCChains.jl/issues/413. + ar = collect(Array(corner.c.value[:, corner.parameters, i]) for i in 1:length(chains(corner.c))) RecipesBase.recipetype(:cornerplot, vcat(ar...)) end diff --git a/test/plot_test.jl b/test/plot_test.jl index 860f383f..a7c6e999 100644 --- a/test/plot_test.jl +++ b/test/plot_test.jl @@ -8,7 +8,7 @@ unicodeplots() n_iter = 500 n_name = 3 -n_chain = 2 +n_chain = 3 val = randn(n_iter, n_name, n_chain) .+ [1, 2, 3]' val = hcat(val, rand(1:2, n_iter, 1, n_chain)) @@ -55,6 +55,11 @@ Logging.disable_logging(Logging.Warn) println("\nmixeddensity") display(mixeddensity(chn, 1)) + println("corner") + display(corner(chn[:, 1:2, :])) + # https://github.com/TuringLang/MCMCChains.jl/issues/413 + display(corner(chn[:, 1:2, 2:3])) + # plotting combinations display(plot(chn)) display(plot(chn, append_chains=true))