Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,35 @@ In its place, you should directly use the accumulator API to:
To do so, we now export a convenience function `get_raw_values(::AbstractVarInfo)` that will get the stored `VarNamedTuple` of raw values.
This is exactly analogous to how `getlogprior(::AbstractVarInfo)` extracts the log-prior from a `LogPriorAccumulator`.

### Pointwise logdensities

Calling `pointwise_logdensities(model, varinfo)` now returns a `VarNamedTuple` of log-densities rather than an `OrderedDict`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to expose the two optional arguments (i.e. ::Val{Prior} and ::Val{Likelihood})?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically they were always exposed. But actually now thinking about it, I kind of prefer not to expose them, so we would move these arguments to an internal function and make pointwise_logdensities call that with (true, true).


Models with implicitly broadcasted observations (e.g. `y ~ Normal()` where `y` is an observed `Array`) will now return an `Array` of log-densities, one per element of `y`.
To recover the previous behaviour, you can sum the log-densities.

The method `pointwise_logdensities(model, chain::MCMCChains.Chains)` no longer accepts a `Tout` argument to control the output type; it always returns a new `MCMCChains.Chains`.

The internal argument `whichlogprob` for `pointwise_logdensities` is removed.
As a replacement you should just directly use `pointwise_logdensities`, `pointwise_loglikelihoods`, or `pointwise_prior_logdensities` as appropriate.

### Function signature changes in tilde-pipeline

`tilde_assume!!` and `accumulate_assume!!` now take extra arguments.
`tilde_assume!!`, `tilde_observe!!`, `accumulate_assume!!`, and `accumulate_observe!!` now take extra arguments.

In particular

- `tilde_assume!!(ctx, dist, vn, vi)` is now `tilde_assume!!(ctx, dist, vn, template, vi)`, where `template` is the value for the top-level symbol in `vn`.
For example, if `vn` is `@varname(x[1])`, then `template` is the current value of `x` in the model.
The DynamicPPL compiler is responsible for generating and providing this argument.

- Likewise, `tilde_observe!!(ctx, dist, left, vn, vi)` is now `tilde_observe!!(ctx, dist, left, vn, template, vi)`, where `template` is the same as above.
- `accumulate_assume!!(acc::AbstractAccumulator, val, logjac, vn, dist)` is now `accumulate_assume!!(acc, val, tval, logjac, vn, dist, template)`.
`template` is the same as above.
`tval` is either the `AbstractTransformedValue` that `DynamicPPL.init` provided (for InitContext), or the `AbstractTransformedValue` found inside the VarInfo (for DefaultContext).
- `accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, dist)` is now `accumulate_assume!!(vi, val, tval, logjac, vn, dist, template)`.
- `accumulate_observe!!(acc::AbstractAccumulator, dist, left, vn)` is now `accumulate_observe!!(acc, dist, left, vn, template)`, where `template` is the same as above.
- `accumulate_observe!!(vi::AbstractVarInfo, dist, left, vn)` is now `accumulate_observe!!(vi, dist, left, vn, template)`.

### `DynamicPPL.DebugUtils`

