diff --git a/Project.toml b/Project.toml
index e863cba..4413737 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,18 +1,22 @@
 name = "TuringCallbacks"
 uuid = "ea0860ee-d0ef-45ef-82e6-cc37d6be2f9c"
 authors = ["Tor Erlend Fjelde and contributors"]
-version = "0.4.0"
+version = "0.4.1"
 
 [deps]
+CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
+DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
+Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
 
 [weakdeps]
 Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
diff --git a/src/TuringCallbacks.jl b/src/TuringCallbacks.jl
index 1d183c3..b758bb8 100644
--- a/src/TuringCallbacks.jl
+++ b/src/TuringCallbacks.jl
@@ -2,9 +2,12 @@ module TuringCallbacks
 
 using Reexport
 
-using LinearAlgebra
+using CSV
+using Random
 using Logging
+using LinearAlgebra
 using DocStringExtensions
+import DynamicPPL: AbstractVarInfo, Model, Sampler
 
 @reexport using OnlineStats # used to compute different statistics on-the-fly
 
@@ -17,13 +20,14 @@ using DataStructures: DefaultDict
     using Requires
 end
 
-export DefaultDict, WindowStat, Thin, Skip, TensorBoardCallback, MultiCallback
+export DefaultDict, WindowStat, Thin, Skip, TensorBoardCallback, MultiCallback, SaveCSV
 
 include("utils.jl")
 include("stats.jl")
 include("tensorboardlogger.jl")
 include("callbacks/tensorboard.jl")
 include("callbacks/multicallback.jl")
+include("callbacks/save.jl")
 
 @static if !isdefined(Base, :get_extension)
     function __init__()
diff --git a/src/callbacks/save.jl b/src/callbacks/save.jl
new file mode 100644
index 0000000..a50b19e
--- /dev/null
+++ b/src/callbacks/save.jl
@@ -0,0 +1,56 @@
+###############################
+### Saves samples on the go ###
+###############################
+
+"""
+    SaveCSV(rng, model, sampler, transition, state, iteration; kwargs...)
+
+Callback that appends the current sample to a `.csv` file during sampling.
+
+Pass it as the `callback` keyword argument of `sample`. Each call appends one row
+to `<chain_name>.csv`, using `;` as the delimiter. The row holds the values that
+`sampler` indexes out of the variable info, which is `state.vi`, or `state` itself
+when it already is an `AbstractVarInfo`.
+
+# Keywords
+- `chain_name = "chain"`: output path without the `.csv` extension.
+
+All other keyword arguments are accepted and ignored.
+
+!!! note
+    The file is only ever appended to, so rows from an earlier run that used the
+    same `chain_name` are kept.
+"""
+function SaveCSV(
+    rng::AbstractRNG,
+    model::Model,
+    sampler::Sampler,
+    transition,
+    state,
+    iteration::Integer;
+    kwargs...,
+)
+    # Generic sampler state: defer to the `AbstractVarInfo` method below.
+    return SaveCSV(rng, model, sampler, transition, state.vi, iteration; kwargs...)
+end
+
+function SaveCSV(
+    rng::AbstractRNG,
+    model::Model,
+    sampler::Sampler,
+    transition,
+    vi::AbstractVarInfo,
+    iteration::Integer;
+    chain_name = "chain",
+    kwargs...,
+)
+    # Indexing does not modify `vi` and the values are written out immediately, so
+    # no defensive copy of the variable info is made.
+    θ = vi[sampler]
+    # NOTE(review): the values are saved as stored in `vi`, without `invlink!!`, so
+    # they may be on the sampler's transformed scale -- confirm this is intended.
+    # TODO: write one named column per parameter, matching the resulting chain.
+    return CSV.write(
+        string(chain_name, ".csv"), Dict("params" => [θ]); append = true, delim = ";"
+    )
+end
diff --git a/test/Project.toml b/test/Project.toml
index 95733df..dee7ff1 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,4 +1,6 @@
 [deps]
+CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
diff --git a/test/runtests.jl b/test/runtests.jl
index cdb8aa6..f16d50e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,4 +1,6 @@
 using Test
+using CSV
+using DataFrames
 using Turing
 using TuringCallbacks
 using TensorBoardLogger, ValueHistories
@@ -26,4 +28,5 @@ const demo_model = demo(randn(100) .+ 1)
 @testset "TuringCallbacks.jl" begin
     include("multicallback.jl")
     include("tensorboardcallback.jl")
+    include("save.jl")
 end
diff --git a/test/save.jl b/test/save.jl
new file mode 100644
index 0000000..0b6213e
--- /dev/null
+++ b/test/save.jl
@@ -0,0 +1,18 @@
+@testset "SaveCallback" begin
+    # Number of MCMC samples/steps
+    num_samples = 100
+    num_adapts = 50
+
+    # Sampling algorithm to use
+    alg = NUTS(num_adapts, 0.65)
+
+    # `SaveCSV` appends to its output file, so write into a fresh temporary
+    # directory: a file left over from an earlier (failed) run can then never
+    # inflate the row count, and cleanup happens even if an assertion throws.
+    mktempdir() do dir
+        chain_name = joinpath(dir, "chain_1")
+        sample(demo_model, alg, num_samples; callback = SaveCSV, chain_name = chain_name)
+        chain = Matrix(CSV.read(chain_name * ".csv", DataFrame; header = false))
+        @test size(chain) == (num_samples, 2)
+    end
+end