Skip to content

Commit

Permalink
Merge pull request #21 from TuringLang/filter
Browse files Browse the repository at this point in the history
Provide default progress loggers
  • Loading branch information
cpfiffer authored Mar 6, 2020
2 parents c867a65 + 930708c commit b29b38d
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 69 deletions.
12 changes: 9 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,28 @@ desc = "A lightweight interface for common MCMC methods."
version = "0.5.0"

[deps]
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"

[compat]
ConsoleProgressMonitor = "0.1"
LoggingExtras = "0.4"
ProgressLogging = "0.1"
StatsBase = "0.32"
TerminalLoggers = "0.1"
julia = "1"

[extras]
Atom = "c52e3926-4ff0-5f6e-af25-54175e0327b1"
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Statistics", "Test", "TerminalLoggers"]
test = ["Atom", "IJulia", "Statistics", "Test"]
115 changes: 67 additions & 48 deletions src/AbstractMCMC.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,62 @@
module AbstractMCMC

import ConsoleProgressMonitor
import LoggingExtras
import ProgressLogging
import StatsBase
using StatsBase: sample
import TerminalLoggers

import Distributed
import Logging
using Random: GLOBAL_RNG, AbstractRNG, seed!
import UUIDs

# avoid creating a progress bar with @withprogress if progress logging is disabled
# and add a custom progress logger if the current logger does not seem to be able to handle
# progress logs
macro ifwithprogresslogger(progress, exprs...)
return quote
if $progress
if $hasprogresslevel($Logging.current_logger())
$ProgressLogging.@withprogress $(exprs...)
else
$with_progresslogger($Logging.current_logger()) do
$ProgressLogging.@withprogress $(exprs...)
end
end
else
$(exprs[end])
end
end |> esc
end

# improved checks?
function hasprogresslevel(logger)
return Logging.min_enabled_level(logger) ProgressLogging.ProgressLevel
end

# filter better, e.g., according to group?
function with_progresslogger(f, logger)
_module = @__MODULE__
logger1 = LoggingExtras.EarlyFilteredLogger(progresslogger()) do log
log._module === _module && log.level == ProgressLogging.ProgressLevel
end
logger2 = LoggingExtras.EarlyFilteredLogger(logger) do log
log._module !== _module || log.level != ProgressLogging.ProgressLevel
end

Logging.with_logger(f, LoggingExtras.TeeLogger(logger1, logger2))
end

function progresslogger()
# detect if code is running under IJulia since TerminalLogger does not work with IJulia
# https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia
if isdefined(Main, :IJulia) && Main.IJulia.inited
return ConsoleProgressMonitor.ProgressLogger()
else
return TerminalLoggers.TerminalLogger()
end
end

