Skip to content

Commit

Permalink
Add compute time instead of wall time (#303)
Browse files Browse the repository at this point in the history
Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
cpfiffer and devmotion authored May 25, 2021
1 parent 65c9c56 commit 45ae381
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 64 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "Chain types and utility functions for MCMC simulations."
version = "4.11.0"
version = "4.12.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
170 changes: 109 additions & 61 deletions src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,17 +335,16 @@ Base.convert(::Type{Array}, chn::Chains) = convert(Array, chn.value)

# Convenience functions to handle different types of
# timestamps.
# Normalise a single timestamp to `DateTime`: `DateTime` values pass through
# unchanged and `Float64` values are interpreted as unix timestamps.
to_datetime(t::DateTime) = t
to_datetime(t::Float64) = unix2datetime(t)
# Any other input is delegated to `missing_datetime`, keyed on the input's type.
to_datetime(t) = missing_datetime(typeof(t))

# Normalise scalar or vector timestamps to a vector of converted timestamps.
# A single method covers both scalar kinds; a separate `DateTime` method would
# only duplicate it.
to_datetime_vec(t::Union{Float64, DateTime}) = [to_datetime(t)]
to_datetime_vec(ts::Vector) = map(to_datetime, ts)
# Unsupported containers (including a bare `missing`) become `[missing]`, so
# that reductions over the result propagate `missing`.
to_datetime_vec(ts) = [missing]

# Earliest / latest timestamp. Each function is defined exactly once here, so no
# method silently overwrites another with the same signature.
min_datetime(ts) = minimum(to_datetime_vec(ts))
max_datetime(ts) = maximum(to_datetime_vec(ts))

# does not specialize on `typeof(T)`
function missing_datetime(T::Type)
Expand All @@ -361,15 +360,7 @@ Retrieve the minimum of the start times (as `DateTime`) from `chain.info`.
It is assumed that the start times are stored in `chain.info.start_time` as
`DateTime` or unix timestamps of type `Float64`.
"""
function min_start(c::Chains)
    # `start_times` already handles an absent `start_time` entry (it yields
    # `[missing]`), so no explicit key check is needed here; reducing over
    # `[missing]` gives `missing`.
    return min_datetime(start_times(c))
end

"""
max_stop(c::Chains)
Expand All @@ -379,15 +370,23 @@ Retrieve the maximum of the stop times (as `DateTime`) from `chain.info`.
It is assumed that the stop times are stored in `chain.info.stop_time` as
`DateTime` or unix timestamps of type `Float64`.
"""
function max_stop(c::Chains)
    # `stop_times` already handles an absent `stop_time` entry (it yields
    # `[missing]`), so no explicit key check is needed here; reducing over
    # `[missing]` gives `missing`.
    return max_datetime(stop_times(c))
end

"""
    start_times(c::Chains)

Return the start times recorded in `c.info.start_time` as a vector, converted with
`to_datetime_vec`.

If `c.info` has no `start_time` entry, the result is `[missing]`.
"""
function start_times(c::Chains)
    recorded = get(c.info, :start_time, missing)
    return to_datetime_vec(recorded)
end

"""
    stop_times(c::Chains)

Return the stop times recorded in `c.info.stop_time` as a vector, converted with
`to_datetime_vec`.

If `c.info` has no `stop_time` entry, the result is `[missing]`.
"""
function stop_times(c::Chains)
    recorded = get(c.info, :stop_time, missing)
    return to_datetime_vec(recorded)
end

"""
wall_duration(c::Chains; start=min_start(c), stop=max_stop(c))
Expand All @@ -407,6 +406,34 @@ function wall_duration(c::Chains; start=min_start(c), stop=max_stop(c))
end
end

"""
    compute_duration(c::Chains; start=start_times(c), stop=stop_times(c))

Calculate the total compute time of all chains in seconds.

The duration is the sum over chains of `stop - start`. `compute_duration` is more
useful than `wall_duration` for parallel sampling, where the wall time may understate
how much computation time was utilized.

# Keywords
- `start`: start times of the chains, by default `start_times(c)`.
- `stop`: stop times of the chains, by default `stop_times(c)`.

# Returns
The summed duration in seconds, or `missing` if `start` or `stop` is `missing` or if
the summed duration is `missing`.
"""
function compute_duration(
    c::Chains;
    start=start_times(c),
    stop=stop_times(c)
)
    # No timing information at all.
    (ismissing(start) || ismissing(stop)) && return missing

    # Elapsed time of each chain, added up. A `missing` entry makes the sum `missing`.
    total = sum(stop - start)
    ismissing(total) && return missing

    # Assumes `total` is a millisecond period (the difference of `DateTime`s), hence
    # the division by 1000 to obtain seconds.
    return Dates.value(total) / 1000
end

#################### Auxiliary Functions ####################

"""
Expand Down Expand Up @@ -524,10 +551,9 @@ function header(c::Chains; section=missing)
"= $(join(map(string, arr), ", "))\n"
)

# Get the wall time
start = min_start(c)
stop = max_stop(c)
wall = wall_duration(c; start=start, stop=stop)
# Get the timing stats
wall = wall_duration(c)
compute = compute_duration(c)

# Set up string array.
section_strings = String[]
Expand All @@ -548,43 +574,44 @@ function header(c::Chains; section=missing)
# Return header.
return string(
ismissing(c.logevidence) ? "" : "Log evidence = $(c.logevidence)\n",
ismissing(start) ? "" : "Start time = $(start)\n",
ismissing(stop) ? "" : "Stop time = $(stop)\n",
ismissing(wall) ? "" : "Wall duration = $(round(wall, digits=2)) seconds\n",
"Iterations = $(first(c)):$(last(c))\n",
"Thinning interval = $(step(c))\n",
"Chains = $(join(map(string, chains(c)), ", "))\n",
"Number of chains = $(size(c, 3))\n",
"Samples per chain = $(length(range(c)))\n",
ismissing(wall) ? "" : "Wall duration = $(round(wall, digits=2)) seconds\n",
ismissing(compute) ? "" : "Compute duration = $(round(compute, digits=2)) seconds\n",
section_strings...
)
end

"""
    indiscretesupport(c::Chains, bounds::Tuple{Real, Real}=(0, Inf))

Return a `Bool` vector with one entry per variable of `c.value` that is `true` if
every value of that variable (over all iterations and chains) is an integer within
`bounds` (inclusive), and `false` otherwise.

If the chain contains no iterations, an empty vector is returned.
"""
function indiscretesupport(
    c::Chains,
    bounds::Tuple{Real, Real}=(0, Inf)
)
    nrows, nvars, nchains = size(c.value)
    # The length is `nvars` for a non-empty chain and 0 if there are no iterations.
    result = Array{Bool}(undef, nvars * (nrows > 0))
    # Iterate over the indices of `result` rather than `1:nvars`, so that an empty
    # `result` is never written to out of bounds.
    for i in eachindex(result)
        result[i] = true
        for j in 1:nrows, k in 1:nchains
            x = c.value[j, i, k]
            if !isinteger(x) || x < bounds[1] || x > bounds[2]
                result[i] = false
                # `break` leaves the combined `j`, `k` loop entirely.
                break
            end
        end
    end
    return result
end

"""
    link(c::Chains)

Return a transformed copy of the data array `c.value.data`.

Every slice along the second dimension whose minimum is strictly positive is
transformed: with `StatsFuns.logit` if its maximum is below 1, and with `log`
otherwise. All other slices are copied unchanged. The chain itself is not modified.
"""
function link(c::Chains)
    cc = copy(c.value.data)
    for j in axes(cc, 2)
        # Indexing copies the slice, so the right-hand side below is computed from
        # the untransformed values.
        x = cc[:, j, :]
        if minimum(x) > 0.0
            cc[:, j, :] = maximum(x) < 1.0 ? StatsFuns.logit.(x) : log.(x)
        end
    end
    return cc
end

### Chains specific functions ###
Expand Down Expand Up @@ -748,11 +775,32 @@ function _cat(::Val{3}, c1::Chains, args::Chains...)
all(c -> names(c) == nms, args) || throw(ArgumentError("chain names differ"))

# concatenate all chains
data = mapreduce(c -> c.value.data, (x, y) -> cat(x, y; dims = 3), args;
init = c1.value.data)
data = mapreduce(
c -> c.value.data,
(x, y) -> cat(x, y; dims = 3),
args;
init = c1.value.data
)
value = AxisArray(data; iter = rng, var = nms, chain = 1:size(data, 3))

return Chains(value, missing, c1.name_map, c1.info)
# Concatenate times, if available
starts = mapreduce(
c -> get(c.info, :start_time, nothing),
vcat,
args,
init = get(c1.info, :start_time, nothing)
)
stops = mapreduce(
c -> get(c.info, :stop_time, nothing),
vcat,
args,
init = get(c1.info, :stop_time, nothing)
)
nontime_props = filter(x -> !(x in [:start_time, :stop_time]), [propertynames(c1.info)...])
new_info = NamedTuple{tuple(nontime_props...)}(tuple([c1.info[n] for n in nontime_props]...))
new_info = merge(new_info, (start_time = starts, stop_time = stops))

return Chains(value, missing, c1.name_map, new_info)
end

function pool_chain(c::Chains)
Expand Down
8 changes: 6 additions & 2 deletions src/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,17 @@ function mean_autocov(k::Int, cache::BDAESSCache)
end

"""
ess(chains::Chains; kwargs...)
ess(chains::Chains; duration=compute_duration, kwargs...)
Estimate the effective sample size and the potential scale reduction.
ESS per second options include `duration=MCMCChains.compute_duration` (the default)
and `duration=MCMCChains.wall_duration`.
"""
function ess(
chains::Chains;
sections = _default_sections(chains),
duration = compute_duration,
kwargs...
)
# subset the chain
Expand All @@ -212,7 +216,7 @@ function ess(
ess, rhat = ess_rhat(_chains.value.data; kwargs...)

# Calculate ESS per second if available
dur = wall_duration(chains)
dur = duration(chains)

# convert to namedtuple
nt = if dur === missing
Expand Down
16 changes: 16 additions & 0 deletions test/ess_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,22 @@ using Random
using Statistics
using Test

@testset "ESS per second" begin
    # Two single-chain objects with unix-timestamp start/stop times roughly one
    # second apart each; the second chain starts about one second after the first.
    c1 = Chains(randn(100, 5, 1), info = (start_time = time(), stop_time = time() + 1))
    c2 = Chains(randn(100, 5, 1), info = (start_time = time() + 1, stop_time = time() + 2))
    c = chainscat(c1, c2)

    wall = MCMCChains.wall_duration(c)
    compute = MCMCChains.compute_duration(c)

    # Wall time spans from the earliest start to the latest stop.
    @test round(wall, digits=1) ≈ round(c2.info.stop_time - c1.info.start_time, digits=1)
    # Compute time is the sum of the per-chain durations.
    @test compute ≈ (MCMCChains.compute_duration(c1) + MCMCChains.compute_duration(c2))

    s = ess(c)
    @test length(s[:,:ess_per_sec]) == 5
    @test all(map(!ismissing, s[:,:ess_per_sec]))
end

@testset "copy and split" begin
# check a matrix with even number of rows
x = rand(50, 20)
Expand Down

2 comments on commit 45ae381

@cpfiffer
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/37430

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v4.12.0 -m "<description of version>" 45ae38163c8a74e6cecfd4d9a1759a0f7b76d6fc
git push origin v4.12.0

Please sign in to comment.