Expand Down
4 changes: 3 additions & 1 deletion docs/src/accs/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ function DynamicPPL.accumulate_assume!!(
return acc
end

function DynamicPPL.accumulate_observe!!(acc::VarNameLogpAccumulator, dist, val, vn)
function DynamicPPL.accumulate_observe!!(
acc::VarNameLogpAccumulator, dist, val, vn, template
)
acc.logps[vn] = (true, logpdf(dist, val))
return acc
end
Expand Down
6 changes: 4 additions & 2 deletions docs/src/tilde.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ begin

elseif is_conditioned(vn)
conditioned_x = get_conditioned_value(vn)
raw_x, __varinfo__ = tilde_observe!!(ctx, dist, conditioned_x, vn, __varinfo__)
raw_x, __varinfo__ = tilde_observe!!(
ctx, dist, conditioned_x, vn, template, __varinfo__
)

elseif is_model_argument(vn)
arg_x = x
raw_x, __varinfo__ = tilde_observe!!(ctx, dist, arg_x, vn, __varinfo__)
raw_x, __varinfo__ = tilde_observe!!(ctx, dist, arg_x, vn, template, __varinfo__)

else
raw_x, __varinfo__ = tilde_assume!!(ctx, dist, vn, template, __varinfo__)
Expand Down
90 changes: 34 additions & 56 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -390,23 +390,40 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
return map(first, reevaluate_with_chain(model, chain, (), nothing))
end

"""
Shared internal helper function.
"""
function _pointwise_logdensities_chain(
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
::Val{Prior}=Val(true),
::Val{Likelihood}=Val(true),
) where {Prior,Likelihood}
acc = DynamicPPL.VNTAccumulator{DynamicPPL.POINTWISE_ACCNAME}(
DynamicPPL.PointwiseLogProb{Prior,Likelihood}()
)
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
# Reevaluating this gives us a VNT of log probs. We can densify and then wrap in
# ParamsWithStats so that we can easily convert back to a Chains object.
pointwise_logps = map(
reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing)
) do (_, oavi)
logprobs = DynamicPPL.get_pointwise_logprobs(oavi)
dense_logprobs = DynamicPPL.densify!!(logprobs)
DynamicPPL.ParamsWithStats(dense_logprobs, (;))
end
return AbstractMCMC.from_samples(MCMCChains.Chains, pointwise_logps)
end

"""
DynamicPPL.pointwise_logdensities(
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
::Type{Tout}=MCMCChains.Chains
::Val{whichlogprob}=Val(:both),
)

Runs `model` on each sample in `chain`, returning a new `MCMCChains.Chains` object where
the log-density of each variable at each sample is stored (rather than its value).

`whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or
`:likelihood`.

You can pass `Tout=OrderedDict` to get the result as an `OrderedDict{VarName,
Matrix{Float64}}` instead.

See also: [`DynamicPPL.pointwise_loglikelihoods`](@ref),
[`DynamicPPL.pointwise_prior_logdensities`](@ref).

Expand Down Expand Up @@ -463,56 +480,17 @@ julia> # The above is the same as:
-1.3822169643436162
-2.0986122886681096
```

julia> # Alternatively:
plds_dict = pointwise_logdensities(model, chain, OrderedDict)
OrderedDict{VarName, Matrix{Float64}} with 6 entries:
s => [-0.802775; -1.38222; -2.09861;;]
m => [-8.91894; -7.51551; -7.46824;;]
xs[1] => [-5.41894; -5.26551; -5.63491;;]
xs[2] => [-2.91894; -3.51551; -4.13491;;]
xs[3] => [-1.41894; -2.26551; -2.96824;;]
y => [-0.918939; -1.51551; -2.13491;;]
"""
function DynamicPPL.pointwise_logdensities(
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
::Type{Tout}=MCMCChains.Chains,
::Val{whichlogprob}=Val(:both),
) where {whichlogprob,Tout}
acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}()
accname = DynamicPPL.accumulator_name(acc)
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
pointwise_logps =
map(reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing)) do (_, vi)
DynamicPPL.getacc(vi, Val(accname)).logps
end
# pointwise_logps is a matrix of OrderedDicts
all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
for d in pointwise_logps
union!(all_keys, DynamicPPL.OrderedCollections.OrderedSet(keys(d)))
end
# this is a 3D array: (iterations, variables, chains)
new_data = [
get(pointwise_logps[iter, chain], k, missing) for
iter in 1:size(pointwise_logps, 1), k in all_keys,
chain in 1:size(pointwise_logps, 2)
]

if Tout == MCMCChains.Chains
return MCMCChains.Chains(new_data, Symbol.(collect(all_keys)))
elseif Tout <: AbstractDict
return Tout{DynamicPPL.VarName,Matrix{Float64}}(
k => new_data[:, i, :] for (i, k) in enumerate(all_keys)
)
end
model::DynamicPPL.Model, chain::MCMCChains.Chains
)
return _pointwise_logdensities_chain(model, chain, Val(true), Val(true))
end

