From db48f7b73050c0e1795ed46f0013c3c5ba0074d2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 13 Feb 2026 16:20:42 +0000 Subject: [PATCH 1/5] Use templates in observe pipeline, make pointwise_logprobs return VNT --- docs/src/accs/overview.md | 4 +- docs/src/tilde.md | 6 +- ext/DynamicPPLMCMCChainsExt.jl | 78 ++++++------------ src/abstract_varinfo.jl | 8 +- src/accumulators.jl | 4 +- src/accumulators/bijector.jl | 2 +- src/accumulators/default.jl | 6 +- src/accumulators/pointwise_logdensities.jl | 96 +++++++++------------- src/accumulators/vector_params.jl | 2 +- src/accumulators/vnt.jl | 2 +- src/compiler.jl | 7 ++ src/contexts.jl | 10 ++- src/contexts/default.jl | 4 +- src/contexts/init.jl | 3 +- src/contexts/prefix.jl | 4 +- src/debug_utils.jl | 2 +- src/submodel.jl | 3 +- src/varnamedtuple/partial_array.jl | 19 +++-- test/accumulators.jl | 16 ++-- test/logdensityfunction.jl | 2 +- test/varinfo.jl | 3 - 21 files changed, 130 insertions(+), 151 deletions(-) diff --git a/docs/src/accs/overview.md b/docs/src/accs/overview.md index 5a3bd6ed9..14afcfc87 100644 --- a/docs/src/accs/overview.md +++ b/docs/src/accs/overview.md @@ -91,7 +91,9 @@ function DynamicPPL.accumulate_assume!!( return acc end -function DynamicPPL.accumulate_observe!!(acc::VarNameLogpAccumulator, dist, val, vn) +function DynamicPPL.accumulate_observe!!( + acc::VarNameLogpAccumulator, dist, val, vn, template +) acc.logps[vn] = (true, logpdf(dist, val)) return acc end diff --git a/docs/src/tilde.md b/docs/src/tilde.md index 97197c784..cd2ac46e7 100644 --- a/docs/src/tilde.md +++ b/docs/src/tilde.md @@ -26,11 +26,13 @@ begin elseif is_conditioned(vn) conditioned_x = get_conditioned_value(vn) - raw_x, __varinfo__ = tilde_observe!!(ctx, dist, conditioned_x, vn, __varinfo__) + raw_x, __varinfo__ = tilde_observe!!( + ctx, dist, conditioned_x, vn, template, __varinfo__ + ) elseif is_model_argument(vn) arg_x = x - raw_x, __varinfo__ = tilde_observe!!(ctx, dist, arg_x, vn, __varinfo__) + raw_x, __varinfo__ = tilde_observe!!(ctx, dist, arg_x, vn, template, __varinfo__) else raw_x, __varinfo__ = tilde_assume!!(ctx, dist, vn, template, __varinfo__) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 2b227d40c..bf8803728 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -394,18 +394,17 @@ end DynamicPPL.pointwise_logdensities( model::DynamicPPL.Model, chain::MCMCChains.Chains, - ::Type{Tout}=MCMCChains.Chains - ::Val{whichlogprob}=Val(:both), + ::Val{Prior}=Val(true), + ::Val{Likelihood}=Val(true) ) Runs `model` on each sample in `chain`, returning a new `MCMCChains.Chains` object where the log-density of each variable at each sample is stored (rather than its value). -`whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or -`:likelihood`. - -You can pass `Tout=OrderedDict` to get the result as an `OrderedDict{VarName, -Matrix{Float64}}` instead. +The `Val`s passed as the last two arguments control which log-probabilities are included in +the result. If both are `true`, then the log-probabilities include both the prior and +likelihood terms. If only one of them is `true`, then only the corresponding +log-probabilities are included. See also: [`DynamicPPL.pointwise_loglikelihoods`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref). @@ -463,56 +462,33 @@ julia> # The above is the same as: -1.3822169643436162 -2.0986122886681096 ``` - -julia> # Alternatively: - plds_dict = pointwise_logdensities(model, chain, OrderedDict) -OrderedDict{VarName, Matrix{Float64}} with 6 entries: - s => [-0.802775; -1.38222; -2.09861;;] - m => [-8.91894; -7.51551; -7.46824;;] - xs[1] => [-5.41894; -5.26551; -5.63491;;] - xs[2] => [-2.91894; -3.51551; -4.13491;;] - xs[3] => [-1.41894; -2.26551; -2.96824;;] - y => [-0.918939; -1.51551; -2.13491;;] """ function DynamicPPL.pointwise_logdensities( model::DynamicPPL.Model, chain::MCMCChains.Chains, - ::Type{Tout}=MCMCChains.Chains, - ::Val{whichlogprob}=Val(:both), -) where {whichlogprob,Tout} - acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}() - accname = DynamicPPL.accumulator_name(acc) + ::Val{Prior}=Val(true), + ::Val{Likelihood}=Val(true), +) where {Prior,Likelihood} + acc = DynamicPPL.VNTAccumulator{DynamicPPL.POINTWISE_ACCNAME}( + DynamicPPL.PointwiseLogProb{Prior,Likelihood}() + ) parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - pointwise_logps = - map(reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing)) do (_, vi) - DynamicPPL.getacc(vi, Val(accname)).logps - end - # pointwise_logps is a matrix of OrderedDicts - all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}() - for d in pointwise_logps - union!(all_keys, DynamicPPL.OrderedCollections.OrderedSet(keys(d))) - end - # this is a 3D array: (iterations, variables, chains) - new_data = [ - get(pointwise_logps[iter, chain], k, missing) for - iter in 1:size(pointwise_logps, 1), k in all_keys, - chain in 1:size(pointwise_logps, 2) - ] - - if Tout == MCMCChains.Chains - return MCMCChains.Chains(new_data, Symbol.(collect(all_keys))) - elseif Tout <: AbstractDict - return Tout{DynamicPPL.VarName,Matrix{Float64}}( - k => new_data[:, i, :] for (i, k) in enumerate(all_keys) - ) + # Reevaluating this gives us a VNT of log probs. We can densify and then wrap in + # ParamsWithStats so that we can easily convert back to a Chains object. + pointwise_logps = map( + reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing) + ) do (_, oavi) + logprobs = DynamicPPL.get_pointwise_logprobs(oavi) + dense_logprobs = DynamicPPL.densify!!(logprobs) + DynamicPPL.ParamsWithStats(dense_logprobs, (;)) end + return AbstractMCMC.from_samples(MCMCChains.Chains, pointwise_logps) end """ DynamicPPL.pointwise_loglikelihoods( model::DynamicPPL.Model, chain::MCMCChains.Chains, - ::Type{Tout}=MCMCChains.Chains ) Compute the pointwise log-likelihoods of the model given the chain. This is the same as @@ -521,9 +497,9 @@ Compute the pointwise log-likelihoods of the model given the chain. This is the See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref). """ function DynamicPPL.pointwise_loglikelihoods( - model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains -) where {Tout} - return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:likelihood)) + model::DynamicPPL.Model, chain::MCMCChains.Chains +) + return DynamicPPL.pointwise_logdensities(model, chain, Val(false), Val(true)) end """ @@ -538,9 +514,9 @@ Compute the pointwise log-prior-densities of the model given the chain. This is See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_loglikelihoods`](@ref). """ function DynamicPPL.pointwise_prior_logdensities( - model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains -) where {Tout} - return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:prior)) + model::DynamicPPL.Model, chain::MCMCChains.Chains +) + return DynamicPPL.pointwise_logdensities(model, chain, Val(true), Val(false)) end """ diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 122413e9a..ba30404ec 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -326,12 +326,14 @@ function accumulate_assume!!(vi::AbstractVarInfo, val, tval, logjac, vn, right, end """ - accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) + accumulate_observe!!(vi::AbstractVarInfo, right, left, vn, template) Update all the accumulators of `vi` by calling `accumulate_observe!!` on them. """ -function accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) - return map_accumulators!!(acc -> accumulate_observe!!(acc, right, left, vn), vi) +function accumulate_observe!!(vi::AbstractVarInfo, right, left, vn, template) + return map_accumulators!!( + acc -> accumulate_observe!!(acc, right, left, vn, template), vi + ) end """ diff --git a/src/accumulators.jl b/src/accumulators.jl index 861522823..84364fab5 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -11,7 +11,7 @@ seen so far. An accumulator type `T <: AbstractAccumulator` must implement the following methods: - `accumulator_name(acc::T)` or `accumulator_name(::Type{T})` -- `accumulate_observe!!(acc::T, dist, val, vn)` +- `accumulate_observe!!(acc::T, dist, val, vn, template)` - `accumulate_assume!!(acc::T, val, tval, logjac, vn, dist, template)` - `reset(acc::T)` - `Base.copy(acc::T)` @@ -53,7 +53,7 @@ depends on the type of `acc`, not on its value. accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc)) """ - accumulate_observe!!(acc::AbstractAccumulator, right, left, vn) + accumulate_observe!!(acc::AbstractAccumulator, right, left, vn, template) Update `acc` in a `tilde_observe!!` call. Returns the updated `acc`. diff --git a/src/accumulators/bijector.jl b/src/accumulators/bijector.jl index c2876324c..473fa5062 100644 --- a/src/accumulators/bijector.jl +++ b/src/accumulators/bijector.jl @@ -37,7 +37,7 @@ function accumulate_assume!!( return acc end -accumulate_observe!!(acc::BijectorAccumulator, right, left, vn) = acc +accumulate_observe!!(acc::BijectorAccumulator, right, left, vn, template) = acc """ bijector(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) diff --git a/src/accumulators/default.jl b/src/accumulators/default.jl index 8d561f34e..bd5b98cc7 100644 --- a/src/accumulators/default.jl +++ b/src/accumulators/default.jl @@ -96,7 +96,7 @@ function accumulate_assume!!( ) return acclogp(acc, logpdf(right, val)) end -accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc +accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn, template) = acc """ LogJacobianAccumulator{T<:Real} <: LogProbAccumulator{T} @@ -142,7 +142,7 @@ function accumulate_assume!!( ) return acclogp(acc, logjac) end -accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc +accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn, template) = acc """ LogLikelihoodAccumulator{T<:Real} <: LogProbAccumulator{T} @@ -166,7 +166,7 @@ function accumulate_assume!!( ) return acc end -function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) +function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn, template) # Note that it's important to use the loglikelihood function here, not logpdf, because # they handle vectors differently: # https://github.com/JuliaStats/Distributions.jl/issues/1972 diff --git a/src/accumulators/pointwise_logdensities.jl b/src/accumulators/pointwise_logdensities.jl index 9e4be7c6a..dcbce2ce5 100644 --- a/src/accumulators/pointwise_logdensities.jl +++ b/src/accumulators/pointwise_logdensities.jl @@ -9,81 +9,61 @@ the variable names and the values are log-probabilities. `whichlogprob` is a symbol that can be `:both`, `:prior`, or `:likelihood`, and specifies which log-probabilities to store in the accumulator. """ -struct PointwiseLogProbAccumulator{whichlogprob} <: AbstractAccumulator - logps::OrderedDict{VarName,LogProbType} - function PointwiseLogProbAccumulator{whichlogprob}( - d::OrderedDict{VarName,LogProbType}=OrderedDict{VarName,LogProbType}() - ) where {whichlogprob} - return new{whichlogprob}(d) +struct PointwiseLogProb{Prior,Likelihood} end +function (plp::PointwiseLogProb{Prior,Likelihood})( + val, tval, logjac, vn, dist +) where {Prior,Likelihood} + if Prior + return logpdf(dist, val) + else + return DoNotAccumulate() end end +const POINTWISE_ACCNAME = :PointwiseLogProbAccumulator -function Base.:(==)( - acc1::PointwiseLogProbAccumulator{wlp1}, acc2::PointwiseLogProbAccumulator{wlp2} -) where {wlp1,wlp2} - return (wlp1 == wlp2 && acc1.logps == acc2.logps) -end - -function Base.copy(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} - return PointwiseLogProbAccumulator{whichlogprob}(copy(acc.logps)) -end - -function accumulator_name( - ::Type{<:PointwiseLogProbAccumulator{whichlogprob}} -) where {whichlogprob} - return Symbol("PointwiseLogProbAccumulator{$whichlogprob}") -end - -function _zero(::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} - return PointwiseLogProbAccumulator{whichlogprob}() -end -reset(acc::PointwiseLogProbAccumulator) = _zero(acc) -split(acc::PointwiseLogProbAccumulator) = _zero(acc) -function combine( - acc::PointwiseLogProbAccumulator{whichlogprob}, - acc2::PointwiseLogProbAccumulator{whichlogprob}, -) where {whichlogprob} - return PointwiseLogProbAccumulator{whichlogprob}(mergewith(+, acc.logps, acc2.logps)) -end - -function accumulate_assume!!( - acc::PointwiseLogProbAccumulator{whichlogprob}, val, tval, logjac, vn, right, template -) where {whichlogprob} - if whichlogprob == :both || whichlogprob == :prior - acc.logps[vn] = logpdf(right, val) - end - return acc +# Not exported +function get_pointwise_logprobs(varinfo::AbstractVarInfo) + return getacc(varinfo, Val(POINTWISE_ACCNAME)).values end +# Have to overload accumulate_assume!! since VNTAccumulator by default does not track +# observe statements. function accumulate_observe!!( - acc::PointwiseLogProbAccumulator{whichlogprob}, right, left, vn -) where {whichlogprob} - # If `vn` is nothing the LHS of ~ is a literal and we don't have a name to attach this - # acc to, and thus do nothing. - if vn === nothing - return acc - end - if whichlogprob == :both || whichlogprob == :likelihood - acc.logps[vn] = loglikelihood(right, left) + acc::VNTAccumulator{POINTWISE_ACCNAME,PointwiseLogProb{Prior,Likelihood}}, + right, + left, + vn, + template, +) where {Prior,Likelihood} + # vn could be `nothing`, in which case we can't store it in a VNT. + return if Likelihood && vn isa VarName + logp = logpdf(right, left) + new_values = DynamicPPL.templated_setindex!!(acc.values, logp, vn, template) + return VNTAccumulator{POINTWISE_ACCNAME}(acc.f, new_values) + else + # No need to accumulate likelihoods. + acc end - return acc end function pointwise_logdensities( - model::Model, varinfo::AbstractVarInfo, ::Val{whichlogprob}=Val(:both) -) where {whichlogprob} - AccType = PointwiseLogProbAccumulator{whichlogprob} - oavi = OnlyAccsVarInfo((AccType(),)) + model::Model, + varinfo::AbstractVarInfo, + ::Val{Prior}=Val(true), + ::Val{Likelihood}=Val(true), +) where {Prior,Likelihood} + acc = VNTAccumulator{POINTWISE_ACCNAME}(PointwiseLogProb{Prior,Likelihood}()) + oavi = OnlyAccsVarInfo(acc) init_strategy = InitFromParams(varinfo.values, nothing) oavi = last(init!!(model, oavi, init_strategy, UnlinkAll())) - return getacc(oavi, Val(accumulator_name(AccType))).logps + return get_pointwise_logprobs(oavi) end function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) - return pointwise_logdensities(model, varinfo, Val(:likelihood)) + return pointwise_logdensities(model, varinfo, Val(false), Val(true)) end function pointwise_prior_logdensities(model::Model, varinfo::AbstractVarInfo) - return pointwise_logdensities(model, varinfo, Val(:prior)) + return pointwise_logdensities(model, varinfo, Val(true), Val(false)) end diff --git a/src/accumulators/vector_params.jl b/src/accumulators/vector_params.jl index 1f266d48d..0f7b9dd7d 100644 --- a/src/accumulators/vector_params.jl +++ b/src/accumulators/vector_params.jl @@ -29,7 +29,7 @@ const VECTOR_ACC_NAME = :VectorParamAccumulator DynamicPPL.accumulator_name(::Type{<:VectorParamAccumulator}) = VECTOR_ACC_NAME function DynamicPPL.accumulate_observe!!( - acc::VectorParamAccumulator, ::Distribution, val, ::Union{VarName,Nothing} + acc::VectorParamAccumulator, ::Distribution, val, ::Union{VarName,Nothing}, ::Any ) return acc end diff --git a/src/accumulators/vnt.jl b/src/accumulators/vnt.jl index c6db0e841..b85da5d49 100644 --- a/src/accumulators/vnt.jl +++ b/src/accumulators/vnt.jl @@ -68,4 +68,4 @@ function accumulate_assume!!( VNTAccumulator{AccName}(acc.f, new_values) end end -accumulate_observe!!(acc::VNTAccumulator, right, left, vn) = acc +accumulate_observe!!(acc::VNTAccumulator, right, left, vn, template) = acc diff --git a/src/compiler.jl b/src/compiler.jl index 5e759702a..7a67125ae 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -436,6 +436,7 @@ function generate_tilde_literal(left, right) $(DynamicPPL.check_tilde_rhs)($right), $left, nothing, + $(NoTemplate()), __varinfo__, ) $value @@ -464,6 +465,11 @@ variables. """ function generate_tilde(left, right) isliteral(left) && return generate_tilde_literal(left, right) + template = if left isa Symbol # i.e. identity optic + :($(NoTemplate)()) + else + get_top_level_symbol(left) + end # Otherwise it is determined by the model or its value, # if the LHS represents an observation @@ -509,6 +515,7 @@ function generate_tilde(left, right) $(DynamicPPL.check_tilde_rhs)($dist), $supplied_val, $vn, + $template, __varinfo__, ) $(assign_or_set!!(left, value, vn)) diff --git a/src/contexts.jl b/src/contexts.jl index 80fae8411..7a14b6970 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -131,6 +131,7 @@ end right::Distribution, left, vn::Union{VarName, Nothing}, + template::Any, vi::AbstractVarInfo )::Tuple{Any,AbstractVarInfo} @@ -145,9 +146,9 @@ the VarInfo object `vi` (except for fixed variables, which do not contribute to log-probability). `left` is the actual value that the left-hand side evaluates to. `vn` is the VarName on the -left-hand side, or `nothing` if the left-hand side is a literal value. - -Observations of submodels are not yet supported in DynamicPPL. +left-hand side, or `nothing` if the left-hand side is a literal value. `template` is the +value of the top-level symbol in `vn`; if `vn` is `nothing`, then `template` will be +`NoTemplate()`. This function should return a tuple `(left, vi)`, where `left` is the same as the input, and `vi` is the updated VarInfo. @@ -157,9 +158,10 @@ function tilde_observe!!( right::Distribution, left, vn::Union{VarName,Nothing}, + template::Any, vi::AbstractVarInfo, ) - return tilde_observe!!(childcontext(context), right, left, vn, vi) + return tilde_observe!!(childcontext(context), right, left, vn, template, vi) end function tilde_observe!!( context::AbstractContext, diff --git a/src/contexts/default.jl b/src/contexts/default.jl index 338e165bc..dca451eab 100644 --- a/src/contexts/default.jl +++ b/src/contexts/default.jl @@ -59,6 +59,7 @@ end right::Distribution, left, vn::Union{VarName,Nothing}, + template::Any, vi::AbstractVarInfo, ) @@ -69,8 +70,9 @@ function tilde_observe!!( right::Distribution, left, vn::Union{VarName,Nothing}, + template::Any, vi::AbstractVarInfo, ) - vi = accumulate_observe!!(vi, right, left, vn) + vi = accumulate_observe!!(vi, right, left, vn, template) return left, vi end diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 092e50e48..894d8e6e5 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -405,7 +405,8 @@ function tilde_observe!!( right::Distribution, left, vn::Union{VarName,Nothing}, + template::Any, vi::AbstractVarInfo, ) - return tilde_observe!!(DefaultContext(), right, left, vn, vi) + return tilde_observe!!(DefaultContext(), right, left, vn, template, vi) end diff --git a/src/contexts/prefix.jl b/src/contexts/prefix.jl index 7bab58d62..2fdd9633b 100644 --- a/src/contexts/prefix.jl +++ b/src/contexts/prefix.jl @@ -128,6 +128,7 @@ function tilde_observe!!( right::Distribution, left, vn::Union{VarName,Nothing}, + template::Any, vi::AbstractVarInfo, ) # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal @@ -138,7 +139,8 @@ function tilde_observe!!( else vn, childcontext(context) end - return tilde_observe!!(new_context, right, left, new_vn, vi) + n = optic_skip_length(AbstractPPL.getoptic(context.vn_prefix)) + 1 + return tilde_observe!!(new_context, right, left, new_vn, SkipTemplate{n}(template), vi) end function store_coloneq_value!!( diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 87ff369ff..0a24c7dcc 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -79,7 +79,7 @@ function DynamicPPL.accumulate_assume!!( end function DynamicPPL.accumulate_observe!!( - acc::DebugAccumulator, right::Distribution, val, vn::Union{VarName,Nothing} + acc::DebugAccumulator, right::Distribution, val, vn::Union{VarName,Nothing}, template ) if _has_partial_missings(val, right) msg = if vn === nothing diff --git a/src/submodel.jl b/src/submodel.jl index c7e836d8e..31b6712ac 100644 --- a/src/submodel.jl +++ b/src/submodel.jl @@ -210,6 +210,7 @@ function tilde_observe!!( right::DynamicPPL.Submodel, left::Any, vn::VarName, + template::Any, vi::AbstractVarInfo, ) # TODO(penelopeysm) This is VERY BAD. See @@ -252,7 +253,7 @@ function tilde_observe!!( return _evaluate!!(right, vi, context, vn) end function tilde_observe!!( - ::AbstractContext, ::DynamicPPL.Submodel, left, ::Nothing, ::AbstractVarInfo + ::AbstractContext, ::DynamicPPL.Submodel, left, ::Nothing, template, ::AbstractVarInfo ) throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is a literal")) end diff --git a/src/varnamedtuple/partial_array.jl b/src/varnamedtuple/partial_array.jl index 55d47d1a9..ad60524a2 100644 --- a/src/varnamedtuple/partial_array.jl +++ b/src/varnamedtuple/partial_array.jl @@ -571,15 +571,16 @@ function BangBang.setindex!!(pa::PartialArray, value, inds::Vararg{Any}; kw...) if _needs_arraylikeblock(new_data, value, inds...; kw...) # Check that we're trying to set a block that has the right size. idx_sz = size(@view new_data[inds..., kw...]) - vnt_sz = vnt_size(value) - if vnt_sz != idx_sz - throw( - DimensionMismatch( - "Assigned value has size $(vnt_sz), which does not match " * - "the size implied by the indices $(idx_sz).", - ), - ) - end + + # vnt_sz = vnt_size(value) + # if vnt_sz != idx_sz + # throw( + # DimensionMismatch( + # "Assigned value has size $(vnt_sz), which does not match " * + # "the size implied by the indices $(idx_sz).", + # ), + # ) + # end alb = ArrayLikeBlock(value, inds, NamedTuple(kw), idx_sz) new_data = setindex!!(new_data, fill(alb, idx_sz...), inds...; kw...) fill!(view(new_mask, inds...; kw...), true) diff --git a/test/accumulators.jl b/test/accumulators.jl index 7d3848c15..c23449f58 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -91,12 +91,16 @@ using DynamicPPL: right = Normal() left = 2.0 vn = @varname(x) - @test accumulate_observe!!(LogPriorAccumulator(1.0), right, left, vn) == - LogPriorAccumulator(1.0) - @test accumulate_observe!!(LogJacobianAccumulator(1.0), right, left, vn) == - LogJacobianAccumulator(1.0) - @test accumulate_observe!!(LogLikelihoodAccumulator(1.0), right, left, vn) == - LogLikelihoodAccumulator(1.0 + logpdf(right, left)) + template = nothing + @test accumulate_observe!!( + LogPriorAccumulator(1.0), right, left, vn, template + ) == LogPriorAccumulator(1.0) + @test accumulate_observe!!( + LogJacobianAccumulator(1.0), right, left, vn, template + ) == LogJacobianAccumulator(1.0) + @test accumulate_observe!!( + LogLikelihoodAccumulator(1.0), right, left, vn, template + ) == LogLikelihoodAccumulator(1.0 + logpdf(right, left)) end end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index bf71ca4c6..1831513c6 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -123,7 +123,7 @@ end ::ErrorAccumulator, ::Any, ::Any, ::Any, ::VarName, ::Distribution, ::Any ) = throw(ErrorAccumulatorException()) DynamicPPL.accumulate_observe!!( - ::ErrorAccumulator, ::Distribution, ::Any, ::Union{VarName,Nothing} + ::ErrorAccumulator, ::Distribution, ::Any, ::Union{VarName,Nothing}, ::Any ) = throw(ErrorAccumulatorException()) DynamicPPL.reset(ea::ErrorAccumulator) = ea Base.copy(ea::ErrorAccumulator) = ea diff --git a/test/varinfo.jl b/test/varinfo.jl index 75e17faa2..c52ebe7d2 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -195,9 +195,6 @@ end vi_orig = DynamicPPL.setacc!!(vi_orig, DynamicPPL.DebugUtils.DebugAccumulator(true)) vi_orig = DynamicPPL.setacc!!(vi_orig, DynamicPPL.RawValueAccumulator(true)) vi_orig = DynamicPPL.setacc!!(vi_orig, DynamicPPL.PriorDistributionAccumulator()) - vi_orig = DynamicPPL.setacc!!( - vi_orig, DynamicPPL.PointwiseLogProbAccumulator{:both}() - ) # And evaluate the model once so that they are populated. _, vi_orig = DynamicPPL.evaluate_nowarn!!(model, vi_orig) From c787174807eb776f59d1b9e4acbc85748b525101 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 13 Feb 2026 16:24:25 +0000 Subject: [PATCH 2/5] Changelog --- HISTORY.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index f96cb63b1..90c3334d1 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -110,9 +110,15 @@ In its place, you should directly use the accumulator API to: To do so, we now export a convenience function `get_raw_values(::AbstractVarInfo)` that will get the stored `VarNamedTuple` of raw values. This is exactly analogous to how `getlogprior(::AbstractVarInfo)` extracts the log-prior from a `LogPriorAccumulator`. +### Pointwise logdensities + +Calling `pointwise_logdensities(model, varinfo)` now returns a `VarNamedTuple` of log-densities rather than an `OrderedDict`. + +The method `pointwise_logdensities(model, chain::MCMCChains.Chains)` no longer accepts a `Tout` argument to control the output type; it always returns a new `MCMCChains.Chains`. + ### Function signature changes in tilde-pipeline -`tilde_assume!!` and `accumulate_assume!!` now take extra arguments. +`tilde_assume!!`, `tilde_observe!!`, `accumulate_assume!!`, and `accumulate_observe!!` now take extra arguments. In particular @@ -120,10 +126,13 @@ In particular For example, if `vn` is `@varname(x[1])`, then `template` is the current value of `x` in the model. The DynamicPPL compiler is responsible for generating and providing this argument. + - Likewise, `tilde_observe!!(ctx, dist, left, vn, vi)` is now `tilde_observe!!(ctx, dist, left, vn, template, vi)`, where `template` is the same as above. - `accumulate_assume!!(acc::AbstractAccumulator, val, logjac, vn, dist)` is now `accumulate_assume!!(acc, val, tval, logjac, vn, dist, template)`. `template` is the same as above. `tval` is either the `AbstractTransformedValue` that `DynamicPPL.init` provided (for InitContext), or the `AbstractTransformedValue` found inside the VarInfo (for DefaultContext). - `accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, dist)` is now `accumulate_assume!!(vi, val, tval, logjac, vn, dist, template)`. + - `accumulate_observe!!(acc::AbstractAccumulator, dist, left, vn)` is now `accumulate_observe!!(acc, dist, left, vn, template)`, where `template` is the same as above. + - `accumulate_observe!!(vi::AbstractVarInfo, dist, left, vn)` is now `accumulate_observe!!(vi, dist, left, vn, template)`. ### `DynamicPPL.DebugUtils` From 0db3bd2d2d4459dc35671e6c7d42d83fa0be7a2b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 21 Feb 2026 01:44:50 +0000 Subject: [PATCH 3/5] Improve PointwiseLogProbs docstring --- src/accumulators/pointwise_logdensities.jl | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/accumulators/pointwise_logdensities.jl b/src/accumulators/pointwise_logdensities.jl index dcbce2ce5..fbd1786e5 100644 --- a/src/accumulators/pointwise_logdensities.jl +++ b/src/accumulators/pointwise_logdensities.jl @@ -1,15 +1,17 @@ """ - PointwiseLogProbAccumulator{whichlogprob} <: AbstractAccumulator + PointwiseLogProb{Prior,Likelihood} -An accumulator that stores the log-probabilities of each variable in a model. +A callable struct that computes the log probability of a given value under a distribution. +The `Prior` and `Likelihood` type parameters are used to control whether the log probability +is computed for prior or likelihood terms, respectively. -Internally this accumulator stores the log-probabilities in a dictionary, where the keys are -the variable names and the values are log-probabilities. +This struct is used in conjunction with `VNTAccumulator`, via -`whichlogprob` is a symbol that can be `:both`, `:prior`, or `:likelihood`, and specifies -which log-probabilities to store in the accumulator. -""" + acc = VNTAccumulator{POINTWISE_ACCNAME}(PointwiseLogProb{Prior,Likelihood}()) +where `Prior` and `Likelihood` are the boolean type parameters. This accumulator will then +store the log-probabilities for all tilde-statements in the model. +""" struct PointwiseLogProb{Prior,Likelihood} end function (plp::PointwiseLogProb{Prior,Likelihood})( val, tval, logjac, vn, dist From d2dc47aed8c403bc46d2d87554a4c35cf478a733 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 21 Feb 2026 01:47:56 +0000 Subject: [PATCH 4/5] Hide `Prior` and `Likelihood` internal arguments --- ext/DynamicPPLMCMCChainsExt.jl | 58 +++++++++++----------- src/accumulators/pointwise_logdensities.jl | 21 ++++++-- 2 files changed, 48 insertions(+), 31 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index bf8803728..0fc8920c2 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -390,22 +390,40 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha return map(first, reevaluate_with_chain(model, chain, (), nothing)) end +""" +Shared internal helper function. +""" +function _pointwise_logdensities_chain( + model::DynamicPPL.Model, + chain::MCMCChains.Chains, + ::Val{Prior}=Val(true), + ::Val{Likelihood}=Val(true), +) where {Prior,Likelihood} + acc = DynamicPPL.VNTAccumulator{DynamicPPL.POINTWISE_ACCNAME}( + DynamicPPL.PointwiseLogProb{Prior,Likelihood}() + ) + parameter_only_chain = MCMCChains.get_sections(chain, :parameters) + # Reevaluating this gives us a VNT of log probs. We can densify and then wrap in + # ParamsWithStats so that we can easily convert back to a Chains object. + pointwise_logps = map( + reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing) + ) do (_, oavi) + logprobs = DynamicPPL.get_pointwise_logprobs(oavi) + dense_logprobs = DynamicPPL.densify!!(logprobs) + DynamicPPL.ParamsWithStats(dense_logprobs, (;)) + end + return AbstractMCMC.from_samples(MCMCChains.Chains, pointwise_logps) +end + """ DynamicPPL.pointwise_logdensities( model::DynamicPPL.Model, chain::MCMCChains.Chains, - ::Val{Prior}=Val(true), - ::Val{Likelihood}=Val(true) ) Runs `model` on each sample in `chain`, returning a new `MCMCChains.Chains` object where the log-density of each variable at each sample is stored (rather than its value). -The `Val`s passed as the last two arguments control which log-probabilities are included in -the result. If both are `true`, then the log-probabilities include both the prior and -likelihood terms. If only one of them is `true`, then only the corresponding -log-probabilities are included. - See also: [`DynamicPPL.pointwise_loglikelihoods`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref). @@ -464,25 +482,9 @@ julia> # The above is the same as: ``` """ function DynamicPPL.pointwise_logdensities( - model::DynamicPPL.Model, - chain::MCMCChains.Chains, - ::Val{Prior}=Val(true), - ::Val{Likelihood}=Val(true), -) where {Prior,Likelihood} - acc = DynamicPPL.VNTAccumulator{DynamicPPL.POINTWISE_ACCNAME}( - DynamicPPL.PointwiseLogProb{Prior,Likelihood}() - ) - parameter_only_chain = MCMCChains.get_sections(chain, :parameters) - # Reevaluating this gives us a VNT of log probs. We can densify and then wrap in - # ParamsWithStats so that we can easily convert back to a Chains object. - pointwise_logps = map( - reevaluate_with_chain(model, parameter_only_chain, (acc,), nothing) - ) do (_, oavi) - logprobs = DynamicPPL.get_pointwise_logprobs(oavi) - dense_logprobs = DynamicPPL.densify!!(logprobs) - DynamicPPL.ParamsWithStats(dense_logprobs, (;)) - end - return AbstractMCMC.from_samples(MCMCChains.Chains, pointwise_logps) + model::DynamicPPL.Model, chain::MCMCChains.Chains +) + return _pointwise_logdensities_chain(model, chain, Val(true), Val(true)) end """ @@ -499,7 +501,7 @@ See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_pr function DynamicPPL.pointwise_loglikelihoods( model::DynamicPPL.Model, chain::MCMCChains.Chains ) - return DynamicPPL.pointwise_logdensities(model, chain, Val(false), Val(true)) + return _pointwise_logdensities_chain(model, chain, Val(false), Val(true)) end """ @@ -516,7 +518,7 @@ See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_lo function DynamicPPL.pointwise_prior_logdensities( model::DynamicPPL.Model, chain::MCMCChains.Chains ) - return DynamicPPL.pointwise_logdensities(model, chain, Val(true), Val(false)) + return _pointwise_logdensities_chain(model, chain, Val(true), Val(false)) end """ diff --git a/src/accumulators/pointwise_logdensities.jl b/src/accumulators/pointwise_logdensities.jl index fbd1786e5..335abb88e 100644 --- a/src/accumulators/pointwise_logdensities.jl +++ b/src/accumulators/pointwise_logdensities.jl @@ -49,7 +49,18 @@ function accumulate_observe!!( end end -function pointwise_logdensities( +""" + _pointwise_logdensities( + model::Model, + varinfo::AbstractVarInfo, + ::Val{Prior}=Val(true), + ::Val{Likelihood}=Val(true), + ) where {Prior,Likelihood} + +Shared internal function that computes pointwise log-densities (either priors, likelihoods, +or both). +""" +function _pointwise_logdensities( model::Model, varinfo::AbstractVarInfo, ::Val{Prior}=Val(true), @@ -62,10 +73,14 @@ function pointwise_logdensities( return get_pointwise_logprobs(oavi) end +function pointwise_logdensities(model::Model, varinfo::AbstractVarInfo) + return _pointwise_logdensities(model, varinfo, Val(true), Val(true)) +end + function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) - return pointwise_logdensities(model, varinfo, Val(false), Val(true)) + return _pointwise_logdensities(model, varinfo, Val(false), Val(true)) end function pointwise_prior_logdensities(model::Model, varinfo::AbstractVarInfo) - return pointwise_logdensities(model, varinfo, Val(true), Val(false)) + return _pointwise_logdensities(model, varinfo, Val(true), Val(false)) end From e5ff7003d6950e1e2fda8219633a4fae12e9db60 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 21 Feb 2026 02:17:09 +0000 Subject: [PATCH 5/5] Add changelog note about logpdf/loglikelihood --- HISTORY.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index 90c3334d1..6f8949000 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -114,8 +114,14 @@ In its place, you should directly use the accumulator API to: Calling `pointwise_logdensities(model, varinfo)` now returns a `VarNamedTuple` of log-densities rather than an `OrderedDict`. +Models with implicitly broadcasted observations (e.g. `y ~ Normal()` where `y` is an observed `Array`) will now return an `Array` of log-densities, one per element of `y`. +To recover the previous behaviour, you can sum the log-densities. + The method `pointwise_logdensities(model, chain::MCMCChains.Chains)` no longer accepts a `Tout` argument to control the output type; it always returns a new `MCMCChains.Chains`. +The internal argument `whichlogprob` for `pointwise_logdensities` is removed. +As a replacement you should just directly use `pointwise_logdensities`, `pointwise_loglikelihoods`, or `pointwise_prior_logdensities` as appropriate. + ### Function signature changes in tilde-pipeline `tilde_assume!!`, `tilde_observe!!`, `accumulate_assume!!`, and `accumulate_observe!!` now take extra arguments.