From 7ea6ef78f7d0d0d990fd6a7faaf94e282a066b61 Mon Sep 17 00:00:00 2001 From: jaimerz Date: Wed, 26 Jul 2023 10:36:05 +0100 Subject: [PATCH 1/6] save CSV --- src/TuringCallbacks.jl | 5 ++++- src/callbacks/save.jl | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 src/callbacks/save.jl diff --git a/src/TuringCallbacks.jl b/src/TuringCallbacks.jl index 758a2e4..2268f3e 100644 --- a/src/TuringCallbacks.jl +++ b/src/TuringCallbacks.jl @@ -5,6 +5,8 @@ using Reexport using LinearAlgebra using Logging using DocStringExtensions +using DynamicPPL: Model, Sampler, AbstractVarInfo, invlink!! +using CSV: write @reexport using OnlineStats # used to compute different statistics on-the-fly @@ -21,8 +23,9 @@ export DefaultDict, WindowStat, Thin, Skip, TensorBoardCallback, MultiCallback include("stats.jl") include("tensorboardlogger.jl") -include("callbacks/tensorboard.jl") include("callbacks/multicallback.jl") +include("callbacks/save.jl") +include("callbacks/tensorboard.jl") @static if !isdefined(Base, :get_extension) function __init__() diff --git a/src/callbacks/save.jl b/src/callbacks/save.jl new file mode 100644 index 0000000..c183be7 --- /dev/null +++ b/src/callbacks/save.jl @@ -0,0 +1,36 @@ +############################### +### Saves samples on the go ### +############################### + +""" + SaveCSV + +A callback saves samples to .csv file during sampling +""" +function SaveCSV(rng::AbstractRNG, + model::Model, + sampler::Sampler, + transition, + state, + iteration::Int64; + kwargs... +) + SaveCSV(model, sampler, transition, state.vi, iteration; kwargs...) +end + +function SaveCSV(rng::AbstractRNG, + model::Model, + sampler::Sampler, + transition, + vi::AbstractVarInfo, + iteration::Int64; + kwargs... 
+) + vii = deepcopy(vi) + invlink!!(vii, model) + θ = vii[sampler] + # it would be good to have the param names as in the chain + chain_name = get(kwargs, :chain_name, "chain") + write(string(chain_name,".csv"), Dict("params"=>[θ]); + append=true, delim=";") +end \ No newline at end of file From 5268bc5e261b0dd2bba8f541bf2145bcc4cdccfa Mon Sep 17 00:00:00 2001 From: jaimerz Date: Wed, 26 Jul 2023 10:44:23 +0100 Subject: [PATCH 2/6] formatting --- Project.toml | 3 +++ docs/make.jl | 24 +++++++++----------- ext/TuringCallbacksTuringExt.jl | 8 +++++-- src/TuringCallbacks.jl | 7 ++++-- src/callbacks/multicallback.jl | 3 ++- src/callbacks/save.jl | 15 +++++++------ src/callbacks/tensorboard.jl | 39 ++++++++++++++++++++++++--------- src/stats.jl | 29 ++++++++++-------------- src/tensorboardlogger.jl | 19 ++++++++-------- test/runtests.jl | 6 ++--- 10 files changed, 88 insertions(+), 65 deletions(-) diff --git a/Project.toml b/Project.toml index 177f8a5..b93fb37 100644 --- a/Project.toml +++ b/Project.toml @@ -4,12 +4,15 @@ authors = ["Tor Erlend Fjelde and contributors"] version = "0.3.1" [deps] +CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" diff --git a/docs/make.jl b/docs/make.jl index d62bb21..82080ed 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,20 +2,16 @@ using TuringCallbacks using Documenter makedocs(; - modules=[TuringCallbacks], - authors="Tor", - 
repo="https://github.com/TuringLang/TuringCallbacks.jl/blob/{commit}{path}#L{line}", - sitename="TuringCallbacks.jl", - format=Documenter.HTML(; - prettyurls=get(ENV, "CI", "false") == "true", - canonical="https://turinglang.github.io/TuringCallbacks.jl", - assets=String[], + modules = [TuringCallbacks], + authors = "Tor", + repo = "https://github.com/TuringLang/TuringCallbacks.jl/blob/{commit}{path}#L{line}", + sitename = "TuringCallbacks.jl", + format = Documenter.HTML(; + prettyurls = get(ENV, "CI", "false") == "true", + canonical = "https://turinglang.github.io/TuringCallbacks.jl", + assets = String[], ), - pages=[ - "Home" => "index.md", - ], + pages = ["Home" => "index.md"], ) -deploydocs(; - repo="github.com/TuringLang/TuringCallbacks.jl", -) +deploydocs(; repo = "github.com/TuringLang/TuringCallbacks.jl") diff --git a/ext/TuringCallbacksTuringExt.jl b/ext/TuringCallbacksTuringExt.jl index b18fad7..ab4f32d 100644 --- a/ext/TuringCallbacksTuringExt.jl +++ b/ext/TuringCallbacksTuringExt.jl @@ -12,13 +12,17 @@ end const TuringTransition = Union{Turing.Inference.Transition,Turing.Inference.HMCTransition} function TuringCallbacks.params_and_values(transition::TuringTransition; kwargs...) - return Iterators.map(zip(Turing.Inference._params_to_array([transition])...)) do (ksym, val) + return Iterators.map( + zip(Turing.Inference._params_to_array([transition])...), + ) do (ksym, val) return string(ksym), val end end function TuringCallbacks.extras(transition::TuringTransition; kwargs...) 
- return Iterators.map(zip(Turing.Inference.get_transition_extras([transition])...)) do (ksym, val) + return Iterators.map( + zip(Turing.Inference.get_transition_extras([transition])...), + ) do (ksym, val) return string(ksym), val end end diff --git a/src/TuringCallbacks.jl b/src/TuringCallbacks.jl index 2268f3e..e7ad06e 100644 --- a/src/TuringCallbacks.jl +++ b/src/TuringCallbacks.jl @@ -6,7 +6,8 @@ using LinearAlgebra using Logging using DocStringExtensions using DynamicPPL: Model, Sampler, AbstractVarInfo, invlink!! -using CSV: write +using CSV: write +using Random: AbstractRNG @reexport using OnlineStats # used to compute different statistics on-the-fly @@ -29,7 +30,9 @@ include("callbacks/tensorboard.jl") @static if !isdefined(Base, :get_extension) function __init__() - @require Turing="fce5fe82-541a-59a6-adf8-730c64b5f9a0" include("../ext/TuringCallbacksTuringExt.jl") + @require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" include( + "../ext/TuringCallbacksTuringExt.jl", + ) end end diff --git a/src/callbacks/multicallback.jl b/src/callbacks/multicallback.jl index 2e270cd..2277281 100644 --- a/src/callbacks/multicallback.jl +++ b/src/callbacks/multicallback.jl @@ -20,4 +20,5 @@ MultiCallback(callbacks...) = MultiCallback(callbacks) Add a callback to the list of callbacks, mutating if possible. """ push!!(c::MultiCallback{<:Tuple}, callback) = MultiCallback((c.callbacks..., callback)) -push!!(c::MultiCallback{<:AbstractArray}, callback) = (push!(c.callbacks, callback); return c) +push!!(c::MultiCallback{<:AbstractArray}, callback) = + (push!(c.callbacks, callback); return c) diff --git a/src/callbacks/save.jl b/src/callbacks/save.jl index c183be7..9478b00 100644 --- a/src/callbacks/save.jl +++ b/src/callbacks/save.jl @@ -7,30 +7,31 @@ A callback saves samples to .csv file during sampling """ -function SaveCSV(rng::AbstractRNG, +function SaveCSV( + rng::AbstractRNG, model::Model, sampler::Sampler, transition, state, iteration::Int64; - kwargs... 
+ kwargs..., ) SaveCSV(model, sampler, transition, state.vi, iteration; kwargs...) end -function SaveCSV(rng::AbstractRNG, +function SaveCSV( + rng::AbstractRNG, model::Model, sampler::Sampler, transition, vi::AbstractVarInfo, iteration::Int64; - kwargs... + kwargs..., ) vii = deepcopy(vi) invlink!!(vii, model) θ = vii[sampler] # it would be good to have the param names as in the chain chain_name = get(kwargs, :chain_name, "chain") - write(string(chain_name,".csv"), Dict("params"=>[θ]); - append=true, delim=";") -end \ No newline at end of file + write(string(chain_name, ".csv"), Dict("params" => [θ]); append = true, delim = ";") +end diff --git a/src/callbacks/tensorboard.jl b/src/callbacks/tensorboard.jl index 9d28f9a..3904156 100644 --- a/src/callbacks/tensorboard.jl +++ b/src/callbacks/tensorboard.jl @@ -72,7 +72,7 @@ function TensorBoardCallback(args...; comment = "", directory = nothing, kwargs. end # Set up the logger - lg = TBLogger(log_dir, min_level=Logging.Info; step_increment=0) + lg = TBLogger(log_dir, min_level = Logging.Info; step_increment = 0) return TensorBoardCallback(lg, args...; kwargs...) end @@ -87,14 +87,15 @@ function TensorBoardCallback( filter = nothing, param_prefix::String = "", extras_prefix::String = "extras/", - kwargs... 
+ kwargs..., ) # Lookups: create default ones if not given stats_lookup = if stats isa OnlineStat # Warn the user if they've provided a non-empty `OnlineStat` - OnlineStats.nobs(stats) > 0 && @warn("using statistic with observations as a base: $(stats)") + OnlineStats.nobs(stats) > 0 && + @warn("using statistic with observations as a base: $(stats)") let o = stats - DefaultDict{String, typeof(o)}(() -> deepcopy(o)) + DefaultDict{String,typeof(o)}(() -> deepcopy(o)) end elseif !isnothing(stats) # If it's not an `OnlineStat` nor `nothing`, assume user knows what they're doing @@ -102,12 +103,19 @@ function TensorBoardCallback( else # This is default let o = OnlineStats.Series(Mean(), Variance(), KHist(num_bins)) - DefaultDict{String, typeof(o)}(() -> deepcopy(o)) + DefaultDict{String,typeof(o)}(() -> deepcopy(o)) end end return TensorBoardCallback( - lg, stats_lookup, filter, include, exclude, include_extras, param_prefix, extras_prefix + lg, + stats_lookup, + filter, + include, + exclude, + include_extras, + param_prefix, + extras_prefix, ) end @@ -133,7 +141,8 @@ function filter_param_and_value(cb::TensorBoardCallback, param, value) # Otherwise we return `true` by default. return true end -filter_param_and_value(cb::TensorBoardCallback, param_and_value::Tuple) = filter_param_and_value(cb, param_and_value...) +filter_param_and_value(cb::TensorBoardCallback, param_and_value::Tuple) = + filter_param_and_value(cb, param_and_value...) """ default_param_names_for_values(x) @@ -150,7 +159,8 @@ default_param_names_for_values(x) = ("θ[$i]" for i = 1:length(x)) Return an iterator over parameter names and values from a `transition`. """ params_and_values(transition, state; kwargs...) = params_and_values(transition; kwargs...) -params_and_values(model, sampler, transition, state; kwargs...) = params_and_values(transition, state; kwargs...) +params_and_values(model, sampler, transition, state; kwargs...) = + params_and_values(transition, state; kwargs...) 
""" extras(transition[, state]; kwargs...) @@ -167,14 +177,23 @@ extras(model, sampler, transition, state; kwargs...) = extras(transition, state; increment_step!(lg::TensorBoardLogger.TBLogger, Δ_Step) = TensorBoardLogger.increment_step!(lg, Δ_Step) -function (cb::TensorBoardCallback)(rng, model, sampler, transition, state, iteration; kwargs...) +function (cb::TensorBoardCallback)( + rng, + model, + sampler, + transition, + state, + iteration; + kwargs..., +) stats = cb.stats lg = cb.logger filterf = Base.Fix1(filter_param_and_value, cb) # TODO: Should we use the explicit interface for TensorBoardLogger? with_logger(lg) do - for (k, val) in Iterators.filter(filterf, params_and_values(transition, state; kwargs...)) + for (k, val) in + Iterators.filter(filterf, params_and_values(transition, state; kwargs...)) stat = stats[k] # Log the raw value diff --git a/src/stats.jl b/src/stats.jl index 7115386..e0d40e8 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -10,7 +10,7 @@ $(TYPEDEF) Skips the first `b` observations before passing them on to `stat`. """ -mutable struct Skip{T, O<:OnlineStat{T}} <: OnlineStat{T} +mutable struct Skip{T,O<:OnlineStat{T}} <: OnlineStat{T} b::Int current_index::Int stat::O @@ -29,10 +29,8 @@ function OnlineStats._fit!(o::Skip, x::Real) return o end -Base.show(io::IO, o::Skip) = print( - io, - "Skip ($(o.b)): current_index=$(o.current_index) | stat=$(o.stat)`" -) +Base.show(io::IO, o::Skip) = + print(io, "Skip ($(o.b)): current_index=$(o.current_index) | stat=$(o.stat)`") """ $(TYPEDEF) @@ -43,7 +41,7 @@ $(TYPEDEF) Thins `stat` with an interval `b`, i.e. only passes every b-th observation to `stat`. 
""" -mutable struct Thin{T, O<:OnlineStat{T}} <: OnlineStat{T} +mutable struct Thin{T,O<:OnlineStat{T}} <: OnlineStat{T} b::Int current_index::Int stat::O @@ -62,10 +60,8 @@ function OnlineStats._fit!(o::Thin, x::Real) return o end -Base.show(io::IO, o::Thin) = print( - io, - "Thin ($(o.b)): current_index=$(o.current_index) | stat=$(o.stat)`" -) +Base.show(io::IO, o::Thin) = + print(io, "Thin ($(o.b)): current_index=$(o.current_index) | stat=$(o.stat)`") """ $(TYPEDEF) @@ -80,27 +76,26 @@ $(TYPEDEF) `stat`, which is *only* fitted on the batched data contained in the `MovingWindow`. """ -struct WindowStat{T, O} <: OnlineStat{T} +struct WindowStat{T,O} <: OnlineStat{T} window::MovingWindow{T} stat::O end -WindowStat(b::Int, T::Type, o) = WindowStat{T, typeof(o)}(MovingWindow(b, T), o) -WindowStat(b::Int, o::OnlineStat{T}) where {T} = WindowStat{T, typeof(o)}( - MovingWindow(b, T), o -) +WindowStat(b::Int, T::Type, o) = WindowStat{T,typeof(o)}(MovingWindow(b, T), o) +WindowStat(b::Int, o::OnlineStat{T}) where {T} = + WindowStat{T,typeof(o)}(MovingWindow(b, T), o) # Proxy methods to the window OnlineStats.nobs(o::WindowStat) = OnlineStats.nobs(o.window) OnlineStats._fit!(o::WindowStat, x) = OnlineStats._fit!(o.window, x) -function OnlineStats.value(o::WindowStat{<:Any, <:OnlineStat}) +function OnlineStats.value(o::WindowStat{<:Any,<:OnlineStat}) stat_new = deepcopy(o.stat) fit!(stat_new, OnlineStats.value(o.window)) return stat_new end -function OnlineStats.value(o::WindowStat{<:Any, <:Function}) +function OnlineStats.value(o::WindowStat{<:Any,<:Function}) stat_new = o.stat() fit!(stat_new, OnlineStats.value(o.window)) return stat_new diff --git a/src/tensorboardlogger.jl b/src/tensorboardlogger.jl index 16a74fe..9727bc2 100644 --- a/src/tensorboardlogger.jl +++ b/src/tensorboardlogger.jl @@ -41,10 +41,10 @@ end function TBL.preprocess(name, stat::AutoCov, data) autocor = OnlineStats.autocor(stat) - for b = 1:(stat.lag.b - 1) + for b = 1:(stat.lag.b-1) # `autocor[i]` 
corresponds to the lag of size `i - 1` and `autocor[1] = 1.0` bname = tb_name(stat, b) - TBL.preprocess(tb_name(name, bname), autocor[b + 1], data) + TBL.preprocess(tb_name(name, bname), autocor[b+1], data) end end @@ -60,22 +60,23 @@ function TBL.preprocess(name, hist::KHist, data) # Creates a NORMALIZED histogram edges = OnlineStats.edges(hist) cnts = OnlineStats.counts(hist) - TBL.preprocess( - name, (edges, cnts ./ sum(cnts)), data - ) + TBL.preprocess(name, (edges, cnts ./ sum(cnts)), data) end end # Unlike the `preprocess` overload, this allows us to specify if we want to normalize function TBL.log_histogram( - logger::AbstractLogger, name::AbstractString, hist::OnlineStats.HistogramStat; - step=nothing, normalize=false + logger::AbstractLogger, + name::AbstractString, + hist::OnlineStats.HistogramStat; + step = nothing, + normalize = false, ) edges = edges(hist) cnts = Float64.(OnlineStats.counts(hist)) if normalize - return TBL.log_histogram(logger, name, (edges, cnts ./ sum(cnts)); step=step) + return TBL.log_histogram(logger, name, (edges, cnts ./ sum(cnts)); step = step) else - return TBL.log_histogram(logger, name, (edges, cnts); step=step) + return TBL.log_histogram(logger, name, (edges, cnts); step = step) end end diff --git a/test/runtests.jl b/test/runtests.jl index 78d8225..8037f25 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ using TuringCallbacks using TensorBoardLogger, ValueHistories Base.@kwdef struct CountingCallback - count::Ref{Int}=Ref(0) + count::Ref{Int} = Ref(0) end (c::CountingCallback)(args...; kwargs...) = c.count[] += 1 @@ -31,7 +31,7 @@ end @testset "MultiCallback" begin callback = MultiCallback(CountingCallback(), CountingCallback()) - chain = sample(model, alg, num_samples, callback=callback) + chain = sample(model, alg, num_samples, callback = callback) # Both should have been trigger an equal number of times. 
counts = map(c -> c.count[], callback.callbacks) @@ -49,7 +49,7 @@ end callback = TensorBoardCallback(mktempdir()) # Sample - chain = sample(model, alg, num_samples; callback=callback) + chain = sample(model, alg, num_samples; callback = callback) # Extract the values. hist = convert(MVHistory, callback.logger) From e839e18d3579aca13a8d3a6d8fb76c1e10adaddd Mon Sep 17 00:00:00 2001 From: jaimerz Date: Wed, 26 Jul 2023 11:26:59 +0100 Subject: [PATCH 3/6] tests --- src/TuringCallbacks.jl | 2 +- src/callbacks/save.jl | 2 +- test/Project.toml | 2 ++ test/runtests.jl | 12 +++++++++++- 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/TuringCallbacks.jl b/src/TuringCallbacks.jl index e7ad06e..f8761ce 100644 --- a/src/TuringCallbacks.jl +++ b/src/TuringCallbacks.jl @@ -20,7 +20,7 @@ using DataStructures: DefaultDict using Requires end -export DefaultDict, WindowStat, Thin, Skip, TensorBoardCallback, MultiCallback +export DefaultDict, WindowStat, Thin, Skip, TensorBoardCallback, MultiCallback, SaveCSV include("stats.jl") include("tensorboardlogger.jl") diff --git a/src/callbacks/save.jl b/src/callbacks/save.jl index 9478b00..638f78f 100644 --- a/src/callbacks/save.jl +++ b/src/callbacks/save.jl @@ -16,7 +16,7 @@ function SaveCSV( iteration::Int64; kwargs..., ) - SaveCSV(model, sampler, transition, state.vi, iteration; kwargs...) + SaveCSV(rng, model, sampler, transition, state.vi, iteration; kwargs...) 
end function SaveCSV( diff --git a/test/Project.toml b/test/Project.toml index 95733df..dee7ff1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,6 @@ [deps] +CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/test/runtests.jl b/test/runtests.jl index 8037f25..6eda623 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,8 @@ using Test using Turing using TuringCallbacks using TensorBoardLogger, ValueHistories +using CSV +using DataFrames Base.@kwdef struct CountingCallback count::Ref{Int} = Ref(0) @@ -28,7 +30,7 @@ end # Sampling algorithm to use alg = NUTS(num_adapts, 0.65) - + @testset "MultiCallback" begin callback = MultiCallback(CountingCallback(), CountingCallback()) chain = sample(model, alg, num_samples, callback = callback) @@ -61,4 +63,12 @@ end @test m_mean ≈ mean(chain[:m]) @test s_mean ≈ mean(chain[:s]) end + + @testset "SaveCallback" begin + # Sample + sample(model, alg, num_samples; callback = SaveCSV, chain_name="chain_1") + chain = Matrix(CSV.read("chain_1.csv", DataFrame, header=false)) + @test size(chain) == (num_samples, 2) + rm("chain_1.csv") + end end From eae1d72b9788bc2560d4d841a35e611675026c4a Mon Sep 17 00:00:00 2001 From: jaimerz Date: Thu, 12 Oct 2023 17:14:29 +0100 Subject: [PATCH 4/6] match master --- Project.toml | 9 +- docs/make.jl | 24 ++-- ext/TuringCallbacksTuringExt.jl | 69 +++++++--- src/TuringCallbacks.jl | 13 +- src/callbacks/multicallback.jl | 3 +- src/callbacks/tensorboard.jl | 215 +++++++++++++++++++++++--------- src/stats.jl | 29 +++-- src/tensorboardlogger.jl | 19 ++- src/utils.jl | 10 ++ test/Project.toml | 2 - test/multicallback.jl | 21 ++++ test/runtests.jl | 73 +++-------- test/save.jl | 7 ++ test/tensorboardcallback.jl | 163 ++++++++++++++++++++++++ 14 files changed, 475 
insertions(+), 182 deletions(-) create mode 100644 src/utils.jl create mode 100644 test/multicallback.jl create mode 100644 test/save.jl create mode 100644 test/tensorboardcallback.jl diff --git a/Project.toml b/Project.toml index b93fb37..e863cba 100644 --- a/Project.toml +++ b/Project.toml @@ -1,18 +1,15 @@ name = "TuringCallbacks" uuid = "ea0860ee-d0ef-45ef-82e6-cc37d6be2f9c" authors = ["Tor Erlend Fjelde and contributors"] -version = "0.3.1" +version = "0.4.0" [deps] -CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" @@ -29,8 +26,8 @@ DocStringExtensions = "0.8, 0.9" OnlineStats = "1.5" Reexport = "0.2, 1.0" Requires = "1" -TensorBoardLogger = "0.1" -Turing = "0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22" +TensorBoardLogger = "0.1.22" +Turing = "0.29" julia = "1" [extras] diff --git a/docs/make.jl b/docs/make.jl index 82080ed..d62bb21 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,16 +2,20 @@ using TuringCallbacks using Documenter makedocs(; - modules = [TuringCallbacks], - authors = "Tor", - repo = "https://github.com/TuringLang/TuringCallbacks.jl/blob/{commit}{path}#L{line}", - sitename = "TuringCallbacks.jl", - format = Documenter.HTML(; - prettyurls = get(ENV, "CI", "false") == "true", - canonical = "https://turinglang.github.io/TuringCallbacks.jl", - assets = String[], + modules=[TuringCallbacks], + authors="Tor", + 
repo="https://github.com/TuringLang/TuringCallbacks.jl/blob/{commit}{path}#L{line}", + sitename="TuringCallbacks.jl", + format=Documenter.HTML(; + prettyurls=get(ENV, "CI", "false") == "true", + canonical="https://turinglang.github.io/TuringCallbacks.jl", + assets=String[], ), - pages = ["Home" => "index.md"], + pages=[ + "Home" => "index.md", + ], ) -deploydocs(; repo = "github.com/TuringLang/TuringCallbacks.jl") +deploydocs(; + repo="github.com/TuringLang/TuringCallbacks.jl", +) diff --git a/ext/TuringCallbacksTuringExt.jl b/ext/TuringCallbacksTuringExt.jl index ab4f32d..9eb9d16 100644 --- a/ext/TuringCallbacksTuringExt.jl +++ b/ext/TuringCallbacksTuringExt.jl @@ -1,30 +1,69 @@ module TuringCallbacksTuringExt if isdefined(Base, :get_extension) - using Turing: Turing + using Turing: Turing, DynamicPPL using TuringCallbacks: TuringCallbacks else # Requires compatible. - using ..Turing: Turing + using ..Turing: Turing, DynamicPPL using ..TuringCallbacks: TuringCallbacks end -const TuringTransition = Union{Turing.Inference.Transition,Turing.Inference.HMCTransition} +const TuringTransition = Union{ + Turing.Inference.Transition, + Turing.Inference.SMCTransition, + Turing.Inference.PGTransition +} -function TuringCallbacks.params_and_values(transition::TuringTransition; kwargs...) - return Iterators.map( - zip(Turing.Inference._params_to_array([transition])...), - ) do (ksym, val) - return string(ksym), val - end +function TuringCallbacks.params_and_values( + model::DynamicPPL.Model, + transition::TuringTransition; + kwargs... +) + vns, vals = Turing.Inference._params_to_array(model, [transition]) + return zip(Iterators.map(string, vns), vals) end -function TuringCallbacks.extras(transition::TuringTransition; kwargs...) - return Iterators.map( - zip(Turing.Inference.get_transition_extras([transition])...), - ) do (ksym, val) - return string(ksym), val - end +function TuringCallbacks.extras( + model::DynamicPPL.Model, transition::TuringTransition; + kwargs... 
+) + names, vals = Turing.Inference.get_transition_extras([transition]) + return zip(string.(names), vec(vals)) +end + +default_hyperparams(sampler::DynamicPPL.Sampler) = default_hyperparams(sampler.alg) +default_hyperparams(alg::Turing.Inference.InferenceAlgorithm) = ( + string(f) => getfield(alg, f) for f in fieldnames(typeof(alg)) +) + +const AlgsWithDefaultHyperparams = Union{ + Turing.Inference.HMC, + Turing.Inference.HMCDA, + Turing.Inference.NUTS, + Turing.Inference.SGHMC, + +} + +function TuringCallbacks.hyperparams( + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:AlgsWithDefaultHyperparams}; + kwargs... +) + return default_hyperparams(sampler) +end + +function TuringCallbacks.hyperparam_metrics( + model, + sampler::Turing.Sampler{<:Turing.Inference.NUTS} +) + return [ + "extras/acceptance_rate/stat/Mean", + "extras/max_hamiltonian_energy_error/stat/Mean", + "extras/lp/stat/Mean", + "extras/n_steps/stat/Mean", + "extras/tree_depth/stat/Mean" + ] end end diff --git a/src/TuringCallbacks.jl b/src/TuringCallbacks.jl index f8761ce..1d183c3 100644 --- a/src/TuringCallbacks.jl +++ b/src/TuringCallbacks.jl @@ -5,9 +5,6 @@ using Reexport using LinearAlgebra using Logging using DocStringExtensions -using DynamicPPL: Model, Sampler, AbstractVarInfo, invlink!! 
-using CSV: write -using Random: AbstractRNG @reexport using OnlineStats # used to compute different statistics on-the-fly @@ -20,19 +17,17 @@ using DataStructures: DefaultDict using Requires end -export DefaultDict, WindowStat, Thin, Skip, TensorBoardCallback, MultiCallback, SaveCSV +export DefaultDict, WindowStat, Thin, Skip, TensorBoardCallback, MultiCallback +include("utils.jl") include("stats.jl") include("tensorboardlogger.jl") -include("callbacks/multicallback.jl") -include("callbacks/save.jl") include("callbacks/tensorboard.jl") +include("callbacks/multicallback.jl") @static if !isdefined(Base, :get_extension) function __init__() - @require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" include( - "../ext/TuringCallbacksTuringExt.jl", - ) + @require Turing="fce5fe82-541a-59a6-adf8-730c64b5f9a0" include("../ext/TuringCallbacksTuringExt.jl") end end diff --git a/src/callbacks/multicallback.jl b/src/callbacks/multicallback.jl index 2277281..2e270cd 100644 --- a/src/callbacks/multicallback.jl +++ b/src/callbacks/multicallback.jl @@ -20,5 +20,4 @@ MultiCallback(callbacks...) = MultiCallback(callbacks) Add a callback to the list of callbacks, mutating if possible. """ push!!(c::MultiCallback{<:Tuple}, callback) = MultiCallback((c.callbacks..., callback)) -push!!(c::MultiCallback{<:AbstractArray}, callback) = - (push!(c.callbacks, callback); return c) +push!!(c::MultiCallback{<:AbstractArray}, callback) = (push!(c.callbacks, callback); return c) diff --git a/src/callbacks/tensorboard.jl b/src/callbacks/tensorboard.jl index 3904156..0f1dc1d 100644 --- a/src/callbacks/tensorboard.jl +++ b/src/callbacks/tensorboard.jl @@ -30,9 +30,22 @@ provided instead of `lg`. particular variable and value; expected signature is `filter(varname, value)`. If `isnothing` a default-filter constructed from `exclude` and `include` will be used. -- `exclude = nothing`: If non-empty, these variables will not be logged. 
-- `include = nothing`: If non-empty, only these variables will be logged. +- `exclude = String[]`: If non-empty, these variables will not be logged. +- `include = String[]`: If non-empty, only these variables will be logged. - `include_extras::Bool = true`: Include extra statistics from transitions. +- `extras_include = String[]`: If non-empty, only these extra statistics will be logged. +- `extras_exclude = String[]`: If non-empty, these extra statistics will not be logged. +- `extras_filter = nothing`: Filter determining whether or not we should log + extra statistics; expected signature is `filter(extra, value)`. + If `isnothing` a default-filter constructed from `extras_exclude` and + `extras_include` will be used. +- `include_hyperparams::Bool = true`: Include hyperparameters. +- `hyperparam_include = String[]`: If non-empty, only these hyperparameters will be logged. +- `hyperparam_exclude = String[]`: If non-empty, these hyperparameters will not be logged. +- `hyperparam_filter = nothing`: Filter determining whether or not we should log + hyperparameters; expected signature is `filter(hyperparam, value)`. + If `isnothing` a default-filter constructed from `hyperparam_exclude` and + `hyperparam_include` will be used. - `directory::String = nothing`: if specified, will together with `comment` be used to define the logging directory. - `comment::String = nothing`: if specified, will together with `directory` be used to @@ -41,19 +54,21 @@ provided instead of `lg`. # Fields $(TYPEDFIELDS) """ -struct TensorBoardCallback{L,F,VI,VE} +struct TensorBoardCallback{L,F1,F2,F3} "Underlying logger." logger::AbstractLogger "Lookup for variable name to statistic estimate." stats::L - "Filter determining whether or not we should log stats for a particular variable." - filter::F - "Variables to include in the logging." - include::VI - "Variables to exclude from the logging." - exclude::VE + "Filter determining whether to include stats for a particular variable." 
+ variable_filter::F1 "Include extra statistics from transitions." include_extras::Bool + "Filter determining whether to include a particular extra statistic." + extras_filter::F2 + "Include hyperparameters." + include_hyperparams::Bool + "Filter determining whether to include a particular hyperparameter." + hyperparam_filter::F3 "Prefix used for logging realizations/parameters" param_prefix::String "Prefix used for logging extra statistics" @@ -72,30 +87,48 @@ function TensorBoardCallback(args...; comment = "", directory = nothing, kwargs. end # Set up the logger - lg = TBLogger(log_dir, min_level = Logging.Info; step_increment = 0) + lg = TBLogger(log_dir, min_level=Logging.Info; step_increment=0) return TensorBoardCallback(lg, args...; kwargs...) end +maybe_filter(f; kwargs...) = f +maybe_filter(::Nothing; exclude=nothing, include=nothing) = NameFilter(; exclude, include) + function TensorBoardCallback( lg::AbstractLogger, stats = nothing; num_bins::Int = 100, exclude = nothing, include = nothing, - include_extras::Bool = true, filter = nothing, + include_extras::Bool = true, + extras_include = nothing, + extras_exclude = nothing, + extras_filter = nothing, + include_hyperparams::Bool = false, + hyperparams_include = nothing, + hyperparams_exclude = nothing, + hyperparams_filter = nothing, param_prefix::String = "", extras_prefix::String = "extras/", - kwargs..., + kwargs... ) + # Create the filters. 
+ variable_filter_f = maybe_filter(filter; include=include, exclude=exclude) + extras_filter_f = maybe_filter( + extras_filter; include=extras_include, exclude=extras_exclude + ) + hyperparams_filter_f = maybe_filter( + hyperparams_filter; include=hyperparams_include, exclude=hyperparams_exclude + ) + # Lookups: create default ones if not given stats_lookup = if stats isa OnlineStat # Warn the user if they've provided a non-empty `OnlineStat` - OnlineStats.nobs(stats) > 0 && - @warn("using statistic with observations as a base: $(stats)") + OnlineStats.nobs(stats) > 0 && @warn("using statistic with observations as a base: $(stats)") let o = stats - DefaultDict{String,typeof(o)}(() -> deepcopy(o)) + DefaultDict{String, typeof(o)}(() -> deepcopy(o)) end elseif !isnothing(stats) # If it's not an `OnlineStat` nor `nothing`, assume user knows what they're doing @@ -103,19 +136,20 @@ function TensorBoardCallback( else # This is default let o = OnlineStats.Series(Mean(), Variance(), KHist(num_bins)) - DefaultDict{String,typeof(o)}(() -> deepcopy(o)) + DefaultDict{String, typeof(o)}(() -> deepcopy(o)) end end return TensorBoardCallback( lg, stats_lookup, - filter, - include, - exclude, + variable_filter_f, include_extras, + extras_filter_f, + include_hyperparams, + hyperparams_filter_f, param_prefix, - extras_prefix, + extras_prefix ) end @@ -125,24 +159,11 @@ end Filter parameters and values from a `transition` based on the `filter` of `cb`. """ function filter_param_and_value(cb::TensorBoardCallback, param, value) - if !isnothing(cb.filter) - return cb.filter(param, value) - end - - # Otherwise we construct from `include` and `exclude`. - if !isnothing(cb.include) - # If only `include` is given, we only return the variables in `include`. - return param ∈ cb.include - elseif !isnothing(cb.exclude) - # If only `exclude` is given, we return all variables except those in `exclude`. - return !(param ∈ cb.exclude) - end - - # Otherwise we return `true` by default. 
- return true + return cb.variable_filter(param, value) end -filter_param_and_value(cb::TensorBoardCallback, param_and_value::Tuple) = +function filter_param_and_value(cb::TensorBoardCallback, param_and_value::Tuple) filter_param_and_value(cb, param_and_value...) +end """ default_param_names_for_values(x) @@ -153,47 +174,117 @@ default_param_names_for_values(x) = ("θ[$i]" for i = 1:length(x)) """ - params_and_values(transition[, state]; kwargs...) + params_and_values(model, transition[, state]; kwargs...) params_and_values(model, sampler, transition, state; kwargs...) Return an iterator over parameter names and values from a `transition`. """ -params_and_values(transition, state; kwargs...) = params_and_values(transition; kwargs...) -params_and_values(model, sampler, transition, state; kwargs...) = - params_and_values(transition, state; kwargs...) +function params_and_values(model, transition, state; kwargs...) + return params_and_values(model, transition; kwargs...) +end +function params_and_values(model, sampler, transition, state; kwargs...) + return params_and_values(model, transition, state; kwargs...) +end """ - extras(transition[, state]; kwargs...) + extras(model, transition[, state]; kwargs...) extras(model, sampler, transition, state; kwargs...) Return an iterator with elements of the form `(name, value)` for additional statistics in `transition`. Default implementation returns an empty iterator. """ -extras(transition; kwargs...) = () -extras(transition, state; kwargs...) = extras(transition; kwargs...) -extras(model, sampler, transition, state; kwargs...) = extras(transition, state; kwargs...) +extras(model, transition; kwargs...) = () +extras(model, transition, state; kwargs...) = extras(model, transition; kwargs...) +function extras(model, sampler, transition, state; kwargs...) + return extras(model, transition, state; kwargs...) 
+end + +""" + filter_extras_and_value(cb::TensorBoardCallback, name, value) + +Filter extras and values from a `transition` based on the `filter` of `cb`. +""" +function filter_extras_and_value(cb::TensorBoardCallback, name, value) + return cb.extras_filter(name, value) +end +function filter_extras_and_value(cb::TensorBoardCallback, name_and_value::Tuple) + return filter_extras_and_value(cb, name_and_value...) +end + +""" + hyperparams(model, sampler[, transition, state]; kwargs...) + +Return an iterator with elements of the form `(name, value)` for hyperparameters in `model`. +""" +function hyperparams(model, sampler; kwargs...) + @warn "`hyperparams(model, sampler; kwargs...)` is not implemented for $(typeof(model)) and $(typeof(sampler)). If you want to record hyperparameters, please implement this method." + return Pair{String, Any}[] +end +function hyperparams(model, sampler, transition, state; kwargs...) + return hyperparams(model, sampler; kwargs...) +end + +""" + filter_hyperparams_and_value(cb::TensorBoardCallback, name, value) + +Filter hyperparameters and values from a `transition` based on the `filter` of `cb`. +""" +function filter_hyperparams_and_value(cb::TensorBoardCallback, name, value) + return cb.hyperparam_filter(name, value) +end +function filter_hyperparams_and_value( + cb::TensorBoardCallback, + name_and_value::Union{Pair,Tuple} +) + return filter_hyperparams_and_value(cb, name_and_value...) +end + +""" + hyperparam_metrics(model, sampler[, transition, state]; kwargs...) + +Return a `Vector{String}` of metrics for hyperparameters in `model`. +""" +function hyperparam_metrics(model, sampler; kwargs...) + @warn "`hyperparam_metrics(model, sampler; kwargs...)` is not implemented for $(typeof(model)) and $(typeof(sampler)). If you want to use some of the other recorded values as hyperparameters metrics, please implement this method." + return String[] +end +function hyperparam_metrics(model, sampler, transition, state; kwargs...) 
+ return hyperparam_metrics(model, sampler; kwargs...) +end increment_step!(lg::TensorBoardLogger.TBLogger, Δ_Step) = TensorBoardLogger.increment_step!(lg, Δ_Step) -function (cb::TensorBoardCallback)( - rng, - model, - sampler, - transition, - state, - iteration; - kwargs..., -) +function (cb::TensorBoardCallback)(rng, model, sampler, transition, state, iteration; kwargs...) stats = cb.stats lg = cb.logger - filterf = Base.Fix1(filter_param_and_value, cb) + variable_filter = Base.Fix1(filter_param_and_value, cb) + extras_filter = Base.Fix1(filter_extras_and_value, cb) + hyperparams_filter = Base.Fix1(filter_hyperparams_and_value, cb) + + if iteration == 1 && cb.include_hyperparams + # If it's the first iteration, we write the hyperparameters. + hparams = Dict(Iterators.filter( + hyperparams_filter, + hyperparams(model, sampler, transition, state; kwargs...) + )) + if !isempty(hparams) + TensorBoardLogger.write_hparams!( + lg, + hparams, + hyperparam_metrics(model, sampler) + ) + end + end + # TODO: Should we use the explicit interface for TensorBoardLogger? with_logger(lg) do - for (k, val) in - Iterators.filter(filterf, params_and_values(transition, state; kwargs...)) + for (k, val) in Iterators.filter( + variable_filter, + params_and_values(model, sampler, transition, state; kwargs...) + ) stat = stats[k] # Log the raw value @@ -208,8 +299,18 @@ function (cb::TensorBoardCallback)( # Transition statstics if cb.include_extras - for (name, val) in extras(transition, state; kwargs...) + for (name, val) in Iterators.filter( + extras_filter, + extras(model, sampler, transition, state; kwargs...) + ) @info "$(cb.extras_prefix)$(name)" val + + # TODO: Make this customizable. + if val isa Real + stat = stats["$(cb.extras_prefix)$(name)"] + fit!(stat, float(val)) + @info ("$(cb.extras_prefix)$(name)") stat + end end end # Increment the step for the logger. 
diff --git a/src/stats.jl b/src/stats.jl index e0d40e8..7115386 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -10,7 +10,7 @@ $(TYPEDEF) Skips the first `b` observations before passing them on to `stat`. """ -mutable struct Skip{T,O<:OnlineStat{T}} <: OnlineStat{T} +mutable struct Skip{T, O<:OnlineStat{T}} <: OnlineStat{T} b::Int current_index::Int stat::O @@ -29,8 +29,10 @@ function OnlineStats._fit!(o::Skip, x::Real) return o end -Base.show(io::IO, o::Skip) = - print(io, "Skip ($(o.b)): current_index=$(o.current_index) | stat=$(o.stat)`") +Base.show(io::IO, o::Skip) = print( + io, + "Skip ($(o.b)): current_index=$(o.current_index) | stat=$(o.stat)`" +) """ $(TYPEDEF) @@ -41,7 +43,7 @@ $(TYPEDEF) Thins `stat` with an interval `b`, i.e. only passes every b-th observation to `stat`. """ -mutable struct Thin{T,O<:OnlineStat{T}} <: OnlineStat{T} +mutable struct Thin{T, O<:OnlineStat{T}} <: OnlineStat{T} b::Int current_index::Int stat::O @@ -60,8 +62,10 @@ function OnlineStats._fit!(o::Thin, x::Real) return o end -Base.show(io::IO, o::Thin) = - print(io, "Thin ($(o.b)): current_index=$(o.current_index) | stat=$(o.stat)`") +Base.show(io::IO, o::Thin) = print( + io, + "Thin ($(o.b)): current_index=$(o.current_index) | stat=$(o.stat)`" +) """ $(TYPEDEF) @@ -76,26 +80,27 @@ $(TYPEDEF) `stat`, which is *only* fitted on the batched data contained in the `MovingWindow`. 
""" -struct WindowStat{T,O} <: OnlineStat{T} +struct WindowStat{T, O} <: OnlineStat{T} window::MovingWindow{T} stat::O end -WindowStat(b::Int, T::Type, o) = WindowStat{T,typeof(o)}(MovingWindow(b, T), o) -WindowStat(b::Int, o::OnlineStat{T}) where {T} = - WindowStat{T,typeof(o)}(MovingWindow(b, T), o) +WindowStat(b::Int, T::Type, o) = WindowStat{T, typeof(o)}(MovingWindow(b, T), o) +WindowStat(b::Int, o::OnlineStat{T}) where {T} = WindowStat{T, typeof(o)}( + MovingWindow(b, T), o +) # Proxy methods to the window OnlineStats.nobs(o::WindowStat) = OnlineStats.nobs(o.window) OnlineStats._fit!(o::WindowStat, x) = OnlineStats._fit!(o.window, x) -function OnlineStats.value(o::WindowStat{<:Any,<:OnlineStat}) +function OnlineStats.value(o::WindowStat{<:Any, <:OnlineStat}) stat_new = deepcopy(o.stat) fit!(stat_new, OnlineStats.value(o.window)) return stat_new end -function OnlineStats.value(o::WindowStat{<:Any,<:Function}) +function OnlineStats.value(o::WindowStat{<:Any, <:Function}) stat_new = o.stat() fit!(stat_new, OnlineStats.value(o.window)) return stat_new diff --git a/src/tensorboardlogger.jl b/src/tensorboardlogger.jl index 9727bc2..16a74fe 100644 --- a/src/tensorboardlogger.jl +++ b/src/tensorboardlogger.jl @@ -41,10 +41,10 @@ end function TBL.preprocess(name, stat::AutoCov, data) autocor = OnlineStats.autocor(stat) - for b = 1:(stat.lag.b-1) + for b = 1:(stat.lag.b - 1) # `autocor[i]` corresponds to the lag of size `i - 1` and `autocor[1] = 1.0` bname = tb_name(stat, b) - TBL.preprocess(tb_name(name, bname), autocor[b+1], data) + TBL.preprocess(tb_name(name, bname), autocor[b + 1], data) end end @@ -60,23 +60,22 @@ function TBL.preprocess(name, hist::KHist, data) # Creates a NORMALIZED histogram edges = OnlineStats.edges(hist) cnts = OnlineStats.counts(hist) - TBL.preprocess(name, (edges, cnts ./ sum(cnts)), data) + TBL.preprocess( + name, (edges, cnts ./ sum(cnts)), data + ) end end # Unlike the `preprocess` overload, this allows us to specify if we want to 
normalize function TBL.log_histogram(
-    logger::AbstractLogger,
-    name::AbstractString,
-    hist::OnlineStats.HistogramStat;
-    step = nothing,
-    normalize = false,
+    logger::AbstractLogger, name::AbstractString, hist::OnlineStats.HistogramStat;
+    step=nothing, normalize=false
 )
     edges = edges(hist)
     cnts = Float64.(OnlineStats.counts(hist))
     if normalize
-        return TBL.log_histogram(logger, name, (edges, cnts ./ sum(cnts)); step = step)
+        return TBL.log_histogram(logger, name, (edges, cnts ./ sum(cnts)); step=step)
     else
-        return TBL.log_histogram(logger, name, (edges, cnts); step = step)
+        return TBL.log_histogram(logger, name, (edges, cnts); step=step)
     end
 end
diff --git a/src/utils.jl b/src/utils.jl
new file mode 100644
index 0000000..2220659
--- /dev/null
+++ b/src/utils.jl
@@ -0,0 +1,10 @@
+Base.@kwdef struct NameFilter{A,B}
+    include::A=nothing
+    exclude::B=nothing
+end
+
+(f::NameFilter)(name, value) = f(name)
+function (f::NameFilter)(name)
+    include, exclude = f.include, f.exclude
+    (exclude === nothing || name ∉ exclude) && (include === nothing || name ∈ include)
+end
diff --git a/test/Project.toml b/test/Project.toml
index dee7ff1..95733df 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,6 +1,4 @@
 [deps]
-CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
-DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
diff --git a/test/multicallback.jl b/test/multicallback.jl
new file mode 100644
index 0000000..69ba91f
--- /dev/null
+++ b/test/multicallback.jl
@@ -0,0 +1,21 @@
+@testset "MultiCallback" begin
+    # Number of MCMC samples/steps
+    num_samples = 100
+    num_adapts = 50
+
+    # Sampling algorithm to use
+    alg = NUTS(num_adapts, 0.65)
+
+    callback = MultiCallback(CountingCallback(), CountingCallback())
+    chain = sample(demo_model, alg, num_samples, callback=callback)
+
+    # Both should have been triggered an equal 
number of times. + counts = map(c -> c.count[], callback.callbacks) + @test counts[1] == counts[2] + @test counts[1] == num_samples + + # Add a new one and make sure it's not like the others. + callback = TuringCallbacks.push!!(callback, CountingCallback()) + counts = map(c -> c.count[], callback.callbacks) + @test counts[1] == counts[2] != counts[3] +end diff --git a/test/runtests.jl b/test/runtests.jl index 6eda623..cdb8aa6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,73 +2,28 @@ using Test using Turing using TuringCallbacks using TensorBoardLogger, ValueHistories -using CSV -using DataFrames Base.@kwdef struct CountingCallback - count::Ref{Int} = Ref(0) + count::Ref{Int}=Ref(0) end (c::CountingCallback)(args...; kwargs...) = c.count[] += 1 -@testset "TuringCallbacks.jl" begin - # TODO: Improve. - @model function demo(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, √s) - for i in eachindex(x) - x[i] ~ Normal(m, √s) - end - end - - xs = randn(100) .+ 1 - model = demo(xs) - - # Number of MCMC samples/steps - num_samples = 1_000 - num_adapts = 500 - - # Sampling algorithm to use - alg = NUTS(num_adapts, 0.65) - - @testset "MultiCallback" begin - callback = MultiCallback(CountingCallback(), CountingCallback()) - chain = sample(model, alg, num_samples, callback = callback) - - # Both should have been trigger an equal number of times. - counts = map(c -> c.count[], callback.callbacks) - @test counts[1] == counts[2] - @test counts[1] == num_samples - - # Add a new one and make sure it's not like the others. 
- callback = TuringCallbacks.push!!(callback, CountingCallback()) - counts = map(c -> c.count[], callback.callbacks) - @test counts[1] == counts[2] != counts[3] +@model function demo(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in eachindex(x) + x[i] ~ Normal(m, √s) end +end - @testset "TensorBoardCallback" begin - # Create the callback - callback = TensorBoardCallback(mktempdir()) - - # Sample - chain = sample(model, alg, num_samples; callback = callback) - - # Extract the values. - hist = convert(MVHistory, callback.logger) +function DynamicPPL.TestUtils.varnames(::DynamicPPL.Model{typeof(demo)}) + return [@varname(s), @varname(m)] +end - # Compare the recorded values to the chain. - m_mean = last(last(hist["m/stat/Mean"])) - s_mean = last(last(hist["s/stat/Mean"])) +const demo_model = demo(randn(100) .+ 1) - @test m_mean ≈ mean(chain[:m]) - @test s_mean ≈ mean(chain[:s]) - end - - @testset "SaveCallback" begin - # Sample - sample(model, alg, num_samples; callback = SaveCSV, chain_name="chain_1") - chain = Matrix(CSV.read("chain_1.csv", DataFrame, header=false)) - @test size(chain) == (num_samples, 2) - rm("chain_1.csv") - end +@testset "TuringCallbacks.jl" begin + include("multicallback.jl") + include("tensorboardcallback.jl") end diff --git a/test/save.jl b/test/save.jl new file mode 100644 index 0000000..c29dc95 --- /dev/null +++ b/test/save.jl @@ -0,0 +1,7 @@ +@testset "SaveCallback" begin + # Sample + sample(model, alg, num_samples; callback = SaveCSV, chain_name="chain_1") + chain = Matrix(CSV.read("chain_1.csv", DataFrame, header=false)) + @test size(chain) == (num_samples, 2) + rm("chain_1.csv") +end \ No newline at end of file diff --git a/test/tensorboardcallback.jl b/test/tensorboardcallback.jl new file mode 100644 index 0000000..a8043d7 --- /dev/null +++ b/test/tensorboardcallback.jl @@ -0,0 +1,163 @@ +@testset "TensorBoardCallback" begin + tmpdir = mktempdir() + mkpath(tmpdir) + + vns = DynamicPPL.TestUtils.varnames(demo_model) + + # 
Number of MCMC samples/steps + num_samples = 100 + num_adapts = 50 + + # Sampling algorithm to use + alg = NUTS(num_adapts, 0.65) + + @testset "Correctness of values" begin + # Create the callback + callback = TensorBoardCallback(joinpath(tmpdir, "runs")) + + # Sample + chain = sample(demo_model, alg, num_samples; callback=callback) + + # Extract the values. + hist = convert(MVHistory, callback.logger) + + # Compare the recorded values to the chain. + m_mean = last(last(hist["m/stat/Mean"])) + s_mean = last(last(hist["s/stat/Mean"])) + + @test m_mean ≈ mean(chain[:m]) + @test s_mean ≈ mean(chain[:s]) + end + + @testset "Default" begin + # Create the callback + callback = TensorBoardCallback( + joinpath(tmpdir, "runs"); + ) + + # Sample + chain = sample(demo_model, alg, num_samples; callback=callback) + + # Read the logging info. + hist = convert(MVHistory, callback.logger) + + # Check the variables. + @testset "$vn" for vn in vns + # Should have the `val` field. + @test haskey(hist, Symbol(vn, "/val")) + # Should have the `Mean` and `Variance` stat. + @test haskey(hist, Symbol(vn, "/stat/Mean")) + @test haskey(hist, Symbol(vn, "/stat/Variance")) + end + + # Check the extra statistics. + @testset "extras" begin + @test haskey(hist, Symbol("extras/lp/val")) + @test haskey(hist, Symbol("extras/acceptance_rate/val")) + end + end + + @testset "Exclude variable" begin + # Create the callback + callback = TensorBoardCallback( + joinpath(tmpdir, "runs"); + exclude=["s"] + ) + + # Sample + chain = sample(demo_model, alg, num_samples; callback=callback) + + # Read the logging info. + hist = convert(MVHistory, callback.logger) + + # Check the variables. 
+ @testset "$vn" for vn in vns + if vn == @varname(s) + @test !haskey(hist, Symbol(vn, "/val")) + @test !haskey(hist, Symbol(vn, "/stat/Mean")) + @test !haskey(hist, Symbol(vn, "/stat/Variance")) + else + @test haskey(hist, Symbol(vn, "/val")) + @test haskey(hist, Symbol(vn, "/stat/Mean")) + @test haskey(hist, Symbol(vn, "/stat/Variance")) + end + end + + # Check the extra statistics. + @testset "extras" begin + @test haskey(hist, Symbol("extras/lp/val")) + @test haskey(hist, Symbol("extras/acceptance_rate/val")) + end + end + + @testset "Exclude extras" begin + # Create the callback + callback = TensorBoardCallback( + joinpath(tmpdir, "runs"); + include_extras=false + ) + + # Sample + chain = sample(demo_model, alg, num_samples; callback=callback) + + # Read the logging info. + hist = convert(MVHistory, callback.logger) + + # Check the variables. + @testset "$vn" for vn in vns + @test haskey(hist, Symbol(vn, "/val")) + @test haskey(hist, Symbol(vn, "/stat/Mean")) + @test haskey(hist, Symbol(vn, "/stat/Variance")) + end + + # Check the extra statistics. + @testset "extras" begin + @test !haskey(hist, Symbol("extras/lp/val")) + @test !haskey(hist, Symbol("extras/acceptance_rate/val")) + end + end + + @testset "With hyperparams" begin + @testset "$alg (has hyperparam: $hashyp)" for (alg, hashyp) in [ + (HMC(0.05, 10), true), + (HMCDA(num_adapts, 0.65, 1.0), true), + (NUTS(num_adapts, 0.65), true), + (MH(), false), + ] + + # Create the callback + callback = TensorBoardCallback( + joinpath(tmpdir, "runs"); + include_hyperparams=true, + ) + + # Sample + chain = sample(demo_model, alg, num_samples; callback=callback) + + # HACK: This touches internals so might just break at some point. + # If it some point does, let's just remove this test. 
+ # Inspiration: https://github.com/JuliaLogging/TensorBoardLogger.jl/blob/3d9c1a554a08179785459ad7b83bce0177b90275/src/Deserialization/deserialization.jl#L244-L258 + iter = TensorBoardLogger.TBEventFileCollectionIterator( + callback.logger.logdir, purge=true + ) + + found_one = false + for event_file in iter + for event in event_file + event.what === nothing && continue + !(event.what.value isa TensorBoardLogger.Summary) && continue + + for (tag, _) in event.what.value + if tag == "_hparams_/experiment" + found_one = true + break + end + end + end + + found_one && break + end + @test (hashyp && found_one) || (!hashyp && !found_one) + end + end +end From 9e9fb54ce7b0e332348eed7ecedd0fff359b17ce Mon Sep 17 00:00:00 2001 From: Jaime RZ Date: Thu, 12 Oct 2023 17:36:39 +0100 Subject: [PATCH 5/6] Bump Turing.jl compat + allow recording of hyperparameters (#46) (#47) * added more tests * added support for filtering extras, computing stats for extras, and for adding hyperparameters * bump versions of Turing and TensorBoardLogger now that it supports hyperparams * added automatic recording of hyperparms for some Turing samplers, and improved testing of this * added useful comment on where i got that weird test from * fixed tests + added test for MH, which does not currently have hyperparams * bump minor version since this is breaking * updated docstring for TensorBoardLogger Co-authored-by: Tor Erlend Fjelde From 584ed6206d2ae48c5e982f866630d2c22e5002a6 Mon Sep 17 00:00:00 2001 From: Jaime RZ Date: Thu, 12 Oct 2023 17:38:52 +0100 Subject: [PATCH 6/6] Bump Turing.jl compat + allow recording of hyperparameters (#46) (#48) * added more tests * added support for filtering extras, computing stats for extras, and for adding hyperparameters * bump versions of Turing and TensorBoardLogger now that it supports hyperparams * added automatic recording of hyperparms for some Turing samplers, and improved testing of this * added useful comment on where i got that weird test from * 
fixed tests + added test for MH, which does not currently have hyperparams * bump minor version since this is breaking * updated docstring for TensorBoardLogger Co-authored-by: Tor Erlend Fjelde