From a7cf400e7d58695b727d5c8dcb7dd525254a10bf Mon Sep 17 00:00:00 2001 From: jaimerz Date: Thu, 12 Oct 2023 17:45:52 +0100 Subject: [PATCH 1/4] clean start --- src/TuringCallbacks.jl | 1 + src/callbacks/save.jl | 37 +++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + test/save.jl | 7 +++++++ 4 files changed, 46 insertions(+) create mode 100644 src/callbacks/save.jl create mode 100644 test/save.jl diff --git a/src/TuringCallbacks.jl b/src/TuringCallbacks.jl index 1d183c3..2ebfcd8 100644 --- a/src/TuringCallbacks.jl +++ b/src/TuringCallbacks.jl @@ -24,6 +24,7 @@ include("stats.jl") include("tensorboardlogger.jl") include("callbacks/tensorboard.jl") include("callbacks/multicallback.jl") +include("callbacks/save.jl") @static if !isdefined(Base, :get_extension) function __init__() diff --git a/src/callbacks/save.jl b/src/callbacks/save.jl new file mode 100644 index 0000000..638f78f --- /dev/null +++ b/src/callbacks/save.jl @@ -0,0 +1,37 @@ +############################### +### Saves samples on the go ### +############################### + +""" + SaveCSV + +A callback saves samples to .csv file during sampling +""" +function SaveCSV( + rng::AbstractRNG, + model::Model, + sampler::Sampler, + transition, + state, + iteration::Int64; + kwargs..., +) + SaveCSV(rng, model, sampler, transition, state.vi, iteration; kwargs...) +end + +function SaveCSV( + rng::AbstractRNG, + model::Model, + sampler::Sampler, + transition, + vi::AbstractVarInfo, + iteration::Int64; + kwargs..., +) + vii = deepcopy(vi) + invlink!!(vii, model) + θ = vii[sampler] + # it would be good to have the param names as in the chain + chain_name = get(kwargs, :chain_name, "chain") + write(string(chain_name, ".csv"), Dict("params" => [θ]); append = true, delim = ";") +end diff --git a/test/runtests.jl b/test/runtests.jl index cdb8aa6..fb67b70 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,4 +26,5 @@ const demo_model = demo(randn(100) .+ 1) @testset "TuringCallbacks.jl" begin include("multicallback.jl") include("tensorboardcallback.jl") + include("save.jl") end diff --git a/test/save.jl b/test/save.jl new file mode 100644 index 0000000..c29dc95 --- /dev/null +++ b/test/save.jl @@ -0,0 +1,7 @@ +@testset "SaveCallback" begin + # Sample + sample(model, alg, num_samples; callback = SaveCSV, chain_name="chain_1") + chain = Matrix(CSV.read("chain_1.csv", DataFrame, header=false)) + @test size(chain) == (num_samples, 2) + rm("chain_1.csv") +end \ No newline at end of file From 01ffc090e48f7918ffa7f4aebbd8f6b832f3f5f6 Mon Sep 17 00:00:00 2001 From: jaimerz Date: Thu, 12 Oct 2023 17:47:44 +0100 Subject: [PATCH 2/4] bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e863cba..bedde1f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TuringCallbacks" uuid = "ea0860ee-d0ef-45ef-82e6-cc37d6be2f9c" authors = ["Tor Erlend Fjelde and contributors"] -version = "0.4.0" +version = "0.4.1" [deps] DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" From ceba035d41a20ae0ebbdb652ab3bcbdcfc48b4bc Mon Sep 17 00:00:00 2001 From: jaimerz Date: Mon, 16 Oct 2023 11:09:33 +0100 Subject: [PATCH 3/4] working --- Project.toml | 4 ++++ src/TuringCallbacks.jl | 7 +++++-- src/callbacks/save.jl | 4 ++-- test/Project.toml | 2 ++ test/runtests.jl | 2 ++ test/save.jl | 11 +++++++++-- 6 files changed, 24 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index bedde1f..4413737 100644 --- a/Project.toml +++ b/Project.toml @@ -4,15 +4,19 @@ authors = ["Tor Erlend Fjelde and contributors"] version = "0.4.1" [deps] +CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [weakdeps] Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/src/TuringCallbacks.jl b/src/TuringCallbacks.jl index 2ebfcd8..b758bb8 100644 --- a/src/TuringCallbacks.jl +++ b/src/TuringCallbacks.jl @@ -2,9 +2,12 @@ module TuringCallbacks using Reexport -using LinearAlgebra +using CSV +using Random using Logging +using LinearAlgebra using DocStringExtensions +import DynamicPPL: AbstractVarInfo, Model, Sampler @reexport using OnlineStats # used to compute different statistics on-the-fly @@ -17,7 +20,7 @@ using DataStructures: DefaultDict using Requires end -export DefaultDict, WindowStat, Thin, Skip, TensorBoardCallback, MultiCallback +export DefaultDict, WindowStat, Thin, Skip, TensorBoardCallback, MultiCallback, SaveCSV include("utils.jl") include("stats.jl") diff --git a/src/callbacks/save.jl b/src/callbacks/save.jl index 638f78f..a50b19e 100644 --- a/src/callbacks/save.jl +++ b/src/callbacks/save.jl @@ -29,9 +29,9 @@ function SaveCSV( kwargs..., ) vii = deepcopy(vi) - invlink!!(vii, model) + #invlink!!(vii, model) θ = vii[sampler] # it would be good to have the param names as in the chain chain_name = get(kwargs, :chain_name, "chain") - write(string(chain_name, ".csv"), Dict("params" => [θ]); append = true, delim = ";") + CSV.write(string(chain_name, ".csv"), Dict("params" => [θ]), append = true, delim = ";") end diff --git a/test/Project.toml b/test/Project.toml index 95733df..dee7ff1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,6 @@ [deps] +CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/test/runtests.jl b/test/runtests.jl index fb67b70..f16d50e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,6 @@ using Test +using CSV +using DataFrames using Turing using TuringCallbacks using TensorBoardLogger, ValueHistories diff --git a/test/save.jl b/test/save.jl index c29dc95..a4410b4 100644 --- a/test/save.jl +++ b/test/save.jl @@ -1,7 +1,14 @@ @testset "SaveCallback" begin - # Sample - sample(model, alg, num_samples; callback = SaveCSV, chain_name="chain_1") + # Number of MCMC samples/steps + num_samples = 100 + num_adapts = 50 + + # Sampling algorithm to use + alg = NUTS(num_adapts, 0.65) + + sample(demo_model, alg, num_samples; callback = SaveCSV, chain_name="chain_1") chain = Matrix(CSV.read("chain_1.csv", DataFrame, header=false)) + println(chain) @test size(chain) == (num_samples, 2) rm("chain_1.csv") end \ No newline at end of file From af4dfcf4fc060da02a53f204929f2a4232012302 Mon Sep 17 00:00:00 2001 From: jaimerz Date: Mon, 16 Oct 2023 11:10:26 +0100 Subject: [PATCH 4/4] no print --- test/save.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/save.jl b/test/save.jl index a4410b4..0b6213e 100644 --- a/test/save.jl +++ b/test/save.jl @@ -8,7 +8,6 @@ sample(demo_model, alg, num_samples; callback = SaveCSV, chain_name="chain_1") chain = Matrix(CSV.read("chain_1.csv", DataFrame, header=false)) - println(chain) - @test size(chain) == (num_samples, 2) + @test size(Matrix(chain)) == (num_samples, 2) rm("chain_1.csv") end \ No newline at end of file