diff --git a/Project.toml b/Project.toml index a18c8807..2efb10e4 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "Chain types and utility functions for MCMC simulations." -version = "4.11.0" +version = "4.12.0" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/chains.jl b/src/chains.jl index bebdc573..d216a04e 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -335,17 +335,16 @@ Base.convert(::Type{Array}, chn::Chains) = convert(Array, chn.value) # Convenience functions to handle different types of # timestamps. -min_datetime(t::DateTime) = t -min_datetime(ts::Vector{DateTime}) = minimum(ts) -min_datetime(t::Float64) = unix2datetime(t) -min_datetime(ts::Vector{Float64}) = unix2datetime(minimum(ts)) -min_datetime(ts) = missing_datetime(typeof(ts)) - -max_datetime(t::DateTime) = t -max_datetime(ts::Vector{DateTime}) = maximum(ts) -max_datetime(t::Float64) = unix2datetime(t) -max_datetime(ts::Vector{Float64}) = unix2datetime(maximum(ts)) -max_datetime(ts) = missing_datetime(typeof(ts)) +to_datetime(t::DateTime) = t +to_datetime(t::Float64) = unix2datetime(t) +to_datetime(t) = missing_datetime(typeof(t)) +to_datetime_vec(t::Union{Float64, DateTime}) = [to_datetime(t)] +to_datetime_vec(t::DateTime) = [to_datetime(t)] +to_datetime_vec(ts::Vector) = map(to_datetime, ts) +to_datetime_vec(ts) = [missing] + +min_datetime(ts) = minimum(to_datetime_vec(ts)) +max_datetime(ts) = maximum(to_datetime_vec(ts)) # does not specialize on `typeof(T)` function missing_datetime(T::Type) @@ -361,15 +360,7 @@ Retrieve the minimum of the start times (as `DateTime`) from `chain.info`. It is assumed that the start times are stored in `chain.info.start_time` as `DateTime` or unix timestamps of type `Float64`. """ -function min_start(c::Chains) - return if :start_time in keys(c.info) - # We've got some times, return the minimum. - min_datetime(c.info.start_time) - else - # Times not found -- spit out missing. - missing - end -end +min_start(c::Chains) = min_datetime(start_times(c)) """ max_stop(c::Chains) @@ -379,15 +370,23 @@ Retrieve the maximum of the stop times (as `DateTime`) from `chain.info`. It is assumed that the start times are stored in `chain.info.stop_time` as `DateTime` or unix timestamps of type `Float64`. """ -function max_stop(c::Chains) - return if :stop_time in keys(c.info) - # We've got some times, return the minimum. - return max_datetime(c.info.stop_time) - else - # Times not found -- spit out missing. - return missing - end -end +max_stop(c::Chains) = max_datetime(stop_times(c)) + +""" + start_times(c::Chains) + +Retrieve the contents of `c.info.start_time`, or `missing` if no +`start_time` is set. +""" +start_times(c::Chains) = to_datetime_vec(get(c.info, :start_time, missing)) + +""" + stop_times(c::Chains) + +Retrieve the contents of `c.info.stop_time`, or `missing` if no +`stop_time` is set. +""" +stop_times(c::Chains) = to_datetime_vec(get(c.info, :stop_time, missing)) """ wall_duration(c::Chains; start=min_start(c), stop=max_stop(c)) @@ -407,6 +406,34 @@ function wall_duration(c::Chains; start=min_start(c), stop=max_stop(c)) end end +""" + compute_duration(c::Chains; start=start_times(c), stop=stop_times(c)) + +Calculate the compute time for all chains in seconds. + +The duration is calculated as the sum of `start - stop` in seconds. + +`compute_duration` is more useful in cases of parallel sampling, where `wall_duration` +may understate how much computation time was utilitzed. +""" +function compute_duration( + c::Chains; + start=start_times(c), + stop=stop_times(c) +) + # Calculate total time for each chain, then add it up. + if start === missing || stop === missing + return missing + else + calc = sum(stop - start) + if calc === missing + return missing + else + return Dates.value(calc) / 1000 + end + end +end + #################### Auxilliary Functions #################### """ @@ -524,10 +551,9 @@ function header(c::Chains; section=missing) "= $(join(map(string, arr), ", "))\n" ) - # Get the wall time - start = min_start(c) - stop = max_stop(c) - wall = wall_duration(c; start=start, stop=stop) + # Get the timing stats + wall = wall_duration(c) + compute = compute_duration(c) # Set up string array. section_strings = String[] @@ -548,43 +574,44 @@ function header(c::Chains; section=missing) # Return header. return string( ismissing(c.logevidence) ? "" : "Log evidence = $(c.logevidence)\n", - ismissing(start) ? "" : "Start time = $(start)\n", - ismissing(stop) ? "" : "Stop time = $(stop)\n", - ismissing(wall) ? "" : "Wall duration = $(round(wall, digits=2)) seconds\n", "Iterations = $(first(c)):$(last(c))\n", "Thinning interval = $(step(c))\n", - "Chains = $(join(map(string, chains(c)), ", "))\n", + "Number of chains = $(size(c, 3))\n", "Samples per chain = $(length(range(c)))\n", + ismissing(wall) ? "" : "Wall duration = $(round(wall, digits=2)) seconds\n", + ismissing(compute) ? "" : "Compute duration = $(round(compute, digits=2)) seconds\n", section_strings... ) end -function indiscretesupport(c::Chains, - bounds::Tuple{Real, Real}=(0, Inf)) - nrows, nvars, nchains = size(c.value) - result = Array{Bool}(undef, nvars * (nrows > 0)) - for i in 1:nvars - result[i] = true - for j in 1:nrows, k in 1:nchains - x = c.value[j, i, k] - if !isinteger(x) || x < bounds[1] || x > bounds[2] - result[i] = false - break - end +function indiscretesupport( + c::Chains, + bounds::Tuple{Real, Real}=(0, Inf) +) + nrows, nvars, nchains = size(c.value) + result = Array{Bool}(undef, nvars * (nrows > 0)) + for i in 1:nvars + result[i] = true + for j in 1:nrows, k in 1:nchains + x = c.value[j, i, k] + if !isinteger(x) || x < bounds[1] || x > bounds[2] + result[i] = false + break + end + end end - end - result + return result end function link(c::Chains) - cc = copy(c.value.data) - for j in axes(cc, 2) - x = cc[:, j, :] - if minimum(x) > 0.0 - cc[:, j, :] = maximum(x) < 1.0 ? StatsFuns.logit.(x) : log.(x) + cc = copy(c.value.data) + for j in axes(cc, 2) + x = cc[:, j, :] + if minimum(x) > 0.0 + cc[:, j, :] = maximum(x) < 1.0 ? StatsFuns.logit.(x) : log.(x) + end end - end - cc + return cc end ### Chains specific functions ### @@ -748,11 +775,32 @@ function _cat(::Val{3}, c1::Chains, args::Chains...) all(c -> names(c) == nms, args) || throw(ArgumentError("chain names differ")) # concatenate all chains - data = mapreduce(c -> c.value.data, (x, y) -> cat(x, y; dims = 3), args; - init = c1.value.data) + data = mapreduce( + c -> c.value.data, + (x, y) -> cat(x, y; dims = 3), + args; + init = c1.value.data + ) value = AxisArray(data; iter = rng, var = nms, chain = 1:size(data, 3)) - return Chains(value, missing, c1.name_map, c1.info) + # Concatenate times, if available + starts = mapreduce( + c -> get(c.info, :start_time, nothing), + vcat, + args, + init = get(c1.info, :start_time, nothing) + ) + stops = mapreduce( + c -> get(c.info, :stop_time, nothing), + vcat, + args, + init = get(c1.info, :stop_time, nothing) + ) + nontime_props = filter(x -> !(x in [:start_time, :stop_time]), [propertynames(c1.info)...]) + new_info = NamedTuple{tuple(nontime_props...)}(tuple([c1.info[n] for n in nontime_props]...)) + new_info = merge(new_info, (start_time = starts, stop_time = stops)) + + return Chains(value, missing, c1.name_map, new_info) end function pool_chain(c::Chains) diff --git a/src/ess.jl b/src/ess.jl index d5cc1e21..e178e90f 100644 --- a/src/ess.jl +++ b/src/ess.jl @@ -196,13 +196,17 @@ function mean_autocov(k::Int, cache::BDAESSCache) end """ - ess(chains::Chains; kwargs...) + ess(chains::Chains; duration=compute_duration, kwargs...) Estimate the effective sample size and the potential scale reduction. + +ESS per second options include `duration=MCMCChains.compute_duration` (the default) +and `duration=MCMCChains.wall_duration`. """ function ess( chains::Chains; sections = _default_sections(chains), + duration = compute_duration, kwargs... ) # subset the chain @@ -212,7 +216,7 @@ function ess( ess, rhat = ess_rhat(_chains.value.data; kwargs...) # Calculate ESS/minute if available - dur = wall_duration(chains) + dur = duration(chains) # convert to namedtuple nt = if dur === missing diff --git a/test/ess_tests.jl b/test/ess_tests.jl index 780b123b..8690643d 100644 --- a/test/ess_tests.jl +++ b/test/ess_tests.jl @@ -5,6 +5,22 @@ using Random using Statistics using Test +@testset "ESS per second" begin + c1 = Chains(randn(100,5, 1), info = (start_time=time(), stop_time = time()+1)) + c2 = Chains(randn(100,5, 1), info = (start_time=time()+1, stop_time = time()+2)) + c = chainscat(c1, c2) + + wall = MCMCChains.wall_duration(c) + compute = MCMCChains.compute_duration(c) + + @test round(wall, digits=1) ≈ round(c2.info.stop_time - c1.info.start_time, digits=1) + @test compute ≈ (MCMCChains.compute_duration(c1) + MCMCChains.compute_duration(c2)) + + s = ess(c) + @test length(s[:,:ess_per_sec]) == 5 + @test all(map(!ismissing, s[:,:ess_per_sec])) +end + @testset "copy and split" begin # check a matrix with even number of rows x = rand(50, 20)