Skip to content

Commit

Permalink
Merge pull request #74 from TuringLang/exportfunctions
Browse files Browse the repository at this point in the history
Updated to README and docs. Simplified sampling_tests.jl by using new…
  • Loading branch information
cpfiffer authored Mar 27, 2019
2 parents c3f9380 + 50b1370 commit a9b69b4
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 55 deletions.
84 changes: 71 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ chn = Chains(val,
Dict(:internals => ["d", "e"]))
```

Or through the `set_section` function, which returns a new `Chains` object. `Chains` objects cannot be modified in place due to section map immutability:
Or through the `set_section` function, which returns a new `Chains` object (as `Chains` objects cannot be modified in place due to section map immutability):

```julia
chn2 = set_section(chn, Dict(:internals => ["d", "e"]))
Expand Down Expand Up @@ -152,18 +152,6 @@ You can access each of the `P[. . .]` variables by indexing, using `x.P[1]`, `x.

Note that `x.P` is a tuple which has to be indexed by the relevant index, while `x.D` is just a vector.

### Saving and Loading Chains

Chains objects can be serialized and deserialized using `read` and `write`.

```julia
# Save a chain.
write("chain-file.jls", chn)

# Read a chain.
chn2 = read("chain-file.jls", Chains)
```

### Convergence Diagnostics functions
#### Discrete Diagnostic
Options for method are `[:weiss, :hangartner, :DARBOOT, MCBOOT, :billinsgley, :billingsleyBOOT]`
Expand Down Expand Up @@ -221,5 +209,75 @@ autocorplot(c::AbstractChains)
corner(c::AbstractChains, [:A, :B])
```

### Saving and Loading Chains

Chains objects can be serialized and deserialized using `read` and `write`.

```julia
# Save a chain.
write("chain-file.jls", chn)

# Read a chain.
chn2 = read("chain-file.jls", Chains)
```

### Exporting Chains

A few utility export functions have been provided to convers `Chains` objects to either an Array or a DataFrame:

```julia
# Several examples of creating an Array object:
Array(chns)
Array(chns[:s])
Array(chns, [:parameters])
Array(chns, [:parameters, :internals])

# By default chains are appended. This can be disabled
# using the append_chains keyword argument:
Array(chns, append_chains=false)

# This will return an `Array{Array, 1}` object containing
# an Array for each chain.

# A final option is:
Array(chns, remove_missing_union=false)

# This will not convert the Array columns from a
`Union{Missing, Real}` to a `Vector{Real}`.
```

Similarly, for DataFrames:

```julia
DataFrame(chns)
DataFrame(chns[:s])
DataFrame(chns, [:parameters])
DataFrame(chns, [:parameters, :internals])
DataFrame(chns, append_chains=false)
DataFrame(chns, remove_missing_union=false)
```

See also ?MCMCChains.DataFrame and ?MCMCChains.Array for more help.

### Sampling Chains

MCMCChains overloads several `sample()` methods as defined in StatsBase:

```julia
# Sampling `n` samples from the chain `a`. Optionally
# weighting the samples using `wv`.
sample([rng], a, [wv::AbstractWeights], n::Integer)

# E.g. creating 10000 weighted samples:
c = kde(Array(chn[:s]))
chn_weighted_sample = sample(c.x, Weights(c.density), 100000)

# As above, but supports replacing and ordering.
sample([rng], a, [wv::AbstractWeights], n::Integer; replace=true,
ordered=false)
```

See also ?MCMCChains.sample for additional help.

## License Notice
Note that this package heavily uses and adapts code from the Mamba.jl package licensed under MIT License, see License.md.
2 changes: 0 additions & 2 deletions src/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,6 @@ function Array(chn::MCMCChains.AbstractChains,
b
end

Base.convert(::Type{Array}, chn::MCMCChains.Chains) = convert(Array, chn.value)

"""
# DataFrame
Expand Down
36 changes: 0 additions & 36 deletions test/df_chainsummary_tests.jl

This file was deleted.

8 changes: 4 additions & 4 deletions test/sampling_tests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Turing, MCMCChains, KernelDensity, StatsBase, Test, Statistics

@testset "sampling api" begin
#@testset "sampling api" begin

@model gdemo(x) = begin
m ~ Normal(1, 0.01)
Expand All @@ -14,9 +14,9 @@ using Turing, MCMCChains, KernelDensity, StatsBase, Test, Statistics
chn_sample = sample(chn, 5)
@test range(chn_sample) == 1:1:5

c = kde(reshape(convert(Array{Float64}, chn[:s].value), 500))
c = kde(Array(chn[:s]))
chn_weighted_sample = sample(c.x, Weights(c.density), 100000)

@test mean(convert(Array{Float64}, chn[:s].value)) 5.0 atol=0.1
@test mean(Array(chn[:s])) 5.0 atol=0.1

end
#end

0 comments on commit a9b69b4

Please sign in to comment.