diff --git a/src/chains.jl b/src/chains.jl index e57a2f5e..3079273f 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -107,9 +107,9 @@ end #################### Indexing #################### function Base.getindex(c::Chains, window, names, chains) - inds1 = window2inds(c, window) - inds2 = names2inds(c, names) - Chains(c.value[inds1, inds2, chains], + inds1 = window2inds(c, window) + inds2 = names2inds(c, names) + return Chains(c.value[inds1, inds2, chains], start = first(c) + (first(inds1) - 1) * step(c), thin = step(inds1) * step(c), names = c.names[inds2], chains = c.chains[chains]) diff --git a/src/plot.jl b/src/plot.jl index d0e46714..11552783 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -38,7 +38,12 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity) if st == :mixeddensity discrete = MCMCChain.indiscretesupport(c, barbounds) - st = (discrete[i] && colordim == :chain) ? :histogram : :density + st = if colordim == :chain + discrete[i] ? :histogram : :density + else + # NOTE: It might make sense to overlay histograms and density plots here. + :density + end seriestype := st end @@ -57,14 +62,14 @@ end @recipe function f(p::_DensityPlot) xaxis --> "Sample value" yaxis --> "Density" - [collect(skipmissing(p.val[:,k])) for k in axes(p.val, 2)] + [collect(skipmissing(p.val[:,k])) for k in 1:size(p.val, 2)] end @recipe function f(p::_HistogramPlot) xaxis --> "Sample value" yaxis --> "Frequency" fillalpha --> 0.7 - [collect(skipmissing(p.val[:,k])) for k in axes(p.val, 2)] + [collect(skipmissing(p.val[:,k])) for k in 1:size(p.val, 2)] end @recipe function f(p::_MeanPlot) @@ -100,7 +105,7 @@ end ptypes = ptypes isa AbstractVector || ptypes isa Tuple ? ptypes : (ptypes,) @assert all(map(ptype -> ptype ∈ supportedplots, ptypes)) - nrows, nvars, nchains = size(c.value) + nrows, nvars, nchains = size(c) ntypes = length(ptypes) N = colordim == :chain ? nvars : nchains layout := (N, ntypes) diff --git a/test/plot_test.jl b/test/plot_test.jl index d6401928..9bbfe0a8 100644 --- a/test/plot_test.jl +++ b/test/plot_test.jl @@ -11,32 +11,46 @@ val = hcat(val, rand(1:2, n_iter, 1, n_chain)) chn = Chains(val) -# plotting singe plotting types -ps_trace = traceplot(chn, 1) -@test isa(ps_trace, Plots.Plot) +@testset "Plotting tests" begin -ps_mean = meanplot(chn, 1) -@test isa(ps_mean, Plots.Plot) + # plotting singe plotting types + ps_trace = traceplot(chn, 1) + @test isa(ps_trace, Plots.Plot) -ps_density = density(chn, 1) -@test isa(ps_density, Plots.Plot) + ps_mean = meanplot(chn, 1) + @test isa(ps_mean, Plots.Plot) -ps_autocor = autocorplot(chn, 1) -@test isa(ps_autocor, Plots.Plot) + ps_density = density(chn, 1) + @test isa(ps_density, Plots.Plot) -#ps_contour = plot(chn, :contour) + ps_autocor = autocorplot(chn, 1) + @test isa(ps_autocor, Plots.Plot) -ps_hist = histogram(chn, 1) -@test isa(ps_hist, Plots.Plot) + #ps_contour = plot(chn, :contour) -ps_mixed = mixeddensity(chn, 1) -@test isa(ps_mixed, Plots.Plot) + ps_hist = histogram(chn, 1) + @test isa(ps_hist, Plots.Plot) -# plotting combinations -ps_trace_mean = plot(chn) -@test isa(ps_trace_mean, Plots.Plot) + ps_mixed = mixeddensity(chn, 1) + @test isa(ps_mixed, Plots.Plot) -savefig("demo-plot.png") + # plotting combinations + ps_trace_mean = plot(chn) + @test isa(ps_trace_mean, Plots.Plot) -ps_mixed_auto = plot(chn, seriestype = (:mixeddensity, :autocorplot)) -@test isa(ps_mixed_auto, Plots.Plot) + savefig("demo-plot.png") + + ps_mixed_auto = plot(chn, seriestype = (:mixeddensity, :autocorplot)) + @test isa(ps_mixed_auto, Plots.Plot) + + # Test plotting using colordim keyword + p_colordim = plot(chn, colordim = :parameter) + @test isa(p_colordim, Plots.Plot) + + # Test if plotting a sub-set work.s + p_subset = plot(chn[:,2,:]) + @test isa(p_subset, Plots.Plot) + + p_subset_colordim = plot(chn[:,2,:], colordim = :parameter) + @test isa(p_subset_colordim, Plots.Plot) +end