"""
DynamicPPL.pointwise_loglikelihoods(
model::DynamicPPL.Model,
chain::MCMCChains.Chains,
::Type{Tout}=MCMCChains.Chains
)

Compute the pointwise log-likelihoods of the model given the chain. This is the same as
Expand All @@ -521,9 +499,9 @@ Compute the pointwise log-likelihoods of the model given the chain. This is the
See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref).
"""
function DynamicPPL.pointwise_loglikelihoods(
model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains
) where {Tout}
return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:likelihood))
model::DynamicPPL.Model, chain::MCMCChains.Chains
)
return _pointwise_logdensities_chain(model, chain, Val(false), Val(true))
end

"""
Expand All @@ -538,9 +516,9 @@ Compute the pointwise log-prior-densities of the model given the chain. This is
See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_loglikelihoods`](@ref).
"""
function DynamicPPL.pointwise_prior_logdensities(
model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains
) where {Tout}
return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:prior))
model::DynamicPPL.Model, chain::MCMCChains.Chains
)
return _pointwise_logdensities_chain(model, chain, Val(true), Val(false))
end

"""
Expand Down
8 changes: 5 additions & 3 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,12 +326,14 @@ function accumulate_assume!!(vi::AbstractVarInfo, val, tval, logjac, vn, right,
end

"""
accumulate_observe!!(vi::AbstractVarInfo, right, left, vn)
accumulate_observe!!(vi::AbstractVarInfo, right, left, vn, template)

Update all the accumulators of `vi` by calling `accumulate_observe!!` on them.
"""
function accumulate_observe!!(vi::AbstractVarInfo, right, left, vn)
return map_accumulators!!(acc -> accumulate_observe!!(acc, right, left, vn), vi)
function accumulate_observe!!(vi::AbstractVarInfo, right, left, vn, template)
return map_accumulators!!(
acc -> accumulate_observe!!(acc, right, left, vn, template), vi
)
end

"""
Expand Down
4 changes: 2 additions & 2 deletions src/accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ seen so far.

An accumulator type `T <: AbstractAccumulator` must implement the following methods:
- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})`
- `accumulate_observe!!(acc::T, dist, val, vn)`
- `accumulate_observe!!(acc::T, dist, val, vn, template)`
- `accumulate_assume!!(acc::T, val, tval, logjac, vn, dist, template)`
- `reset(acc::T)`
- `Base.copy(acc::T)`
Expand Down Expand Up @@ -53,7 +53,7 @@ depends on the type of `acc`, not on its value.
accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc))

"""
accumulate_observe!!(acc::AbstractAccumulator, right, left, vn)
accumulate_observe!!(acc::AbstractAccumulator, right, left, vn, template)

Update `acc` in a `tilde_observe!!` call. Returns the updated `acc`.

Expand Down
2 changes: 1 addition & 1 deletion src/accumulators/bijector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function accumulate_assume!!(
return acc
end

accumulate_observe!!(acc::BijectorAccumulator, right, left, vn) = acc
accumulate_observe!!(acc::BijectorAccumulator, right, left, vn, template) = acc

"""
bijector(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
Expand Down
6 changes: 3 additions & 3 deletions src/accumulators/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ function accumulate_assume!!(
)
return acclogp(acc, logpdf(right, val))
end
accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc
accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn, template) = acc

"""
LogJacobianAccumulator{T<:Real} <: LogProbAccumulator{T}
Expand Down Expand Up @@ -142,7 +142,7 @@ function accumulate_assume!!(
)
return acclogp(acc, logjac)
end
accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc
accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn, template) = acc

"""
LogLikelihoodAccumulator{T<:Real} <: LogProbAccumulator{T}
Expand All @@ -166,7 +166,7 @@ function accumulate_assume!!(
)
return acc
end
function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn, template)
# Note that it's important to use the loglikelihood function here, not logpdf, because
# they handle vectors differently:
# https://github.com/JuliaStats/Distributions.jl/issues/1972
Expand Down
Loading