diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 00000000..ba942b23 --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1,2 @@ +style="blue" +format_markdown=true \ No newline at end of file diff --git a/.github/workflows/Format.yml b/.github/workflows/Format.yml new file mode 100644 index 00000000..6a6df765 --- /dev/null +++ b/.github/workflows/Format.yml @@ -0,0 +1,38 @@ +name: Format + +on: + push: + branches: + - master + pull_request: + branches: + - master + merge_group: + types: [checks_requested] + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@latest + with: + version: 1 + - name: Format code + run: | + using Pkg + Pkg.add(; name="JuliaFormatter", uuid="98e50ef6-434e-11e9-1051-2b60c6c9e899") + using JuliaFormatter + format("."; verbose=true) + shell: julia --color=yes {0} + - uses: reviewdog/action-suggester@v1 + if: github.event_name == 'pull_request' + with: + tool_name: JuliaFormatter + fail_on_error: true \ No newline at end of file diff --git a/docs/make.jl b/docs/make.jl index e1138577..5f3322a8 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -4,13 +4,18 @@ using Bijectors # Doctest setup DocMeta.setdocmeta!(Bijectors, :DocTestSetup, :(using Bijectors); recursive=true) -makedocs( - sitename = "Bijectors", - format = Documenter.HTML(), - modules = [Bijectors], - pages = ["Home" => "index.md", "Transforms" => "transforms.md", "Distributions.jl integration" => "distributions.md", "Examples" => "examples.md"], +makedocs(; + sitename="Bijectors", + format=Documenter.HTML(), + modules=[Bijectors], + pages=[ + "Home" => "index.md", + "Transforms" => "transforms.md", + "Distributions.jl integration" => "distributions.md", + "Examples" => "examples.md", + ], strict=false, checkdocs=:exports, ) -deploydocs(repo = "github.com/TuringLang/Bijectors.jl.git", push_preview=true) +deploydocs(; repo="github.com/TuringLang/Bijectors.jl.git", push_preview=true) diff --git a/docs/src/distributions.md b/docs/src/distributions.md index 514fc4bf..a1135516 100644 --- a/docs/src/distributions.md +++ b/docs/src/distributions.md @@ -1,10 +1,13 @@ ## Basic usage + Other than the `logpdf_with_trans` methods, the package also provides a more composable interface through the `Bijector` types. Consider for example the one from above with `Beta(2, 2)`. ```julia -julia> using Random; Random.seed!(42); +julia> using Random; + Random.seed!(42); -julia> using Bijectors; using Bijectors: Logit +julia> using Bijectors; + using Bijectors: Logit; julia> dist = Beta(2, 2) Beta{Float64}(α=2.0, β=2.0) diff --git a/docs/src/examples.md b/docs/src/examples.md index fff55f28..a29e6875 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -3,9 +3,11 @@ using Bijectors ``` ## Univariate ADVI example + But the real utility of `TransformedDistribution` becomes more apparent when using `transformed(dist, b)` for any bijector `b`. To get the transformed distribution corresponding to the `Beta(2, 2)`, we called `transformed(dist)` before. This is simply an alias for `transformed(dist, bijector(dist))`. Remember `bijector(dist)` returns the constrained-to-constrained bijector for that particular `Distribution`. But we can of course construct a `TransformedDistribution` using different bijectors with the same `dist`. This is particularly useful in something called _Automatic Differentiation Variational Inference (ADVI)_.[2] An important part of ADVI is to approximate a constrained distribution, e.g. `Beta`, as follows: -1. Sample `x` from a `Normal` with parameters `μ` and `σ`, i.e. `x ~ Normal(μ, σ)`. -2. Transform `x` to `y` s.t. `y ∈ support(Beta)`, with the transform being a differentiable bijection with a differentiable inverse (a "bijector") + + 1. Sample `x` from a `Normal` with parameters `μ` and `σ`, i.e. `x ~ Normal(μ, σ)`. + 2. Transform `x` to `y` s.t. `y ∈ support(Beta)`, with the transform being a differentiable bijection with a differentiable inverse (a "bijector") This then defines a probability density with same _support_ as `Beta`! Of course, it's unlikely that it will be the same density, but it's an _approximation_. Creating such a distribution becomes trivial with `Bijector` and `TransformedDistribution`: @@ -16,7 +18,7 @@ dist = Beta(2, 2) b = bijector(dist) # (0, 1) → ℝ b⁻¹ = inverse(b) # ℝ → (0, 1) td = transformed(Normal(), b⁻¹) # x ∼ 𝓝(0, 1) then b(x) ∈ (0, 1) - x = rand(rng, td) # ∈ (0, 1) +x = rand(rng, td) # ∈ (0, 1) ``` It's worth noting that `support(Beta)` is the _closed_ interval `[0, 1]`, while the constrained-to-unconstrained bijection, `Logit` in this case, is only well-defined as a map `(0, 1) → ℝ` for the _open_ interval `(0, 1)`. This is of course not an implementation detail. `ℝ` is itself open, thus no continuous bijection exists from a _closed_ interval to `ℝ`. But since the boundaries of a closed interval has what's known as measure zero, this doesn't end up affecting the resulting density with support on the entire real line. In practice, this means that @@ -29,25 +31,22 @@ inverse(td.transform)(rand(rng, td)) will never result in `0` or `1` though any sample arbitrarily close to either `0` or `1` is possible. _Disclaimer: numerical accuracy is limited, so you might still see `0` and `1` if you're lucky._ ## Multivariate ADVI example + We can also do _multivariate_ ADVI using the `Stacked` bijector. `Stacked` gives us a way to combine univariate and/or multivariate bijectors into a singe multivariate bijector. Say you have a vector `x` of length 2 and you want to transform the first entry using `Exp` and the second entry using `Log`. `Stacked` gives you an easy and efficient way of representing such a bijector. ```@repl advi using Bijectors: SimplexBijector # Original distributions -dists = ( - Beta(), - InverseGamma(), - Dirichlet(2, 3) -); +dists = (Beta(), InverseGamma(), Dirichlet(2, 3)); # Construct the corresponding ranges ranges = []; idx = 1; -for i = 1:length(dists) +for i in 1:length(dists) d = dists[i] - push!(ranges, idx:idx + length(d) - 1) + push!(ranges, idx:(idx + length(d) - 1)) global idx idx += length(d) @@ -74,6 +73,7 @@ sum(y[3:4]) ≈ 1.0 ``` ## Normalizing flows + A very interesting application is that of _normalizing flows_.[1] Usually this is done by sampling from a multivariate normal distribution, and then transforming this to a target distribution using invertible neural networks. Currently there are two such transforms available in Bijectors.jl: `PlanarLayer` and `RadialLayer`. Let's create a flow with a single `PlanarLayer`: ```@setup normalizing-flows @@ -144,7 +144,7 @@ f = NLLObjective(reconstruct, MvNormal(2, 1), xs); # Train using gradient descent. ε = 1e-3; -for i = 1:100 +for i in 1:100 ∇s = Zygote.gradient(f, θs...) θs = map(θs, ∇s) do θ, ∇ θ - ε .* ∇ diff --git a/docs/src/index.md b/docs/src/index.md index 7f7ed0a4..a347da07 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,31 +1,32 @@ # Bijectors.jl This package implements a set of functions for transforming constrained random variables (e.g. simplexes, intervals) to Euclidean space. The 3 main functions implemented in this package are the `link`, `invlink` and `logpdf_with_trans` for a number of distributions. The distributions supported are: -1. `RealDistribution`: `Union{Cauchy, Gumbel, Laplace, Logistic, NoncentralT, Normal, NormalCanon, TDist}`, -2. `PositiveDistribution`: `Union{BetaPrime, Chi, Chisq, Erlang, Exponential, FDist, Frechet, Gamma, InverseGamma, InverseGaussian, Kolmogorov, LogNormal, NoncentralChisq, NoncentralF, Rayleigh, Weibull}`, -3. `UnitDistribution`: `Union{Beta, KSOneSided, NoncentralBeta}`, -4. `SimplexDistribution`: `Union{Dirichlet}`, -5. `PDMatDistribution`: `Union{InverseWishart, Wishart}`, and -6. `TransformDistribution`: `Union{T, Truncated{T}} where T<:ContinuousUnivariateDistribution`. + + 1. `RealDistribution`: `Union{Cauchy, Gumbel, Laplace, Logistic, NoncentralT, Normal, NormalCanon, TDist}`, + 2. `PositiveDistribution`: `Union{BetaPrime, Chi, Chisq, Erlang, Exponential, FDist, Frechet, Gamma, InverseGamma, InverseGaussian, Kolmogorov, LogNormal, NoncentralChisq, NoncentralF, Rayleigh, Weibull}`, + 3. `UnitDistribution`: `Union{Beta, KSOneSided, NoncentralBeta}`, + 4. `SimplexDistribution`: `Union{Dirichlet}`, + 5. `PDMatDistribution`: `Union{InverseWishart, Wishart}`, and + 6. `TransformDistribution`: `Union{T, Truncated{T}} where T<:ContinuousUnivariateDistribution`. All exported names from the [Distributions.jl](https://github.com/TuringLang/Bijectors.jl) package are reexported from `Bijectors`. Bijectors.jl also provides a nice interface for working with these maps: composition, inversion, etc. The following table lists mathematical operations for a bijector and the corresponding code in Bijectors.jl. -| Operation | Method | Automatic | -|:------------------------------------:|:-----------------:|:-----------:| -| `b ↦ b⁻¹` | `inverse(b)` | ✓ | -| `(b₁, b₂) ↦ (b₁ ∘ b₂)` | `b₁ ∘ b₂` | ✓ | -| `(b₁, b₂) ↦ [b₁, b₂]` | `stack(b₁, b₂)` | ✓ | -| `x ↦ b(x)` | `b(x)` | × | -| `y ↦ b⁻¹(y)` | `inverse(b)(y)` | × | -| `x ↦ log|det J(b, x)|` | `logabsdetjac(b, x)` | AD | -| `x ↦ b(x), log|det J(b, x)|` | `with_logabsdet_jacobian(b, x)` | ✓ | -| `p ↦ q := b_* p` | `q = transformed(p, b)` | ✓ | -| `y ∼ q` | `y = rand(q)` | ✓ | -| `p ↦ b` such that `support(b_* p) = ℝᵈ` | `bijector(p)` | ✓ | -| `(x ∼ p, b(x), log|det J(b, x)|, log q(y))` | `forward(q)` | ✓ | +| Operation | Method | Automatic | +|:-------------------------------------------:|:-------------------------------:|:---------:| +| `b ↦ b⁻¹` | `inverse(b)` | ✓ | +| `(b₁, b₂) ↦ (b₁ ∘ b₂)` | `b₁ ∘ b₂` | ✓ | +| `(b₁, b₂) ↦ [b₁, b₂]` | `stack(b₁, b₂)` | ✓ | +| `x ↦ b(x)` | `b(x)` | × | +| `y ↦ b⁻¹(y)` | `inverse(b)(y)` | × | +| `x ↦ log|det J(b, x)|` | `logabsdetjac(b, x)` | AD | +| `x ↦ b(x), log|det J(b, x)|` | `with_logabsdet_jacobian(b, x)` | ✓ | +| `p ↦ q := b_* p` | `q = transformed(p, b)` | ✓ | +| `y ∼ q` | `y = rand(q)` | ✓ | +| `p ↦ b` such that `support(b_* p) = ℝᵈ` | `bijector(p)` | ✓ | +| `(x ∼ p, b(x), log|det J(b, x)|, log q(y))` | `forward(q)` | ✓ | In this table, `b` denotes a `Bijector`, `J(b, x)` denotes the Jacobian of `b` evaluated at `x`, `b_*` denotes the [push-forward](https://www.wikiwand.com/en/Pushforward_measure) of `p` by `b`, and `x ∼ p` denotes `x` sampled from the distribution with density `p`. diff --git a/docs/src/transforms.md b/docs/src/transforms.md index cf9223aa..8617a828 100644 --- a/docs/src/transforms.md +++ b/docs/src/transforms.md @@ -1,8 +1,9 @@ ## Usage A very simple example of a "bijector"/diffeomorphism, i.e. a differentiable transformation with a differentiable inverse, is the `exp` function: -- The inverse of `exp` is `log`. -- The derivative of `exp` at an input `x` is simply `exp(x)`, hence `logabsdetjac` is simply `x`. + + - The inverse of `exp` is `log`. + - The derivative of `exp` at an input `x` is simply `exp(x)`, hence `logabsdetjac` is simply `x`. ```@repl usage using Bijectors @@ -100,4 +101,3 @@ Bijectors.OrderedBijector Bijectors.NamedTransform Bijectors.NamedCoupling ``` - diff --git a/src/Bijectors.jl b/src/Bijectors.jl index d934381f..92e00b92 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -40,41 +40,41 @@ using InverseFunctions: InverseFunctions import ChangesOfVariables: ChangesOfVariables, with_logabsdet_jacobian import InverseFunctions: inverse -import ChainRulesCore -import Functors -import IrrationalConstants -import LogExpFunctions -import Roots - -export TransformDistribution, - PositiveDistribution, - UnitDistribution, - SimplexDistribution, - PDMatDistribution, - link, - invlink, - logpdf_with_trans, - isclosedform, - transform, - transform!, - with_logabsdet_jacobian, - with_logabsdet_jacobian!, - inverse, - logabsdetjac, - logabsdetjac!, - logabsdetjacinv, - Bijector, - Inverse, - Stacked, - bijector, - transformed, - UnivariateTransformed, - MultivariateTransformed, - PlanarLayer, - RadialLayer, - Coupling, - InvertibleBatchNorm, - elementwise +using ChainRulesCore: ChainRulesCore +using Functors: Functors +using IrrationalConstants: IrrationalConstants +using LogExpFunctions: LogExpFunctions +using Roots: Roots + +export TransformDistribution, + PositiveDistribution, + UnitDistribution, + SimplexDistribution, + PDMatDistribution, + link, + invlink, + logpdf_with_trans, + isclosedform, + transform, + transform!, + with_logabsdet_jacobian, + with_logabsdet_jacobian!, + inverse, + logabsdetjac, + logabsdetjac!, + logabsdetjacinv, + Bijector, + Inverse, + Stacked, + bijector, + transformed, + UnivariateTransformed, + MultivariateTransformed, + PlanarLayer, + RadialLayer, + Coupling, + InvertibleBatchNorm, + elementwise if VERSION < v"1.1" using Compat: eachcol @@ -98,27 +98,27 @@ end function mapvcat(f, args...) out = map(f, args...) init = vcat(out[1]) - return reduce(vcat, drop(out, 1); init = init) + return reduce(vcat, drop(out, 1); init=init) end function maphcat(f, args...) out = map(f, args...) init = reshape(out[1], :, 1) - return reduce(hcat, drop(out, 1); init = init) + return reduce(hcat, drop(out, 1); init=init) end function eachcolmaphcat(f, x1, x2) - out = [f(x1[:,i], x2[i]) for i in 1:size(x1, 2)] + out = [f(x1[:, i], x2[i]) for i in 1:size(x1, 2)] init = reshape(out[1], :, 1) - return reduce(hcat, drop(out, 1); init = init) + return reduce(hcat, drop(out, 1); init=init) end function eachcolmaphcat(f, x) out = map(f, eachcol(x)) init = reshape(out[1], :, 1) - return reduce(hcat, drop(out, 1); init = init) + return reduce(hcat, drop(out, 1); init=init) end function sumeachcol(f, x1, x2) # Using a view below for x1 breaks Tracker - return sum(f(x1[:,i], x2[i]) for i in 1:size(x1, 2)) + return sum(f(x1[:, i], x2[i]) for i in 1:size(x1, 2)) end # Distributions @@ -129,14 +129,21 @@ invlink(d::Distribution, y) = inverse(bijector(d))(y) # To still allow `logpdf_with_trans` to work with "batches" in a similar way # as `logpdf` can. _logabsdetjac_dist(d::UnivariateDistribution, x::Real) = logabsdetjac(bijector(d), x) -_logabsdetjac_dist(d::UnivariateDistribution, x::AbstractArray) = logabsdetjac.((bijector(d),), x) +function _logabsdetjac_dist(d::UnivariateDistribution, x::AbstractArray) + return logabsdetjac.((bijector(d),), x) +end -_logabsdetjac_dist(d::MultivariateDistribution, x::AbstractVector) = logabsdetjac(bijector(d), x) -_logabsdetjac_dist(d::MultivariateDistribution, x::AbstractMatrix) = logabsdetjac.((bijector(d),), eachcol(x)) +function _logabsdetjac_dist(d::MultivariateDistribution, x::AbstractVector) + return logabsdetjac(bijector(d), x) +end +function _logabsdetjac_dist(d::MultivariateDistribution, x::AbstractMatrix) + return logabsdetjac.((bijector(d),), eachcol(x)) +end _logabsdetjac_dist(d::MatrixDistribution, x::AbstractMatrix) = logabsdetjac(bijector(d), x) -_logabsdetjac_dist(d::MatrixDistribution, x::AbstractVector{<:AbstractMatrix}) = logabsdetjac.((bijector(d),), x) - +function _logabsdetjac_dist(d::MatrixDistribution, x::AbstractVector{<:AbstractMatrix}) + return logabsdetjac.((bijector(d),), x) +end function logpdf_with_trans(d::Distribution, x, transform::Bool) if ispd(d) @@ -155,15 +162,27 @@ end ## Univariate -const TransformDistribution = Union{ - T, - Truncated{T}, -} where T <: ContinuousUnivariateDistribution +const TransformDistribution = + Union{T,Truncated{T}} where {T<:ContinuousUnivariateDistribution} const PositiveDistribution = Union{ - BetaPrime, Chi, Chisq, Erlang, Exponential, FDist, Frechet, Gamma, InverseGamma, - InverseGaussian, Kolmogorov, LogNormal, NoncentralChisq, NoncentralF, Rayleigh, Weibull, + BetaPrime, + Chi, + Chisq, + Erlang, + Exponential, + FDist, + Frechet, + Gamma, + InverseGamma, + InverseGaussian, + Kolmogorov, + LogNormal, + NoncentralChisq, + NoncentralF, + Rayleigh, + Weibull, } -const UnitDistribution = Union{Beta, KSOneSided, NoncentralBeta} +const UnitDistribution = Union{Beta,KSOneSided,NoncentralBeta} function logpdf_with_trans(d::UnivariateDistribution, x, transform::Bool) if transform @@ -183,33 +202,23 @@ isdirichlet(::Distribution) = false # ∑xᵢ = 1 # ########### -function link( - d::Dirichlet, - x::AbstractVecOrMat{<:Real}, - ::Val{proj}=Val(true), -) where {proj} +function link(d::Dirichlet, x::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)) where {proj} return SimplexBijector{proj}()(x) end function link_jacobian( - d::Dirichlet, - x::AbstractVector{<:Real}, - ::Val{proj}=Val(true), + d::Dirichlet, x::AbstractVector{<:Real}, ::Val{proj}=Val(true) ) where {proj} return jacobian(SimplexBijector{proj}(), x) end function invlink( - d::Dirichlet, - y::AbstractVecOrMat{<:Real}, - ::Val{proj}=Val(true), + d::Dirichlet, y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true) ) where {proj} return inverse(SimplexBijector{proj}())(y) end function invlink_jacobian( - d::Dirichlet, - y::AbstractVector{<:Real}, - ::Val{proj}=Val(true), + d::Dirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true) ) where {proj} return jacobian(inverse(SimplexBijector{proj}()), y) end @@ -220,14 +229,12 @@ end # Positive definite # ##################### -const PDMatDistribution = Union{MatrixBeta, InverseWishart, Wishart} +const PDMatDistribution = Union{MatrixBeta,InverseWishart,Wishart} ispd(::Distribution) = false ispd(::PDMatDistribution) = true function logpdf_with_trans( - d::MatrixDistribution, - X::AbstractArray{<:AbstractMatrix{<:Real}}, - transform::Bool, + d::MatrixDistribution, X::AbstractArray{<:AbstractMatrix{<:Real}}, transform::Bool ) return map(X) do x logpdf_with_trans(d, x, transform) @@ -235,7 +242,7 @@ function logpdf_with_trans( end function pd_logpdf_with_trans(d, X::AbstractMatrix{<:Real}, transform::Bool) T = eltype(X) - Xcf = cholesky(X, check = false) + Xcf = cholesky(X; check=false) if !issuccess(Xcf) Xcf = cholesky(X + max(eps(T), eps(T) * norm(X)) * I) end @@ -265,7 +272,7 @@ include("interface.jl") include("chainrules.jl") # Broadcasting here breaks Tracker for some reason -maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...) +maporbroadcast(f, x::AbstractArray{<:Any,N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) # optional dependencies @@ -281,11 +288,17 @@ function __init__() return copy(f.(x1, x2, x3, x...)) end end - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("compat/forwarddiff.jl") - @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("compat/tracker.jl") - @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" include("compat/zygote.jl") - @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("compat/reversediff.jl") - @require DistributionsAD="ced4e74d-a319-5a8a-b0ac-84af2272839c" include("compat/distributionsad.jl") + @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include( + "compat/forwarddiff.jl" + ) + @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("compat/tracker.jl") + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("compat/zygote.jl") + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include( + "compat/reversediff.jl" + ) + @require DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" include( + "compat/distributionsad.jl" + ) end end # module diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 252ecc68..995708dd 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -67,8 +67,8 @@ with_logabsdet_jacobian(b::CorrBijector, x) = transform(b, x), logabsdetjac(b, x function transform(b::CorrBijector, x::AbstractMatrix{<:Real}) w = cholesky(x).U # keep LowerTriangular until here can avoid some computation - r = _link_chol_lkj(w) - return r + zero(x) + r = _link_chol_lkj(w) + return r + zero(x) # This dense format itself is required by a test, though I can't get the point. # https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67 end @@ -80,15 +80,17 @@ end function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) K = LinearAlgebra.checksquare(y) - + result = float(zero(eltype(y))) for j in 2:K, i in 1:(j - 1) @inbounds abs_y_i_j = abs(y[i, j]) - result += (K - i + 1) * ( - IrrationalConstants.logtwo - (abs_y_i_j + LogExpFunctions.log1pexp(-2 * abs_y_i_j)) - ) + result += + (K - i + 1) * ( + IrrationalConstants.logtwo - + (abs_y_i_j + LogExpFunctions.log1pexp(-2 * abs_y_i_j)) + ) end - + return result end function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) @@ -98,27 +100,27 @@ function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) `logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})` if possible. =# - return -logabsdetjac(inverse(b), (b(X))) + return -logabsdetjac(inverse(b), (b(X))) end function _inv_link_chol_lkj(y) K = LinearAlgebra.checksquare(y) w = similar(y) - + @inbounds for j in 1:K w[1, j] = 1 for i in 2:j - z = tanh(y[i-1, j]) - tmp = w[i-1, j] - w[i-1, j] = z * tmp + z = tanh(y[i - 1, j]) + tmp = w[i - 1, j] + w[i - 1, j] = z * tmp w[i, j] = tmp * sqrt(1 - z^2) end - for i in (j+1):K + for i in (j + 1):K w[i, j] = 0 end end - + return w end @@ -163,7 +165,7 @@ function _link_chol_lkj(w) # This block can't be integrated with loop below, because w[1,1] != 0. @inbounds z[1, 1] = 0 - @inbounds for j=2:K + @inbounds for j in 2:K z[1, j] = atanh(w[1, j]) tmp = sqrt(1 - w[1, j]^2) for i in 2:(j - 1) @@ -173,6 +175,6 @@ function _link_chol_lkj(w) end z[j, j] = 0 end - + return z end diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl index 9aaaf829..b907583d 100644 --- a/src/bijectors/coupling.jl +++ b/src/bijectors/coupling.jl @@ -48,13 +48,15 @@ PartitionMask{Float32,SparseArrays.SparseMatrixCSC{Float32,Int64}}( [3, 1] = 1.0) ``` """ -struct PartitionMask{T, A} +struct PartitionMask{T,A} A_1::A A_2::A A_3::A # Only make it possible to construct using matrices - PartitionMask(A_1::A, A_2::A, A_3::A) where {T<:Real, A <: AbstractMatrix{T}} = new{T, A}(A_1, A_2, A_3) + function PartitionMask(A_1::A, A_2::A, A_3::A) where {T<:Real,A<:AbstractMatrix{T}} + return new{T,A}(A_1, A_2, A_3) + end end PartitionMask(args...) = PartitionMask{Bool}(args...) @@ -63,7 +65,7 @@ function PartitionMask{T}( n::Int, indices_1::AbstractVector{Int}, indices_2::AbstractVector{Int}, - indices_3::AbstractVector{Int} + indices_3::AbstractVector{Int}, ) where {T<:Real} A_1 = sparse(indices_1, 1:length(indices_1), one(T), n, length(indices_1)) A_2 = sparse(indices_2, 1:length(indices_2), one(T), n, length(indices_2)) @@ -72,25 +74,29 @@ function PartitionMask{T}( return PartitionMask(A_1, A_2, A_3) end -PartitionMask{T}( - n::Int, - indices_1::AbstractVector{Int}, - indices_2::AbstractVector{Int}; -) where {T} = PartitionMask{T}(n, indices_1, indices_2, nothing) +function PartitionMask{T}( + n::Int, indices_1::AbstractVector{Int}, indices_2::AbstractVector{Int}; +) where {T} + return PartitionMask{T}(n, indices_1, indices_2, nothing) +end -PartitionMask{T}( +function PartitionMask{T}( n::Int, indices_1::AbstractVector{Int}, indices_2::AbstractVector{Int}, indices_3::Nothing, -) where {T} = PartitionMask{T}(n, indices_1, indices_2, setdiff(1:n, indices_1, indices_2)) +) where {T} + return PartitionMask{T}(n, indices_1, indices_2, setdiff(1:n, indices_1, indices_2)) +end -PartitionMask{T}( +function PartitionMask{T}( n::Int, indices_1::AbstractVector{Int}, indices_2::Nothing, indices_3::AbstractVector{Int}, -) where {T} = PartitionMask{T}(n, indices_1, setdiff(1:n, indices_1, indices_3), indices_3) +) where {T} + return PartitionMask{T}(n, indices_1, setdiff(1:n, indices_1, indices_3), indices_3) +end """ PartitionMask(n::Int, indices) @@ -123,8 +129,9 @@ Combines `x_1`, `x_2`, and `x_3` into a single vector. Partitions `x` into 3 disjoint subvectors. """ -@inline partition(m::PartitionMask, x) = (transpose(m.A_1) * x, transpose(m.A_2) * x, transpose(m.A_3) * x) - +@inline function partition(m::PartitionMask, x) + return (transpose(m.A_1) * x, transpose(m.A_2) * x, transpose(m.A_3) * x) +end # Coupling @@ -168,7 +175,7 @@ julia> with_logabsdet_jacobian(cl, x) # References [1] Kobyzev, I., Prince, S., & Brubaker, M. A., Normalizing flows: introduction and ideas, CoRR, (), (2019). """ -struct Coupling{F, M} <: Bijector where {F, M <: PartitionMask} +struct Coupling{F,M} <: Bijector where {F,M<:PartitionMask} θ::F mask::M end @@ -233,7 +240,7 @@ end function transform(icl::Inverse{<:Coupling}, y::AbstractVector) cl = icl.orig - + y_1, y_2, y_3 = partition(cl.mask, y) b = cl.θ(y_2) diff --git a/src/bijectors/exp_log.jl b/src/bijectors/exp_log.jl index 7236a74e..8af505b3 100644 --- a/src/bijectors/exp_log.jl +++ b/src/bijectors/exp_log.jl @@ -1,4 +1,6 @@ -transform!(b::Union{Elementwise{typeof(log)}, Elementwise{typeof(exp)}}, x, y) = broadcast!(b.x, y, x) +function transform!(b::Union{Elementwise{typeof(log)},Elementwise{typeof(exp)}}, x, y) + return broadcast!(b.x, y, x) +end logabsdetjac(b::typeof(exp), x::Real) = x logabsdetjac(b::Elementwise{typeof(exp)}, x) = sum(x) diff --git a/src/bijectors/logit.jl b/src/bijectors/logit.jl index 1df73514..946228bf 100644 --- a/src/bijectors/logit.jl +++ b/src/bijectors/logit.jl @@ -25,4 +25,6 @@ logit_logabsdetjac(x, a, b) = -log((x - a) * (b - x) / (b - a)) logabsdetjac(b::Logit, x) = sum(logit_logabsdetjac.(x, b.a, b.b)) # `with_logabsdet_jacobian` -with_logabsdet_jacobian(b::Logit, x) = _logit.(x, b.a, b.b), sum(logit_logabsdetjac.(x, b.a, b.b)) +function with_logabsdet_jacobian(b::Logit, x) + return _logit.(x, b.a, b.b), sum(logit_logabsdetjac.(x, b.a, b.b)) +end diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index d147afc6..85e53e9b 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -24,16 +24,16 @@ julia> (a = 2 * x.a, b = exp(x.b), c = x.c) (a = 2.0, b = 1.0, c = 42.0) ``` """ -struct NamedTransform{names, Bs<:NamedTuple{names}} <: AbstractNamedTransform +struct NamedTransform{names,Bs<:NamedTuple{names}} <: AbstractNamedTransform bs::Bs end # fields contain nested numerical parameters -function Functors.functor(::Type{<:NamedTransform{names}}, x) where names +function Functors.functor(::Type{<:NamedTransform{names}}, x) where {names} function reconstruct_namedbijector(xs) return NamedTransform{names,typeof(xs.bs)}(xs.bs) end - return (bs = x.bs,), reconstruct_namedbijector + return (bs=x.bs,), reconstruct_namedbijector end # TODO: Use recursion instead of `@generated`? @@ -43,9 +43,8 @@ inverse(t::NamedTransform) = NamedTransform(map(inverse, t.bs)) isinvertible(t::NamedTransform) = all(isinvertible, t.bs) @generated function transform( - b::NamedTransform{names1}, - x::NamedTuple{names2} -) where {names1, names2} + b::NamedTransform{names1}, x::NamedTuple{names2} +) where {names1,names2} exprs = [] for n in names2 if n in names1 @@ -56,7 +55,7 @@ isinvertible(t::NamedTransform) = all(isinvertible, t.bs) push!(exprs, :($n = x.$n)) end end - return :($(exprs...), ) + return :($(exprs...),) end @generated function logabsdetjac(b::NamedTransform{names}, x::NamedTuple) where {names} @@ -65,18 +64,20 @@ end end @generated function with_logabsdet_jacobian( - b::NamedTransform{names1}, - x::NamedTuple{names2} -) where {names1, names2} + b::NamedTransform{names1}, x::NamedTuple{names2} +) where {names1,names2} body_exprs = [] logjac_expr = Expr(:call, :+) - val_expr = Expr(:tuple, ) + val_expr = Expr(:tuple) for n in names2 if n in names1 val_sym = Symbol("y_$n") logjac_sym = Symbol("logjac_$n") - push!(body_exprs, :(($val_sym, $logjac_sym) = with_logabsdet_jacobian(b.bs.$n, x.$n))) + push!( + body_exprs, + :(($val_sym, $logjac_sym) = with_logabsdet_jacobian(b.bs.$n, x.$n)), + ) push!(logjac_expr.args, logjac_sym) push!(val_expr.args, :($n = $val_sym)) else @@ -115,13 +116,13 @@ julia> (a = x.a, b = (x.a + x.c) * x.b, c = x.c) (a = 1.0, b = 8.0, c = 3.0) ``` """ -struct NamedCoupling{target, deps, F} <: AbstractNamedBijector where {F, target} +struct NamedCoupling{target,deps,F} <: AbstractNamedBijector where {F,target} f::F end -NamedCoupling(target, deps, f::F) where {F} = NamedCoupling{target, deps, F}(f) -function NamedCoupling(::Val{target}, ::Val{deps}, f::F) where {target, deps, F} - return NamedCoupling{target, deps, F}(f) +NamedCoupling(target, deps, f::F) where {F} = NamedCoupling{target,deps,F}(f) +function NamedCoupling(::Val{target}, ::Val{deps}, f::F) where {target,deps,F} + return NamedCoupling{target,deps,F}(f) end isinvertible(::NamedCoupling) = true @@ -130,20 +131,24 @@ coupling(b::NamedCoupling) = b.f # For some reason trying to use the parameteric types doesn't always work # so we have to do this weird approach of extracting type and then index `parameters`. target(b::NamedCoupling{Target}) where {Target} = Target -deps(b::NamedCoupling{<:Any, Deps}) where {Deps} = Deps +deps(b::NamedCoupling{<:Any,Deps}) where {Deps} = Deps -@generated function with_logabsdet_jacobian(nc::NamedCoupling{target, deps, F}, x::NamedTuple) where {target, deps, F} +@generated function with_logabsdet_jacobian( + nc::NamedCoupling{target,deps,F}, x::NamedTuple +) where {target,deps,F} return quote b = nc.f($([:(x.$d) for d in deps]...)) x_target, logjac = with_logabsdet_jacobian(b, x.$target) - return merge(x, ($target = x_target, )), logjac + return merge(x, ($target=x_target,)), logjac end end -@generated function with_logabsdet_jacobian(ni::Inverse{<:NamedCoupling{target, deps, F}}, x::NamedTuple) where {target, deps, F} +@generated function with_logabsdet_jacobian( + ni::Inverse{<:NamedCoupling{target,deps,F}}, x::NamedTuple +) where {target,deps,F} return quote ib = inverse(ni.orig.f($([:(x.$d) for d in deps]...))) x_target, logjac = with_logabsdet_jacobian(ib, x.$target) - return merge(x, ($target = x_target, )), logjac + return merge(x, ($target=x_target,)), logjac end end diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index 6b131d27..5ff91f6d 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -7,32 +7,30 @@ using Statistics: mean istraining() = false mutable struct InvertibleBatchNorm{T1,T2,T3} <: Bijector - b :: T1 # bias - logs :: T1 # log-scale - m :: T2 # moving mean - v :: T2 # moving variance - eps :: T3 - mtm :: T3 # momentum + b::T1 # bias + logs::T1 # log-scale + m::T2 # moving mean + v::T2 # moving variance + eps::T3 + mtm::T3 # momentum end function Base.:(==)(b1::InvertibleBatchNorm, b2::InvertibleBatchNorm) - return b1.b == b2.b && - b1.logs == b2.logs && - b1.m == b2.m && - b1.v == b2.v && - b1.eps == b2.eps && - b1.mtm == b2.mtm + return b1.b == b2.b && + b1.logs == b2.logs && + b1.m == b2.m && + b1.v == b2.v && + b1.eps == b2.eps && + b1.mtm == b2.mtm end function InvertibleBatchNorm( - chs::Int; - eps::T=1f-5, - mtm::T=1f-1, + chs::Int; eps::T=1.0f-5, mtm::T=1.0f-1 ) where {T<:AbstractFloat} return InvertibleBatchNorm( zeros(T, chs), zeros(T, chs), # logs = 0 means s = 1 zeros(T, chs), - ones(T, chs), + ones(T, chs), eps, mtm, ) @@ -42,8 +40,9 @@ Functors.@functor InvertibleBatchNorm (b, logs) function with_logabsdet_jacobian(bn::InvertibleBatchNorm, x) dims = ndims(x) - size(x, dims - 1) == length(bn.b) || - error("InvertibleBatchNorm expected $(length(bn.b)) channels, got $(size(x, dims - 1))") + size(x, dims - 1) == length(bn.b) || error( + "InvertibleBatchNorm expected $(length(bn.b)) channels, got $(size(x, dims - 1))", + ) channels = size(x, dims - 1) as = ntuple(i -> i == ndims(x) - 1 ? size(x, i) : 1, dims) logs = reshape(bn.logs, as...) @@ -51,9 +50,9 @@ function with_logabsdet_jacobian(bn::InvertibleBatchNorm, x) b = reshape(bn.b, as...) if istraining() n = div(prod(size(x)), channels) - axes = [1:dims-2; dims] # axes to reduce along (all but channels axis) - m = mean(x, dims = axes) - v = sum((x .- m) .^ 2, dims = axes) ./ n + axes = [1:(dims - 2); dims] # axes to reduce along (all but channels axis) + m = mean(x; dims=axes) + v = sum((x .- m) .^ 2; dims=axes) ./ n # Update moving mean and variance mtm = bn.mtm T = eltype(bn.m) @@ -65,9 +64,7 @@ function with_logabsdet_jacobian(bn::InvertibleBatchNorm, x) end result = s .* (x .- m) ./ sqrt.(v .+ bn.eps) .+ b - logabsdetjac = ( - fill(sum(logs - log.(v .+ bn.eps) / 2), size(x, dims)) - ) + logabsdetjac = (fill(sum(logs - log.(v .+ bn.eps) / 2), size(x, dims))) return (result, logabsdetjac) end @@ -91,5 +88,5 @@ end transform(bn::Inverse{<:InvertibleBatchNorm}, y) = first(with_logabsdet_jacobian(bn, y)) function Base.show(io::IO, l::InvertibleBatchNorm) - print(io, "InvertibleBatchNorm($(join(size(l.b), ", ")))") + return print(io, "InvertibleBatchNorm($(join(size(l.b), ", ")))") end diff --git a/src/bijectors/ordered.jl b/src/bijectors/ordered.jl index 9bc172b2..d52aa34d 100644 --- a/src/bijectors/ordered.jl +++ b/src/bijectors/ordered.jl @@ -18,7 +18,11 @@ This transformation is currently only supported for otherwise unconstrained dist """ function ordered(d::ContinuousMultivariateDistribution) if bijector(d) !== identity - throw(ArgumentError("ordered transform is currently only supported for unconstrained distributions.")) + throw( + ArgumentError( + "ordered transform is currently only supported for unconstrained distributions.", + ), + ) end return transformed(d, OrderedBijector()) end @@ -32,7 +36,7 @@ function _transform_ordered(y::AbstractVector) @assert !isempty(y) @inbounds x[1] = y[1] - @inbounds for i = 2:length(x) + @inbounds for i in 2:length(x) x[i] = x[i - 1] + exp(y[i]) end @@ -43,7 +47,7 @@ function _transform_ordered(y::AbstractMatrix) x = similar(y) @assert !isempty(y) - @inbounds for j = 1:size(x, 2), i = 1:size(x, 1) + @inbounds for j in 1:size(x, 2), i in 1:size(x, 1) if i == 1 x[i, j] = y[i, j] else @@ -60,7 +64,7 @@ function _transform_inverse_ordered(x::AbstractVector) @assert !isempty(y) @inbounds y[1] = x[1] - @inbounds for i = 2:length(y) + @inbounds for i in 2:length(y) y[i] = log(x[i] - x[i - 1]) end @@ -71,7 +75,7 @@ function _transform_inverse_ordered(x::AbstractMatrix) y = similar(x) @assert !isempty(y) - @inbounds for j = 1:size(y, 2), i = 1:size(y, 1) + @inbounds for j in 1:size(y, 2), i in 1:size(y, 1) if i == 1 y[i, j] = x[i, j] else @@ -83,4 +87,4 @@ function _transform_inverse_ordered(x::AbstractMatrix) end logabsdetjac(b::OrderedBijector, x::AbstractVector) = sum(@view(x[2:end])) -logabsdetjac(b::OrderedBijector, x::AbstractMatrix) = vec(sum(@view(x[2:end, :]); dims = 1)) +logabsdetjac(b::OrderedBijector, x::AbstractMatrix) = vec(sum(@view(x[2:end, :]); dims=1)) diff --git a/src/bijectors/pd.jl b/src/bijectors/pd.jl index bed6ee9a..4a68bdd8 100644 --- a/src/bijectors/pd.jl +++ b/src/bijectors/pd.jl @@ -9,7 +9,7 @@ function replace_diag(f, X) end transform(b::PDBijector, X::AbstractMatrix{<:Real}) = pd_link(X) function pd_link(X) - Y = lower(parent(cholesky(X; check = true).L)) + Y = lower(parent(cholesky(X; check=true).L)) return replace_diag(log, Y) end lower(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) @@ -22,7 +22,7 @@ getpd(X) = LowerTriangular(X) * LowerTriangular(X)' function logabsdetjac(b::PDBijector, X::AbstractMatrix{<:Real}) T = eltype(X) - Xcf = cholesky(X, check = false) + Xcf = cholesky(X; check=false) if !issuccess(Xcf) Xcf = cholesky(X + max(eps(T), eps(T) * norm(X)) * I) end @@ -35,7 +35,7 @@ function logabsdetjac_pdbijector_chol(Xcf::Cholesky) UL = Xcf.UL d = size(UL, 1) z = sum(((d + 1):(-1):2) .* log.(diag(UL))) - return - (z + d * oftype(z, IrrationalConstants.logtwo)) + return -(z + d * oftype(z, IrrationalConstants.logtwo)) end # TODO: Implement explicitly. diff --git a/src/bijectors/permute.jl b/src/bijectors/permute.jl index a2b49aa7..ddfacaf8 100644 --- a/src/bijectors/permute.jl +++ b/src/bijectors/permute.jl @@ -99,7 +99,7 @@ function Permute(indices::AbstractVector{Int}) return Permute(A) end -function Permute(n::Int, indices::Pair{Int, Int}...) +function Permute(n::Int, indices::Pair{Int,Int}...) A = spdiagm(0 => ones(n)) dests = Set{Int}() @@ -111,7 +111,7 @@ function Permute(n::Int, indices::Pair{Int, Int}...) push!(dests, dst) push!(sources, src) - + A[dst, src] = 1.0 A[src, src] = 0.0 # <= remove `src => src` end @@ -122,7 +122,7 @@ function Permute(n::Int, indices::Pair{Int, Int}...) return Permute(A) end -function Permute(n::Int, indices::Pair{Vector{Int}, Vector{Int}}...) +function Permute(n::Int, indices::Pair{Vector{Int},Vector{Int}}...) A = spdiagm(0 => ones(n)) dests = Set{Int}() @@ -130,14 +130,14 @@ function Permute(n::Int, indices::Pair{Vector{Int}, Vector{Int}}...) for (srcs, dsts) in indices @argcheck length(srcs) == length(dsts) "$srcs => $dsts is not bijective" - + for (src, dst) in zip(srcs, dsts) @argcheck dst ∉ dests "$dst used more than once" @argcheck src ∉ sources "$src used more than once" push!(dests, dst) push!(sources, src) - + A[dst, src] = 1.0 A[src, src] = 0.0 # <= remove `src => src` end @@ -149,7 +149,6 @@ function Permute(n::Int, indices::Pair{Vector{Int}, Vector{Int}}...) return Permute(A) end - transform(b::Permute, x::AbstractVecOrMat) = b.A * x inverse(b::Permute) = Permute(transpose(b.A)) diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index be46de3a..4267dd61 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -10,7 +10,8 @@ # TODO: add docstring -struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Union{Real, AbstractVector{<:Real}}} <: Bijector +struct PlanarLayer{T1<:AbstractVector{<:Real},T2<:Union{Real,AbstractVector{<:Real}}} <: + Bijector w::T1 u::T1 b::T2 @@ -29,7 +30,9 @@ end # all fields are numerical parameters Functors.@functor PlanarLayer -Base.show(io::IO, b::PlanarLayer) = print(io, "PlanarLayer(w = $(b.w), u = $(b.u), b = $(b.b))") +function Base.show(io::IO, b::PlanarLayer) + return print(io, "PlanarLayer(w = $(b.w), u = $(b.u), b = $(b.b))") +end """ get_u_hat(u::AbstractVector{<:Real}, w::AbstractVector{<:Real}) @@ -73,7 +76,7 @@ function _transform(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) û, wT_û = get_u_hat(flow.u, w) wT_z = aT_b(w, z) transformed = z .+ û .* tanh.(wT_z .+ b) - return (transformed = transformed, wT_û = wT_û, wT_z = wT_z) + return (transformed=transformed, wT_û=wT_û, wT_z=wT_z) end transform(b::PlanarLayer, z) = _transform(b, z).transformed @@ -103,7 +106,7 @@ function with_logabsdet_jacobian(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) b = first(flow.b) log_det_jacobian = log1p.(wT_û .* abs2.(sech.(_vec(wT_z) .+ b))) - return (result = transformed, logabsdetjac = log_det_jacobian) + return (result=transformed, logabsdetjac=log_det_jacobian) end function transform(ib::Inverse{<:PlanarLayer}, y::AbstractVecOrMat{<:Real}) diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index d4156f01..d2ed127f 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -8,7 +8,9 @@ # RadialLayer # ############### -mutable struct RadialLayer{T1<:Union{Real, AbstractVector{<:Real}}, T2<:AbstractVector{<:Real}} <: Bijector +mutable struct RadialLayer{ + T1<:Union{Real,AbstractVector{<:Real}},T2<:AbstractVector{<:Real} +} <: Bijector α_::T1 β::T1 z_0::T2 @@ -27,7 +29,9 @@ end # all fields are numerical parameters Functors.@functor RadialLayer -Base.show(io::IO, b::RadialLayer) = print(io, "RadialLayer(α_ = $(b.α_), β = $(b.β), z_0 = $(b.z_0))") +function Base.show(io::IO, b::RadialLayer) + return print(io, "RadialLayer(α_ = $(b.α_), β = $(b.β), z_0 = $(b.z_0))") +end h(α, r) = 1 ./ (α .+ r) # for radial flow from eq(14) #dh(α, r) = .- (1 ./ (α .+ r)) .^ 2 # for radial flow; derivative of h() @@ -42,10 +46,10 @@ function _radial_transform(α_, β, z_0, z) if z isa AbstractVector r = norm(z .- z_0) else - r = vec(sqrt.(sum(abs2, z .- z_0; dims = 1))) + r = vec(sqrt.(sum(abs2, z .- z_0; dims=1))) end transformed = z .+ β_hat ./ (α .+ r') .* (z .- z_0) # from eq(14) - return (transformed = transformed, α = α, β_hat = β_hat, r = r) + return (transformed=transformed, α=α, β_hat=β_hat, r=r) end transform(b::RadialLayer, z::AbstractVector{<:Real}) = vec(_transform(b, z).transformed) @@ -62,10 +66,9 @@ function with_logabsdet_jacobian(flow::RadialLayer, z::AbstractVecOrMat) T = typeof(vec(transformed)) end log_det_jacobian::T = @. ( - (d - 1) * log(1 + β_hat * h_) - + log(1 + β_hat * h_ + β_hat * (- h_ ^ 2) * r) + (d - 1) * log(1 + β_hat * h_) + log(1 + β_hat * h_ + β_hat * (-h_^2) * r) ) # from eq(14) - return (result = transformed, logabsdetjac = log_det_jacobian) + return (result=transformed, logabsdetjac=log_det_jacobian) end function transform(ib::Inverse{<:RadialLayer}, y::AbstractVector{<:Real}) @@ -125,4 +128,6 @@ function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat) return r end -logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = last(with_logabsdet_jacobian(flow, x)) +function logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) + return last(with_logabsdet_jacobian(flow, x)) +end diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index 6c0dd601..03e10a8a 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -78,21 +78,17 @@ struct RationalQuadraticSpline{T} <: Bijector derivatives::T # K derivatives, with endpoints being ones function RationalQuadraticSpline( - widths::T, - heights::T, - derivatives::T + widths::T, heights::T, derivatives::T ) where {T<:AbstractVector} # TODO: add a `NoArgCheck` type and argument so we can circumvent if we want @assert length(widths) == length(heights) == length(derivatives) @assert all(derivatives .> 0) "derivatives need to be positive" - + return new{T}(widths, heights, derivatives) end function RationalQuadraticSpline( - widths::T, - heights::T, - derivatives::T + widths::T, heights::T, derivatives::T ) where {T<:AbstractMatrix} @assert size(widths, 2) == size(heights, 2) == size(derivatives, 2) @assert all(derivatives .> 0) "derivatives need to be positive" @@ -101,32 +97,28 @@ struct RationalQuadraticSpline{T} <: Bijector end function RationalQuadraticSpline( - widths::A, - heights::A, - derivatives::A, - B::T2 -) where {T1, T2, A <: AbstractVector{T1}} + widths::A, heights::A, derivatives::A, B::T2 +) where {T1,T2,A<:AbstractVector{T1}} return RationalQuadraticSpline( (cumsum(vcat([zero(T1)], LogExpFunctions.softmax(widths))) .- 0.5) * 2 * B, (cumsum(vcat([zero(T1)], LogExpFunctions.softmax(heights))) .- 0.5) * 2 * B, - vcat([one(T1)], LogExpFunctions.log1pexp.(derivatives), [one(T1)]) + vcat([one(T1)], LogExpFunctions.log1pexp.(derivatives), [one(T1)]), ) end function RationalQuadraticSpline( - widths::A, - heights::A, - derivatives::A, - B::T2 -) where {T1, T2, A <: AbstractMatrix{T1}} - ws = hcat(zeros(T1, size(widths, 1)), LogExpFunctions.softmax(widths; dims = 2)) - hs = hcat(zeros(T1, size(widths, 1)), LogExpFunctions.softmax(heights; dims = 2)) - ds = hcat(ones(T1, size(widths, 1)), LogExpFunctions.log1pexp.(derivatives), ones(T1, size(widths, 1))) + widths::A, heights::A, derivatives::A, B::T2 +) where {T1,T2,A<:AbstractMatrix{T1}} + ws = hcat(zeros(T1, size(widths, 1)), LogExpFunctions.softmax(widths; dims=2)) + hs = hcat(zeros(T1, size(widths, 1)), LogExpFunctions.softmax(heights; dims=2)) + ds = hcat( + ones(T1, size(widths, 1)), + LogExpFunctions.log1pexp.(derivatives), + ones(T1, size(widths, 1)), + ) return RationalQuadraticSpline( - (2 * B) .* (cumsum(ws; dims = 2) .- 0.5), - (2 * B) .* (cumsum(hs; dims = 2) .- 0.5), - ds + (2 * B) .* (cumsum(ws; dims=2) .- 0.5), (2 * B) .* (cumsum(hs; dims=2) .- 0.5), ds ) end @@ -171,7 +163,6 @@ function rqs_univariate(widths, heights, derivatives, x::Real) return g end - # univariate function transform(b::RationalQuadraticSpline{<:AbstractVector}, x::Real) return rqs_univariate(b.widths, b.heights, b.derivatives, x) @@ -180,7 +171,10 @@ end # multivariate # TODO: Improve. function transform(b::RationalQuadraticSpline{<:AbstractMatrix}, x::AbstractVector) - return [rqs_univariate(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], x[i]) for i = 1:length(x)] + return [ + rqs_univariate(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], x[i]) for + i in 1:length(x) + ] end ########################## @@ -215,10 +209,10 @@ function rqs_univariate_inverse(widths, heights, derivatives, y::Real) # Eq. (26) a2 = Δy * d_k - (y - h_k) * ds # Eq. (27) - a3 = - s * (y - h_k) + a3 = -s * (y - h_k) # Eq. (24). There's a mistake in the paper; says `x` but should be `ξ` - numerator = - 2 * a3 + numerator = -2 * a3 denominator = (a2 + sqrt(a2^2 - 4 * a1 * a3)) ξ = numerator / denominator @@ -232,7 +226,10 @@ end # TODO: Improve. function transform(ib::Inverse{<:RationalQuadraticSpline}, y::AbstractVector) b = ib.orig - return [rqs_univariate_inverse(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], y[i]) for i = 1:length(y)] + return [ + rqs_univariate_inverse(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], y[i]) + for i in 1:length(y) + ] end ###################### @@ -241,7 +238,7 @@ end function rqs_logabsdetjac(widths, heights, derivatives, x::Real) T = promote_type(eltype(widths), eltype(heights), eltype(derivatives), eltype(y)) K = length(widths) - 1 - + # Find which bin `x` is in k = searchsortedfirst(widths, x) - 1 @@ -259,19 +256,15 @@ function rqs_logabsdetjac(widths, heights, derivatives, x::Real) s = Δy / w ξ = (x - widths[k]) / w - numerator = s^2 * (derivatives[k + 1] * ξ^2 - + 2 * s * ξ * (1 - ξ) - + derivatives[k] * (1 - ξ)^2) + numerator = + s^2 * (derivatives[k + 1] * ξ^2 + 2 * s * ξ * (1 - ξ) + derivatives[k] * (1 - ξ)^2) denominator = s + (derivatives[k + 1] + derivatives[k] - 2 * s) * ξ * (1 - ξ) return log(numerator) - 2 * log(denominator) end function rqs_logabsdetjac( - widths::AbstractVector, - heights::AbstractVector, - derivatives::AbstractVector, - x::Real + widths::AbstractVector, heights::AbstractVector, derivatives::AbstractVector, x::Real ) T = promote_type(eltype(widths), eltype(heights), eltype(derivatives), eltype(x)) @@ -310,8 +303,8 @@ end # TODO: Improve. function logabsdetjac(b::RationalQuadraticSpline{<:AbstractMatrix}, x::AbstractVector) return sum([ - rqs_logabsdetjac(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], x[i]) - for i = 1:length(x) + rqs_logabsdetjac(b.widths[i, :], b.heights[i, :], b.derivatives[i, :], x[i]) for + i in 1:length(x) ]) end @@ -322,10 +315,7 @@ end # TODO: implement this for `x::AbstractVector` and similarily for 1-dimensional `b`, # and possibly inverses too? function rqs_forward( - widths::AbstractVector, - heights::AbstractVector, - derivatives::AbstractVector, - x::Real + widths::AbstractVector, heights::AbstractVector, derivatives::AbstractVector, x::Real ) T = promote_type(eltype(widths), eltype(heights), eltype(derivatives), eltype(x)) @@ -370,6 +360,8 @@ function with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractVector}, x return rqs_forward(b.widths, b.heights, b.derivatives, x) end -function with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractMatrix}, x::AbstractVector) +function with_logabsdet_jacobian( + b::RationalQuadraticSpline{<:AbstractMatrix}, x::AbstractVector +) return transform(b, x), logabsdetjac(b, x) end diff --git a/src/bijectors/scale.jl b/src/bijectors/scale.jl index bff549e5..1a277e21 100644 --- a/src/bijectors/scale.jl +++ b/src/bijectors/scale.jl @@ -18,7 +18,7 @@ transform(ib::Inverse{<:Scale{<:AbstractMatrix}}, y::AbstractVecOrMat) = ib.orig # We're going to implement custom adjoint for this logabsdetjac(b::Scale, x::Real) = _logabsdetjac_scale(b.a, x, Val(0)) -function logabsdetjac(b::Scale, x::AbstractArray{<:Real, N}) where {N} +function logabsdetjac(b::Scale, x::AbstractArray{<:Real,N}) where {N} return _logabsdetjac_scale(b.a, x, Val(N)) end diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index 908815a6..c02620ef 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -14,7 +14,7 @@ inverse(b::Shift) = Shift(-b.a) transform(b::Shift, x) = b.a .+ x # FIXME: implement custom adjoint to ensure we don't get tracking -function logabsdetjac(b::Shift, x::Union{Real, AbstractArray{<:Real}}) +function logabsdetjac(b::Shift, x::Union{Real,AbstractArray{<:Real}}) return _logabsdetjac_shift(b.a, x) end diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index 1fbc28d2..153d09c7 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -9,7 +9,9 @@ with_logabsdet_jacobian(b::SimplexBijector, x) = transform(b, x), logabsdetjac(b transform(b::SimplexBijector, x) = _simplex_bijector(x, b) transform!(b::SimplexBijector, y, x) = _simplex_bijector!(y, x, b) -_simplex_bijector(x::AbstractArray, b::SimplexBijector) = _simplex_bijector!(similar(x), x, b) +function _simplex_bijector(x::AbstractArray, b::SimplexBijector) + return _simplex_bijector!(similar(x), x, b) +end # Vector implementation. function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where {proj} @@ -24,7 +26,7 @@ function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where sum_tmp += x[k - 1] # z ∈ [ϵ, 1-ϵ] # x[k] = 0 && sum_tmp = 1 -> z ≈ 1 - z = (x[k] + ϵ)*(one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp) + z = (x[k] + ϵ) * (one(T) - 2ϵ) / ((one(T) + ϵ) - sum_tmp) y[k] = LogExpFunctions.logit(z) + log(T(K - k)) end @inbounds sum_tmp += x[K - 1] @@ -49,10 +51,10 @@ function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where Y[1, n] = LogExpFunctions.logit(z) + log(T(K - 1)) for k in 2:(K - 1) sum_tmp += X[k - 1, n] - z = (X[k, n] + ϵ)*(one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp) + z = (X[k, n] + ϵ) * (one(T) - 2ϵ) / ((one(T) + ϵ) - sum_tmp) Y[k, n] = LogExpFunctions.logit(z) + log(T(K - k)) end - sum_tmp += X[K-1, n] + sum_tmp += X[K - 1, n] if proj Y[K, n] = zero(T) else @@ -64,11 +66,11 @@ function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where end # Inverse. -transform(ib::Inverse{<:SimplexBijector}, y::AbstractArray) = _simplex_inv_bijector(y, ib.orig) +function transform(ib::Inverse{<:SimplexBijector}, y::AbstractArray) + return _simplex_inv_bijector(y, ib.orig) +end function transform!( - ib::Inverse{<:SimplexBijector}, - x::AbstractArray{T}, - y::AbstractArray{T}, + ib::Inverse{<:SimplexBijector}, x::AbstractArray{T}, y::AbstractArray{T} ) where {T} return _simplex_inv_bijector!(x, y, ib.orig) end @@ -83,9 +85,9 @@ function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj}) @inbounds z = LogExpFunctions.logistic(y[1] - log(T(K - 1))) @inbounds x[1] = _clamp((z - ϵ) / (one(T) - 2ϵ), 0, 1) sum_tmp = zero(T) - @inbounds @simd for k = 2:(K - 1) + @inbounds @simd for k in 2:(K - 1) z = LogExpFunctions.logistic(y[k] - log(T(K - k))) - sum_tmp += x[k-1] + sum_tmp += x[k - 1] x[k] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, 0, 1) end @inbounds sum_tmp += x[K - 1] @@ -94,7 +96,7 @@ function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj}) else x[K] = _clamp(one(T) - sum_tmp - y[K], 0, 1) end - + return x end @@ -125,14 +127,14 @@ end function logabsdetjac(b::SimplexBijector, x::AbstractVector{T}) where {T} ϵ = _eps(T) lp = zero(T) - + K = length(x) sum_tmp = zero(eltype(x)) @inbounds z = x[1] lp += log(max(z, ϵ)) + log(max(one(T) - z, ϵ)) @inbounds @simd for k in 2:(K - 1) - sum_tmp += x[k-1] + sum_tmp += x[k - 1] z = x[k] / max(one(T) - sum_tmp, ϵ) lp += log(max(z, ϵ)) + log(max(one(T) - z, ϵ)) + log(max(one(T) - sum_tmp, ϵ)) end @@ -141,7 +143,7 @@ function logabsdetjac(b::SimplexBijector, x::AbstractVector{T}) where {T} end function simplex_logabsdetjac_gradient(x::AbstractVector) T = eltype(x) - ϵ = _eps(T) + ϵ = _eps(T) K = length(x) g = similar(x) g .= 0 @@ -151,9 +153,9 @@ function simplex_logabsdetjac_gradient(x::AbstractVector) c1 = z >= ϵ zc = one(T) - z c2 = zc >= ϵ - g[1] = ifelse(c1 & c2, -1/z + 1/zc, ifelse(c1, -1/z, 1/zc)) + g[1] = ifelse(c1 & c2, -1 / z + 1 / zc, ifelse(c1, -1 / z, 1 / zc)) @inbounds @simd for k in 2:(K - 1) - sum_tmp += x[k-1] + sum_tmp += x[k - 1] temp = 1 / (1 - sum_tmp) c0 = temp >= ϵ z = ifelse(c0, x[k] * temp, x[k] / ϵ) @@ -162,14 +164,14 @@ function simplex_logabsdetjac_gradient(x::AbstractVector) c1 = z >= ϵ zc = one(T) - z c2 = zc >= ϵ - dldz = ifelse(c1 & c2, 1/z - 1/zc, ifelse(c1, 1/z, -1/zc)) + dldz = ifelse(c1 & c2, 1 / z - 1 / zc, ifelse(c1, 1 / z, -1 / zc)) dldx = dldz * dzdx - g[k] -= dldx - for i in 1:k-1 - dzdxp = ifelse(c0, x[k] * dzdx^2, zero(T)) - dldxp = dldz * dzdxp - ifelse(c0, temp, zero(T)) - g[i] -= dldxp - end + g[k] -= dldx + for i in 1:(k - 1) + dzdxp = ifelse(c0, x[k] * dzdx^2, zero(T)) + dldxp = dldz * dzdxp - ifelse(c0, temp, zero(T)) + g[i] -= dldxp + end end return g end @@ -182,29 +184,29 @@ function simplex_logabsdetjac_gradient(x::AbstractMatrix) g .= 0 @inbounds @simd for col in 1:size(x, 2) sum_tmp = zero(eltype(x)) - z = x[1,col] + z = x[1, col] #lp += log(z + ϵ) + log((one(T) + ϵ) - z) c1 = z >= ϵ zc = one(T) - z c2 = zc >= ϵ - g[1,col] = ifelse(c1 & c2, -1/z + 1/zc, ifelse(c1, -1/z, 1/zc)) + g[1, col] = ifelse(c1 & c2, -1 / z + 1 / zc, ifelse(c1, -1 / z, 1 / zc)) for k in 2:(K - 1) - sum_tmp += x[k-1,col] + sum_tmp += x[k - 1, col] temp = 1 / (1 - sum_tmp) c0 = temp >= ϵ - z = ifelse(c0, x[k,col] * temp, x[k,col] / ϵ) + z = ifelse(c0, x[k, col] * temp, x[k, col] / ϵ) #lp += log(z + ϵ) + log((one(T) + ϵ) - z) + log(temp) dzdx = ifelse(c0, temp, one(T)) c1 = z >= ϵ zc = one(T) - z c2 = zc >= ϵ - dldz = ifelse(c1 & c2, 1/z - 1/zc, ifelse(c1, 1/z, -1/zc)) + dldz = ifelse(c1 & c2, 1 / z - 1 / zc, ifelse(c1, 1 / z, -1 / zc)) dldx = dldz * dzdx - g[k,col] -= dldx - for i in 1:k-1 - dzdxp = ifelse(c0, x[k,col] * dzdx^2, zero(T)) + g[k, col] -= dldx + for i in 1:(k - 1) + dzdxp = ifelse(c0, x[k, col] * dzdx^2, zero(T)) dldxp = dldz * dzdxp - ifelse(c0, temp, zero(T)) - g[i,col] -= dldxp + g[i, col] -= dldxp end end end @@ -212,9 +214,8 @@ function simplex_logabsdetjac_gradient(x::AbstractMatrix) end function simplex_link_jacobian( - x::AbstractVector{T}, - ::Val{proj}=Val(true), -) where {T<:Real, proj} + x::AbstractVector{T}, ::Val{proj}=Val(true) +) where {T<:Real,proj} K = length(x) @assert K > 1 "x needs to be of length greater than 1" dydxt = similar(x, length(x), length(x)) @@ -223,26 +224,28 @@ function simplex_link_jacobian( sum_tmp = zero(T) @inbounds z = x[1] * (one(T) - 2ϵ) + ϵ # z ∈ [ϵ, 1-ϵ] - @inbounds dydxt[1,1] = (1/z + 1/(1-z)) * (one(T) - 2ϵ) + @inbounds dydxt[1, 1] = (1 / z + 1 / (1 - z)) * (one(T) - 2ϵ) @inbounds @simd for k in 2:(K - 1) sum_tmp += x[k - 1] # z ∈ [ϵ, 1-ϵ] # x[k] = 0 && sum_tmp = 1 -> z ≈ 1 - z = (x[k] + ϵ)*(one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp) - dydxt[k,k] = (1/z + 1/(1-z)) * (one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp) - for i in 1:k-1 - dydxt[i,k] = (1/z + 1/(1-z)) * (x[k] + ϵ)*(one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp)^2 + z = (x[k] + ϵ) * (one(T) - 2ϵ) / ((one(T) + ϵ) - sum_tmp) + dydxt[k, k] = (1 / z + 1 / (1 - z)) * (one(T) - 2ϵ) / ((one(T) + ϵ) - sum_tmp) + for i in 1:(k - 1) + dydxt[i, k] = + (1 / z + 1 / (1 - z)) * (x[k] + ϵ) * (one(T) - 2ϵ) / + ((one(T) + ϵ) - sum_tmp)^2 end end @inbounds sum_tmp += x[K - 1] @inbounds if !proj @simd for i in 1:K - dydxt[i,K] = -1 + dydxt[i, K] = -1 end end return UpperTriangular(dydxt)' end -function jacobian(b::SimplexBijector{proj}, x::AbstractVector{T}) where {proj, T} +function jacobian(b::SimplexBijector{proj}, x::AbstractVector{T}) where {proj,T} return simplex_link_jacobian(x, Val(proj)) end @@ -313,9 +316,8 @@ end =# function simplex_invlink_jacobian( - y::AbstractVector{T}, - ::Val{proj}=Val(true), -) where {T<:Real, proj} + y::AbstractVector{T}, ::Val{proj}=Val(true) +) where {T<:Real,proj} K = length(y) @assert K > 1 "x needs to be of length greater than 1" dxdy = similar(y, length(y), length(y)) @@ -326,45 +328,45 @@ function simplex_invlink_jacobian( unclamped_x = (z - ϵ) / (one(T) - 2ϵ) clamped_x = _clamp(unclamped_x, 0, 1) @inbounds if unclamped_x == clamped_x - dxdy[1,1] = z * (1 - z) / (one(T) - 2ϵ) + dxdy[1, 1] = z * (1 - z) / (one(T) - 2ϵ) end sum_tmp = zero(T) - @inbounds for k = 2:(K - 1) + @inbounds for k in 2:(K - 1) z = LogExpFunctions.logistic(y[k] - log(T(K - k))) sum_tmp += clamped_x unclamped_x = ((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ clamped_x = _clamp(unclamped_x, 0, 1) if unclamped_x == clamped_x - dxdy[k,k] = z * (1 - z) * ((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) - for i in 1:k-1 - for j in i:k-1 - dxdy[k,i] += -dxdy[j,i] * z / (one(T) - 2ϵ) + dxdy[k, k] = z * (1 - z) * ((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) + for i in 1:(k - 1) + for j in i:(k - 1) + dxdy[k, i] += -dxdy[j, i] * z / (one(T) - 2ϵ) end end end end @inbounds sum_tmp += clamped_x @inbounds if proj - unclamped_x = one(T) - sum_tmp + unclamped_x = one(T) - sum_tmp clamped_x = _clamp(unclamped_x, 0, 1) else - unclamped_x = one(T) - sum_tmp - y[K] + unclamped_x = one(T) - sum_tmp - y[K] clamped_x = _clamp(unclamped_x, 0, 1) if unclamped_x == clamped_x - dxdy[K,K] = -1 + dxdy[K, K] = -1 end end @inbounds if unclamped_x == clamped_x - for i in 1:K-1 - @simd for j in i:K-1 - dxdy[K,i] += -dxdy[j,i] + for i in 1:(K - 1) + @simd for j in i:(K - 1) + dxdy[K, i] += -dxdy[j, i] end end end return LowerTriangular(dxdy) end # jacobian -function jacobian(ib::Inverse{<:SimplexBijector{proj}}, y::AbstractVector{T}) where {proj, T} +function jacobian(ib::Inverse{<:SimplexBijector{proj}}, y::AbstractVector{T}) where {proj,T} return simplex_invlink_jacobian(y, Val(proj)) end diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 9f0e73f4..4f0596cb 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -21,7 +21,7 @@ b = stack(b1, b2) b([0.0, 1.0]) == [b1(0.0), 1.0] # => true ``` """ -struct Stacked{Bs, Rs<:Union{Tuple,AbstractArray}} <: Transform +struct Stacked{Bs,Rs<:Union{Tuple,AbstractArray}} <: Transform bs::Bs ranges::Rs end @@ -48,27 +48,26 @@ isclosedform(b::Stacked) = all(isclosedform, b.bs) isinvertible(b::Stacked) = all(isinvertible, b.bs) - # For some reason `inverse.(sb.bs)` was unstable... This works though. inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges) # map is not type stable for many stacked bijectors as a large tuple # hence the generated function -@generated function inverse(sb::Stacked{A}) where {A <: Tuple} +@generated function inverse(sb::Stacked{A}) where {A<:Tuple} exprs = [] - for i = 1:length(A.parameters) + for i in 1:length(A.parameters) push!(exprs, :(inverse(sb.bs[$i]))) end - :(Stacked(($(exprs...), ), sb.ranges)) + return :(Stacked(($(exprs...),), sb.ranges)) end -@generated function _transform(x, rs::NTuple{N, UnitRange{Int}}, bs...) where N +@generated function _transform(x, rs::NTuple{N,UnitRange{Int}}, bs...) where {N} exprs = [] - for i = 1:N + for i in 1:N push!(exprs, :(bs[$i](x[rs[$i]]))) end return :(vcat($(exprs...))) end -function _transform(x, rs::NTuple{1, UnitRange{Int}}, b) +function _transform(x, rs::NTuple{1,UnitRange{Int}}, b) @assert rs[1] == 1:length(x) return b(x) end @@ -89,10 +88,7 @@ function transform(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real}) return y end -function logabsdetjac( - b::Stacked, - x::AbstractVector{<:Real} -) +function logabsdetjac(b::Stacked, x::AbstractVector{<:Real}) N = length(b.bs) init = sum(logabsdetjac(b.bs[1], x[b.ranges[1]])) @@ -106,8 +102,7 @@ function logabsdetjac( end function logabsdetjac( - b::Stacked{<:NTuple{N, Any}, <:NTuple{N, Any}}, - x::AbstractVector{<:Real} + b::Stacked{<:NTuple{N,Any},<:NTuple{N,Any}}, x::AbstractVector{<:Real} ) where {N} init = sum(logabsdetjac(b.bs[1], x[b.ranges[1]])) @@ -129,7 +124,9 @@ end # logjac += sum(_logjac) # return (vcat(y_1, y_2), logjac) # end -@generated function with_logabsdet_jacobian(b::Stacked{<:NTuple{N, Any}, <:NTuple{N, Any}}, x::AbstractVector) where {N} +@generated function with_logabsdet_jacobian( + b::Stacked{<:NTuple{N,Any},<:NTuple{N,Any}}, x::AbstractVector +) where {N} expr = Expr(:block) y_names = [] @@ -137,9 +134,12 @@ end # TODO: drop the `sum` when we have dimensionality push!(expr.args, :(logjac = sum(_logjac))) push!(y_names, :y_1) - for i = 2:N + for i in 2:N y_name = Symbol("y_$i") - push!(expr.args, :(($y_name, _logjac) = with_logabsdet_jacobian(b.bs[$i], x[b.ranges[$i]]))) + push!( + expr.args, + :(($y_name, _logjac) = with_logabsdet_jacobian(b.bs[$i], x[b.ranges[$i]])), + ) # TODO: drop the `sum` when we have dimensionality push!(expr.args, :(logjac += sum(_logjac))) diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index 52c09f4d..14ea0a2e 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -1,7 +1,7 @@ ####################################################### # Constrained to unconstrained distribution bijectors # ####################################################### -struct TruncatedBijector{T1, T2} <: Bijector +struct TruncatedBijector{T1,T2} <: Bijector lb::T1 ub::T2 end @@ -15,7 +15,7 @@ end function transform(b::TruncatedBijector, x) a, b = b.lb, b.ub return truncated_link.(_clamp.(x, a, b), a, b) -end +end function truncated_link(x::Real, a, b) lowerbounded, upperbounded = isfinite(a), isfinite(b) @@ -56,11 +56,11 @@ end function truncated_logabsdetjac(x, a, b) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded - return - log((x - a) * (b - x) / (b - a)) + return -log((x - a) * (b - x) / (b - a)) elseif lowerbounded - return - log(x - a) + return -log(x - a) elseif upperbounded - return - log(b - x) + return -log(b - x) else return zero(x) end diff --git a/src/chainrules.jl b/src/chainrules.jl index 45cacf82..26826e5a 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -1,10 +1,8 @@ # differentation rule for the iterative algorithm in the inverse of `PlanarLayer` ChainRulesCore.@scalar_rule( find_alpha(wt_y::Real, wt_u_hat::Real, b::Real), - @setup( - x = inv(1 + wt_u_hat * sech(Ω + b)^2), - ), - (x, - tanh(Ω + b) * x, x - 1), + @setup(x = inv(1 + wt_u_hat * sech(Ω + b)^2),), + (x, -tanh(Ω + b) * x, x - 1), ) function ChainRulesCore.rrule(::typeof(combine), m::PartitionMask, x_1, x_2, x_3) @@ -15,7 +13,9 @@ function ChainRulesCore.rrule(::typeof(combine), m::PartitionMask, x_1, x_2, x_3 function combine_pullback(ΔΩ) Δ = ChainRulesCore.unthunk(ΔΩ) dx_1, dx_2, dx_3 = partition(m, Δ) - return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), proj_x_1(dx_1), proj_x_2(dx_2), proj_x_3(dx_3) + return ChainRulesCore.NoTangent(), + ChainRulesCore.NoTangent(), proj_x_1(dx_1), proj_x_2(dx_2), + proj_x_3(dx_3) end return combine(m, x_1, x_2, x_3), combine_pullback @@ -81,7 +81,7 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractV project_x = ChainRulesCore.ProjectTo(x) r = similar(x) - @inbounds for i = 1:length(r) + @inbounds for i in 1:length(r) if i == 1 r[i] = 1 else @@ -95,7 +95,7 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractV @assert length(Δ_new) == length(Δ) n = length(Δ_new) - @inbounds for j = 1:n - 1 + @inbounds for j in 1:(n - 1) Δ_new[j] = (Δ[j] / r[j]) - (Δ[j + 1] / r[j + 1]) end @inbounds Δ_new[n] = Δ[n] / r[n] @@ -105,7 +105,7 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractV y = similar(x) @inbounds y[1] = x[1] - @inbounds for i = 2:length(x) + @inbounds for i in 2:length(x) y[i] = log(r[i]) end @@ -117,7 +117,7 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractM project_x = ChainRulesCore.ProjectTo(x) r = similar(x) - @inbounds for j = 1:size(x, 2), i = 1:size(x, 1) + @inbounds for j in 1:size(x, 2), i in 1:size(x, 1) if i == 1 r[i, j] = 1 else @@ -131,11 +131,11 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractM n = size(Δ, 1) @assert size(Δ) == size(Δ_new) - @inbounds for j = 1:size(Δ_new, 2), i = 1:n - 1 + @inbounds for j in 1:size(Δ_new, 2), i in 1:(n - 1) Δ_new[i, j] = (Δ[i, j] / r[i, j]) - (Δ[i + 1, j] / r[i + 1, j]) end - @inbounds for j = 1:size(Δ_new, 2) + @inbounds for j in 1:size(Δ_new, 2) Δ_new[n, j] = Δ[n, j] / r[n, j] end @@ -145,7 +145,7 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractM # Compute primal here so we can make use of the already # computed `r`. y = similar(x) - @inbounds for j = 1:size(x, 2), i = 1:size(x, 1) + @inbounds for j in 1:size(x, 2), i in 1:size(x, 1) if i == 1 y[i, j] = x[i, j] else @@ -157,4 +157,4 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractM end # Fixes Zygote's issues with `@debug` -ChainRulesCore.@non_differentiable _debug(::Any) \ No newline at end of file +ChainRulesCore.@non_differentiable _debug(::Any) diff --git a/src/compat/distributionsad.jl b/src/compat/distributionsad.jl index 99dc81d5..af5397f9 100644 --- a/src/compat/distributionsad.jl +++ b/src/compat/distributionsad.jl @@ -1,7 +1,15 @@ -using .DistributionsAD: TuringDirichlet, TuringWishart, TuringInverseWishart, - FillVectorOfUnivariate, FillMatrixOfUnivariate, - MatrixOfUnivariate, FillVectorOfMultivariate, VectorOfMultivariate, - TuringScalMvNormal, TuringDiagMvNormal, TuringDenseMvNormal +using .DistributionsAD: + TuringDirichlet, + TuringWishart, + TuringInverseWishart, + FillVectorOfUnivariate, + FillMatrixOfUnivariate, + MatrixOfUnivariate, + FillVectorOfMultivariate, + VectorOfMultivariate, + TuringScalMvNormal, + TuringDiagMvNormal, + TuringDenseMvNormal using Distributions: AbstractMvLogNormal # Bijectors @@ -20,49 +28,41 @@ bijector(d::MatrixOfUnivariate{Continuous}) = TruncatedBijector(_minmax(d.dists) bijector(d::VectorOfMultivariate{Discrete}) = identity for T in (:VectorOfMultivariate, :FillVectorOfMultivariate) @eval begin - bijector(d::$T{Continuous, <:MvNormal}) = identity - bijector(d::$T{Continuous, <:TuringScalMvNormal}) = identity - bijector(d::$T{Continuous, <:TuringDiagMvNormal}) = identity - bijector(d::$T{Continuous, <:TuringDenseMvNormal}) = identity - bijector(d::$T{Continuous, <:MvNormalCanon}) = identity - bijector(d::$T{Continuous, <:AbstractMvLogNormal}) = Log() - bijector(d::$T{Continuous, <:SimplexDistribution}) = SimplexBijector() - bijector(d::$T{Continuous, <:TuringDirichlet}) = SimplexBijector() + bijector(d::$T{Continuous,<:MvNormal}) = identity + bijector(d::$T{Continuous,<:TuringScalMvNormal}) = identity + bijector(d::$T{Continuous,<:TuringDiagMvNormal}) = identity + bijector(d::$T{Continuous,<:TuringDenseMvNormal}) = identity + bijector(d::$T{Continuous,<:MvNormalCanon}) = identity + bijector(d::$T{Continuous,<:AbstractMvLogNormal}) = Log() + bijector(d::$T{Continuous,<:SimplexDistribution}) = SimplexBijector() + bijector(d::$T{Continuous,<:TuringDirichlet}) = SimplexBijector() end end bijector(d::FillVectorOfMultivariate{Continuous}) = columnwise(bijector(d.dists.value)) -isdirichlet(::VectorOfMultivariate{Continuous, <:Dirichlet}) = true -isdirichlet(::VectorOfMultivariate{Continuous, <:TuringDirichlet}) = true +isdirichlet(::VectorOfMultivariate{Continuous,<:Dirichlet}) = true +isdirichlet(::VectorOfMultivariate{Continuous,<:TuringDirichlet}) = true isdirichlet(::TuringDirichlet) = true function link( - d::TuringDirichlet, - x::AbstractVecOrMat{<:Real}, - ::Val{proj}=Val(true), + d::TuringDirichlet, x::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true) ) where {proj} return SimplexBijector{proj}()(x) end function link_jacobian( - d::TuringDirichlet, - x::AbstractVector{<:Real}, - ::Val{proj}=Val(true), + d::TuringDirichlet, x::AbstractVector{<:Real}, ::Val{proj}=Val(true) ) where {proj} return jacobian(SimplexBijector{proj}(), x) end function invlink( - d::TuringDirichlet, - y::AbstractVecOrMat{<:Real}, - ::Val{proj}=Val(true), + d::TuringDirichlet, y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true) ) where {proj} return inverse(SimplexBijector{proj}())(y) end function invlink_jacobian( - d::TuringDirichlet, - y::AbstractVector{<:Real}, - ::Val{proj}=Val(true), + d::TuringDirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true) ) where {proj} return jacobian(inverse(SimplexBijector{proj}()), y) end diff --git a/src/compat/forwarddiff.jl b/src/compat/forwarddiff.jl index a947cf18..1b51bb0b 100644 --- a/src/compat/forwarddiff.jl +++ b/src/compat/forwarddiff.jl @@ -1,7 +1,7 @@ import .ForwardDiff -_eps(::Type{<:ForwardDiff.Dual{<:Any, Real}}) = _eps(Real) -_eps(::Type{<:ForwardDiff.Dual{<:Any, <:Integer}}) = _eps(Real) +_eps(::Type{<:ForwardDiff.Dual{<:Any,Real}}) = _eps(Real) +_eps(::Type{<:ForwardDiff.Dual{<:Any,<:Integer}}) = _eps(Real) # Define forward-mode rule for ForwardDiff and don't trust support for ForwardDiff in Roots # https://github.com/JuliaMath/Roots.jl/issues/314 diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index c498205a..47c58054 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -1,17 +1,34 @@ module ReverseDiffCompat -using ..ReverseDiff: ReverseDiff, @grad, value, track, TrackedReal, TrackedVector, - TrackedMatrix +using ..ReverseDiff: + ReverseDiff, @grad, value, track, TrackedReal, TrackedVector, TrackedMatrix using Requires, LinearAlgebra -using ..Bijectors: Elementwise, SimplexBijector, maphcat, simplex_link_jacobian, - simplex_invlink_jacobian, simplex_logabsdetjac_gradient, Inverse -import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector, - _simplex_inv_bijector, replace_diag, jacobian, getpd, lower, - _inv_link_chol_lkj, _link_chol_lkj, _transform_ordered, _transform_inverse_ordered, +using ..Bijectors: + Elementwise, + SimplexBijector, + maphcat, + simplex_link_jacobian, + simplex_invlink_jacobian, + simplex_logabsdetjac_gradient, + Inverse +import ..Bijectors: + _eps, + logabsdetjac, + _logabsdetjac_scale, + _simplex_bijector, + _simplex_inv_bijector, + replace_diag, + jacobian, + getpd, + lower, + _inv_link_chol_lkj, + _link_chol_lkj, + _transform_ordered, + _transform_inverse_ordered, find_alpha -import ChainRulesCore +using ChainRulesCore: ChainRulesCore using Compat: eachcol using Distributions: LocationScale @@ -34,7 +51,9 @@ function Base.maximum(d::LocationScale{<:TrackedReal}) end end -logabsdetjac(b::Elementwise{typeof(log)}, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x) +function logabsdetjac(b::Elementwise{typeof(log)}, x::Union{TrackedVector,TrackedMatrix}) + return track(logabsdetjac, b, x) +end @grad function logabsdetjac(b::Elementwise{typeof(log)}, x::AbstractVector) return -sum(log, value(x)), Δ -> (nothing, -Δ ./ value(x)) end @@ -42,7 +61,8 @@ function _logabsdetjac_scale(a::TrackedReal, x::Real, ::Val{0}) return track(_logabsdetjac_scale, a, value(x), Val(0)) end @grad function _logabsdetjac_scale(a::Real, x::Real, v::Val{0}) - return _logabsdetjac_scale(value(a), value(x), Val(0)), Δ -> (inv(value(a)) .* Δ, nothing, nothing) + return _logabsdetjac_scale(value(a), value(x), Val(0)), + Δ -> (inv(value(a)) .* Δ, nothing, nothing) end # Need to treat `AbstractVector` and `AbstractMatrix` separately due to ambiguity errors function _logabsdetjac_scale(a::TrackedReal, x::AbstractVector, ::Val{0}) @@ -51,7 +71,8 @@ end @grad function _logabsdetjac_scale(a::Real, x::AbstractVector, v::Val{0}) da = value(a) J = fill(inv.(da), length(x)) - return _logabsdetjac_scale(da, value(x), Val(0)), Δ -> (transpose(J) * Δ, nothing, nothing) + return _logabsdetjac_scale(da, value(x), Val(0)), + Δ -> (transpose(J) * Δ, nothing, nothing) end function _logabsdetjac_scale(a::TrackedReal, x::AbstractMatrix, ::Val{0}) return track(_logabsdetjac_scale, a, value(x), Val(0)) @@ -59,7 +80,8 @@ end @grad function _logabsdetjac_scale(a::Real, x::AbstractMatrix, v::Val{0}) da = value(a) J = fill(size(x, 1) / da, size(x, 2)) - return _logabsdetjac_scale(da, value(x), Val(0)), Δ -> (transpose(J) * Δ, nothing, nothing) + return _logabsdetjac_scale(da, value(x), Val(0)), + Δ -> (transpose(J) * Δ, nothing, nothing) end # adjoints for 1-dim and 2-dim `Scale` using `AbstractVector` function _logabsdetjac_scale(a::TrackedVector, x::AbstractVector, ::Val{1}) @@ -82,7 +104,7 @@ end Jᵀ = repeat(inv.(da), 1, size(x, 2)) return _logabsdetjac_scale(da, value(x), Val(1)), Δ -> (Jᵀ * Δ, nothing, nothing) end -function _simplex_bijector(X::Union{TrackedVector, TrackedMatrix}, b::SimplexBijector) +function _simplex_bijector(X::Union{TrackedVector,TrackedMatrix}, b::SimplexBijector) return track(_simplex_bijector, X, b) end @grad function _simplex_bijector(Y::AbstractVector, b::SimplexBijector) @@ -90,7 +112,7 @@ end return _simplex_bijector(Yd, b), Δ -> (simplex_link_jacobian(Yd)' * Δ, nothing) end -function _simplex_inv_bijector(X::Union{TrackedVector, TrackedMatrix}, b::SimplexBijector) +function _simplex_inv_bijector(X::Union{TrackedVector,TrackedMatrix}, b::SimplexBijector) return track(_simplex_inv_bijector, X, b) end @grad function _simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector) @@ -99,10 +121,12 @@ end end @grad function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector) Yd = value(Y) - return _simplex_inv_bijector(Yd, b), Δ -> begin + return _simplex_inv_bijector(Yd, b), + Δ -> begin maphcat(eachcol(Yd), eachcol(Δ)) do c1, c2 simplex_invlink_jacobian(c1)' * c2 - end, nothing + end, + nothing end end @@ -112,7 +136,7 @@ replace_diag(::typeof(log), X::TrackedMatrix) = track(replace_diag, log, X) f(i, j) = i == j ? log(Xd[i, j]) : Xd[i, j] out = f.(1:size(Xd, 1), (1:size(Xd, 2))') out, ∇ -> begin - g(i, j) = i == j ? ∇[i, j]/Xd[i, j] : ∇[i, j] + g(i, j) = i == j ? ∇[i, j] / Xd[i, j] : ∇[i, j] return (nothing, g.(1:size(Xd, 1), (1:size(Xd, 2))')) end end @@ -123,12 +147,14 @@ replace_diag(::typeof(exp), X::TrackedMatrix) = track(replace_diag, exp, X) f(i, j) = ifelse(i == j, exp(Xd[i, j]), Xd[i, j]) out = f.(1:size(Xd, 1), (1:size(Xd, 2))') out, ∇ -> begin - g(i, j) = ifelse(i == j, ∇[i, j]*exp(Xd[i, j]), ∇[i, j]) + g(i, j) = ifelse(i == j, ∇[i, j] * exp(Xd[i, j]), ∇[i, j]) return (nothing, g.(1:size(Xd, 1), (1:size(Xd, 2))')) end end -logabsdetjac(b::SimplexBijector, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x) +function logabsdetjac(b::SimplexBijector, x::Union{TrackedVector,TrackedMatrix}) + return track(logabsdetjac, b, x) +end @grad function logabsdetjac(b::SimplexBijector, x::AbstractVector) xd = value(x) return logabsdetjac(b, xd), Δ -> begin @@ -139,7 +165,8 @@ end getpd(X::TrackedMatrix) = track(getpd, X) @grad function getpd(X::AbstractMatrix) Xd = value(X) - return LowerTriangular(Xd) * LowerTriangular(Xd)', Δ -> begin + return LowerTriangular(Xd) * LowerTriangular(Xd)', + Δ -> begin Xl = LowerTriangular(Xd) return (LowerTriangular(Δ' * Xl + Δ * Xl),) end @@ -157,7 +184,7 @@ end α = find_alpha(value(wt_y), value(wt_u_hat), value(b)) ∂wt_y = inv(1 + wt_u_hat * sech(α + b)^2) - ∂wt_u_hat = - tanh(α + b) * ∂wt_y + ∂wt_u_hat = -tanh(α + b) * ∂wt_y ∂b = ∂wt_y - 1 find_alpha_pullback(Δ::Real) = (Δ * ∂wt_y, Δ * ∂wt_u_hat, Δ * ∂b) @@ -165,7 +192,7 @@ end end # `OrderedBijector` -function _transform_ordered(y::Union{TrackedVector, TrackedMatrix}) +function _transform_ordered(y::Union{TrackedVector,TrackedMatrix}) return track(_transform_ordered, y) end @grad function _transform_ordered(y::AbstractVecOrMat) @@ -173,7 +200,7 @@ end return x, (wrap_chainrules_output ∘ Base.tail ∘ dx) end -function _transform_inverse_ordered(x::Union{TrackedVector, TrackedMatrix}) +function _transform_inverse_ordered(x::Union{TrackedVector,TrackedMatrix}) return track(_transform_inverse_ordered, x) end @grad function _transform_inverse_ordered(x::AbstractVecOrMat) diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index 1166a29e..efa45650 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -1,21 +1,22 @@ module TrackerCompat -using ..Tracker: Tracker, - TrackedReal, - TrackedVector, - TrackedMatrix, - TrackedArray, - TrackedVecOrMat, - @grad, - track, - data, - param +using ..Tracker: + Tracker, + TrackedReal, + TrackedVector, + TrackedMatrix, + TrackedArray, + TrackedVecOrMat, + @grad, + track, + data, + param import ..Bijectors using ..Bijectors: Elementwise, SimplexBijector, Inverse, Stacked -import ChainRulesCore -import LogExpFunctions +using ChainRulesCore: ChainRulesCore +using LogExpFunctions: LogExpFunctions using Compat: eachcol using LinearAlgebra @@ -23,11 +24,9 @@ using Distributions: LocationScale Bijectors.maporbroadcast(f, x::TrackedArray...) = f.(x...) function Bijectors.maporbroadcast( - f, - x1::TrackedArray{T, N}, - x::AbstractArray{<:TrackedReal}..., -) where {T, N} - return f.(convert(Array{TrackedReal{T}, N}, x1), x...) + f, x1::TrackedArray{T,N}, x::AbstractArray{<:TrackedReal}... +) where {T,N} + return f.(convert(Array{TrackedReal{T},N}, x1), x...) end Bijectors._eps(::Type{<:TrackedReal{T}}) where {T} = Bijectors._eps(T) @@ -56,16 +55,12 @@ function Bijectors._logabsdetjac_shift(a::TrackedReal, x::AbstractVector{<:Real} return tracker_shift_logabsdetjac(a, x, Val(0)) end function Bijectors._logabsdetjac_shift( - a::Union{TrackedReal, TrackedVector{<:Real}}, - x::AbstractVector{<:Real}, - ::Val{1} + a::Union{TrackedReal,TrackedVector{<:Real}}, x::AbstractVector{<:Real}, ::Val{1} ) return tracker_shift_logabsdetjac(a, x, Val(1)) end function Bijectors._logabsdetjac_shift( - a::Union{TrackedReal, TrackedVector{<:Real}}, - x::AbstractMatrix{<:Real}, - ::Val{1} + a::Union{TrackedReal,TrackedVector{<:Real}}, x::AbstractMatrix{<:Real}, ::Val{1} ) return tracker_shift_logabsdetjac(a, x, Val(1)) end @@ -88,7 +83,8 @@ function Bijectors._logabsdetjac_scale(a::TrackedReal, x::Real, ::Val{0}) return track(Bijectors._logabsdetjac_scale, a, data(x), Val(0)) end @grad function Bijectors._logabsdetjac_scale(a::Real, x::Real, ::Val{0}) - return Bijectors._logabsdetjac_scale(data(a), data(x), Val(0)), Δ -> (inv(data(a)) .* Δ, nothing, nothing) + return Bijectors._logabsdetjac_scale(data(a), data(x), Val(0)), + Δ -> (inv(data(a)) .* Δ, nothing, nothing) end # Need to treat `AbstractVector` and `AbstractMatrix` separately due to ambiguity errors function Bijectors._logabsdetjac_scale(a::TrackedReal, x::AbstractVector, ::Val{0}) @@ -97,7 +93,8 @@ end @grad function Bijectors._logabsdetjac_scale(a::Real, x::AbstractVector, ::Val{0}) da = data(a) J = fill(inv.(da), length(x)) - return Bijectors._logabsdetjac_scale(da, data(x), Val(0)), Δ -> (transpose(J) * Δ, nothing, nothing) + return Bijectors._logabsdetjac_scale(da, data(x), Val(0)), + Δ -> (transpose(J) * Δ, nothing, nothing) end function Bijectors._logabsdetjac_scale(a::TrackedReal, x::AbstractMatrix, ::Val{0}) return track(Bijectors._logabsdetjac_scale, a, data(x), Val(0)) @@ -105,7 +102,8 @@ end @grad function Bijectors._logabsdetjac_scale(a::Real, x::AbstractMatrix, ::Val{0}) da = data(a) J = fill(size(x, 1) / da, size(x, 2)) - return Bijectors._logabsdetjac_scale(da, data(x), Val(0)), Δ -> (transpose(J) * Δ, nothing, nothing) + return Bijectors._logabsdetjac_scale(da, data(x), Val(0)), + Δ -> (transpose(J) * Δ, nothing, nothing) end # adjoints for 1-dim and 2-dim `Scale` using `AbstractVector` function Bijectors._logabsdetjac_scale(a::TrackedVector, x::AbstractVector, ::Val{1}) @@ -118,7 +116,8 @@ end # = (1 / aᵢ) da = data(a) J = inv.(da) - return Bijectors._logabsdetjac_scale(da, data(x), Val(1)), Δ -> (J .* Δ, nothing, nothing) + return Bijectors._logabsdetjac_scale(da, data(x), Val(1)), + Δ -> (J .* Δ, nothing, nothing) end function Bijectors._logabsdetjac_scale(a::TrackedVector, x::AbstractMatrix, ::Val{1}) return track(Bijectors._logabsdetjac_scale, a, data(x), Val(1)) @@ -126,7 +125,8 @@ end @grad function Bijectors._logabsdetjac_scale(a::TrackedVector, x::AbstractMatrix, ::Val{1}) da = data(a) Jᵀ = repeat(inv.(da), 1, size(x, 2)) - return Bijectors._logabsdetjac_scale(da, data(x), Val(1)), Δ -> (Jᵀ * Δ, nothing, nothing) + return Bijectors._logabsdetjac_scale(da, data(x), Val(1)), + Δ -> (Jᵀ * Δ, nothing, nothing) end # TODO: implement analytical gradient for scaling a vector using a matrix # function _logabsdetjac_scale(a::TrackedMatrix, x::AbstractVector, ::Val{1}) @@ -145,39 +145,48 @@ function Bijectors._simplex_inv_bijector(Y::TrackedVecOrMat, b::SimplexBijector) end @grad function Bijectors._simplex_bijector(X::AbstractVector, b::SimplexBijector) Xd = data(X) - return Bijectors._simplex_bijector(Xd, b), Δ -> (Bijectors.simplex_link_jacobian(Xd)' * Δ, nothing) + return Bijectors._simplex_bijector(Xd, b), + Δ -> (Bijectors.simplex_link_jacobian(Xd)' * Δ, nothing) end @grad function Bijectors._simplex_inv_bijector(Y::AbstractVector, b::SimplexBijector) Yd = data(Y) - return Bijectors._simplex_inv_bijector(Yd, b), Δ -> (Bijectors.simplex_invlink_jacobian(Yd)' * Δ, nothing) + return Bijectors._simplex_inv_bijector(Yd, b), + Δ -> (Bijectors.simplex_invlink_jacobian(Yd)' * Δ, nothing) end -Bijectors.replace_diag(::typeof(log), X::TrackedMatrix) = track(Bijectors.replace_diag, log, X) +function Bijectors.replace_diag(::typeof(log), X::TrackedMatrix) + return track(Bijectors.replace_diag, log, X) +end @grad function Bijectors.replace_diag(::typeof(log), X) Xd = data(X) f(i, j) = i == j ? log(Xd[i, j]) : Xd[i, j] out = f.(1:size(Xd, 1), (1:size(Xd, 2))') out, ∇ -> begin - g(i, j) = i == j ? ∇[i, j]/Xd[i, j] : ∇[i, j] + g(i, j) = i == j ? ∇[i, j] / Xd[i, j] : ∇[i, j] return (nothing, g.(1:size(Xd, 1), (1:size(Xd, 2))')) end end -Bijectors.replace_diag(::typeof(exp), X::TrackedMatrix) = track(Bijectors.replace_diag, exp, X) +function Bijectors.replace_diag(::typeof(exp), X::TrackedMatrix) + return track(Bijectors.replace_diag, exp, X) +end @grad function Bijectors.replace_diag(::typeof(exp), X) Xd = data(X) f(i, j) = ifelse(i == j, exp(Xd[i, j]), Xd[i, j]) out = f.(1:size(Xd, 1), (1:size(Xd, 2))') out, ∇ -> begin - g(i, j) = ifelse(i == j, ∇[i, j]*exp(Xd[i, j]), ∇[i, j]) + g(i, j) = ifelse(i == j, ∇[i, j] * exp(Xd[i, j]), ∇[i, j]) return (nothing, g.(1:size(Xd, 1), (1:size(Xd, 2))')) end end -Bijectors.logabsdetjac(b::SimplexBijector, x::TrackedVecOrMat) = track(Bijectors.logabsdetjac, b, x) +function Bijectors.logabsdetjac(b::SimplexBijector, x::TrackedVecOrMat) + return track(Bijectors.logabsdetjac, b, x) +end @grad function Bijectors.logabsdetjac(b::SimplexBijector, x::AbstractVector) xd = data(x) - return Bijectors.logabsdetjac(b, xd), Δ -> begin + return Bijectors.logabsdetjac(b, xd), + Δ -> begin (nothing, Bijectors.simplex_logabsdetjac_gradient(xd) * Δ) end end @@ -212,7 +221,7 @@ for header in [ Tr = promote_type(eltype(z), eltype(z_0)) r::Tr = norm((z .- z_0)::TV) transformed::T = z .+ β_hat ./ (α .+ r') .* (z .- z_0) # from eq(14) - return (transformed = transformed, α = α, β_hat = β_hat, r = r) + return (transformed=transformed, α=α, β_hat=β_hat, r=r) end end end @@ -247,7 +256,7 @@ for header in [ end r::TV = eachcolnorm(z .- z_0) transformed::T = z .+ β_hat ./ (α .+ r') .* (z .- z_0) # from eq(14) - return (transformed = transformed, α = α, β_hat = β_hat, r = r) + return (transformed=transformed, α=α, β_hat=β_hat, r=r) end end end @@ -256,7 +265,7 @@ eachcolnorm(X::TrackedMatrix) = track(eachcolnorm, X) @grad function eachcolnorm(X) Xd = data(X) y = map(norm, eachcol(Xd)) - y, Δ -> begin + return y, Δ -> begin (Xd .* (Δ ./ y)',) end end @@ -284,7 +293,8 @@ end Bijectors.getpd(X::TrackedMatrix) = track(Bijectors.getpd, X) @grad function Bijectors.getpd(X::AbstractMatrix) Xd = data(X) - return Bijectors.LowerTriangular(Xd) * Bijectors.LowerTriangular(Xd)', Δ -> begin + return Bijectors.LowerTriangular(Xd) * Bijectors.LowerTriangular(Xd)', + Δ -> begin Xl = Bijectors.LowerTriangular(Xd) return (Bijectors.LowerTriangular(Δ' * Xl + Δ * Xl),) end @@ -306,20 +316,20 @@ Bijectors._inv_link_chol_lkj(y::TrackedMatrix) = track(Bijectors._inv_link_chol_ z_mat = similar(y) # cache for adjoint tmp_mat = similar(y) - + @inbounds for j in 1:K w[1, j] = 1 for i in 2:j - z = tanh(y[i-1, j]) - tmp = w[i-1, j] + z = tanh(y[i - 1, j]) + tmp = w[i - 1, j] z_mat[i, j] = z tmp_mat[i, j] = tmp - w[i-1, j] = z * tmp + w[i - 1, j] = z * tmp w[i, j] = tmp * sqrt(1 - z^2) end - for i in (j+1):K + for i in (j + 1):K w[i, j] = 0 end end @@ -330,14 +340,16 @@ Bijectors._inv_link_chol_lkj(y::TrackedMatrix) = track(Bijectors._inv_link_chol_ Δy = zero(y) @inbounds for j in 1:K - Δtmp = Δw[j,j] + Δtmp = Δw[j, j] for i in j:-1:2 - Δz = Δw[i-1, j] * tmp_mat[i, j] - Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j] - Δy[i-1, j] = Δz / cosh(y[i-1, j])^2 - Δtmp = Δw[i-1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2) + Δz = + Δw[i - 1, j] * tmp_mat[i, j] - + Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j] + Δy[i - 1, j] = Δz / cosh(y[i - 1, j])^2 + Δtmp = Δw[i - 1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2) end end - + return (Δy,) end @@ -349,14 +361,14 @@ Bijectors._link_chol_lkj(w::TrackedMatrix) = track(Bijectors._link_chol_lkj, w) w = data(w_tracked) K = LinearAlgebra.checksquare(w) - + z = similar(w) @inbounds z[1, 1] = 0 tmp_mat = similar(w) # cache for pullback. - @inbounds for j=2:K + @inbounds for j in 2:K z[1, j] = atanh(w[1, j]) tmp = sqrt(1 - w[1, j]^2) tmp_mat[1, j] = tmp @@ -374,22 +386,22 @@ Bijectors._link_chol_lkj(w::TrackedMatrix) = track(Bijectors._link_chol_lkj, w) Δw = similar(w) - @inbounds Δw[1,1] = zero(eltype(Δz)) + @inbounds Δw[1, 1] = zero(eltype(Δz)) - @inbounds for j=2:K + @inbounds for j in 2:K Δw[j, j] = 0 Δtmp = zero(eltype(Δz)) # Δtmp_mat[j-1,j] - for i in (j-1):-1:2 - p = w[i, j] / tmp_mat[i-1, j] + for i in (j - 1):-1:2 + p = w[i, j] / tmp_mat[i - 1, j] ftmp = sqrt(1 - p^2) d_ftmp_p = -p / ftmp - d_p_tmp = -w[i,j] / tmp_mat[i-1, j]^2 + d_p_tmp = -w[i, j] / tmp_mat[i - 1, j]^2 - Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p - Δw[i, j] = Δp / tmp_mat[i-1, j] + Δp = Δz[i, j] / (1 - p^2) + Δtmp * tmp_mat[i - 1, j] * d_ftmp_p + Δw[i, j] = Δp / tmp_mat[i - 1, j] Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp end - Δw[1, j] = Δz[1, j] / (1-w[1,j]^2) - Δtmp / sqrt(1 - w[1,j]^2) * w[1,j] + Δw[1, j] = Δz[1, j] / (1 - w[1, j]^2) - Δtmp / sqrt(1 - w[1, j]^2) * w[1, j] end return (Δw,) @@ -401,11 +413,13 @@ end function Bijectors.find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:TrackedReal} return track(Bijectors.find_alpha, wt_y, wt_u_hat, b) end -@grad function Bijectors.find_alpha(wt_y::TrackedReal, wt_u_hat::TrackedReal, b::TrackedReal) +@grad function Bijectors.find_alpha( + wt_y::TrackedReal, wt_u_hat::TrackedReal, b::TrackedReal +) α = Bijectors.find_alpha(data(wt_y), data(wt_u_hat), data(b)) ∂wt_y = inv(1 + wt_u_hat * sech(α + b)^2) - ∂wt_u_hat = - tanh(α + b) * ∂wt_y + ∂wt_u_hat = -tanh(α + b) * ∂wt_y ∂b = ∂wt_y - 1 find_alpha_pullback(Δ::Real) = (Δ * ∂wt_y, Δ * ∂wt_u_hat, Δ * ∂b) @@ -413,7 +427,9 @@ end end # `OrderedBijector` -Bijectors._transform_ordered(y::Union{TrackedVector,TrackedMatrix}) = track(Bijectors._transform_ordered, y) +function Bijectors._transform_ordered(y::Union{TrackedVector,TrackedMatrix}) + return track(Bijectors._transform_ordered, y) +end @grad function Bijectors._transform_ordered(y::AbstractVecOrMat) x, dx = ChainRulesCore.rrule(Bijectors._transform_ordered, data(y)) return x, (wrap_chainrules_output ∘ Base.tail ∘ dx) diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index eedf4b3d..8533b58f 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -11,14 +11,14 @@ end @adjoint function eachcolmaphcat(f, x1, x2) function g(f, x1, x2) init = reshape(f(view(x1, :, 1), x2[1]), :, 1) - return reduce(hcat, [f(view(x1, :, i), x2[i]) for i in 2:size(x1, 2)]; init = init) + return reduce(hcat, [f(view(x1, :, i), x2[i]) for i in 2:size(x1, 2)]; init=init) end return pullback(g, f, x1, x2) end @adjoint function eachcolmaphcat(f, x) function g(f, x) init = reshape(f(view(x, :, 1)), :, 1) - return reduce(hcat, [f(view(x, :, i)) for i in 2:size(x, 2)]; init = init) + return reduce(hcat, [f(view(x, :, i)) for i in 2:size(x, 2)]; init=init) end return pullback(g, f, x) end @@ -59,7 +59,7 @@ end @adjoint function replace_diag(::typeof(log), X) f(i, j) = i == j ? log(X[i, j]) : X[i, j] out = f.(1:size(X, 1), (1:size(X, 2))') - out, ∇ -> begin + return out, ∇ -> begin g(i, j) = i == j ? ∇[i, j] / X[i, j] : ∇[i, j] (nothing, g.(1:size(X, 1), (1:size(X, 2))')) end @@ -67,28 +67,20 @@ end @adjoint function replace_diag(::typeof(exp), X) f(i, j) = ifelse(i == j, exp(X[i, j]), X[i, j]) out = f.(1:size(X, 1), (1:size(X, 2))') - out, ∇ -> begin + return out, ∇ -> begin g(i, j) = ifelse(i == j, ∇[i, j] * exp(X[i, j]), ∇[i, j]) (nothing, g.(1:size(X, 1), (1:size(X, 2))')) end end -@adjoint function pd_logpdf_with_trans( - d, - X::AbstractMatrix{<:Real}, - transform::Bool, -) +@adjoint function pd_logpdf_with_trans(d, X::AbstractMatrix{<:Real}, transform::Bool) return pullback(pd_logpdf_with_trans_zygote, d, X, transform) end -function pd_logpdf_with_trans_zygote( - d, - X::AbstractMatrix{<:Real}, - transform::Bool, -) +function pd_logpdf_with_trans_zygote(d, X::AbstractMatrix{<:Real}, transform::Bool) T = eltype(X) - Xcf = cholesky(X, check = false) + Xcf = cholesky(X; check=false) if !issuccess(Xcf) - Xcf = cholesky(X + max(eps(T), eps(T) * norm(X)) * I, check = true) + Xcf = cholesky(X + max(eps(T), eps(T) * norm(X)) * I; check=true) end lp = getlogp(d, Xcf, X) if transform && isfinite(lp) @@ -114,17 +106,21 @@ end end @adjoint function _simplex_bijector(X::AbstractMatrix, b::SimplexBijector) - return _simplex_bijector(X, b), Δ -> begin + return _simplex_bijector(X, b), + Δ -> begin maphcat(eachcol(X), eachcol(Δ)) do c1, c2 simplex_link_jacobian(c1)' * c2 - end, nothing + end, + nothing end end @adjoint function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector) - return _simplex_inv_bijector(Y, b), Δ -> begin + return _simplex_inv_bijector(Y, b), + Δ -> begin maphcat(eachcol(Y), eachcol(Δ)) do c1, c2 simplex_invlink_jacobian(c1)' * c2 - end, nothing + end, + nothing end end @@ -162,14 +158,15 @@ end return lower(A), Δ -> (lower(Δ),) end @adjoint function getpd(X::AbstractMatrix) - return LowerTriangular(X) * LowerTriangular(X)', Δ -> begin + return LowerTriangular(X) * LowerTriangular(X)', + Δ -> begin Xl = LowerTriangular(X) return (LowerTriangular(Δ' * Xl + Δ * Xl),) end end @adjoint function pd_link(X::AbstractMatrix{<:Real}) return pullback(X) do X - Y = cholesky(X; check = true).L + Y = cholesky(X; check=true).L return replace_diag(log, Y) end end @@ -181,20 +178,20 @@ end z_mat = similar(y) # cache for adjoint tmp_mat = similar(y) - + @inbounds for j in 1:K w[1, j] = 1 for i in 2:j - z = tanh(y[i-1, j]) - tmp = w[i-1, j] + z = tanh(y[i - 1, j]) + tmp = w[i - 1, j] z_mat[i, j] = z tmp_mat[i, j] = tmp - w[i-1, j] = z * tmp + w[i - 1, j] = z * tmp w[i, j] = tmp * sqrt(1 - z^2) end - for i in (j+1):K + for i in (j + 1):K w[i, j] = 0 end end @@ -205,14 +202,16 @@ end Δy = zero(y) @inbounds for j in 1:K - Δtmp = Δw[j,j] + Δtmp = Δw[j, j] for i in j:-1:2 - Δz = Δw[i-1, j] * tmp_mat[i, j] - Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j] - Δy[i-1, j] = Δz / cosh(y[i-1, j])^2 - Δtmp = Δw[i-1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2) + Δz = + Δw[i - 1, j] * tmp_mat[i, j] - + Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j] + Δy[i - 1, j] = Δz / cosh(y[i - 1, j])^2 + Δtmp = Δw[i - 1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2) end end - + return (Δy,) end @@ -221,14 +220,14 @@ end @adjoint function _link_chol_lkj(w) K = LinearAlgebra.checksquare(w) - + z = similar(w) @inbounds z[1, 1] = 0 tmp_mat = similar(w) # cache for pullback. - @inbounds for j=2:K + @inbounds for j in 2:K z[1, j] = atanh(w[1, j]) tmp = sqrt(1 - w[1, j]^2) tmp_mat[1, j] = tmp @@ -246,27 +245,26 @@ end Δw = similar(w) - @inbounds Δw[1,1] = zero(eltype(Δz)) + @inbounds Δw[1, 1] = zero(eltype(Δz)) - @inbounds for j=2:K + @inbounds for j in 2:K Δw[j, j] = 0 Δtmp = zero(eltype(Δz)) # Δtmp_mat[j-1,j] - for i in (j-1):-1:2 - p = w[i, j] / tmp_mat[i-1, j] + for i in (j - 1):-1:2 + p = w[i, j] / tmp_mat[i - 1, j] ftmp = sqrt(1 - p^2) d_ftmp_p = -p / ftmp - d_p_tmp = -w[i,j] / tmp_mat[i-1, j]^2 + d_p_tmp = -w[i, j] / tmp_mat[i - 1, j]^2 - Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p - Δw[i, j] = Δp / tmp_mat[i-1, j] + Δp = Δz[i, j] / (1 - p^2) + Δtmp * tmp_mat[i - 1, j] * d_ftmp_p + Δw[i, j] = Δp / tmp_mat[i - 1, j] Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp end - Δw[1, j] = Δz[1, j] / (1-w[1,j]^2) - Δtmp / sqrt(1 - w[1,j]^2) * w[1,j] + Δw[1, j] = Δz[1, j] / (1 - w[1, j]^2) - Δtmp / sqrt(1 - w[1, j]^2) * w[1, j] end return (Δw,) end return z, pullback_link_chol_lkj - end diff --git a/src/interface.jl b/src/interface.jl index 2efcac98..91c9a961 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -3,7 +3,7 @@ import Base: ∘ import Random: AbstractRNG import Distributions: logpdf, rand, rand!, _rand!, _logpdf -const Elementwise{F} = Base.Fix1{<:Union{typeof(map),typeof(broadcast)}, F} +const Elementwise{F} = Base.Fix1{<:Union{typeof(map),typeof(broadcast)},F} """ elementwise(f) @@ -16,7 +16,9 @@ In the case where `f::ComposedFunction`, the result is elementwise(f) = Base.Fix1(broadcast, f) # TODO: This is makes dispatching quite a bit easier, but uncertain if this is really # the way to go. -elementwise(f::ComposedFunction) = ComposedFunction(elementwise(f.outer), elementwise(f.inner)) +function elementwise(f::ComposedFunction) + return ComposedFunction(elementwise(f.outer), elementwise(f.inner)) +end const Columnwise{F} = Base.Fix1{typeof(eachcolmaphcat),F} """ @@ -74,7 +76,9 @@ transform(f::F, x) where {F<:Function} = f(x) function transform(t::Transform, x) res = with_logabsdet_jacobian(t, x) if res isa ChangesOfVariables.NoLogAbsDetJacobian - error("`transform` not implemented for $(typeof(b)); implement `transform` and/or `with_logabsdet_jacobian`.") + error( + "`transform` not implemented for $(typeof(b)); implement `transform` and/or `with_logabsdet_jacobian`.", + ) end return first(res) @@ -98,7 +102,9 @@ Return `log(abs(det(J(b, x))))`, where `J(b, x)` is the jacobian of `b` at `x`. function logabsdetjac(b, x) res = with_logabsdet_jacobian(b, x) if res isa ChangesOfVariables.NoLogAbsDetJacobian - error("`logabsdetjac` not implemented for $(typeof(b)); implement `logabsdetjac` and/or `with_logabsdet_jacobian`.") + error( + "`logabsdetjac` not implemented for $(typeof(b)); implement `logabsdetjac` and/or `with_logabsdet_jacobian`.", + ) end return last(res) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index ffa30237..5f0a0d8d 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -1,11 +1,18 @@ # Transformed distributions -struct TransformedDistribution{D, B, V} <: Distribution{V, Continuous} where {D<:Distribution{V, Continuous}, B} +struct TransformedDistribution{D,B,V} <: + Distribution{V,Continuous} where {D<:Distribution{V,Continuous},B} dist::D transform::B - TransformedDistribution(d::UnivariateDistribution, b) = new{typeof(d), typeof(b), Univariate}(d, b) - TransformedDistribution(d::MultivariateDistribution, b) = new{typeof(d), typeof(b), Multivariate}(d, b) - TransformedDistribution(d::MatrixDistribution, b) = new{typeof(d), typeof(b), Matrixvariate}(d, b) + function TransformedDistribution(d::UnivariateDistribution, b) + return new{typeof(d),typeof(b),Univariate}(d, b) + end + function TransformedDistribution(d::MultivariateDistribution, b) + return new{typeof(d),typeof(b),Multivariate}(d, b) + end + function TransformedDistribution(d::MatrixDistribution, b) + return new{typeof(d),typeof(b),Matrixvariate}(d, b) + end end # fields may contain nested numerical parameters @@ -17,7 +24,6 @@ const MvTransformed = MultivariateTransformed const MatrixTransformed = TransformedDistribution{<:Distribution,<:Any,Matrixvariate} const Transformed = TransformedDistribution - """ transformed(d::Distribution) transformed(d::Distribution, b::Bijector) @@ -64,14 +70,14 @@ bijector(d::KSOneSided) = Logit(zero(eltype(d)), one(eltype(d))) bijector_bounded(d, a=minimum(d), b=maximum(d)) = Logit(a, b) bijector_lowerbounded(d, a=minimum(d)) = elementwise(log) ∘ Shift(-a) -bijector_upperbounded(d, b=maximum(d)) = elementwise(log) ∘ Shift(b) ∘ Scale(- one(typeof(b))) +function bijector_upperbounded(d, b=maximum(d)) + return elementwise(log) ∘ Shift(b) ∘ Scale(-one(typeof(b))) +end -const BoundedDistribution = Union{ - Arcsine, Biweight, Cosine, Epanechnikov, Beta, NoncentralBeta -} +const BoundedDistribution = Union{Arcsine,Biweight,Cosine,Epanechnikov,Beta,NoncentralBeta} bijector(d::BoundedDistribution) = bijector_bounded(d) -const LowerboundedDistribution = Union{Pareto, Levy} +const LowerboundedDistribution = Union{Pareto,Levy} bijector(d::LowerboundedDistribution) = bijector_lowerbounded(d) bijector(d::PDMatDistribution) = PDBijector() @@ -111,7 +117,7 @@ function logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractMatrix{<:Real}) ϵ = _eps(T) x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) - return logpdf(td.dist, mappedarray(x->x+ϵ, x)) + logjac + return logpdf(td.dist, mappedarray(x -> x + ϵ, x)) + logjac end function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) @@ -124,7 +130,7 @@ function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) ϵ = _eps(T) x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) - return logpdf(td.dist, mappedarray(x->x+ϵ, x)) + logjac + return logpdf(td.dist, mappedarray(x -> x + ϵ, x)) + logjac end # TODO: should eventually drop using `logpdf_with_trans` and replace with @@ -145,20 +151,23 @@ rand(rng::AbstractRNG, td::MvTransformed) = td.transform(rand(rng, td.dist)) # TODO: implement more efficiently for flows function rand(rng::AbstractRNG, td::MvTransformed, num_samples::Int) samples = rand(rng, td.dist, num_samples) - res = reduce(hcat, map(axes(samples, 2)) do i - return td.transform(view(samples, :, i)) - end) + res = reduce( + hcat, + map(axes(samples, 2)) do i + return td.transform(view(samples, :, i)) + end, + ) return res end function _rand!(rng::AbstractRNG, td::MvTransformed, x::AbstractVector{<:Real}) rand!(rng, td.dist, x) - x .= td.transform(x) + return x .= td.transform(x) end function _rand!(rng::AbstractRNG, td::MatrixTransformed, x::DenseMatrix{<:Real}) rand!(rng, td.dist, x) - x .= td.transform(x) + return x .= td.transform(x) end # utility stuff @@ -173,4 +182,3 @@ function Base.minimum(td::UnivariateTransformed) min, max = td.transform.((Base.minimum(td.dist), Base.maximum(td.dist))) return max < min ? max : min end - diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index b0e4dc2e..ebb461f0 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -5,7 +5,13 @@ test_frule(Bijectors.find_alpha, x, y, z) test_rrule(Bijectors.find_alpha, x, y, z) - test_rrule(Bijectors.combine, Bijectors.PartitionMask(3, [1], [2]) ⊢ ChainRulesTestUtils.NoTangent(), [1.0], [2.0], [3.0]) + test_rrule( + Bijectors.combine, + Bijectors.PartitionMask(3, [1], [2]) ⊢ ChainRulesTestUtils.NoTangent(), + [1.0], + [2.0], + [3.0], + ) # ordered bijector b = Bijectors.OrderedBijector() diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 6bf8365f..dc7d8234 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -1,43 +1,45 @@ # Figure out which AD backend to test const AD = get(ENV, "AD", "All") -function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) +function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1] if AD == "All" || AD == "Tracker" if :Tracker in broken - @test_broken Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol=rtol atol=atol + @test_broken Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol = rtol atol = + atol else ∇tracker = Tracker.gradient(f, x)[1] - @test Tracker.data(∇tracker) ≈ finitediff rtol=rtol atol=atol + @test Tracker.data(∇tracker) ≈ finitediff rtol = rtol atol = atol @test Tracker.istracked(∇tracker) end end if AD == "All" || AD == "ForwardDiff" if :ForwardDiff in broken - @test_broken ForwardDiff.gradient(f, x) ≈ finitediff rtol=rtol atol=atol + @test_broken ForwardDiff.gradient(f, x) ≈ finitediff rtol = rtol atol = atol else - @test ForwardDiff.gradient(f, x) ≈ finitediff rtol=rtol atol=atol + @test ForwardDiff.gradient(f, x) ≈ finitediff rtol = rtol atol = atol end end if AD == "All" || AD == "Zygote" if :Zygote in broken - @test_broken Zygote.gradient(f, x)[1] ≈ finitediff rtol=rtol atol=atol + @test_broken Zygote.gradient(f, x)[1] ≈ finitediff rtol = rtol atol = atol else ∇zygote = Zygote.gradient(f, x)[1] - @test (all(finitediff .== 0) && ∇zygote === nothing) || isapprox(∇zygote, finitediff, rtol=rtol, atol=atol) + @test (all(finitediff .== 0) && ∇zygote === nothing) || + isapprox(∇zygote, finitediff; rtol=rtol, atol=atol) end end if AD == "All" || AD == "ReverseDiff" if :ReverseDiff in broken - @test_broken ReverseDiff.gradient(f, x) ≈ finitediff rtol=rtol atol=atol + @test_broken ReverseDiff.gradient(f, x) ≈ finitediff rtol = rtol atol = atol else - @test ReverseDiff.gradient(f, x) ≈ finitediff rtol=rtol atol=atol + @test ReverseDiff.gradient(f, x) ≈ finitediff rtol = rtol atol = atol end end - return + return nothing end diff --git a/test/bijectors/coupling.jl b/test/bijectors/coupling.jl index 5d6c6607..849fa6d1 100644 --- a/test/bijectors/coupling.jl +++ b/test/bijectors/coupling.jl @@ -1,12 +1,4 @@ -using Bijectors: - Coupling, - PartitionMask, - coupling, - couple, - partition, - combine, - Shift, - Scale +using Bijectors: Coupling, PartitionMask, coupling, couple, partition, combine, Shift, Scale @testset "Coupling" begin @testset "PartitionMask" begin @@ -15,9 +7,9 @@ using Bijectors: @test (m1.A_1 == m2.A_1) & (m1.A_2 == m2.A_2) & (m1.A_3 == m2.A_3) - x = [1., 2., 3.] + x = [1.0, 2.0, 3.0] x1, x2, x3 = partition(m1, x) - @test (x1 == [1.]) & (x2 == [2.]) & (x3 == [3.]) + @test (x1 == [1.0]) & (x2 == [2.0]) & (x3 == [3.0]) y = combine(m1, x1, x2, x3) @test y == x @@ -27,8 +19,8 @@ using Bijectors: m = PartitionMask(3, [1], [2]) cl1 = Coupling(x -> Shift(x[1]), m) - x = [1., 2., 3.] - @test cl1(x) == [3., 2., 3.] + x = [1.0, 2.0, 3.0] + @test cl1(x) == [3.0, 2.0, 3.0] cl2 = Coupling(θ -> Shift(θ[1]), m) @test cl2(x) == cl1(x) @@ -46,7 +38,7 @@ using Bijectors: # with_logabsdet_jacobian @test with_logabsdet_jacobian(cl1, x) == (cl1(x), logabsdetjac(cl1, x)) - @test with_logabsdet_jacobian(icl1, cl1(x)) == (x, - logabsdetjac(cl1, x)) + @test with_logabsdet_jacobian(icl1, cl1(x)) == (x, -logabsdetjac(cl1, x)) end @testset "Classic" begin @@ -54,12 +46,12 @@ using Bijectors: # With `Scale` cl = Coupling(x -> Scale(x[1]), m) - x = [-1., -2., -3.] - y = [2., -2., -3.] + x = [-1.0, -2.0, -3.0] + y = [2.0, -2.0, -3.0] test_bijector(cl, x; y=y, logjac=log(2)) - x = [1., 2., 3.] - y = [2., 2., 3.] + x = [1.0, 2.0, 3.0] + y = [2.0, 2.0, 3.0] test_bijector(cl, x; y=y, logjac=log(2)) end end diff --git a/test/bijectors/leaky_relu.jl b/test/bijectors/leaky_relu.jl index a51f33e5..f8cb1a08 100644 --- a/test/bijectors/leaky_relu.jl +++ b/test/bijectors/leaky_relu.jl @@ -18,11 +18,11 @@ using Bijectors: LeakyReLU b = LeakyReLU(Float32(b.α)) # < 0 - x = -1f0 + x = -1.0f0 test_bijector(b, x) # ≥ 0 - x = 1f0 + x = 1.0f0 test_bijector(b, x; test_not_identity=false, test_types=true) end diff --git a/test/bijectors/named_bijector.jl b/test/bijectors/named_bijector.jl index fbdce0ec..787a6b41 100644 --- a/test/bijectors/named_bijector.jl +++ b/test/bijectors/named_bijector.jl @@ -3,38 +3,38 @@ using Bijectors using Bijectors: Logit, AbstractNamedTransform, NamedTransform, NamedCoupling, Shift @testset "NamedTransform" begin - b = NamedTransform((a = elementwise(exp), b = elementwise(log))) - @test b((a = 0.0, b = exp(1.0))) == (a = 1.0, b = 1.0) + b = NamedTransform((a=elementwise(exp), b=elementwise(log))) + @test b((a=0.0, b=exp(1.0))) == (a=1.0, b=1.0) - with_logabsdet_jacobian(b, (a = 0.0, b = exp(1.0))) + with_logabsdet_jacobian(b, (a=0.0, b=exp(1.0))) end @testset "NamedCoupling" begin - nc = NamedCoupling(Val(:b), Val((:a, )), a -> Logit(zero(a), a)) - @inferred NamedCoupling(Val(:b), Val((:a, )), Shift) + nc = NamedCoupling(Val(:b), Val((:a,)), a -> Logit(zero(a), a)) + @inferred NamedCoupling(Val(:b), Val((:a,)), Shift) - nc = NamedCoupling(:b, (:a, ), a -> Logit(0., a)) # <= not type-inferrable but eh + nc = NamedCoupling(:b, (:a,), a -> Logit(0.0, a)) # <= not type-inferrable but eh @test Bijectors.target(nc) == :b - @test Bijectors.deps(nc) == (:a, ) + @test Bijectors.deps(nc) == (:a,) @inferred Bijectors.target(nc) @inferred Bijectors.deps(nc) - x = (a = 1.0, b = 0.5, c = 99999.) + x = (a=1.0, b=0.5, c=99999.0) @test Bijectors.coupling(nc)(x.a) isa Logit @test inverse(nc)(nc(x)) == x - @test logabsdetjac(nc, x) == logabsdetjac(Logit(0., 1.), x.b) + @test logabsdetjac(nc, x) == logabsdetjac(Logit(0.0, 1.0), x.b) @test logabsdetjac(inverse(nc), nc(x)) == -logabsdetjac(nc, x) - x = (a = 0.0, b = 2.0, c = 1.0) + x = (a=0.0, b=2.0, c=1.0) nc = NamedCoupling(:c, (:a, :b), (a, b) -> Logit(a, b)) @test nc(x).c == 0.0 @test inverse(nc)(nc(x)) == x - x = (a = 0.0, b = 2.0, c = 1.0) - nc = NamedCoupling(:c, (:b, ), b -> Shift(b)) + x = (a=0.0, b=2.0, c=1.0) + nc = NamedCoupling(:c, (:b,), b -> Shift(b)) @test nc(x).c == 3.0 @test inverse(nc)(nc(x)) == x end diff --git a/test/bijectors/permute.jl b/test/bijectors/permute.jl index 49e90a82..6ab0e968 100644 --- a/test/bijectors/permute.jl +++ b/test/bijectors/permute.jl @@ -8,10 +8,10 @@ using Bijectors: Permute # in the sense that the map is {1, 2} => {1} @test_throws ArgumentError Permute(2, 2 => 1) @test_throws ArgumentError Permute(2, [1, 2, 3] => [2, 1]) - + # Simplest case b1 = Permute([ - 0 1; + 0 1 1 0 ]) b2 = Permute([2, 1]) @@ -20,7 +20,7 @@ using Bijectors: Permute @test b1.A == b2.A == b3.A == b4.A - x = [1., 2.] + x = [1.0, 2.0] @test (inverse(b1) ∘ b1)(x) == x @test (inverse(b2) ∘ b2)(x) == x @test (inverse(b3) ∘ b3)(x) == x @@ -28,8 +28,8 @@ using Bijectors: Permute # Slightly more complex case; one entry is not permuted b1 = Permute([ - 0 1 0; - 1 0 0; + 0 1 0 + 1 0 0 0 0 1 ]) b2 = Permute([2, 1, 3]) @@ -37,8 +37,8 @@ using Bijectors: Permute b4 = Permute(3, [1, 2] => [2, 1]) @test b1.A == b2.A == b3.A == b4.A - - x = [1., 2., 3.] + + x = [1.0, 2.0, 3.0] @test (inverse(b1) ∘ b1)(x) == x @test (inverse(b2) ∘ b2)(x) == x @test (inverse(b3) ∘ b3)(x) == x diff --git a/test/bijectors/rational_quadratic_spline.jl b/test/bijectors/rational_quadratic_spline.jl index fe0ddc31..4b8be14b 100644 --- a/test/bijectors/rational_quadratic_spline.jl +++ b/test/bijectors/rational_quadratic_spline.jl @@ -4,7 +4,9 @@ using Bijectors: RationalQuadraticSpline @testset "RationalQuadraticSpline" begin # Monotonic spline on '[-B, B]' with `K` intermediate knots/"connection points". - d = 2; K = 3; B = 2; + d = 2 + K = 3 + B = 2 b_uv = RationalQuadraticSpline(randn(K), randn(K), randn(K - 1), B) b_mv = RationalQuadraticSpline(randn(d, K), randn(d, K), randn(d, K - 1), B) @@ -54,7 +56,7 @@ using Bijectors: RationalQuadraticSpline test_bijector(b, x) # Outside of domain - x = [-5., 5.] + x = [-5.0, 5.0] test_bijector(b, x; y=x, logjac=zero(eltype(x))) end end diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index dc1d3a55..b9ab1242 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -15,15 +15,16 @@ function test_bijector( changes_of_variables_test=true, inverse_functions_test=true, compare=isapprox, - kwargs... + kwargs..., ) # Ensure that everything is type-stable. ib = @inferred inverse(b) logjac_test = @inferred logabsdetjac(b, x) - res = @inferred with_logabsdet_jacobian(b, x) + res = @inferred with_logabsdet_jacobian(b, x) y_test = @inferred b(x) - ilogjac_test = !isnothing(y) ? @inferred(logabsdetjac(ib, y)) : @inferred(logabsdetjac(ib, y_test)) + ilogjac_test = + !isnothing(y) ? @inferred(logabsdetjac(ib, y)) : @inferred(logabsdetjac(ib, y_test)) ires = if !isnothing(y) @inferred(with_logabsdet_jacobian(inverse(b), y)) else @@ -34,18 +35,20 @@ function test_bijector( # For non-bijective transformations, these tests always fail since determinant of # the Jacobian is zero. Hence we allow the caller to disable them if necessary. if changes_of_variables_test - ChangesOfVariables.test_with_logabsdet_jacobian(b, x, getjacobian; compare=compare, kwargs...) ChangesOfVariables.test_with_logabsdet_jacobian( - ib, isnothing(y) ? y_test : y, getjacobian; - compare=compare, - kwargs... + b, x, getjacobian; compare=compare, kwargs... + ) + ChangesOfVariables.test_with_logabsdet_jacobian( + ib, isnothing(y) ? y_test : y, getjacobian; compare=compare, kwargs... ) end # InverseFunctions.jl if inverse_functions_test InverseFunctions.test_inverse(b, x; compare, kwargs...) - InverseFunctions.test_inverse(ib, isnothing(y) ? y_test : y; compare=compare, kwargs...) + InverseFunctions.test_inverse( + ib, isnothing(y) ? y_test : y; compare=compare, kwargs... + ) end # Always want the following to hold @@ -100,9 +103,9 @@ function test_functor(x, xs) @test _xs == xs end -function test_bijector_parameter_gradient(b::Bijectors.Transform, x, y = b(x)) +function test_bijector_parameter_gradient(b::Bijectors.Transform, x, y=b(x)) args, re = Functors.functor(b) - recon(k, param) = re(merge(args, NamedTuple{(k, )}((param, )))) + recon(k, param) = re(merge(args, NamedTuple{(k,)}((param,)))) # Compute the gradient wrt. one argument at the time. for (k, v) in pairs(args) diff --git a/test/interface.jl b/test/interface.jl index e1f8d0e4..7d989abf 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -7,7 +7,19 @@ using Tracker using DistributionsAD using Bijectors -using Bijectors: Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, RationalQuadraticSpline, LeakyReLU +using Bijectors: + Shift, + Scale, + Logit, + SimplexBijector, + PDBijector, + Permute, + PlanarLayer, + RadialLayer, + Stacked, + TruncatedBijector, + RationalQuadraticSpline, + LeakyReLU Random.seed!(123) @@ -19,7 +31,7 @@ contains(predicate::Function, b::Stacked) = any(contains.(predicate, b.bs)) # Tests with scalar-valued distributions. uni_dists = [ Arcsine(2, 4), - Beta(2,2), + Beta(2, 2), BetaPrime(), Biweight(), Cauchy(), @@ -44,10 +56,10 @@ contains(predicate::Function, b::Stacked) = any(contains.(predicate, b.bs)) Rayleigh(1.0), TDist(2), truncated(Normal(0, 1), -Inf, 2), - transformed(Beta(2,2)), + transformed(Beta(2, 2)), transformed(Exponential()), ] - + for dist in uni_dists @testset "$dist: dist" begin td = @inferred transformed(dist) @@ -68,7 +80,8 @@ contains(predicate::Function, b::Stacked) = any(contains.(predicate, b.bs)) b = @inferred bijector(d) x = rand(d) y = @inferred b(x) - @test logpdf(d, inverse(b)(y)) + logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) + @test logpdf(d, inverse(b)(y)) + logabsdetjacinv(b, y) ≈ + logpdf_with_trans(d, x, true) @test logpdf(d, x) - logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, true) # verify against AD @@ -78,8 +91,9 @@ contains(predicate::Function, b::Stacked) = any(contains.(predicate, b.bs)) y = b(x) # `ForwardDiff.derivative` can lead to some numerical inaccuracy, # so we use a slightly higher `atol` than default. - @test log(abs(ForwardDiff.derivative(b, x))) ≈ logabsdetjac(b, x) atol=1e-6 - @test log(abs(ForwardDiff.derivative(inverse(b), y))) ≈ logabsdetjac(inverse(b), y) atol=1e-6 + @test log(abs(ForwardDiff.derivative(b, x))) ≈ logabsdetjac(b, x) atol = 1e-6 + @test log(abs(ForwardDiff.derivative(inverse(b), y))) ≈ + logabsdetjac(inverse(b), y) atol = 1e-6 end end end @@ -91,7 +105,8 @@ end y = b(x) @test y ≈ link(d, x) @test inverse(b)(y) ≈ x - @test logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, false) - logpdf_with_trans(d, x, true) + @test logabsdetjac(b, x) ≈ + logpdf_with_trans(d, x, false) - logpdf_with_trans(d, x, true) d = truncated(Normal(), -Inf, 1) b = bijector(d) @@ -99,7 +114,8 @@ end y = b(x) @test y ≈ link(d, x) @test inverse(b)(y) ≈ x - @test logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, false) - logpdf_with_trans(d, x, true) + @test logabsdetjac(b, x) ≈ + logpdf_with_trans(d, x, false) - logpdf_with_trans(d, x, true) d = truncated(Normal(), 1, Inf) b = bijector(d) @@ -107,7 +123,8 @@ end y = b(x) @test y ≈ link(d, x) @test inverse(b)(y) ≈ x - @test logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, false) - logpdf_with_trans(d, x, true) + @test logabsdetjac(b, x) ≈ + logpdf_with_trans(d, x, false) - logpdf_with_trans(d, x, true) end @testset "Multivariate" begin @@ -117,7 +134,7 @@ end Dirichlet([eps(Float64), 1000 * one(Float64)]), MvNormal(randn(10), Diagonal(exp.(randn(10)))), MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10))))), - Dirichlet([1000 * one(Float64), eps(Float64)]), + Dirichlet([1000 * one(Float64), eps(Float64)]), Dirichlet([eps(Float64), 1000 * one(Float64)]), transformed(MvNormal(randn(10), Diagonal(exp.(randn(10))))), transformed(MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10)))))), @@ -144,11 +161,16 @@ end # which in turn will lead to differences between `ForwardDiff.jacobian` # and `logabsdetjac` due to how we handle the boundary values in `SimplexBijector`. # We therefore test the realizations _on_ the boundary rather if we're near the boundary. - x = any(rand(dist) .> 0.9999) ? [0.0, 1.0][sortperm(rand(dist))] : rand(dist) + x = if any(rand(dist) .> 0.9999) + [0.0, 1.0][sortperm(rand(dist))] + else + rand(dist) + end y = b(x) @test b(param(x)) isa TrackedArray @test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x) - @test log(abs(det(ForwardDiff.jacobian(inverse(b), y)))) ≈ logabsdetjac(inverse(b), y) + @test log(abs(det(ForwardDiff.jacobian(inverse(b), y)))) ≈ + logabsdetjac(inverse(b), y) else b = bijector(dist) x = rand(dist) @@ -156,8 +178,10 @@ end # `ForwardDiff.derivative` can lead to some numerical inaccuracy, # so we use a slightly higher `atol` than default. @test b(param(x)) isa TrackedArray - @test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x) atol=1e-6 - @test log(abs(det(ForwardDiff.jacobian(inverse(b), y)))) ≈ logabsdetjac(inverse(b), y) atol=1e-6 + @test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x) atol = + 1e-6 + @test log(abs(det(ForwardDiff.jacobian(inverse(b), y)))) ≈ + logabsdetjac(inverse(b), y) atol = 1e-6 end end end @@ -169,11 +193,11 @@ end S[1, 2] = S[2, 1] = 0.5 matrix_dists = [ - Wishart(v,S), - InverseWishart(v,S), - TuringWishart(v,S), - TuringInverseWishart(v,S), - LKJ(3, 1.), + Wishart(v, S), + InverseWishart(v, S), + TuringWishart(v, S), + TuringInverseWishart(v, S), + LKJ(3, 1.0), reshape(MvNormal(zeros(6), I), 2, 3), ] @@ -224,16 +248,16 @@ end @test sb1(param([x, x, y, y])) isa TrackedArray @test sb1([x, x, y, y]) ≈ res1[1] - @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0 atol=1e-6 - @test res1[2] ≈ 0 atol=1e-6 + @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0 atol = 1e-6 + @test res1[2] ≈ 0 atol = 1e-6 sb2 = Stacked([b, b, inverse(b), inverse(b)]) # <= Array res2 = with_logabsdet_jacobian(sb2, [x, x, y, y]) @test sb2(param([x, x, y, y])) isa TrackedArray @test sb2([x, x, y, y]) ≈ res2[1] - @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 atol=1e-12 - @test res2[2] ≈ 0.0 atol=1e-12 + @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 atol = 1e-12 + @test res2[2] ≈ 0.0 atol = 1e-12 # value-test x = ones(3) @@ -242,10 +266,10 @@ end @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] @test res[1] == [exp(x[1]), log(x[2]), x[3] + 5.0] - @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:3]) + @test logabsdetjac(sb, x) == + sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i in 1:3]) @test res[2] == logabsdetjac(sb, x) - # TODO: change when we have dimensionality in the type sb = @inferred Stacked((elementwise(exp), SimplexBijector()), (1:1, 2:3)) x = ones(3) ./ 3.0 @@ -253,7 +277,8 @@ end @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] - @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2]) + @test logabsdetjac(sb, x) == + sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i in 1:2]) @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 @@ -266,7 +291,8 @@ end @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] - @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2]) + @test logabsdetjac(sb, x) == + sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i in 1:2]) @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 @@ -280,7 +306,8 @@ end @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] - @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2]) + @test logabsdetjac(sb, x) == + sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i in 1:2]) @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 @@ -293,13 +320,13 @@ end @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] - @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2]) + @test logabsdetjac(sb, x) == + sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i in 1:2]) @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 @test_throws AssertionError sb(x) - @testset "Stacked: ADVI with MvNormal" begin # MvNormal test dists = [ @@ -313,14 +340,14 @@ end InverseGamma(), Cauchy(), Gamma(), - MvNormal(zeros(2), I) + MvNormal(zeros(2), I), ] ranges = [] idx = 1 - for i = 1:length(dists) + for i in 1:length(dists) d = dists[i] - push!(ranges, idx:idx + length(d) - 1) + push!(ranges, idx:(idx + length(d) - 1)) idx += length(d) end ranges = tuple(ranges...) @@ -337,7 +364,7 @@ end @test sb isa Stacked td = transformed(d, sb) # => MultivariateTransformed <: Distribution{Multivariate, Continuous} - @test td isa Distribution{Multivariate, Continuous} + @test td isa Distribution{Multivariate,Continuous} # check that wrong ranges fails sb = Stacked(ibs) @@ -360,8 +387,8 @@ end # verification of computation x = rand(d) y = sb(x) - y_ = vcat([ibs[i](x[ranges[i]]) for i = 1:length(dists)]...) - x_ = vcat([bs[i](y[ranges[i]]) for i = 1:length(dists)]...) + y_ = vcat([ibs[i](x[ranges[i]]) for i in 1:length(dists)]...) + x_ = vcat([bs[i](y[ranges[i]]) for i in 1:length(dists)]...) @test x ≈ x_ @test y ≈ y_ @@ -371,8 +398,8 @@ end # Ensure `Stacked` works for a single bijector d = (MvNormal(zeros(2), I),) - sb = Stacked(bijector.(d), (1:2, )) - x = [.5, 1.] + sb = Stacked(bijector.(d), (1:2,)) + x = [0.5, 1.0] @test sb(x) == x @test logabsdetjac(sb, x) == 0 @test with_logabsdet_jacobian(sb, x) == (x, zero(eltype(x))) @@ -433,8 +460,8 @@ end elementwise(log), Scale(2.0), Scale(3.0), - Scale(rand(2,2)), - Scale(rand(2,2)), + Scale(rand(2, 2)), + Scale(rand(2, 2)), Shift(2.0), Shift(3.0), Shift(rand(2)), @@ -474,7 +501,7 @@ end end @testset "test_inverse and test_with_logabsdet_jacobian" begin - b = Bijectors.Scale{Float64,}(4.2) + b = Bijectors.Scale{Float64}(4.2) x = 0.3 InverseFunctions.test_inverse(b, x) diff --git a/test/norm_flows.jl b/test/norm_flows.jl index d00f632c..74fe3142 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -11,14 +11,15 @@ seed!(1) @test inverse(inverse(bn)) == bn @test inverse(bn)(bn(x)) ≈ x @test (inverse(bn) ∘ bn)(x) ≈ x - @test_throws ErrorException with_logabsdet_jacobian(bn, randn(10,2)) - @test logabsdetjac(inverse(bn), bn(x)) ≈ - logabsdetjac(bn, x) + @test_throws ErrorException with_logabsdet_jacobian(bn, randn(10, 2)) + @test logabsdetjac(inverse(bn), bn(x)) ≈ -logabsdetjac(bn, x) y, ladj = with_logabsdet_jacobian(bn, x) @test log(abs(det(ForwardDiff.jacobian(bn, x)))) ≈ sum(ladj) - @test log(abs(det(ForwardDiff.jacobian(inverse(bn), y)))) ≈ sum(logabsdetjac(inverse(bn), y)) + @test log(abs(det(ForwardDiff.jacobian(inverse(bn), y)))) ≈ + sum(logabsdetjac(inverse(bn), y)) - test_functor(bn, (b = bn.b, logs = bn.logs)) + test_functor(bn, (b=bn.b, logs=bn.logs)) end @testset "PlanarLayer" begin @@ -40,8 +41,8 @@ end z = ones(10, 100) @test inverse(flow)(flow(z)) ≈ z - test_functor(flow, (w = w, u = u, b = b)) - test_functor(inverse(flow), (orig = flow,)) + test_functor(flow, (w=w, u=u, b=b)) + test_functor(inverse(flow), (orig=flow,)) @testset "find_alpha" begin for wt_y in (-20.3, -3, -3//2, 0.0, 5, 29//4, 12.3) @@ -55,7 +56,8 @@ end # check if α is an approximate solution to the considered equation # have to set atol if wt_y is zero (otherwise only equality is checked) - @test wt_y ≈ α + wt_u_hat * tanh(α + b) atol=iszero(wt_y) ? 1e-14 : 0.0 + @test wt_y ≈ α + wt_u_hat * tanh(α + b) atol = + iszero(wt_y) ? 1e-14 : 0.0 end end end @@ -77,8 +79,8 @@ end our_method = sum(with_logabsdet_jacobian(flow, z)[2]) @test our_method ≈ forward_diff - @test inverse(flow)(flow(z)) ≈ z rtol=0.2 - @test (inverse(flow) ∘ flow)(z) ≈ z rtol=0.2 + @test inverse(flow)(flow(z)) ≈ z rtol = 0.2 + @test (inverse(flow) ∘ flow)(z) ≈ z rtol = 0.2 end α_ = 1.0 @@ -88,8 +90,8 @@ end flow = RadialLayer(α_, β, z_0) @test inverse(flow)(flow(z)) ≈ z - test_functor(flow, (α_ = α_, β = β, z_0 = z_0)) - test_functor(inverse(flow), (orig = flow,)) + test_functor(flow, (α_=α_, β=β, z_0=z_0)) + test_functor(inverse(flow), (orig=flow,)) end @testset "Flows" begin @@ -106,7 +108,7 @@ end lp = logpdf(d, x) - res[2] @test res[1] ≈ y - @test logpdf(flow, y) ≈ lp rtol=0.1 + @test logpdf(flow, y) ≈ lp rtol = 0.1 # flow with unconstrained-to-constrained d1 = Beta() diff --git a/test/runtests.jl b/test/runtests.jl index b48c656d..ee2401eb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,8 +13,17 @@ using Zygote using Random, LinearAlgebra, Test -using Bijectors: Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, - PlanarLayer, RadialLayer, Stacked, TruncatedBijector +using Bijectors: + Shift, + Scale, + Logit, + SimplexBijector, + PDBijector, + Permute, + PlanarLayer, + RadialLayer, + Stacked, + TruncatedBijector using ChangesOfVariables: ChangesOfVariables using InverseFunctions: InverseFunctions @@ -43,4 +52,3 @@ if GROUP == "All" || GROUP == "AD" include("ad/chainrules.jl") include("ad/flows.jl") end - diff --git a/test/transform.jl b/test/transform.jl index 7be147d3..15cd9178 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -24,9 +24,10 @@ function single_sample_tests(dist, jacobian) # Check that the implementation of the logpdf agrees with the AD version. x = rand(dist) if dist isa SimplexDistribution - logpdf_ad = logpdf(dist, x .+ ϵ) - _logabsdet(jacobian(x->link(dist, x, false), x)) + logpdf_ad = + logpdf(dist, x .+ ϵ) - _logabsdet(jacobian(x -> link(dist, x, false), x)) else - logpdf_ad = logpdf(dist, x) - _logabsdet(jacobian(x->link(dist, x), x)) + logpdf_ad = logpdf(dist, x) - _logabsdet(jacobian(x -> link(dist, x), x)) end @test logpdf_ad ≈ logpdf_with_trans(dist, x, true) end @@ -38,7 +39,7 @@ function single_sample_tests(dist) # Check that invlink is inverse of link. x = rand(dist) - @test @inferred(invlink(dist, link(dist, copy(x)))) ≈ x atol=1e-9 + @test @inferred(invlink(dist, link(dist, copy(x)))) ≈ x atol = 1e-9 # Check that link is inverse of invlink. Hopefully this just holds given the above... y = @inferred(link(dist, x)) @@ -49,128 +50,139 @@ function single_sample_tests(dist) # 1.0 # julia> logistic(logit(0.9999999999999998)) # 0.9999999999999998 - @test @inferred(link(dist, invlink(dist, copy(y)))) ≈ y atol=0.5 + @test @inferred(link(dist, invlink(dist, copy(y)))) ≈ y atol = 0.5 else - @test @inferred(link(dist, invlink(dist, copy(y)))) ≈ y atol=1e-9 + @test @inferred(link(dist, invlink(dist, copy(y)))) ≈ y atol = 1e-9 end if dist isa SimplexDistribution # This should probably be exact. @test logpdf(dist, x .+ ϵ) == logpdf_with_trans(dist, x, false) # Check that invlink maps back to the apppropriate constrained domain. - @test all(isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(x, 0)) .+ ϵ for _ in 1:100])) + @test all( + isfinite, + logpdf.(Ref(dist), [invlink(dist, _rand_real(x, 0)) .+ ϵ for _ in 1:100]), + ) else # This should probably be exact. @test logpdf(dist, x) == logpdf_with_trans(dist, x, false) - @test all(isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(x)) for _ in 1:100])) + @test all( + isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(x)) for _ in 1:100]) + ) end end # Scalar tests @testset "scalar" begin -let - # Tests with scalar-valued distributions. - uni_dists = [ - Arcsine(2, 4), - Beta(2,2), - BetaPrime(), - Biweight(), - Cauchy(), - Chi(3), - Chisq(2), - Cosine(), - Epanechnikov(), - Erlang(), - Exponential(), - FDist(1, 1), - Frechet(), - Gamma(), - InverseGamma(), - InverseGaussian(), - # Kolmogorov(), - Laplace(), - Levy(), - Logistic(), - LogNormal(1.0, 2.5), - Normal(0.1, 2.5), - Pareto(), - Rayleigh(1.0), - TDist(2), - truncated(Normal(0, 1), -Inf, 2), - ] - for dist in uni_dists - - single_sample_tests(dist, ForwardDiff.derivative) + let + # Tests with scalar-valued distributions. + uni_dists = [ + Arcsine(2, 4), + Beta(2, 2), + BetaPrime(), + Biweight(), + Cauchy(), + Chi(3), + Chisq(2), + Cosine(), + Epanechnikov(), + Erlang(), + Exponential(), + FDist(1, 1), + Frechet(), + Gamma(), + InverseGamma(), + InverseGaussian(), + # Kolmogorov(), + Laplace(), + Levy(), + Logistic(), + LogNormal(1.0, 2.5), + Normal(0.1, 2.5), + Pareto(), + Rayleigh(1.0), + TDist(2), + truncated(Normal(0, 1), -Inf, 2), + ] + for dist in uni_dists + single_sample_tests(dist, ForwardDiff.derivative) + end end end -end # Tests with vector-valued distributions. @testset "vector" begin -let ϵ = eps(Float64) - vector_dists = [ - Dirichlet(2, 3), - Dirichlet([1000 * one(Float64), eps(Float64)]), - Dirichlet([eps(Float64), 1000 * one(Float64)]), - MvNormal(randn(10), Diagonal(exp.(randn(10)))), - MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10))))), - Dirichlet([1000 * one(Float64), eps(Float64)]), - Dirichlet([eps(Float64), 1000 * one(Float64)]), - ] - for dist in vector_dists - - if dist isa Dirichlet - single_sample_tests(dist) - - # This should fail at the minute. Not sure what the correct way to test this is. - - # Workaround for intermittent test failures, result of `logpdf_with_trans(dist, x, true)` - # is incorrect for `x == [0.9999999999999998, 0.0]`: - x = if params(dist) == params(Dirichlet([1000 * one(Float64), eps(Float64)])) - [1.0, 0.0] + let ϵ = eps(Float64) + vector_dists = [ + Dirichlet(2, 3), + Dirichlet([1000 * one(Float64), eps(Float64)]), + Dirichlet([eps(Float64), 1000 * one(Float64)]), + MvNormal(randn(10), Diagonal(exp.(randn(10)))), + MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10))))), + Dirichlet([1000 * one(Float64), eps(Float64)]), + Dirichlet([eps(Float64), 1000 * one(Float64)]), + ] + for dist in vector_dists + if dist isa Dirichlet + single_sample_tests(dist) + + # This should fail at the minute. Not sure what the correct way to test this is. + + # Workaround for intermittent test failures, result of `logpdf_with_trans(dist, x, true)` + # is incorrect for `x == [0.9999999999999998, 0.0]`: + x = + if params(dist) == + params(Dirichlet([1000 * one(Float64), eps(Float64)])) + [1.0, 0.0] + else + rand(dist) + end + + logpdf_turing = logpdf_with_trans(dist, x, true) + J = ForwardDiff.jacobian(x -> link(dist, x, Val(false)), x) + @test logpdf(dist, x .+ ϵ) - _logabsdet(J) ≈ logpdf_turing + + # Issue #12 + stepsize = 1e10 + dim = length(dist) + x = [ + logpdf_with_trans( + dist, + invlink(dist, link(dist, rand(dist)) .+ randn(dim) .* stepsize), + true, + ) for _ in 1:1_000 + ] + @test !any(isinf, x) && !any(isnan, x) else - rand(dist) + single_sample_tests(dist, ForwardDiff.jacobian) end - - logpdf_turing = logpdf_with_trans(dist, x, true) - J = ForwardDiff.jacobian(x->link(dist, x, Val(false)), x) - @test logpdf(dist, x .+ ϵ) - _logabsdet(J) ≈ logpdf_turing - - # Issue #12 - stepsize = 1e10 - dim = length(dist) - x = [logpdf_with_trans(dist, invlink(dist, link(dist, rand(dist)) .+ randn(dim) .* stepsize), true) for _ in 1:1_000] - @test !any(isinf, x) && !any(isnan, x) - else - single_sample_tests(dist, ForwardDiff.jacobian) end end end -end # Tests with matrix-valued distributions. @testset "matrix" begin -let - matrix_dists = [ - Wishart(7, [1 0.5; 0.5 1]), - InverseWishart(2, [1 0.5; 0.5 1]), - ] - for dist in matrix_dists - - single_sample_tests(dist) + let + matrix_dists = [Wishart(7, [1 0.5; 0.5 1]), InverseWishart(2, [1 0.5; 0.5 1])] + for dist in matrix_dists + single_sample_tests(dist) - x = rand(dist); x = x + x' + 2I - lowerinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[1] >= I[2]] - upperinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] >= I[1]] - logpdf_turing = logpdf_with_trans(dist, x, true) - J = ForwardDiff.jacobian(x->link(dist, x), x) - J = J[lowerinds, upperinds] - @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing + x = rand(dist) + x = x + x' + 2I + lowerinds = [ + LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[1] >= I[2] + ] + upperinds = [ + LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] >= I[1] + ] + logpdf_turing = logpdf_with_trans(dist, x, true) + J = ForwardDiff.jacobian(x -> link(dist, x), x) + J = J[lowerinds, upperinds] + @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing + end end end -end @testset "correlation matrix" begin - dist = LKJ(2, 1) single_sample_tests(dist) @@ -178,10 +190,12 @@ end x = rand(dist) x = x + x' + 2I d = 1 ./ sqrt.(diag(x)) - x = d .* x .* d' + x = d .* x .* d' - upperinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1]] - J = ForwardDiff.jacobian(x->link(dist, x), x) + upperinds = [ + LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1] + ] + J = ForwardDiff.jacobian(x -> link(dist, x), x) J = J[upperinds, upperinds] logpdf_turing = logpdf_with_trans(dist, x, true) @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing @@ -198,24 +212,26 @@ end # -3.006450206744678 # julia> logpdf_with_trans(Dirichlet([1., 1., 1.]), [-1., -2., -3.], true, true) # -3.006450206744678 -d = Dirichlet([1., 1., 1.]) -r = [-1000., -1000., 0.0] -r2 = [-1., -2., 0.0] +d = Dirichlet([1.0, 1.0, 1.0]) +r = [-1000.0, -1000.0, 0.0] +r2 = [-1.0, -2.0, 0.0] # test vector invlink dist = Dirichlet(ones(5)) -x = [[-2.72689, -2.92751, 1.63114, -1.62054, 0.0] [-1.24249, 2.58902, -3.73043, -3.53685, 0.0]] -@test all(sum(Bijectors.invlink(dist, x), dims = 1) .== 1) +x = [[-2.72689, -2.92751, 1.63114, -1.62054, 0.0] [ + -1.24249, 2.58902, -3.73043, -3.53685, 0.0 +]] +@test all(sum(Bijectors.invlink(dist, x); dims=1) .== 1) # test link #link(d, r) # test invlink -@test invlink(d, r) ≈ [0., 0., 1.] atol=1e-9 +@test invlink(d, r) ≈ [0.0, 0.0, 1.0] atol = 1e-9 # test logpdf_with_trans #@test logpdf_with_trans(d, invlink(d, r), true) -1999.30685281944 1e-9 ≈ # atol=NaN -@test logpdf_with_trans(d, invlink(d, r2), true) ≈ -3.760398892580863 atol=1e-9 +@test logpdf_with_trans(d, invlink(d, r2), true) ≈ -3.760398892580863 atol = 1e-9 macro aeq(x, y) return quote @@ -237,11 +253,20 @@ end g1 = y -> invlink(dist, y, Val(true)) g2 = y -> invlink(dist, y, Val(false)) - @test @aeq ForwardDiff.jacobian(f1, x) @inferred(Bijectors.simplex_link_jacobian(x, Val(true))) - @test @aeq ForwardDiff.jacobian(f2, x) @inferred(Bijectors.simplex_link_jacobian(x, Val(false))) - @test @aeq ForwardDiff.jacobian(g1, y) @inferred(Bijectors.simplex_invlink_jacobian(y, Val(true))) - @test @aeq ForwardDiff.jacobian(g2, y) @inferred(Bijectors.simplex_invlink_jacobian(y, Val(false))) - @test @aeq Bijectors.simplex_link_jacobian(x, Val(false)) * Bijectors.simplex_invlink_jacobian(y, Val(false)) I + @test @aeq ForwardDiff.jacobian(f1, x) @inferred( + Bijectors.simplex_link_jacobian(x, Val(true)) + ) + @test @aeq ForwardDiff.jacobian(f2, x) @inferred( + Bijectors.simplex_link_jacobian(x, Val(false)) + ) + @test @aeq ForwardDiff.jacobian(g1, y) @inferred( + Bijectors.simplex_invlink_jacobian(y, Val(true)) + ) + @test @aeq ForwardDiff.jacobian(g2, y) @inferred( + Bijectors.simplex_invlink_jacobian(y, Val(false)) + ) + @test @aeq Bijectors.simplex_link_jacobian(x, Val(false)) * + Bijectors.simplex_invlink_jacobian(y, Val(false)) I end for i in 1:4 test_link_and_invlink()