"""
AbstractChains
Expand Down Expand Up @@ -44,7 +93,7 @@ abstract type AbstractModel end
Return `N` samples from the MCMC `sampler` for the provided `model`.
If a callback function `f` with type signature
If a callback function `f` with type signature
```julia
f(rng::AbstractRNG, model::AbstractModel, sampler::AbstractSampler, N::Integer,
iteration::Integer, transition; kwargs...)
Expand Down Expand Up @@ -77,15 +126,7 @@ function StatsBase.sample(
# Perform any necessary setup.
sample_init!(rng, model, sampler, N; kwargs...)

# Create a progress bar.
if progress
progressid = UUIDs.uuid4()
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=NaN,
_id=progressid)
end

local transitions
try
@ifwithprogresslogger progress name=progressname begin
# Obtain the initial transition.
transition = step!(rng, model, sampler, N; iteration=1, kwargs...)

Expand All @@ -97,10 +138,7 @@ function StatsBase.sample(
transitions_save!(transitions, 1, transition, model, sampler, N; kwargs...)

# Update the progress bar.
if progress
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=1/N,
_id=progressid)
end
progress && ProgressLogging.@logprogress 1/N

# Step through the sampler.
for i in 2:N
Expand All @@ -114,16 +152,7 @@ function StatsBase.sample(
transitions_save!(transitions, i, transition, model, sampler, N; kwargs...)

# Update the progress bar.
if progress
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=i/N,
_id=progressid)
end
end
finally
# Close the progress bar.
if progress
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress="done",
_id=progressid)
progress && ProgressLogging.@logprogress i/N
end
end

Expand Down Expand Up @@ -178,12 +207,12 @@ function sample_end!(
end

function bundle_samples(
::AbstractRNG,
::AbstractModel,
::AbstractSampler,
::Integer,
::AbstractRNG,
::AbstractModel,
::AbstractSampler,
::Integer,
transitions,
::Type{Any};
::Type{Any};
kwargs...
)
return transitions
Expand Down Expand Up @@ -259,7 +288,7 @@ end
Sample `nchains` chains using the available threads, and combine them into a single chain.
By default, the random number generator, the model and the samplers are deep copied for each
thread to prevent contamination between threads.
thread to prevent contamination between threads.
"""
function psample(
model::AbstractModel,
Expand Down Expand Up @@ -292,24 +321,20 @@ function psample(
# Set up a chains vector.
chains = Vector{Any}(undef, nchains)

# Create a progress bar and a channel for progress logging.
if progress
progressid = UUIDs.uuid4()
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname, progress=NaN,
_id=progressid)
channel = Distributed.RemoteChannel(() -> Channel{Bool}(nchains), 1)
end
@ifwithprogresslogger progress name=progressname begin
# Create a channel for progress logging.
if progress
channel = Distributed.RemoteChannel(() -> Channel{Bool}(nchains), 1)
end

try
Distributed.@sync begin
if progress
Distributed.@async begin
# Update the progress bar.
progresschains = 0
while take!(channel)
progresschains += 1
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname,
progress=progresschains/nchains, _id=progressid)
ProgressLogging.@logprogress progresschains/nchains
end
end
end
Expand All @@ -322,7 +347,7 @@ function psample(
# Seed the thread-specific random number generator with the pre-made seed.
subrng = rngs[id]
seed!(subrng, seeds[i])

# Sample a chain and save it to the vector.
chains[i] = sample(subrng, models[id], samplers[id], N;
progress = false, kwargs...)
Expand All @@ -335,12 +360,6 @@ function psample(
progress && put!(channel, false)
end
end
finally
# Close the progress bar.
if progress
Logging.@logmsg(ProgressLogging.ProgressLevel, progressname,
progress="done", _id=progressid)
end
end

# Concatenate the chains together.
Expand Down
4 changes: 3 additions & 1 deletion test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ function AbstractMCMC.step!(
N::Integer,
transition::Union{Nothing,MyTransition};
sleepy = false,
loggers = false,
kwargs...
)
a = rand(rng)
b = randn(rng)

loggers && push!(LOGGERS, Logging.current_logger())
sleepy && sleep(0.001)

return MyTransition(a, b)
Expand All @@ -50,4 +52,4 @@ function AbstractMCMC.bundle_samples(
return MyChain(as, bs)
end

AbstractMCMC.chainscat(chains::Union{MyChain,Vector{<:MyChain}}...) = vcat(chains...)
AbstractMCMC.chainscat(chains::Union{MyChain,Vector{<:MyChain}}...) = vcat(chains...)
105 changes: 88 additions & 17 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,104 @@
using AbstractMCMC
using AbstractMCMC: sample, psample, steps!
import TerminalLoggers
using Atom.Progress: JunoProgressLogger
using ConsoleProgressMonitor: ProgressLogger
using IJulia
using LoggingExtras: TeeLogger, EarlyFilteredLogger
using TerminalLoggers: TerminalLogger

import Logging
using Random
using Statistics
using Test
using Test: collect_test_logs

# install progress logger
Logging.global_logger(TerminalLoggers.TerminalLogger(right_justify=120))
const LOGGERS = Set()
const CURRENT_LOGGER = Logging.current_logger()

include("interface.jl")

@testset "AbstractMCMC" begin
@testset "Basic sampling" begin
Random.seed!(1234)
N = 1_000
chain = sample(MyModel(), MySampler(), N; sleepy = true)

# test output type and size
@test chain isa Vector{MyTransition}
@test length(chain) == N

# test some statistical properties
@test mean(x.a for x in chain) 0.5 atol=6e-2
@test var(x.a for x in chain) 1 / 12 atol=5e-3
@test mean(x.b for x in chain) 0.0 atol=5e-2
@test var(x.b for x in chain) 1 atol=6e-2
@testset "REPL" begin
empty!(LOGGERS)

Random.seed!(1234)
N = 1_000
chain = sample(MyModel(), MySampler(), N; sleepy = true, loggers = true)

@test length(LOGGERS) == 1
logger = first(LOGGERS)
@test logger isa TeeLogger
@test logger.loggers[1].logger isa TerminalLogger
@test logger.loggers[2].logger === CURRENT_LOGGER
@test Logging.current_logger() === CURRENT_LOGGER

# test output type and size
@test chain isa Vector{MyTransition}
@test length(chain) == N

# test some statistical properties
@test mean(x.a for x in chain) 0.5 atol=6e-2
@test var(x.a for x in chain) 1 / 12 atol=5e-3
@test mean(x.b for x in chain) 0.0 atol=5e-2
@test var(x.b for x in chain) 1 atol=6e-2
end

@testset "Juno" begin
empty!(LOGGERS)

Random.seed!(1234)
N = 10

logger = JunoProgressLogger()
Logging.with_logger(logger) do
sample(MyModel(), MySampler(), N; sleepy = true, loggers = true)
end

@test length(LOGGERS) == 1
@test first(LOGGERS) === logger
@test Logging.current_logger() === CURRENT_LOGGER
end

@testset "IJulia" begin
# emulate running IJulia kernel
@eval IJulia begin
inited = true
end

empty!(LOGGERS)

Random.seed!(1234)
N = 10
sample(MyModel(), MySampler(), N; sleepy = true, loggers = true)

@test length(LOGGERS) == 1
logger = first(LOGGERS)
@test logger isa TeeLogger
@test logger.loggers[1].logger isa ProgressLogger
@test logger.loggers[2].logger === CURRENT_LOGGER
@test Logging.current_logger() === CURRENT_LOGGER

@eval IJulia begin
inited = false
end
end

@testset "Custom logger" begin
empty!(LOGGERS)

Random.seed!(1234)
N = 10

logger = Logging.ConsoleLogger(stderr, Logging.LogLevel(-1))
Logging.with_logger(logger) do
sample(MyModel(), MySampler(), N; sleepy = true, loggers = true)
end

@test length(LOGGERS) == 1
@test first(LOGGERS) === logger
@test Logging.current_logger() === CURRENT_LOGGER
end
end

if VERSION v"1.3"
Expand Down Expand Up @@ -104,4 +175,4 @@ include("interface.jl")
@test Base.IteratorSize(iter) == Base.IsInfinite()
@test Base.IteratorEltype(iter) == Base.EltypeUnknown()
end
end
end

0 comments on commit b29b38d

Please sign in to comment.