Reparameterization attached to Distribution #4

Open
torfjelde opened this issue Sep 24, 2019 · 14 comments

@torfjelde (Member) commented Sep 24, 2019

Overview

Since distributions have to be re-implemented here anyway and the focus is on AD, I was wondering whether it would be of any interest to add reparameterization to Distribution. In an AD context you usually want to work in ℝ (unconstrained space) rather than constrained space, e.g. when optimizing the parameters of a Distribution.

A simple example is Normal(μ, σ). One might want to compute a maximum likelihood estimate (MLE) of μ and σ by gradient descent (GD). This requires differentiating the logpdf wrt. μ and σ and then updating the parameters of the Normal accordingly. But for the distribution to be valid we simultaneously need to ensure that σ > 0. Usually we accomplish this by instead differentiating the function

(μ, logσ) -> logpdf(Normal(μ, exp(logσ)), x)
# instead of
(μ, σ) -> logpdf(Normal(μ, σ), x)
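
For concreteness, a minimal sketch of that MLE-by-GD loop in the unconstrained parameterization could look like this (ForwardDiff and the helper names are purely illustrative):

using Distributions, ForwardDiff

# negative log-likelihood in the unconstrained parameters θ = (μ, logσ)
nll(θ, x) = -loglikelihood(Normal(θ[1], exp(θ[2])), x)

function fit_normal_gd(x; iters = 1_000, η = 1e-2)
    θ = [0.0, 0.0]                                  # initial (μ, logσ), both unconstrained
    for _ in 1:iters
        θ -= η * ForwardDiff.gradient(t -> nll(t, x), θ)
    end
    return θ[1], exp(θ[2])                          # map back to (μ, σ); σ > 0 by construction
end

μ̂, σ̂ = fit_normal_gd(randn(100))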

The proposal is to also allow something like

reparam(μ, logσ) = (μ, exp(logσ))
Normal(μ, logσ, reparam)

which in the MLE case allows us to differentiate

(μ, σ) -> logpdf(Normal(μ, σ, reparam), x)

Why?

As you can see, in the case of a univariate Normal this doesn't offer much advantage over the current approach. But the current approach becomes a special case of what we can then support (by letting reparam be identity), and I believe there certainly are cases where this is very useful:

  • Distributions with Array-valued parameters can be updated in-place rather than by reconstructing the distribution.
  • Specialized implementations of different parameterizations can be provided for possible performance gains.
  • Abstracting optimization away from the user becomes significantly easier. Take the MLE via GD again; if we were to wrap this entire process in some mle function, we'd require the user to also provide the transformation of σ as an argument. If there are a lot of functions depending on this parameterization, it quickly becomes tedious and a bit difficult (speaking from experience) to remember to pass and perform the transformation in each such function. Alternatively, you pass around the unconstrained parameters as additional arguments, but again, this is tedious and you still need to ensure the transformation is performed in each method. For an example, see the implementation of ADVI in Turing.jl: https://github.com/TuringLang/Turing.jl/blob/bc7e5b643abad9529b99c24caac6dbce6a562ad2/src/variational/advi.jl#L74-L77, https://github.com/TuringLang/Turing.jl/blob/bc7e5b643abad9529b99c24caac6dbce6a562ad2/src/variational/advi.jl#L88, https://github.com/TuringLang/Turing.jl/blob/bc7e5b643abad9529b99c24caac6dbce6a562ad2/src/variational/advi.jl#L119.
  • IMO, this works much better with the Tracker.jl and Flux.jl "framework"/approach. At the moment one cannot do something like Normal(param(μ), param(σ)), Tracker.back! through a computation depending on μ and σ, and then update the parameters. If one did this naively, it's possible to step into the σ < 0 region. For very complex cases I think it's easier to attach the parameters to the structs which depend on them, rather than putting everything into a huge array and then packing/unpacking everywhere. We can then use Flux.@treelike to further simplify our lives. The example below shows a situation which arises in things like auto-encoders:
W, b = param(W_init),  param(b_init)
μ, logσ = param(μ_init), param(logσ_init)

for i = 1:num_steps
    nn = Dense(W, b)

    d = MvNormal(μ, exp.(logσ)) # Diagonal MvNormal

    x = rand(d)
    y = nn(x)

    # Do computation using `y`
    # ...

    Tracker.back!(...)

    update!(W, b)
    update!(μ, logσ)
end

# VS.

Flux.@treelike MvNormal

nn = Dense(param(W_init), param(b_init)) 
d = MvNormal(param(μ_init), param(logσ_init), (μ, logσ) -> (μ, exp.(logσ)))

for i = 1:num_steps
    x = rand(d)
    y = nn(x)

    # Do computation using `y`
    # ...

    Tracker.back!(...)

    update!(Flux.params(nn))  # more general
    update!(Flux.params(d))
end

Example implementation

using Distributions, StatsFuns, Random

abstract type ParameterizedDistribution{F, S, P} <: Distribution{F, S} end

# maybe?
transformation(::ParameterizedDistribution{F, S, P}) where {F, S, P} = P

struct NormalAD{T<:Real, P} <: ParameterizedDistribution{Univariate, Continuous, P}
    μ::T
    σ::T
end

NormalAD(μ::T, σ::T) where {T<:Real} = NormalAD{T, identity}(μ, σ)
NormalAD(μ::T, σ::T, f::Function) where {T<:Real} = NormalAD{T, f}(μ, σ)

# convenience; probably don't want to do this in an actual implementation
Base.identity(args...) = identity.(args)

function Distributions.logpdf(d::NormalAD{<:Real, P}, x::Real) where {P}
    μ, σ = P(d.μ, d.σ)
    z = (x - μ) / σ
    return -(z^2 + log2π) / 2 - log(σ)
end

function Distributions.rand(rng::AbstractRNG, d::NormalAD{T, P}) where {T, P}
    μ, σ = P(d.μ, d.σ)
    return μ + σ * randn(rng)
end
julia> # Standard: μ ∈ ℝ, σ ∈ ℝ⁺
       d1 = NormalAD(0.0, 1.0)
NormalAD{Float64,identity}(μ=0.0, σ=1.0)

julia> d2 = Normal(0.0, 1.0)
Normal{Float64}(μ=0.0, σ=1.0)

julia> x = randn()
-0.028232023381049923

julia> logpdf(d1, x) == logpdf(d2, x)
true

julia> # Real-valued: μ ∈ ℝ, σ ∈ ℝ using `exp`
       d3 = NormalAD(0.0, 0.0, (μ, σ) -> (μ, exp(σ)))
NormalAD{Float64,getfield(Main, Symbol("##3#4"))()}(μ=0.0, σ=0.0)

julia> logpdf(d3, x) == logpdf(d2, x)
true

julia> #  Real-valued: μ ∈ ℝ, σ ∈ ℝ using `softplus`
       d4 = NormalAD(0.0, invsoftplus(1.0), (μ, σ) -> (μ, softplus(σ)))
NormalAD{Float64,getfield(Main, Symbol("##9#10"))()}(μ=0.0, σ=0.541324854612918)

julia> logpdf(d4, x) == logpdf(d2, x)
true

Together with Tracker.jl

julia> using Tracker

julia> μ = param(0.0)
0.0 (tracked)

julia> σ = param(0.0)
0.0 (tracked)

julia> d_tracked = NormalAD(μ, σ, (μ, σ) -> (μ, exp(σ)))
NormalAD{Tracker.TrackedReal{Float64},getfield(Main, Symbol("##5#6"))()}(μ=0.0 (tracked), σ=0.0 (tracked))

julia> lp = logpdf(d_tracked, x)
-0.9193370567767668 (tracked)

julia> Tracker.back!(lp)

julia> Tracker.grad.((d_tracked.μ, d_tracked.σ))
(-0.028232023381049923, -0.9992029528558118)

julia> x = rand(d_tracked)
-1.6719800201542028 (tracked)

julia> Tracker.back!(x)

julia> Tracker.grad.((d_tracked.μ, d_tracked.σ))
(0.9717679766189501, -2.6711829730100147)

Alternative approach: wrap Distribution

An alternative approach is to do something similar to TransformedDistribution in Bijectors.jl, where you simply wrap the distribution in a new instance. Then you could require the user to provide a reparam method which takes what's returned from Distributions.params(d::Distribution) and applies the reparameterization correctly.

This requires significantly less work, but isn't as nice nor as easy to extend/work with, IMO.
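
A rough sketch of what such a wrapper could look like (all names here are illustrative, nothing is an existing API):

using Distributions, Random

# `dist` holds the unconstrained parameters; `reparam` maps `params(dist)` to a
# valid distribution with the constrained parameters.
struct Reparameterized{D<:ContinuousUnivariateDistribution, F} <: ContinuousUnivariateDistribution
    dist::D
    reparam::F
end

constrained(d::Reparameterized) = d.reparam(params(d.dist)...)

Distributions.logpdf(d::Reparameterized, x::Real) = logpdf(constrained(d), x)
Distributions.rand(rng::AbstractRNG, d::Reparameterized) = rand(rng, constrained(d))

# store (μ, logσ), but evaluate with σ = exp(logσ)
d = Reparameterized(Normal(0.0, 0.0), (μ, logσ) -> Normal(μ, exp(logσ)))
logpdf(d, 0.5)  # == logpdf(Normal(0.0, 1.0), 0.5)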

@willtebbutt (Member)

I don't really understand why this has to be tied to a distributions library. Wouldn't it be more straightforward / useful to have this as an orthogonal thing that just plays nicely with distributions? I had imagined something along the lines of an interface like

a_positive, a_unconstrained = positive(inv_link_or_link_or_whatever, a_positive_init)

Then we're just talking about the generic parameter handling / transformation problem, rather than anything inherently probabilistic.
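
As a toy sketch of what I mean (hypothetical names, not an existing API):

using Distributions

# pair a constrained parameter with a link mapping an unconstrained value back
# into the valid region; optimizers only ever see the unconstrained value
inverse(::typeof(exp)) = log                      # inverse link for positivity

positive(link, constrained_init) = (constrained_init, inverse(link)(constrained_init))

σ, logσ = positive(exp, 1.0)                      # work with logσ ∈ ℝ from here on

# downstream code rebuilds distributions functionally from the unconstrained value;
# nothing probabilistic is attached to the transform itself
loglik(μ, logσ, x) = logpdf(Normal(μ, exp(logσ)), x)
loglik(0.0, logσ, 0.5)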

Also, could we please try to think about how this plays with Zygote, rather than Tracker, as Tracker's days are numbered?

@willtebbutt (Member)

Oops, didn't mean to close

@willtebbutt reopened this Sep 24, 2019
@torfjelde (Member, Author)

I don't really understand why this has to be tied to a distributions library. Wouldn't it be more straightforward / useful to have this as an orthogonal thing that just plays nicely with distributions?

I see what you're saying, but I think it's just too closely related. And I think it's not far-fetched to say that "reparameterization of a Distribution is related to Distributions.jl"? Also, in some cases it can simplify certain computations, e.g. the entropy of a diagonal MvNormal when exp is used to enforce the positivity constraint on the variance. And my main motivation is that you end up performing the transformations "behind the scenes" rather than the user having to do this in every method that needs it. You do it right once in the implementation of the Distribution and then no more. And the standard case is simply an instance of the more general reparametrizable Distribution, so the user who doesn't care doesn't have to care. Other than more work, I think the only downside is that it's more difficult to perform checks as to whether or not the parameters are valid.

Also, could we please try to think about how this plays with Zygote, rather than Tracker, as Tracker's day are numbered?

But I think Zygote also intends to support AD wrt. parameters of a struct, right? I can't find the issue right now, but I swear I saw @MikeInnes discussing something like this somewhere. If so, I think my argument using Tracker.jl still holds?

@MikeInnes

I haven't followed this issue carefully but (1) yes, Zygote supports structs well and (2) it'd be nice not to have to load DistributionsAD on top of Distributions to get AD to work (not sure if that's the plan). Happy to look at support directly in Zygote, maybe via requires, if that's an option.

@mohamed82008 (Member)

A few comments I have.

  1. Transforming the constrained variables is just one way of doing constrained optimization. There are optimization algorithms that can efficiently handle box constraints, semidefinite constraints, linear constraints, etc.
  2. I think doing the re-parameterization of the constrained parameters at the optimization/differentiation layer, not the distribution layer, is the better approach in many cases at no loss of efficiency, e.g. x -> logpdf(Normal(1.0, exp(x)), 1.0) is pretty efficient.
  3. However, I also see the need for being able to construct a distribution using different parameters, e.g. precision vs covariance matrix, or directly using a triangular matrix which could be the Cholesky of the covariance or precision. I think these should be possible with multiple dispatch. Providing things like MvNormal(mu, Covariance(A)) or MvNormal(mu, Precision(A)). If A is a Cholesky we can also construct the PDMat directly. With these more efficient constructors, we get the triangular re-parameterization for free, e.g. L -> logpdf(MvNormal(mu, Covariance(Cholesky(L, 'L', 0))), x). I believe the distribution (re-)construction in this case should not allocate since we are not factorizing A. (A rough sketch of such constructors follows below.)
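
A rough sketch of what these dispatch-based constructors could look like (Covariance and Precision are hypothetical wrapper types, not existing Distributions.jl API):

using Distributions, LinearAlgebra, PDMats

struct Covariance{T}; A::T; end
struct Precision{T};  A::T; end

# plain matrix: factorize on construction, as usual
Distributions.MvNormal(μ::AbstractVector, Σ::Covariance{<:AbstractMatrix}) =
    MvNormal(μ, Matrix(Σ.A))

# Cholesky already in hand: reuse the factorization, so (re-)construction is cheap
Distributions.MvNormal(μ::AbstractVector, Σ::Covariance{<:Cholesky}) =
    MvNormal(μ, PDMat(Σ.A))

# precision parameterization
Distributions.MvNormal(μ::AbstractVector, Λ::Precision{<:AbstractMatrix}) =
    MvNormal(μ, inv(Matrix(Λ.A)))

# the triangular re-parameterization "for free"
L = LowerTriangular([1.0 0.0; 0.5 2.0])
d = MvNormal(zeros(2), Covariance(Cholesky(Matrix(L), 'L', 0)))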

Since we are discussing changes to Distributions, pinging @matbesancon.

@torfjelde (Member, Author)

  2. I think doing the re-parameterization of the constrained parameters at the optimization/differentiation layer, not the distribution layer, is the better approach in many cases at no loss of efficiency, e.g. x -> logpdf(Normal(1.0, exp(x)), 1.0) is pretty efficient.

That's true, but in multivariate cases you still cannot do in-place updates of the parameters (though to allow this you'd have to take a slightly different approach for certain distributions than what Distributions.jl is currently doing, e.g. MvNormal assumes the covariance matrix is constant, so the Cholesky decomposition is performed once upon construction).

It also doesn't solve the issue of "interoperability" with the parts of the ecosystem in which Distributions.jl is often used, e.g. with Tracker/Zygote. It of course works, but for larger models it can be quite a hassle compared to tying the parameters to the Distribution instance rather than keeping track of them through variables outside of the Distribution.

  3. However, I also see the need for being able to construct a distribution using different parameters, e.g. precision vs covariance matrix, or directly using a triangular matrix which could be the Cholesky of the covariance or precision. I think these should be possible with multiple dispatch. Providing things like MvNormal(mu, Covariance(A)) or MvNormal(mu, Precision(A)). If A is a Cholesky we can also construct the PDMat directly. With these more efficient constructors, we get the triangular re-parameterization for free, e.g. L -> logpdf(MvNormal(mu, Covariance(Cholesky(L, 'L', 0))), x). I believe the distribution (re-)construction in this case should not allocate since we are not factorizing A.

I think MvNormal already does this, no? But what is the difference between this and the more general approach of allowing "lazy" transformations like what this issue is proposing? It seems, uhmm, maybe a bit arbitrary to allow reparameterizations, but only for Cholesky and Precision? I understand you could do this for more reparameterizations, e.g. define Normal(μ, Exp(σ)) and so on, but this will require even more work and be less flexible than what this issue is proposing, right?

@willtebbutt (Member) commented Sep 25, 2019

That's true, but in multivariate cases you still cannot do in-place updates of the parameters (though to allow this you'd have to take a slightly different approach for certain distributions than what Distributions.jl is currently doing, e.g. MvNormal assumes the covariance matrix is constant, so the Cholesky decomposition is performed once upon construction)

I think this is one of the key aspects of this discussion. I'm personally more of a fan of the functional approach, but I appreciate that there are merits to both approaches. I'm not really sure which way the community is leaning here, perhaps @MikeInnes or @oxinabox can comment? If I remember correctly, Zygote's recommended mode of operation is now the functional style?

@torfjelde (Member, Author)

I started out preferring the more functional style, but have recently grown quite fond of the Flux approach. Granted, I've recently been using more neural networks where I think this approach is particularly useful.

Also, it's worth noting that with what's proposed here you can do both (which is why I like it!:) )

@oxinabox

On an earlier point:

I haven't followed this issue carefully but (1) yes, Zygote supports structs well and (2) it'd be nice not to have to load DistributionsAD on top of Distributions to get AD to work (not sure if that's the plan). Happy to look at support directly in Zygote, maybe via requires, if that's an option.

Had this discussion with @matbesancon in the context of ChainRules. My recollection is that while he's happy about AD for derivatives, he absolutely does not want it in Distributions.jl; ChainRules.jl (not ChainRulesCore) is just adding @require for these cases. (Rewriting this for ChainRules is still a little way off; need to continue improving struct support for that, I think.)

@mohamed82008 (Member) commented Sep 26, 2019

I think MvNormal already does this, no? But what is the difference between this and the more general approach of allowing "lazy" transformations like what this issue is proposing? It seems, uhmm, maybe a bit arbitrary to allow reparameterizations, but only for Cholesky and Precision? I understand you could do this for more reparameterizations, e.g. define Normal(μ, Exp(σ)) and so on, but this will require even more work and be less flexible than what this issue is proposing, right?

@torfjelde You can still do lazy transformations by multiple dispatch, like you said, using Normal(μ, Exp(σ)) for example. For MvNormal, we can also do MvNormal(μ, Exp(Σ)), which internally also stores a lazy wrapper of Σ and dispatches to efficient v' Exp(Σ)^-1 v and logdet(Exp(Σ)) implementations where possible. For example, logdet(Exp(Σ)) = tr(Σ).
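
A toy sketch of such a lazy wrapper, just to illustrate the dispatch (hypothetical type, not an existing API):

using LinearAlgebra

# `Exp(Σ)` stores Σ but *means* the matrix exponential exp(Σ); identities such as
# logdet(exp(Σ)) = tr(Σ) can then be exploited by dispatch without materializing exp(Σ)
struct Exp{S, M<:AbstractMatrix{S}} <: AbstractMatrix{S}
    parent::M
end
Exp(A::AbstractMatrix) = Exp{eltype(A), typeof(A)}(A)

Base.size(A::Exp) = size(A.parent)
Base.getindex(A::Exp, i::Int, j::Int) = exp(Matrix(A.parent))[i, j]  # naive fallback

LinearAlgebra.logdet(A::Exp) = tr(A.parent)   # no exponential, no factorization

Σ = [1.0 0.2; 0.2 1.0]
logdet(Exp(Σ)) == tr(Σ)  # true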

Dispatching on reparam in your proposal for efficient tricks like this is only possible if reparam itself uses the lazy Exp internally and we dispatch on Exp for logdet. So if we can avoid making our own AD types using the lazy wrapper approach directly, that would be better.

If we are talking about modifying the distribution in-place (no AD), we can do that using the lazy function wrapper. Note that we always have to define the fields of the distribution according to its struct definition. So we have one of two scenarios:

  1. We tap into the innermost constructors for PDMat and MvNormal, for example, to define our distribution dist once, while keeping a handle to Σ that we can modify in-place from outside, affecting the next logpdf(dist, x) result.
  2. We call an outer constructor that does copying, linear algebra, or calls other functions that render our handles to Σ independent of the distribution struct returned.

This is fundamentally a constructor definition problem. It is a question of how we can construct the distribution while enabling in-place modification of the inputs. Lazy function wrappers take us some of the way. Note that at the end of the day we still need to satisfy the field type signature of the distribution struct, so we may need to modify the type parameters of the distribution struct to accept more generic matrix types like a lazy matrix-valued function which sub-types AbstractMatrix. Learning to live within those boundaries and pushing them where it makes sense to enable dispatch-based laziness seems like a more Julian approach to me than making 2 versions of the same struct, one persistent and one lazy.

So in summary, the anonymous function and dispatch-based laziness approach enables us to:

  1. Think about ways to make various functions more efficient, e.g. logdet,
  2. Avoid the need for an AD version of every distribution,
  3. Keep handles to the inputs passed to the outer constructor if we get laziness right, which enables in-place modification.

Note that at this point, it is not a question of whether we need arbitrary re-parameterization, just the API choice. I am leaning towards not having a struct for every distribution for AD purposes only, and instead using anonymous functions and dispatch-based laziness to gain any efficiency and/or flexibility benefits. Ironically, we already implement an AD distribution for MvNormal here to work around some Distributions-PDMats complexity. But for a long-term solution we should try to live within the boundaries of Distributions.jl and PDMats.jl.

Pinging @ChrisRackauckas in case he has opinions on this.

@torfjelde (Member, Author)

@torfjelde You can still do lazy transformations by multiple dispatch, like you said, using Normal(μ, Exp(σ)) for example. For MvNormal, we can also do MvNormal(μ, Exp(Σ)), which internally also stores a lazy wrapper of Σ and dispatches to efficient v' Exp(Σ)^-1 v and logdet(Exp(Σ)) implementations where possible. For example, logdet(Exp(Σ)) = tr(Σ).

Yeah, I understood that, but it would still require always building an explicit type like Exp which could do this, in contrast to the user just passing in the exp function and us wrapping every use of σ in it (this approach wouldn't work for just any case, but in the univariate case it would be "one implementation works for all transformations").

But after reading your comment I realize we can just make a Lazy{exp}(σ) wrapper of σ and do the same thing as I wanted to do:) (You might have already realized this!) This is basically a "you know what you're doing"-type. Well, it's going to be rather annoying to have to specify different behavior for all combinations of the different parameters, e.g. if you want to apply log to μ and exp to σ you have to implement Normal{Log, Exp}, Normal{<:Real, Exp} and Normal{Log, <:Real} in addition to the existing implementation. Granted, the same issue is a problem in what I'm proposing if you require a separate transform for each parameter and you want specific behavior for exp on σ.

I think I'm coming around to your suggestion!:) It still seems like making this compatible with current Distribution is going to be, uhmm, slightly challenging.

Dispatching on reparam in your proposal for efficient tricks like this is only possible if reparam itself uses the lazy Exp internally and we dispatch on Exp for logdet. So if we can avoid making our own AD types using the lazy wrapper approach directly, that would be better.

You could still do this when P is, say, the actual function exp though, right? But maybe this has some issues I'm not fully aware of.

Learning to live within those boundaries and pushing them where it makes sense to enable dispatch-based laziness seems like a more Julian approach to me than making 2 versions of the same struct, one persistent and one lazy.

"to where it makes sense" -> "to where we can" seems like a more accurate statement 🙃

@mohamed82008 (Member)

You could still do this when P is, say, the actual function exp though, right? But maybe this has some issues I'm not fully aware of.

Well, in your proposal IIUC, P acts on all the arguments together, not on each one individually. So from its type alone we don't really know that it is using exp on the covariance inside, which rules out any magical specialization on P. This means we still need to rely on Exp for the dispatch-based lazy specialization of logdet, for example.

"to where it makes sense" -> "to where we can" seems like a more accurate statement 🙃

True, but if we hit a wall, we can decide to temporarily branch off until the obstacle is removed. This is what we do now for MvNormal and arguably with this whole package.

But after reading your comment I realize we can just make a Lazy{exp}(σ) wrapper of σ and do the same thing as I wanted to do:)

Yes, this is a nice generic way of defining lazy wrappers. Exp can be an alias for Lazy{exp}.
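
E.g. something along these lines (just a sketch):

# store the unconstrained value together with the function that constrains it
struct Lazy{f, T}
    val::T
end
Lazy{f}(val::T) where {f, T} = Lazy{f, T}(val)

value(x::Lazy{f}) where {f} = f(x.val)   # constrained value, computed on demand
value(x) = x                             # plain numbers pass through unchanged

const Exp = Lazy{exp}                    # Exp(σ) lazily means exp(σ)

value(Exp(0.0))  # == 1.0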

@torfjelde (Member, Author)

I completely agree with the last comment:)

One thing though: this "wrapping" approach means that if we want type-stability we'd have to allow all the parameters of a distribution to take on different types, e.g. Beta(a::T, b::T) can't be used since you might want to do Beta(a::Lazy{T, f1}, b::Lazy{T, f2}) where f1 and f2 are two different functions.

It still seems like the best approach, but it's worth noting that this might be a big hurdle to overcome, as we'd basically need to re-define most distributions to accommodate something like this.

@torfjelde (Member, Author)

And I think something like the following works okay as a "default" where we just allow the type itself to specify how to handle the unconstrained-to-constrained transformation:

abstract type Constrained end

struct Unconstrained{T} <: Constrained
    val::T
end
value(c::Unconstrained) = c.val

Normal(μ::Unconstrained{T}, σ::Unconstrained{T}) where {T} = Normal(value(μ), exp(value(σ)))

Could also do something like Unconstrained{T, F} where F is a callable. Then we can use

value(c::Unconstrained{T, F}) where {T, F} = F(c.val)

# when `F = identity` we have a default treatment
Normal(μ::Unconstrained{T, identity}, σ::Unconstrained{T, identity}) where {T} = Normal(value(μ), exp(value(σ)))

# in this case we have to assume that `value` takes care of the transformation
Normal(μ::Unconstrained{T}, σ::Unconstrained{T}) where {T} = Normal(value(μ), value(σ))

Need to think about this further, but doesn't seem like a horrible approach.
