Description
I think unflatten is the single thing that I've complained about the most and I would dearly like to remove all need for this function ever again.
I've thought about this a fair amount, and there are a few blockers, most notably Gibbs (TuringLang/Turing.jl#2764). However, a very positive intermediate step would be to make unflatten!! return a new AbstractInitStrategy instead of a new VarInfo.
If you look at the difference between old slow LDF and new fast LDF, it mostly boils down to:
- slowLDF used unflatten + evaluate!! on a VarInfo;
- fastLDF constructs an InitFromParams{<:VectorWithRanges}, which is semantically equal to a VarInfo whose contents have been replaced with the vector, but avoids actually using a VarInfo.
I think this idea can be extended more broadly to also work with generic VarInfo. Essentially, unflatten!! would probably return exactly the same InitFromParams{<:VectorWithRanges} object that the LDF constructs. We could probably even share some code.
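To make the "vector plus ranges" idea concrete, here is a toy sketch in plain Julia. It does not use DynamicPPL at all: `ToyVectorWithRanges` and `getparam` are hypothetical names made up for illustration, and the real VectorWithRanges and InitFromParams types may look quite different. The point is only that a flat vector plus a name-to-range lookup is enough to answer "what is the value of this variable?" without materialising a VarInfo.

```julia
# Toy illustration only: hypothetical stand-in types, NOT DynamicPPL's actual API.
struct ToyVectorWithRanges{T<:AbstractVector{<:Real}}
    vec::T                               # flat parameter vector
    ranges::Dict{Symbol,UnitRange{Int}}  # which slice belongs to which variable
end

# Look up one variable's values as a view into the flat vector (no copying).
getparam(vwr::ToyVectorWithRanges, name::Symbol) = view(vwr.vec, vwr.ranges[name])

# In the proposal, the ranges would come from an existing VarInfo (or a
# VectorValueAccumulator); here they are written out by hand.
ranges = Dict(:mu => 1:1, :sigma => 2:2, :beta => 3:5)
strategy = ToyVectorWithRanges([0.1, 1.5, 2.0, 3.0, 4.0], ranges)

@assert getparam(strategy, :beta) == [2.0, 3.0, 4.0]
@assert only(getparam(strategy, :mu)) == 0.1
```

Because the lookup is just a view into the vector, swapping in a new parameter vector is cheap, which is roughly why the fast LDF path avoids rebuilding a VarInfo.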
If you REALLY want a new VarInfo with those values, simple! Just do this:
new_init_strategy = unflatten!!(varinfo, vec)
_, new_varinfo = init!!(model, varinfo, new_init_strategy)
The second line will overwrite everything inside the old varinfo. It will also cause the transforms to be updated, so there is no risk of having out-of-date transforms, which is my current number one gripe with unflatten!!.
Essentially, the above would replace the current workflow of
new_varinfo = unflatten!!(varinfo, vec)
# If you forget to do this (e.g. because you want to avoid the cost
# of reevaluating) and just use new_varinfo directly...... maybe in
# a subsequent MCMC step or in a different function ....... Oops!!!!
_, new_varinfo = evaluate!!(model, new_varinfo)
However, returning an AbstractInitStrategy is also more versatile. You can use it with an OnlyAccsVarInfo. You can use it with a LogDensityFunction (once I've reimplemented #1232, which is not difficult). It doesn't tie you into using a full VarInfo.
In fact, unflatten!! doesn't need a full VarInfo either: it only needs a VectorValueAccumulator. That would be the next logical step after the above. This means that any function that uses unflatten!! would be responsible upfront for making sure that there is a VectorValueAccumulator somewhere that it can use.
new_init_strategy = unflatten!!(vector_value_acc, vec)