Reparameterization attached to Distribution
#4
Comments
I don't really understand why this has to be tied to a distributions library. Wouldn't it be more straightforward / useful to have this as an orthogonal thing that just plays nicely with distributions? I had imagined something along the lines of an interface like `a_positive, a_unconstrained = positive(inv_link_or_link_or_whatever, a_positive_init)`. Then we're just talking about the generic parameter handling / transformation problem, rather than anything inherently probabilistic. Also, could we please try to think about how this plays with Zygote rather than Tracker, as Tracker's days are numbered?
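Here is a minimal sketch of that orthogonal, distribution-agnostic interface in plain Julia. `positive` is a hypothetical helper, not an existing package function. Unlike the single-map signature floated above, it takes both the map to the positive reals and its inverse (e.g. `exp` and `log`), because the unconstrained starting value has to be computed from the positive initial value. It also assumes `a_positive` is best returned as a function of the unconstrained parameters rather than as a fixed value.

```julia
# Hypothetical helper (not an existing package API): given the map to the
# positive reals and its inverse, returns a function from unconstrained to
# positive values plus the unconstrained value corresponding to `x_init`.
function positive(to_positive, to_unconstrained, x_init)
    all(x -> x > 0, x_init) || throw(ArgumentError("initial value must be positive"))
    u_init = to_unconstrained.(x_init)
    constrain(u) = to_positive.(u)
    return constrain, u_init
end

a_positive, a_unconstrained = positive(exp, log, [0.5, 2.0])
a_unconstrained .-= 10.0          # any real-valued update is allowed here...
a_positive(a_unconstrained)       # ...and the constrained view stays positive
```

An optimizer (or an AD tool) would only ever touch `a_unconstrained`; nothing in the helper refers to a probability distribution.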
Oops, didn't mean to close.
I see what you're saying, but I think it's just too closely related. And I think it's not far-fetched to say that "reparameterization of a
But I think Zygote also intends to support AD w.r.t. parameters of a struct, right? I can't find the issue right now, but I swear I saw @MikeInnes discussing something like this somewhere. If so, I think my argument using Tracker.jl still holds?
I haven't followed this issue carefully, but (1) yes, Zygote supports structs well and (2) it'd be nice not to have to load
I have a few comments.
Since we are discussing changes to Distributions, pinging @matbesancon.
That's true, but in multivariate cases you still cannot update the parameters in place (though to allow this you'd have to take a slightly different approach to certain distributions than what Distributions.jl is currently doing, e.g. It also doesn't solve the issue of "interoperability" with the parts of the ecosystem in which Distributions.jl is often used, e.g. with Tracker/Zygote. It works, of course, but for larger models it can be quite a hassle compared to tying the parameters to the
I think
I think this is one of the key aspects of this discussion. I'm personally more of a fan of the functional approach, but I appreciate that there are merits to both approaches. I'm not really sure which way the community is leaning here; perhaps @MikeInnes or @oxinabox can comment? If I remember correctly, Zygote's recommended mode of operation is now the functional style?
I started out preferring the more functional style, but have recently grown quite fond of the Flux approach. Granted, I've recently been using more neural networks, where I think this approach is particularly useful. Also, it's worth noting that with what's proposed here you can do both (which is why I like it! :) )
On an earlier point:
Had this discussion with @matbesancon
@torfjelde You can still do lazy transformations by multiple dispatch, like you said using Dispatching on If we are talking about modifying the distribution in place (no AD), we can do that using the lazy function wrapper. Note that we always have to define the fields of the distribution according to its struct definition. So we have one of two scenarios:
This is fundamentally a constructor definition problem. It is a question of how we can construct the distribution while enabling in-place modification of the inputs. Lazy function wrappers take us some of the way. Note that at the end of the day we still need to satisfy the field type signature of the distribution struct, so we may need to modify the type parameters of the distribution struct to accept more generic matrix types like a lazy matrix-valued function which sub-types So in summary, the anonymous function and dispatch-based laziness approach enables us to:
Note that at this point, it is not a question of whether we need arbitrary reparameterization, just of the API choice. I am leaning towards not having a struct for every distribution for AD purposes only, and instead using anonymous functions and dispatch-based laziness to gain any efficiency and/or flexibility benefits. Ironically, we already implement an AD distribution for Pinging @ChrisRackauckas in case he has opinions on this.
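The lazy-wrapper-plus-dispatch idea can be sketched without Distributions.jl. `Lazy`, `resolve`, `ToyNormal`, and `toy_logpdf` are hypothetical names chosen for illustration. `ToyNormal` gives each field its own type parameter, which is what lets a plain `Float64` and a lazy wrapper sit side by side. The `Ref` stands in for mutable storage that is updated in place (no AD involved).

```julia
# Hypothetical lazy wrapper: keeps the unconstrained storage `val` and the map `f`
# that turns it into the constrained value whenever the value is requested.
struct Lazy{F, T}
    f::F
    val::T
end

# Plain values pass through unchanged; lazy ones are transformed on access.
resolve(x) = x
resolve(x::Lazy) = x.f(x.val)

# Toy distribution with one type parameter per field, so μ and σ may differ in type.
struct ToyNormal{Tμ, Tσ}
    μ::Tμ
    σ::Tσ
end

function toy_logpdf(d::ToyNormal, x)
    μ, σ = resolve(d.μ), resolve(d.σ)
    return -log(σ) - 0.5 * log(2π) - 0.5 * ((x - μ) / σ)^2
end

# Unconstrained storage for log σ; mutating it changes what `d` resolves to.
logσ = Ref(0.0)
d = ToyNormal(0.0, Lazy(r -> exp(r[]), logσ))
logσ[] = log(2.0)   # in-place update, no reconstruction of `d`
```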
Yeah, I understood that, but it would still require always building an explicit type But after reading your comment I realize we can just make a I think I'm coming around to your suggestion! :) It still seems like making this compatible with current
You could still do this when
"to where it makes sense" -> "to where we can" seems like a more accurate statement 🙃
Well, in your proposal, IIUC,
True, but if we hit a wall, we can decide to temporarily branch off until the obstacle is removed. This is what we do now for
Yes, this is a nice generic way of defining lazy wrappers.
I completely agree with the last comment :) One thing though: this "wrapping" approach means that if we want type stability we'd have to allow all the parameters of a distribution to take on different types, e.g. It still seems like the best approach, but it's worth noting that this might be a big hurdle to overcome, as we'd basically need to redefine most distributions to accommodate something like this.
And I think something like the following works okay as a "default", where we just allow the type itself to specify how to handle the unconstrained-to-constrained transformation:

```julia
using Distributions
abstract type Constrained end
struct Unconstrained{T} <: Constrained
    val::T
end
value(c::Unconstrained) = c.val
# extending a constructor from another module requires qualifying it
Distributions.Normal(μ::Unconstrained{T}, σ::Unconstrained{T}) where {T} = Normal(value(μ), exp(value(σ)))
```

Could also do something like

```julia
# variant carrying the transformation `F` as a second type parameter (replaces the definition above)
struct Unconstrained{T, F} <: Constrained
    val::T
end
value(c::Unconstrained{T, F}) where {T, F} = F(c.val)
# when `F = identity` we have a default treatment
Distributions.Normal(μ::Unconstrained{T, identity}, σ::Unconstrained{T, identity}) where {T} = Normal(value(μ), exp(value(σ)))
# in this case we have to assume that `value` takes care of the transformation
Distributions.Normal(μ::Unconstrained{T}, σ::Unconstrained{T}) where {T} = Normal(value(μ), value(σ))
```

Need to think about this further, but it doesn't seem like a horrible approach.
Overview
Since distributions have to be re-implemented here and the focus is on AD, I was wondering whether it would be of any interest to add reparameterization to `Distribution`. In an AD context you usually want to work in ℝ (unconstrained space) rather than in a constrained space, e.g. when optimizing the parameters of a `Distribution`.

A simple example is `Normal(μ, σ)`. One might want to compute a maximum likelihood estimate (MLE) of `μ` and `σ` by gradient descent (GD). This requires differentiating the `logpdf` w.r.t. `μ, σ` and then updating the parameters of the `Normal` accordingly. But for the distribution to be valid we simultaneously need to ensure that `σ > 0`. Usually we accomplish this by instead differentiating the function

The proposal is to also allow something like

which in the MLE case allows us to differentiate
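For concreteness, here is the usual workaround described above, sketched in plain Julia: gradient ascent on the average log-likelihood in the unconstrained coordinates `(μ, η)` with `σ = exp(η)`. The gradients are derived by hand (shown in the comments) rather than obtained from an AD package. `avg`, `grad_loglik`, and `fit_normal` are hypothetical names, and the step size and iteration count are arbitrary choices for this toy data.

```julia
# Average Normal log-likelihood in unconstrained coordinates (μ, η), σ = exp(η):
#   ℓ(μ, η) = -η - log(2π)/2 - avg((x - μ)^2) * exp(-2η) / 2
# Partial derivatives:
#   ∂ℓ/∂μ = avg(x - μ) / σ^2
#   ∂ℓ/∂η = -1 + avg((x - μ)^2) / σ^2
avg(v) = sum(v) / length(v)

function grad_loglik(μ, η, xs)
    σ² = exp(2η)
    r = xs .- μ
    return avg(r) / σ², -1 + avg(r .^ 2) / σ²
end

function fit_normal(xs; lr = 0.1, iters = 5_000)
    μ, η = 0.0, 0.0                 # start at σ = exp(0) = 1
    for _ in 1:iters
        dμ, dη = grad_loglik(μ, η, xs)
        μ += lr * dμ                # gradient ascent on the log-likelihood
        η += lr * dη                # η is unconstrained, so σ = exp(η) > 0 always
    end
    return μ, exp(η)
end

mu_hat, sigma_hat = fit_normal([1.0, 2.0, 3.0, 4.0])
```

For this data the iterates approach the Normal MLE: the sample mean 2.5 and the (1/n) variance 1.25.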
Why?
As you can see, in the case of a univariate `Normal` this doesn't offer much advantage over the current approach. But the current approach is a special case of what we could then support (by letting `reparam` equal `identity`), and I believe there are certainly cases where this is very useful:

- `Array` can be updated in place rather than by reconstructing the distribution.
- `mle` function we'd require the user to also provide the transformation of `σ` as an argument. If there are many functions depending on this parameterization, it quickly becomes tedious and a bit difficult (speaking from experience) to remember to pass and perform the transformation in each such function. Alternatively, you can pass the unconstrained parameters around as additional arguments, but again, this is tedious and you still need to ensure you perform the transformation in each method. For an example, see the implementation of ADVI in Turing.jl: https://github.com/TuringLang/Turing.jl/blob/bc7e5b643abad9529b99c24caac6dbce6a562ad2/src/variational/advi.jl#L74-L77, https://github.com/TuringLang/Turing.jl/blob/bc7e5b643abad9529b99c24caac6dbce6a562ad2/src/variational/advi.jl#L88, https://github.com/TuringLang/Turing.jl/blob/bc7e5b643abad9529b99c24caac6dbce6a562ad2/src/variational/advi.jl#L119.
- `Normal(param(μ), param(σ))`, `Tracker.back!` through a computation depending on `μ, σ`, and then update the parameters. If one did this naively, it's possible to step into the `σ < 0` region.
- For very complex cases I think it's easier to attach the parameters to the structs which depend on them, rather than putting everything into one huge array and then packing/unpacking everywhere. We can then use `Flux.@treelike` to further simplify our lives. The example below shows a case which arises in things like auto-encoders:

Example implementation
Together with Tracker.jl
Alternative approach: wrap `Distribution`
An alternative approach is to do something similar to `TransformedDistribution` in Bijectors.jl, where you simply wrap the distribution instance. Then you could require the user to provide a `reparam` method which takes what's returned from `Distributions.params(d::Distribution)` and applies the reparameterization correctly.

This requires significantly less work, but isn't as nice, nor as easy to extend/work with, IMO.
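A rough sketch of the wrapping alternative in plain Julia. `Gauss`, `params_of`, `Reparameterized`, `constrained`, and `gauss_logpdf` are hypothetical stand-ins, not Distributions.jl or Bijectors.jl API. `params_of` plays the role of `Distributions.params`, and `Gauss` accepts any real `σ` (mimicking a constructor with argument checks disabled) so the wrapped instance can hold unconstrained values.

```julia
# Toy stand-in for a distribution type; it accepts any real σ,
# so the wrapped instance can carry unconstrained parameter values.
struct Gauss
    μ::Float64
    σ::Float64
end
params_of(d::Gauss) = (d.μ, d.σ)   # plays the role of Distributions.params

# Hypothetical wrapper: `base` holds unconstrained parameters and the
# user-supplied `reparam` maps that parameter tuple to valid (constrained) ones.
struct Reparameterized{D, F}
    base::D
    reparam::F
end

# Rebuild the distribution that is actually evaluated.
constrained(r::Reparameterized) = typeof(r.base)(r.reparam(params_of(r.base))...)

gauss_logpdf(d::Gauss, x) = -log(d.σ) - 0.5 * log(2π) - 0.5 * ((x - d.μ) / d.σ)^2

# σ is stored as log σ = log(2); the reparameterization exponentiates it.
r = Reparameterized(Gauss(0.0, log(2.0)), p -> (p[1], exp(p[2])))
```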