InitFromVector with fallback to prior #1298

@charlesknipp

Description

I want to clarify: this is not a feature request per se, but more a request for clarification of the init pipeline. That said, I want to understand exactly what needs to be done to improve modularity for something like a partially evaluated model (useful in SMC).

Since #1252 suggests operating in terms of initialization strategies, I find it fitting to push PMCMC in this direction as well. To send a partially populated OnlyAccsVarInfo through the pipeline, I think there is an elegant solution: an SMCContext which contains a child context for initializing a partially evaluated model.

As it stands, however, there seems to be a disconnect in terms of exactly what is done under the hood.

Implementation

For starters, consider the proposed init strategy (see the full replication below for details):

struct PartialInitFromVector{
    T<:AbstractVector{<:Real},VT<:VarNamedTuple,ST<:AbstractTransformStrategy
} <: AbstractInitStrategy
    vect::T
    varname_ranges::VT
    set_indices::Vector{Bool}
    transform_strategy::ST
end
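To make the fallback behavior concrete, here is a dependency-free sketch of the decision that the strategy's init method makes: a variable falls back to sampling from the prior whenever any index in its range is unset. The names here (needs_prior_fallback, the Dict of ranges) are illustrative stand-ins, not DynamicPPL API.

```julia
# Simplified stand-ins for varname_ranges and set_indices from the struct
# above: β occupies index 1 and was populated, σ occupies index 2 and was not.
varname_ranges = Dict(:β => 1:1, :σ => 2:2)
set_indices = [true, false]

# Hypothetical helper: fall back to the prior if any index in the
# variable's range is unset, mirroring the check in DynamicPPL.init below.
needs_prior_fallback(vn) = !all(set_indices[varname_ranges[vn]])

needs_prior_fallback(:β)  # β can be read from the vector
needs_prior_fallback(:σ)  # σ must be drawn from the prior
```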

On the surface, this works just fine. We can check this by choosing a model, rng, and transform of our liking and verifying that the outputs vi1 and vi2 match:

# define a logdensity for the accumulator
logdensity = LogDensityFunction(
    model, DynamicPPL.getlogjoint_internal, DynamicPPL._default_vnt(model, transform)
)

# should sample entirely from the prior
vi = DynamicPPL.setacc!!(OnlyAccsVarInfo(), VectorParamAccumulator(logdensity))
strategy = PartialInitFromVector(vi.accs[:VectorParamAccumulator], transform)
_, vi1 = DynamicPPL.init!!(
    rng, model, vi, strategy, strategy.transform_strategy
)

# should recalculate the same parameters
vi = DynamicPPL.setacc!!(OnlyAccsVarInfo(), VectorParamAccumulator(logdensity))
strategy = PartialInitFromVector(deepcopy(vi1.accs[:VectorParamAccumulator]), transform)
_, vi2 = DynamicPPL.init!!(
    rng, model, vi, strategy, strategy.transform_strategy
)

Concerns

While the proposed interface seems to work exactly as intended, the expected return type of certain functions along the init pipeline may be ambiguous. This specifically fails when running a TapedTask for PMCMC. The underlying error arises while traversing the type-inferred IR, which you can see in detail here:

Stack Trace
ERROR: TypeError: in typeassert, expected VectorValue{SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, DynamicPPL.UnwrapSingletonTransform{Tuple{Int64}}}, got a value of type UntransformedValue{Float64}
Stacktrace:
  [1] deref_phi
    @ /.../julia/packages/Libtask/AGx8L/src/copyable_task.jl:1230 [inlined]
  [2] tilde_assume!!
    @ /.../julia/packages/DynamicPPL/U6J6a/src/contexts/init.jl:400 [inlined]
  [3] (::Tuple{…})(_2::typeof(tilde_assume!!), _3::InitContext{…}, _4::Normal{…}, _5::VarName{…}, _6::DynamicPPL.VarNamedTuples.NoTemplate, _7::OnlyAccsVarInfo{…})
    @ Base.Experimental ./<missing>:0
  [4] (::MistyClosures.MistyClosure{…})(::Function, ::InitContext{…}, ::Normal{…}, ::VarName{…}, ::DynamicPPL.VarNamedTuples.NoTemplate, ::OnlyAccsVarInfo{…})
    @ MistyClosures /.../julia/packages/MistyClosures/2vtLL/src/MistyClosures.jl:22
  [5] (::Libtask.DynamicCallable{…})(::Function, ::InitContext{…}, ::Normal{…}, ::VarName{…}, ::DynamicPPL.VarNamedTuples.NoTemplate, ::OnlyAccsVarInfo{…})
    @ Libtask /.../julia/packages/Libtask/AGx8L/src/copyable_task.jl:1292
  [6] tilde_assume!!
    @ /.../this_file.jl:100 [inlined]
  [7] (::Tuple{…})(_2::typeof(tilde_assume!!), _3::SMCContext{…}, _4::Normal{…}, _5::VarName{…}, _6::DynamicPPL.VarNamedTuples.NoTemplate, _7::OnlyAccsVarInfo{…})
    @ Base.Experimental ./<missing>:0
  [8] (::MistyClosures.MistyClosure{…})(::Function, ::SMCContext{…}, ::Normal{…}, ::VarName{…}, ::DynamicPPL.VarNamedTuples.NoTemplate, ::OnlyAccsVarInfo{…})
    @ MistyClosures /.../julia/packages/MistyClosures/2vtLL/src/MistyClosures.jl:22
  [9] (::Libtask.LazyCallable{…})(::Function, ::SMCContext{…}, ::Normal{…}, ::VarName{…}, ::DynamicPPL.VarNamedTuples.NoTemplate, ::OnlyAccsVarInfo{…})
    @ Libtask /.../julia/packages/Libtask/AGx8L/src/copyable_task.jl:1263
 [10] linear_regression
    @ /.../this_file.jl:183 [inlined]
 [11] (::Tuple{…})(_2::typeof(linear_regression), _3::Model{…}, _4::OnlyAccsVarInfo{…}, _5::Vector{…}, _6::Vector{…})
    @ Base.Experimental ./<missing>:0
 [12] consume
    @ /.../julia/packages/Libtask/AGx8L/src/copyable_task.jl:378 [inlined]
 [13] consume(trace::TracedModel{TapedTask{…}, OnlyAccsVarInfo{…}})
    @ Main /.../this_file.jl:149
 [14] iterate (repeats 2 times)
    @ /.../this_file.jl:156 [inlined]
 [15] iterate
    @ ./iterators.jl:780 [inlined]
 [16] iterate
    @ ./iterators.jl:778 [inlined]
 [17] grow_to!(dest::Vector{Any}, itr::Base.Iterators.Take{TracedModel{TapedTask{…}, OnlyAccsVarInfo{…}}})
    @ Base ./array.jl:863
 [18] _collect
    @ ./array.jl:779 [inlined]
 [19] collect(itr::Base.Iterators.Take{TracedModel{TapedTask{…}, OnlyAccsVarInfo{…}}})
    @ Base ./array.jl:728
 [20] top-level scope
    @ /.../this_file.jl:199
Some type information was truncated. Use `show(err)` to see complete types.

In theory this should behave the same as putting a conditional in the tilde_assume stack to determine which context to use. However, the error suggests I could be doing something wrong that the init pipeline is silently allowing.
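A dependency-free illustration of the failure mode, under the assumption that it is a return-type instability: when a function's return type depends on a runtime branch, inference sees a Union of wrapper types, while Libtask's recorded IR asserts the concrete type observed on the first trace. The types below are minimal stand-ins, not DynamicPPL's.

```julia
# Two wrapper types standing in for VectorValue and UntransformedValue.
struct VectorValue{T}
    x::T
end
struct UntransformedValue{T}
    x::T
end

# The return type depends on a runtime Bool, so inference can only
# conclude a Union of the two wrappers.
unstable_init(use_vector::Bool) =
    use_vector ? VectorValue([1.0]) : UntransformedValue(1.0)

# Inference reports a Union; a typeassert recorded against one branch
# (as in the stack trace above) fails when the other branch is taken.
Base.return_types(unstable_init, (Bool,))
```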

Replication

This is based on my experimental implementation of SMC in Turing, but it is entirely self-contained in the following script.

using BenchmarkTools
using Distributions
using DynamicPPL
using Libtask
using Random
using Random123

## ACCUMULATORS ############################################################################

struct ProduceLogLikelihoodAccumulator{T<:Real} <: DynamicPPL.LogProbAccumulator{T}
    logp::T
end

DynamicPPL.accumulator_name(::Type{<:ProduceLogLikelihoodAccumulator}) = :LogLikelihood
DynamicPPL.logp(acc::ProduceLogLikelihoodAccumulator) = acc.logp

# we send logscore to task local storage to get produced later on in the stack
function DynamicPPL.acclogp(acc::ProduceLogLikelihoodAccumulator, val)
    task_local_storage(:logscore, val)
    newacc = ProduceLogLikelihoodAccumulator(DynamicPPL.logp(acc) + val)
    return newacc
end

function DynamicPPL.accumulate_assume!!(
    acc::ProduceLogLikelihoodAccumulator, val, tval, logjac, vn, dist, template
)
    return acc
end

function DynamicPPL.accumulate_observe!!(
    acc::ProduceLogLikelihoodAccumulator, dist, val, vn, template
)
    return DynamicPPL.acclogp(acc, Distributions.loglikelihood(dist, val))
end

## CONTEXTS ################################################################################

struct PartialInitFromVector{
    T<:AbstractVector{<:Real},VT<:VarNamedTuple,ST<:AbstractTransformStrategy
} <: AbstractInitStrategy
    vect::T
    varname_ranges::VT
    set_indices::Vector{Bool}
    transform_strategy::ST
end

function PartialInitFromVector(
    acc::VectorParamAccumulator, transform::AbstractTransformStrategy
)
    return PartialInitFromVector(acc.vals, acc.vn_ranges, acc.set_indices, transform)
end

function _get_range_and_linked(pifv::PartialInitFromVector, vn::VarName)
    return pifv.varname_ranges[vn]::DynamicPPL.RangeAndLinked
end

function DynamicPPL.init(
    rng::AbstractRNG, vn::VarName, dist::Distribution, pifv::PartialInitFromVector
)
    # fallback to simulate from the prior when necessary
    range_and_linked = _get_range_and_linked(pifv, vn)
    if !all(pifv.set_indices[range_and_linked.range])
        return DynamicPPL.init(rng, vn, dist, InitFromPrior())
    end

    # I feel like this could be handled by dispatch as opposed to a conditional switch
    vect = DynamicPPL.maybe_view_ad(pifv.vect, range_and_linked.range)
    return if pifv.transform_strategy isa LinkAll
        LinkedVectorValue(vect, DynamicPPL.from_linked_vec_transform(dist))
    elseif pifv.transform_strategy isa UnlinkAll
        VectorValue(vect, DynamicPPL.from_vec_transform(dist))
    elseif range_and_linked.is_linked
        LinkedVectorValue(vect, DynamicPPL.from_linked_vec_transform(dist))
    else
        VectorValue(vect, DynamicPPL.from_vec_transform(dist))
    end
end

function DynamicPPL.get_param_eltype(strategy::PartialInitFromVector)
    return eltype(strategy.vect)
end

function init_context(
    rng::AbstractRNG, vi::OnlyAccsVarInfo, ::VarName, transform::AbstractTransformStrategy
)
    acc = vi.accs[:VectorParamAccumulator]
    return InitContext(rng, PartialInitFromVector(acc, transform), transform)
end

struct SMCContext{ST<:AbstractTransformStrategy} <: DynamicPPL.AbstractContext
    transform_strategy::ST
end

# should be removed once we restructure the context
DynamicPPL.get_param_eltype(::OnlyAccsVarInfo, ::SMCContext) = Any

function DynamicPPL.tilde_assume!!(
    ctx::SMCContext, dist::Distribution, vn::VarName, template::Any, vi::AbstractVarInfo
)
    rng = Libtask.get_taped_globals(AbstractRNG)
    dispatch_ctx = init_context(rng, vi, vn, ctx.transform_strategy)
    val, vi = DynamicPPL.tilde_assume!!(dispatch_ctx, dist, vn, template, vi)
    return val, vi
end

function DynamicPPL.tilde_observe!!(
    ::SMCContext,
    dist::Distribution,
    val,
    vn::Union{VarName,Nothing},
    template::Any,
    vi::AbstractVarInfo
)
    val, vi = DynamicPPL.tilde_observe!!(DefaultContext(), dist, val, vn, template, vi)
    task_local_storage(:varinfo, vi)
    Libtask.produce(task_local_storage(:logscore))
    return val, vi
end

Libtask.@might_produce(DynamicPPL.tilde_observe!!)
Libtask.@might_produce(DynamicPPL.tilde_assume!!)
Libtask.@might_produce(DynamicPPL.evaluate!!)
Libtask.@might_produce(DynamicPPL.init!!)
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true

## TAPED TASK ##############################################################################

# this is only necessary for testing multiple iterations of the same object
abstract type AbstractTracedModel end

function Libtask.TapedTask(rng::AbstractRNG, model::Model, vi::AbstractVarInfo)
    args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, vi)
    return TapedTask(rng, model.f, args...; kwargs...)
end

function construct_task(
    rng::AbstractRNG,
    model::Model,
    vi::AbstractVarInfo,
    transform::AbstractTransformStrategy
)
    inner_rng = Random.seed!(Random123.Philox2x(), rand(rng, Random.Sampler(rng, UInt64)))
    inner_model = DynamicPPL.setleafcontext(model, SMCContext(transform))
    return TapedTask(inner_rng, inner_model, vi)
end

# overload consume to store the local varinfo
function Libtask.consume(trace::AbstractTracedModel)
    score = Libtask.consume(trace.task)
    set_varinfo!(trace, score)
    return score
end

# apply the same iteration utilities as a TapedTask
function Base.iterate(trace::AbstractTracedModel, ::Nothing=nothing)
    v = Libtask.consume(trace)
    return v === nothing ? nothing : (v, nothing)
end

Base.IteratorSize(::Type{<:AbstractTracedModel}) = Base.SizeUnknown()
Base.IteratorEltype(::Type{<:AbstractTracedModel}) = Base.EltypeUnknown()

struct TracedModel{TT<:TapedTask,VT<:AbstractVarInfo} <: AbstractTracedModel
    task::TT
    varinfo::Base.RefValue{VT}
end

function TracedModel(rng::AbstractRNG, model::Model, transform::AbstractTransformStrategy)
    logdensity = LogDensityFunction(
        model, DynamicPPL.getlogjoint_internal, DynamicPPL._default_vnt(model, transform)
    )
    accs = DynamicPPL.setacc!!(OnlyAccsVarInfo(), VectorParamAccumulator(logdensity))
    accs = DynamicPPL.setacc!!(accs, ProduceLogLikelihoodAccumulator())
    return TracedModel(construct_task(rng, model, accs, transform), Ref(accs))
end

# if score is nothing, the varinfo is caught up and there's no need to update
set_varinfo!(::TracedModel, ::Nothing) = nothing
set_varinfo!(trace::TracedModel, ::Real) = (trace.varinfo[] = task_local_storage(:varinfo); )

## LINEAR REGRESSION MODEL #################################################################

@model function linear_regression(x, y)
    β ~ Normal(0, 1)
    σ ~ truncated(Cauchy(0, 3); lower=0)
    for t in eachindex(x)
        y[t] ~ Normal(β * x[t], σ)
    end
end

# condition the model
rng = MersenneTwister(1234);
x, y = rand(rng, 100), rand(rng, 100);
reg_model = linear_regression(x, y);

## TESTING #################################################################################

trace = TracedModel(deepcopy(rng), reg_model, UnlinkAll());
collect(Iterators.take(trace, 4))

Final Remarks

This is rather pedantic, but I also find the conditionals in both the init and transformed_value calls to be quite excessive. It seems like something that could be handled via dispatch. Regardless, I cannot offer an immediate solution, so this is just a nit-pick.
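As a sketch of what the dispatch-based alternative could look like, the branch on transform strategy can become methods on the strategy type, with the per-variable link flag handled only in the fallback case. The types and the wrap_value helper here are illustrative stand-ins, not DynamicPPL API.

```julia
# Stand-in strategy types mirroring LinkAll / UnlinkAll plus a
# per-variable default, replacing the isa-chain with method dispatch.
abstract type TransformStrategy end
struct LinkAllS <: TransformStrategy end
struct UnlinkAllS <: TransformStrategy end
struct PerVariableS <: TransformStrategy end

# Hypothetical helper: the strategy type selects the wrapper; only the
# per-variable strategy consults the stored link flag.
wrap_value(::LinkAllS, is_linked::Bool, vect) = (:linked, vect)
wrap_value(::UnlinkAllS, is_linked::Bool, vect) = (:unlinked, vect)
wrap_value(::PerVariableS, is_linked::Bool, vect) =
    is_linked ? (:linked, vect) : (:unlinked, vect)
```

With this shape, init would simply call wrap_value(pifv.transform_strategy, range_and_linked.is_linked, vect), and adding a new strategy means adding a method rather than extending the conditional.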

If you have any questions feel free to let me know. I plan on looking into this a little more in the afternoon.
