Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

save CSV #44

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TuringCallbacks"
uuid = "ea0860ee-d0ef-45ef-82e6-cc37d6be2f9c"
authors = ["Tor Erlend Fjelde <[email protected]> and contributors"]
version = "0.3.1"
version = "0.4.0"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -26,8 +26,8 @@ DocStringExtensions = "0.8, 0.9"
OnlineStats = "1.5"
Reexport = "0.2, 1.0"
Requires = "1"
TensorBoardLogger = "0.1"
Turing = "0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22"
TensorBoardLogger = "0.1.22"
Turing = "0.29"
julia = "1"

[extras]
Expand Down
65 changes: 54 additions & 11 deletions ext/TuringCallbacksTuringExt.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,69 @@
module TuringCallbacksTuringExt

if isdefined(Base, :get_extension)
using Turing: Turing
using Turing: Turing, DynamicPPL
using TuringCallbacks: TuringCallbacks
else
# Requires compatible.
using ..Turing: Turing
using ..Turing: Turing, DynamicPPL
using ..TuringCallbacks: TuringCallbacks
end

const TuringTransition = Union{Turing.Inference.Transition,Turing.Inference.HMCTransition}
const TuringTransition = Union{
Turing.Inference.Transition,
Turing.Inference.SMCTransition,
Turing.Inference.PGTransition
}

function TuringCallbacks.params_and_values(transition::TuringTransition; kwargs...)
return Iterators.map(zip(Turing.Inference._params_to_array([transition])...)) do (ksym, val)
return string(ksym), val
end
function TuringCallbacks.params_and_values(
model::DynamicPPL.Model,
transition::TuringTransition;
kwargs...
)
vns, vals = Turing.Inference._params_to_array(model, [transition])
return zip(Iterators.map(string, vns), vals)
end

function TuringCallbacks.extras(transition::TuringTransition; kwargs...)
return Iterators.map(zip(Turing.Inference.get_transition_extras([transition])...)) do (ksym, val)
return string(ksym), val
end
function TuringCallbacks.extras(
model::DynamicPPL.Model, transition::TuringTransition;
kwargs...
)
names, vals = Turing.Inference.get_transition_extras([transition])
return zip(string.(names), vec(vals))
end

default_hyperparams(sampler::DynamicPPL.Sampler) = default_hyperparams(sampler.alg)
default_hyperparams(alg::Turing.Inference.InferenceAlgorithm) = (
string(f) => getfield(alg, f) for f in fieldnames(typeof(alg))
)

const AlgsWithDefaultHyperparams = Union{
Turing.Inference.HMC,
Turing.Inference.HMCDA,
Turing.Inference.NUTS,
Turing.Inference.SGHMC,

}

function TuringCallbacks.hyperparams(
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler{<:AlgsWithDefaultHyperparams};
kwargs...
)
return default_hyperparams(sampler)
end

function TuringCallbacks.hyperparam_metrics(
model,
sampler::Turing.Sampler{<:Turing.Inference.NUTS}
)
return [
"extras/acceptance_rate/stat/Mean",
"extras/max_hamiltonian_energy_error/stat/Mean",
"extras/lp/stat/Mean",
"extras/n_steps/stat/Mean",
"extras/tree_depth/stat/Mean"
]
end

end
1 change: 1 addition & 0 deletions src/TuringCallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ end

export DefaultDict, WindowStat, Thin, Skip, TensorBoardCallback, MultiCallback

include("utils.jl")
include("stats.jl")
include("tensorboardlogger.jl")
include("callbacks/tensorboard.jl")
Expand Down
37 changes: 37 additions & 0 deletions src/callbacks/save.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
###############################
### Saves samples on the go ###
###############################

"""
SaveCSV

A callback saves samples to .csv file during sampling
"""
function SaveCSV(

Check warning on line 10 in src/callbacks/save.jl

View check run for this annotation

Codecov / codecov/patch

src/callbacks/save.jl#L10

Added line #L10 was not covered by tests
rng::AbstractRNG,
model::Model,
sampler::Sampler,
transition,
state,
iteration::Int64;
kwargs...,
)
SaveCSV(rng, model, sampler, transition, state.vi, iteration; kwargs...)

Check warning on line 19 in src/callbacks/save.jl

View check run for this annotation

Codecov / codecov/patch

src/callbacks/save.jl#L19

Added line #L19 was not covered by tests
end

function SaveCSV(

Check warning on line 22 in src/callbacks/save.jl

View check run for this annotation

Codecov / codecov/patch

src/callbacks/save.jl#L22

Added line #L22 was not covered by tests
rng::AbstractRNG,
model::Model,
sampler::Sampler,
transition,
vi::AbstractVarInfo,
iteration::Int64;
kwargs...,
)
vii = deepcopy(vi)
invlink!!(vii, model)
θ = vii[sampler]

Check warning on line 33 in src/callbacks/save.jl

View check run for this annotation

Codecov / codecov/patch

src/callbacks/save.jl#L31-L33

Added lines #L31 - L33 were not covered by tests
# it would be good to have the param names as in the chain
chain_name = get(kwargs, :chain_name, "chain")
write(string(chain_name, ".csv"), Dict("params" => [θ]); append = true, delim = ";")

Check warning on line 36 in src/callbacks/save.jl

View check run for this annotation

Codecov / codecov/patch

src/callbacks/save.jl#L35-L36

Added lines #L35 - L36 were not covered by tests
end
Loading
Loading