Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save #49

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open

Save #49

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
name = "TuringCallbacks"
uuid = "ea0860ee-d0ef-45ef-82e6-cc37d6be2f9c"
authors = ["Tor Erlend Fjelde <[email protected]> and contributors"]
version = "0.4.0"
version = "0.4.1"

[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Copy link
Member

@yebai yebai Oct 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JaimeRZP I think @torfjelde wanted to avoid explicit dependence on Turing, CSV and DynamicPPL. Let's consider converting this PR into a package extension to make these dependencies optional.


[weakdeps]
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Expand Down
8 changes: 6 additions & 2 deletions src/TuringCallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ module TuringCallbacks

using Reexport

using LinearAlgebra
using CSV
using Random
using Logging
using LinearAlgebra
using DocStringExtensions
import DynamicPPL: AbstractVarInfo, Model, Sampler

@reexport using OnlineStats # used to compute different statistics on-the-fly

Expand All @@ -17,13 +20,14 @@ using DataStructures: DefaultDict
using Requires
end

export DefaultDict, WindowStat, Thin, Skip, TensorBoardCallback, MultiCallback
export DefaultDict, WindowStat, Thin, Skip, TensorBoardCallback, MultiCallback, SaveCSV

include("utils.jl")
include("stats.jl")
include("tensorboardlogger.jl")
include("callbacks/tensorboard.jl")
include("callbacks/multicallback.jl")
include("callbacks/save.jl")

@static if !isdefined(Base, :get_extension)
function __init__()
Expand Down
37 changes: 37 additions & 0 deletions src/callbacks/save.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
###############################
### Saves samples on the go ###
###############################

"""
SaveCSV

A callback saves samples to .csv file during sampling
"""
function SaveCSV(
rng::AbstractRNG,
model::Model,
sampler::Sampler,
transition,
state,
iteration::Int64;
kwargs...,
)
SaveCSV(rng, model, sampler, transition, state.vi, iteration; kwargs...)
end

function SaveCSV(
rng::AbstractRNG,
model::Model,
sampler::Sampler,
transition,
vi::AbstractVarInfo,
iteration::Int64;
kwargs...,
)
vii = deepcopy(vi)
#invlink!!(vii, model)
θ = vii[sampler]
# it would be good to have the param names as in the chain
chain_name = get(kwargs, :chain_name, "chain")
CSV.write(string(chain_name, ".csv"), Dict("params" => [θ]), append = true, delim = ";")
end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using Test
using CSV
using DataFrames
using Turing
using TuringCallbacks
using TensorBoardLogger, ValueHistories
Expand Down Expand Up @@ -26,4 +28,5 @@ const demo_model = demo(randn(100) .+ 1)
@testset "TuringCallbacks.jl" begin
include("multicallback.jl")
include("tensorboardcallback.jl")
include("save.jl")
end
13 changes: 13 additions & 0 deletions test/save.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
@testset "SaveCallback" begin
# Number of MCMC samples/steps
num_samples = 100
num_adapts = 50

# Sampling algorithm to use
alg = NUTS(num_adapts, 0.65)

sample(demo_model, alg, num_samples; callback = SaveCSV, chain_name="chain_1")
chain = Matrix(CSV.read("chain_1.csv", DataFrame, header=false))
@test size(Matrix(chain)) == (num_samples, 2)
rm("chain_1.csv")
end
Loading