From 47214f806790a6a281d1ec608059924e5c592192 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 8 Nov 2021 09:53:36 +0100 Subject: [PATCH] Remove adjoint for `fill` and fix tests (#203) * Remove adjoint for `fill` and fix Zygote tests * Bump version * Fix some more problems * Extension of #203: Fix deprecations in test (#204) * Fix deprecations * Improve CI (AD): cancel builds and no coverage * Improve CI (Others): cancel builds and no coverage * Change parameters to avoid issues with `xlogy` * Tracker does not like Diagonal(Fill(...)) * Unify CI * Fix tests * Update test structure and separate AD better * Fix tests * Relax type constraint * Simplify Zygote tests and use CR * Improve test design * Fix typo * Fix typo * Replace `unpack` with `_to_vec` * Fix tests (a bit) * Fix another test problem * Fix `_to_vec` * Fix handling of broken Zygote tests * Workarounds for `rand_tangent` * Improvements and fixes for Julia 1.3 * Remove Zygote test hack --- .github/workflows/{AD.yml => CI.yml} | 39 ++- .github/workflows/Others.yml | 34 -- Project.toml | 2 +- README.md | 3 +- src/flatten.jl | 7 - src/zygote.jl | 8 - test/ad/distributions.jl | 165 +++++---- test/ad/others.jl | 93 +++++ test/ad/utils.jl | 489 ++++++++++++++++++++------- test/others.jl | 154 +-------- test/runtests.jl | 83 +---- test/test_utils.jl | 23 ++ 12 files changed, 617 insertions(+), 483 deletions(-) rename .github/workflows/{AD.yml => CI.yml} (50%) delete mode 100644 .github/workflows/Others.yml create mode 100644 test/ad/others.jl create mode 100644 test/test_utils.jl diff --git a/.github/workflows/AD.yml b/.github/workflows/CI.yml similarity index 50% rename from .github/workflows/AD.yml rename to .github/workflows/CI.yml index 2012de7d..cc8b7e07 100644 --- a/.github/workflows/AD.yml +++ b/.github/workflows/CI.yml @@ -1,4 +1,4 @@ -name: AD tests +name: CI on: push: @@ -6,10 +6,15 @@ on: - master pull_request: +concurrency: + # Skip intermediate builds: always. 
+ # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + jobs: test: runs-on: ${{ matrix.os }} - continue-on-error: ${{ matrix.version == 'nightly' }} strategy: matrix: version: @@ -19,7 +24,8 @@ jobs: - ubuntu-latest arch: - x64 - AD: + group: + - Others - ForwardDiff - Tracker - ReverseDiff @@ -28,27 +34,42 @@ jobs: - version: '1' os: macOS-latest arch: x64 - AD: ForwardDiff + group: Others + - version: '1' + os: macOS-latest + arch: x64 + group: ForwardDiff - version: '1' os: macOS-latest arch: x64 - AD: Tracker + group: Tracker - version: '1' os: macOS-latest arch: x64 - AD: ReverseDiff + group: ReverseDiff - version: '1' os: macOS-latest arch: x64 - AD: Zygote + group: Zygote steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} + - uses: actions/cache@v1 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@latest - uses: julia-actions/julia-runtest@latest + with: + coverage: false env: - GROUP: AD - AD: ${{ matrix.AD }} + GROUP: ${{ matrix.group }} diff --git a/.github/workflows/Others.yml b/.github/workflows/Others.yml deleted file mode 100644 index dd8e346c..00000000 --- a/.github/workflows/Others.yml +++ /dev/null @@ -1,34 +0,0 @@ -name: Other tests - -on: - push: - branches: - - master - pull_request: - -jobs: - test: - runs-on: ${{ matrix.os }} - strategy: - matrix: - version: - - '1.3' - - '1' - os: - - ubuntu-latest - arch: - - x64 - include: - - version: '1' - os: macOS-latest - arch: x64 - steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 - with: - version: ${{ 
matrix.version }} - arch: ${{ matrix.arch }} - - uses: julia-actions/julia-buildpkg@latest - - uses: julia-actions/julia-runtest@latest - env: - GROUP: Others diff --git a/Project.toml b/Project.toml index 99aa375e..fdef280f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.31" +version = "0.6.32" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/README.md b/README.md index 9d133fea..d10cae49 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,6 @@ # DistributionsAD.jl -[![AD tests](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/AD.yml/badge.svg?branch=master)](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/AD.yml?query=branch%3Amaster) -[![Other tests](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/Others.yml/badge.svg?branch=master)](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/Others.yml?query=branch%3Amaster) +[![CI](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/CI.yml/badge.svg?branch=master)](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/CI.yml?query=branch%3Amaster) This package defines the necessary functions to enable automatic differentiation (AD) of the `logpdf` function from [Distributions.jl](https://github.com/JuliaStats/Distributions.jl) using the packages [Tracker.jl](https://github.com/FluxML/Tracker.jl), [Zygote.jl](https://github.com/FluxML/Zygote.jl), [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) and [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl). The goal of this package is to make the output of `logpdf` differentiable wrt all continuous parameters of a distribution as well as the random variable in the case of continuous distributions. 
diff --git a/src/flatten.jl b/src/flatten.jl index 37e13e42..f386a4b9 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -27,7 +27,6 @@ const flattened_dists = [ Bernoulli, Poisson, Skellam, Arcsine, - Beta, BetaPrime, Biweight, Cauchy, @@ -40,11 +39,9 @@ const flattened_dists = [ Bernoulli, Exponential, FDist, Frechet, - Gamma, GeneralizedExtremeValue, GeneralizedPareto, Gumbel, - #InverseGamma, InverseGaussian, Kolmogorov, Laplace, @@ -54,8 +51,6 @@ const flattened_dists = [ Bernoulli, LogitNormal, LogNormal, Normal, - #NormalCanon, - #NormalInverseGaussian, Pareto, PGeneralizedGaussian, Rayleigh, @@ -64,8 +59,6 @@ const flattened_dists = [ Bernoulli, TriangularDist, Triweight, TuringUniform, - #Truncated, - #VonMises, ] for T in flattened_dists @eval toflatten(::$T) = true diff --git a/src/zygote.jl b/src/zygote.jl index 236878ff..5500e8e3 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -1,11 +1,3 @@ -# Zygote fill has issues with non-numbers -ZygoteRules.@adjoint function fill(x::T, dims...) where {T} - return ZygoteRules.pullback(x, dims...) do x, dims... - return reshape([x for i in 1:prod(dims)], dims) - end -end - - ## Uniform ## ZygoteRules.@adjoint function Distributions.Uniform(args...) diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 3894a17a..08ddb926 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -26,6 +26,13 @@ to_positive(x) = exp.(x) to_positive(x::AbstractArray{<:AbstractArray}) = to_positive.(x) + # The following definition should not be needed + # It seems there is a bug in the default `rand_tangent` that causes a + # StackOverflowError though + function ChainRulesTestUtils.rand_tangent(::Random.AbstractRNG, ::typeof(to_positive)) + return NoTangent() + end + # Tests that have a `broken` field can be executed but, according to FiniteDifferences, # fail to produce the correct result. These tests can be checked with `@test_broken`. 
univariate_distributions = DistSpec[ @@ -68,13 +75,13 @@ DistSpec(Arcsine, (1.0,), 0.5), DistSpec(Arcsine, (0.0, 2.0), 0.5), - DistSpec(Beta, (), 0.5), - DistSpec(Beta, (1.0,), 0.5), - DistSpec(Beta, (1.0, 2.0), 0.5), + DistSpec(Beta, (), 0.4), + DistSpec(Beta, (1.5,), 0.4), + DistSpec(Beta, (1.5, 2.0), 0.4), - DistSpec(BetaPrime, (), 0.5), - DistSpec(BetaPrime, (1.5,), 0.5), - DistSpec(BetaPrime, (1.5, 2.0), 0.5), + DistSpec(BetaPrime, (), 0.4), + DistSpec(BetaPrime, (1.5,), 0.4), + DistSpec(BetaPrime, (1.5, 2.0), 0.4), DistSpec(Biweight, (), 0.5), DistSpec(Biweight, (1.0,), 0.5), @@ -104,9 +111,9 @@ DistSpec(Frechet, (1.0,), 0.5), DistSpec(Frechet, (1.0, 2.0), 0.5), - DistSpec(Gamma, (), 0.5), - DistSpec(Gamma, (1.0,), 0.5), - DistSpec(Gamma, (1.0, 2.0), 0.5), + DistSpec(Gamma, (), 0.4), + DistSpec(Gamma, (1.5,), 0.4), + DistSpec(Gamma, (1.5, 2.0), 0.4), DistSpec(GeneralizedExtremeValue, (1.0, 1.0, 1.0), 0.5), @@ -230,52 +237,45 @@ # Vector x DistSpec((m, A) -> MvNormal(m, to_posdef(A)), (a, A), b), - DistSpec(MvNormal, (a, b), c), DistSpec((m, s) -> MvNormal(m, to_posdef_diagonal(s)), (a, b), c), - DistSpec(MvNormal, (a, alpha), b), DistSpec((m, s) -> MvNormal(m, s^2 * I), (a, alpha), b), DistSpec(A -> MvNormal(to_posdef(A)), (A,), a), - DistSpec(MvNormal, (a,), b), DistSpec(s -> MvNormal(to_posdef_diagonal(s)), (a,), b), - DistSpec(s -> MvNormal(dim, s), (alpha,), a), + DistSpec(s -> MvNormal(zeros(dim), s^2 * I), (alpha,), a), DistSpec((m, A) -> TuringMvNormal(m, to_posdef(A)), (a, A), b), - DistSpec(TuringMvNormal, (a, b), c), DistSpec((m, s) -> TuringMvNormal(m, to_posdef_diagonal(s)), (a, b), c), - DistSpec(TuringMvNormal, (a, alpha), b), DistSpec((m, s) -> TuringMvNormal(m, s^2 * I), (a, alpha), b), DistSpec(A -> TuringMvNormal(to_posdef(A)), (A,), a), - DistSpec(TuringMvNormal, (a,), b), DistSpec(s -> TuringMvNormal(to_posdef_diagonal(s)), (a,), b), - DistSpec(s -> TuringMvNormal(dim, s), (alpha,), a), + DistSpec(s -> TuringMvNormal(zeros(dim), s^2 * I), 
(alpha,), a), DistSpec((m, A) -> MvLogNormal(m, to_posdef(A)), (a, A), b, to_positive), - DistSpec(MvLogNormal, (a, b), c, to_positive), DistSpec((m, s) -> MvLogNormal(m, to_posdef_diagonal(s)), (a, b), c, to_positive), - DistSpec(MvLogNormal, (a, alpha), b, to_positive), + DistSpec((m, s) -> MvLogNormal(m, s^2 * I), (a, alpha), b, to_positive), DistSpec(A -> MvLogNormal(to_posdef(A)), (A,), a, to_positive), - DistSpec(MvLogNormal, (a,), b, to_positive), DistSpec(s -> MvLogNormal(to_posdef_diagonal(s)), (a,), b, to_positive), - DistSpec(s -> MvLogNormal(dim, s), (alpha,), a, to_positive), + DistSpec(s -> MvLogNormal(zeros(dim), s^2 * I), (alpha,), a, to_positive), DistSpec(alpha -> Dirichlet(to_positive(alpha)), (a,), b, to_simplex), # Matrix case - DistSpec(MvNormal, (a, b), A), + DistSpec((m, A) -> MvNormal(m, to_posdef(A)), (a, A), B), DistSpec((m, s) -> MvNormal(m, to_posdef_diagonal(s)), (a, b), A), - DistSpec(MvNormal, (a, alpha), A), DistSpec((m, s) -> MvNormal(m, s^2 * I), (a, alpha), A), - DistSpec(MvNormal, (a,), A), - DistSpec(s -> MvNormal(to_posdef_diagonal(s)), (a,), A), - DistSpec(s -> MvNormal(dim, s), (alpha,), A), - DistSpec((m, A) -> MvNormal(m, to_posdef(A)), (a, A), B), DistSpec(A -> MvNormal(to_posdef(A)), (A,), B), - DistSpec(MvLogNormal, (a, b), A, to_positive), - DistSpec((m, s) -> MvLogNormal(m, to_posdef_diagonal(s)), (a, b), A, to_positive), - DistSpec(MvLogNormal, (a, alpha), A, to_positive), - DistSpec(MvLogNormal, (a,), A, to_positive), - DistSpec(s -> MvLogNormal(to_posdef_diagonal(s)), (a,), A, to_positive), - DistSpec(s -> MvLogNormal(dim, s), (alpha,), A, to_positive), + DistSpec(s -> MvNormal(to_posdef_diagonal(s)), (a,), A), + DistSpec(s -> MvNormal(zeros(dim), s^2 * I), (alpha,), A), + DistSpec((m, A) -> TuringMvNormal(m, to_posdef(A)), (a, A), B), + DistSpec((m, s) -> TuringMvNormal(m, to_posdef_diagonal(s)), (a, b), A), + DistSpec((m, s) -> TuringMvNormal(m, s^2 * I), (a, alpha), A), + DistSpec(A -> 
TuringMvNormal(to_posdef(A)), (A,), B), + DistSpec(s -> TuringMvNormal(to_posdef_diagonal(s)), (a,), A), + DistSpec(s -> TuringMvNormal(zeros(dim), s^2 * I), (alpha,), A), DistSpec((m, A) -> MvLogNormal(m, to_posdef(A)), (a, A), B, to_positive), + DistSpec((m, s) -> MvLogNormal(m, to_posdef_diagonal(s)), (a, b), A, to_positive), + DistSpec((m, s) -> MvLogNormal(m, s^2 * I), (a, alpha), A, to_positive), DistSpec(A -> MvLogNormal(to_posdef(A)), (A,), B, to_positive), + DistSpec(s -> MvLogNormal(to_posdef_diagonal(s)), (a,), A, to_positive), + DistSpec(s -> MvLogNormal(zeros(dim), s^2 * I), (alpha,), A, to_positive), DistSpec(alpha -> Dirichlet(to_positive(alpha)), (a,), A, to_simplex), ] @@ -284,24 +284,28 @@ broken_multivariate_distributions = DistSpec[ # Dispatch error DistSpec((m, A) -> MvNormalCanon(m, to_posdef(A)), (a, A), b), - DistSpec(MvNormalCanon, (a, b), c), - DistSpec(MvNormalCanon, (a, alpha), b), + DistSpec((m, p) -> MvNormalCanon(m, to_posdef_diagonal(p)), (a, b), c), + DistSpec((m, p) -> MvNormalCanon(m, p^2 * I), (a, alpha), b), DistSpec(A -> MvNormalCanon(to_posdef(A)), (A,), a), - DistSpec(MvNormalCanon, (a,), b), - DistSpec(s -> MvNormalCanon(dim, s), (alpha,), a), + DistSpec(p -> MvNormalCanon(to_posdef_diagonal(p)), (a,), b), + DistSpec(p -> MvNormalCanon(zeros(dim), p^2 * I), (alpha,), a), DistSpec((m, A) -> MvNormalCanon(m, to_posdef(A)), (a, A), B), - DistSpec(MvNormalCanon, (a, b), A), - DistSpec(MvNormalCanon, (a, alpha), A), + DistSpec((m, p) -> MvNormalCanon(m, to_posdef_diagonal(p)), (a, b), A), + DistSpec((m, p) -> MvNormalCanon(m, p^2 * I), (a, alpha), A), DistSpec(A -> MvNormalCanon(to_posdef(A)), (A,), B), - DistSpec(MvNormalCanon, (a,), A), - DistSpec(s -> MvNormalCanon(dim, s), (alpha,), A), + DistSpec(p -> MvNormalCanon(to_posdef_diagonal(p)), (a,), A), + DistSpec(p -> MvNormalCanon(zeros(dim), p^2 * I), (alpha,), A), ] # Tests that have a `broken` field can be executed but, according to FiniteDifferences, # fail to produce the 
correct result. These tests can be checked with `@test_broken`. matrixvariate_distributions = DistSpec[ # Matrix x - DistSpec((n1, n2) -> MatrixBeta(dim, n1, n2), (3.0, 3.0), A, to_beta_mat), + # We should use + # DistSpec((n1, n2) -> MatrixBeta(dim, n1, n2), (3.0, 3.0), A, to_beta_mat), + # but the default implementation of `rand_tangent` causes a StackOverflowError + # Thus we use the following workaround + DistSpec((n1, n2) -> MatrixBeta(3, n1, n2), (3.0, 3.0), A, to_beta_mat), DistSpec(() -> MatrixNormal(dim, dim), (), A, to_posdef, broken=(:Zygote,)), DistSpec((df, A) -> Wishart(df, to_posdef(A)), (3.0, A), B, to_posdef), DistSpec((df, A) -> InverseWishart(df, to_posdef(A)), (3.0, A), B, to_posdef), @@ -309,8 +313,17 @@ DistSpec((df, A) -> TuringInverseWishart(df, to_posdef(A)), (3.0, A), B, to_posdef), # Vector of matrices x + # Also here we should use + # DistSpec( + # (n1, n2) -> MatrixBeta(dim, n1, n2), + # (3.0, 3.0), + # [A, B], + # x -> map(to_beta_mat, x), + #), + # but the default implementation of `rand_tangent` causes a StackOverflowError + # Thus we use the following workaround DistSpec( - (n1, n2) -> MatrixBeta(dim, n1, n2), + (n1, n2) -> MatrixBeta(3, n1, n2), (3.0, 3.0), [A, B], x -> map(to_beta_mat, x), @@ -371,6 +384,7 @@ println("\nTesting: Univariate distributions\n") for d in univariate_distributions + @info "Testing: $(nameof(dist_type(d)))" test_ad(d) end end @@ -379,6 +393,7 @@ println("\nTesting: Multivariate distributions\n") for d in multivariate_distributions + @info "Testing: $(nameof(dist_type(d)))" test_ad(d) end @@ -388,41 +403,43 @@ d.x isa Number || continue # Broken distributions - d.f(d.θ...) 
isa Union{VonMises,TriangularDist} && continue + D = dist_type(d) + D <: Union{VonMises,TriangularDist} && continue # Skellam only fails in these tests with ReverseDiff # Ref: https://github.com/TuringLang/DistributionsAD.jl/issues/126 # PoissonBinomial fails with Zygote # Matrix case does not work with Skellam: # https://github.com/TuringLang/DistributionsAD.jl/pull/172#issuecomment-853721493 - filldist_broken = if d.f(d.θ...) isa Skellam + filldist_broken = if D <: Skellam ((d.broken..., :Zygote, :ReverseDiff), (d.broken..., :Zygote, :ReverseDiff)) - elseif d.f(d.θ...) isa PoissonBinomial + elseif D <: PoissonBinomial ((d.broken..., :Zygote), (d.broken..., :Zygote)) - elseif d.f(d.θ...) isa Chernoff + elseif D <: Chernoff # Zygote is not broken with `filldist` ((), ()) else (d.broken, d.broken) end - arraydist_broken = if d.f(d.θ...) isa PoissonBinomial + arraydist_broken = if D <: PoissonBinomial ((d.broken..., :Zygote), (d.broken..., :Zygote)) else (d.broken, d.broken) end # Create `filldist` distribution - f_filldist = (θ...,) -> filldist(d.f(θ...), n) + f = d.f + f_filldist = (θ...,) -> filldist(f(θ...), n) d_filldist = f_filldist(d.θ...) # Create `arraydist` distribution - f_arraydist = (θ...,) -> arraydist([d.f(θ...) for _ in 1:n]) + f_arraydist = (θ...,) -> arraydist([f(θ...) for _ in 1:n]) d_arraydist = f_arraydist(d.θ...) for (i, sz) in enumerate(((n,), (n, 2))) # Matrix case doesn't work for continuous distributions for some reason # now but not too important (?!) 
- if length(sz) == 2 && Distributions.value_support(typeof(d)) === Continuous + if length(sz) == 2 && D <: ContinuousDistribution continue end @@ -430,9 +447,9 @@ x = fill(d.x, sz) # Test AD + @info "Testing: filldist($(nameof(D)), $sz)" test_ad( DistSpec( - Symbol(:filldist, " (", d.name, ", $sz)"), f_filldist, d.θ, x, @@ -440,9 +457,10 @@ broken=filldist_broken[i], ) ) + + @info "Testing: arraydist($(nameof(D)), $sz)" test_ad( DistSpec( - Symbol(:arraydist, " (", d.name, ", $sz)"), f_arraydist, d.θ, x, @@ -458,6 +476,7 @@ println("\nTesting: Matrixvariate distributions\n") for d in matrixvariate_distributions + @info "Testing: $(nameof(dist_type(d)))" test_ad(d) end @@ -465,27 +484,30 @@ n = (2, 2) # always use 2 x 2 distributions for d in univariate_distributions d.x isa Number || continue - Distributions.value_support(typeof(d)) === Discrete && continue + D = dist_type(d) + D <: DiscreteDistribution && continue # Broken distributions - d.f(d.θ...) isa Union{VonMises,TriangularDist} && continue + D <: Union{VonMises,TriangularDist} && continue # Create `filldist` distribution - f_filldist = (θ...,) -> filldist(d.f(θ...), n...) + f = d.f + f_filldist = (θ...,) -> filldist(f(θ...), n...) # Create `arraydist` distribution - f_arraydist = (θ...,) -> arraydist(fill(d.f(θ...), n...)) + # Zygote's fill definition does not like non-numbers, so we use a workaround + f_arraydist = (θ...,) -> arraydist(reshape([f(θ...) for _ in 1:prod(n)], n)) # Matrix `x` x_mat = fill(d.x, n) # Zygote is not broken with `filldist` + Chernoff - filldist_broken = d.f(d.θ...) isa Chernoff ? () : d.broken + filldist_broken = D <: Chernoff ? 
() : d.broken # Test AD + @info "Testing: filldist($(nameof(D)), $n)" test_ad( DistSpec( - Symbol(:filldist, " (", d.name, ", $n)"), f_filldist, d.θ, x_mat, @@ -493,9 +515,9 @@ broken=filldist_broken, ) ) + @info "Testing: arraydist($(nameof(D)), $n)" test_ad( DistSpec( - Symbol(:arraydist, " (", d.name, ", $n)"), f_arraydist, d.θ, x_mat, @@ -508,9 +530,9 @@ x_vec_of_mat = [fill(d.x, n) for _ in 1:2] # Test AD + @info "Testing: filldist($(nameof(D)), $n, 2)" test_ad( DistSpec( - Symbol(:filldist, " (", d.name, ", $n, 2)"), f_filldist, d.θ, x_vec_of_mat, @@ -518,9 +540,9 @@ broken=filldist_broken, ) ) + @info "Testing: arraydist($(nameof(D)), $n, 2)" test_ad( DistSpec( - Symbol(:arraydist, " (", d.name, ", $n, 2)"), f_arraydist, d.θ, x_vec_of_mat, @@ -530,15 +552,15 @@ ) end - # test `filldist` and `arraydist` distributions of multivariate distributions n = 2 # always use two distributions for d in multivariate_distributions d.x isa AbstractVector || continue - Distributions.value_support(typeof(d)) === Discrete && continue + D = dist_type(d) + D <: DiscreteDistribution && continue # Tests are failing for matrix covariance vectorized MvNormal - if d.f(d.θ...) isa Union{ + if D <: Union{ MvNormal,MvLogNormal, DistributionsAD.TuringDenseMvNormal, DistributionsAD.TuringDiagMvNormal, @@ -549,18 +571,19 @@ end # Create `filldist` distribution - f_filldist = (θ...,) -> filldist(d.f(θ...), n) + f = d.f + f_filldist = (θ...,) -> filldist(f(θ...), n) # Create `arraydist` distribution - f_arraydist = (θ...,) -> arraydist(fill(d.f(θ...), n)) + f_arraydist = (θ...,) -> arraydist([f(θ...) 
for _ in 1:n]) # Matrix `x` x_mat = repeat(d.x, 1, n) # Test AD + @info "Testing: filldist($(nameof(D)), $n)" test_ad( DistSpec( - Symbol(:filldist, " (", d.name, ", $n)"), f_filldist, d.θ, x_mat, @@ -568,9 +591,9 @@ broken=d.broken, ) ) + @info "Testing: arraydist($(nameof(D)), $n)" test_ad( DistSpec( - Symbol(:arraydist, " (", d.name, ", $n)"), f_arraydist, d.θ, x_mat, @@ -583,9 +606,9 @@ x_vec_of_mat = [repeat(d.x, 1, n) for _ in 1:2] # Test AD + @info "Testing: filldist($(nameof(D)), $n, 2)" test_ad( DistSpec( - Symbol(:filldist, " (", d.name, ", $n, 2)"), f_filldist, d.θ, x_vec_of_mat, @@ -593,9 +616,9 @@ broken=d.broken, ) ) + @info "Testing: arraydist($(nameof(D)), $n, 2)" test_ad( DistSpec( - Symbol(:arraydist, " (", d.name, ", $n, 2)"), f_arraydist, d.θ, x_vec_of_mat, diff --git a/test/ad/others.jl b/test/ad/others.jl new file mode 100644 index 00000000..a0b74317 --- /dev/null +++ b/test/ad/others.jl @@ -0,0 +1,93 @@ +@testset "AD: Others" begin + if GROUP == "All" || GROUP == "Tracker" + @testset "TuringUniform" begin + @test logpdf(TuringUniform(), param(0.5)) == 0 + end + + @testset "Semicircle" begin + @test Tracker.data(logpdf(Semicircle(1.0), param(0.5))) == logpdf(Semicircle(1.0), 0.5) + end + end + + @testset "logsumexp" begin + x = rand(3) + test_reverse_mode_ad(logsumexp, randn(), x; rtol=1e-8, atol=1e-6) + end + + @testset "zygote_ldiv" begin + A = to_posdef(rand(3, 3)) + B = to_posdef(rand(3, 3)) + + test_reverse_mode_ad(randn(3, 3), A, B) do A, B + return DistributionsAD.zygote_ldiv(A, B) + end + end + + @testset "logdet" begin + N = 7 + B = randn(N, N) + + test_reverse_mode_ad(randn(), B; rtol=1e-8, atol=1e-6) do B + return logdet(cholesky(to_posdef(B))) + end + test_reverse_mode_ad(randn(), B; rtol=1e-8, atol=1e-6) do B + return logdet(cholesky(Symmetric(to_posdef(B)))) + end + end + + @testset "fill" begin + if GROUP == "All" || GROUP == "Tracker" + @test fill(param(1.0), 3) isa TrackedArray + end + + test_reverse_mode_ad(x->fill(x, 7), 
randn(7), randn()) + test_reverse_mode_ad(x->fill(x, 7, 11), randn(7, 11), randn()) + test_reverse_mode_ad(x->fill(x, 7, 11, 13), rand(7, 11, 13), randn()) + end + + @testset "Tracker, Zygote and ReverseDiff + MvNormal" begin + N = 7 + m = rand(N) + B = randn(N, N) + x = rand(TuringDenseMvNormal(m, to_posdef(B))) + + test_reverse_mode_ad(randn(), m, B, x) do m, B, x + return logpdf(MvNormal(m, to_posdef(B)), x) + end + test_reverse_mode_ad(randn(), m, B, x) do m, B, x + return logpdf(TuringMvNormal(m, to_posdef(B)), x) + end + test_reverse_mode_ad(randn(), m, B, x) do m, B, x + return logpdf(TuringMvNormal(m, Symmetric(to_posdef(B))), x) + end + end + + @testset "adapt_randn" begin + rng = MersenneTwister() + n = 50 + dims = (10, 30) + for T in (Float32, Float64) + if GROUP == "All" || GROUP == "ForwardDiff" + let + x = [ForwardDiff.Dual(rand(rng, T)) for _ in 1:n] + test_adapt_randn(rng, x, T, dims...) + end + end + if GROUP == "All" || GROUP == "Tracker" + let + x = Tracker.TrackedArray(rand(rng, T, 50)) + test_adapt_randn(rng, x, T, dims...) + end + end + if GROUP == "All" || GROUP == "ReverseDiff" + let + v = rand(rng, T, n) + d = rand(Int, n) + tp = ReverseDiff.InstructionTape() + x = ReverseDiff.TrackedArray(v, d, tp) + test_adapt_randn(rng, x, T, dims...) 
+ end + end + end + end +end \ No newline at end of file diff --git a/test/ad/utils.jl b/test/ad/utils.jl index 9525523a..ecf4a61a 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -1,6 +1,139 @@ +using ChainRulesCore +using ChainRulesTestUtils +using FiniteDifferences + +const FDM = FiniteDifferences + +# Load AD backends +if GROUP == "All" || GROUP == "ForwardDiff" + @eval using ForwardDiff +end +if GROUP == "All" || GROUP == "Zygote" + @eval using Zygote +end +if GROUP == "All" || GROUP == "ReverseDiff" + @eval using ReverseDiff +end +if GROUP == "All" || GROUP == "Tracker" + @eval using Tracker +end + +function test_reverse_mode_ad(f, ȳ, x...; rtol=1e-6, atol=1e-6) + # Perform a regular forwards-pass. + y = f(x...) + + # Use finite differencing to compute reverse-mode sensitivities. + x̄s_fdm = FDM.j′vp(central_fdm(5, 1), f, ȳ, x...) + + if GROUP == "All" || GROUP == "Zygote" + # Use Zygote to compute reverse-mode sensitivities. + y_zygote, back_zygote = Zygote.pullback(f, x...) + x̄s_zygote = back_zygote(ȳ) + + # Check that Zygote forwards-pass produces the correct answer. + @test y ≈ y_zygote atol=atol rtol=rtol + + # Check that Zygote reverse-mode sensitivities are correct. + @test all(zip(x̄s_zygote, x̄s_fdm)) do (x̄_zygote, x̄_fdm) + return isapprox(x̄_zygote, x̄_fdm; atol=atol, rtol=rtol) + end + end + + if GROUP == "All" || GROUP == "ReverseDiff" + test_rd = length(x) == 1 && y isa Number + if test_rd + # Use ReverseDiff to compute reverse-mode sensitivities.
+ if x[1] isa Array + x̄s_rd = similar(x[1]) + tp = ReverseDiff.GradientTape(x -> f(x), x[1]) + ReverseDiff.gradient!(x̄s_rd, tp, x[1]) + x̄s_rd .*= ȳ + y_rd = ReverseDiff.value(tp.output) + @assert y_rd isa Number + else + x̄s_rd = [x[1]] + tp = ReverseDiff.GradientTape(x -> f(x[1]), [x[1]]) + ReverseDiff.gradient!(x̄s_rd, tp, [x[1]]) + y_rd = ReverseDiff.value(tp.output)[1] + x̄s_rd = x̄s_rd[1] * ȳ + @assert y_rd isa Number + end + + # Check that ReverseDiff forwards-pass produces the correct answer. + @test y ≈ y_rd atol=atol rtol=rtol + + # Check that ReverseDiff reverse-mode sensitivities are correct. + @test x̄s_rd ≈ x̄s_fdm[1] atol=atol rtol=rtol + end + end + + if GROUP == "All" || GROUP == "Tracker" + # Use Tracker to compute reverse-mode sensitivities. + y_tracker, back_tracker = Tracker.forward(f, x...) + x̄s_tracker = back_tracker(ȳ) + + # Check that Tracker forwards-pass produces the correct answer. + @test y ≈ Tracker.data(y_tracker) atol=atol rtol=rtol + + # Check that Tracker reverse-mode sensitivities are correct. 
+ @test all(zip(x̄s_tracker, x̄s_fdm)) do (x̄_tracker, x̄_fdm) + return isapprox(Tracker.data(x̄_tracker), x̄_fdm; atol=atol, rtol=rtol) + end + end +end + +# Define pullback for `to_simplex` +function to_simplex_pullback(ȳ::AbstractArray, y::AbstractArray) + x̄ = ȳ .* y + x̄ .= x̄ .- y .* sum(x̄; dims=1) + return x̄ +end +function ChainRulesCore.rrule(::typeof(to_simplex), x::AbstractArray{<:Real}) + y = to_simplex(x) + pullback(ȳ) = (NoTangent(), to_simplex_pullback(ȳ, y)) + return y, pullback +end + +# Define adjoints for ReverseDiff +if GROUP == "All" || GROUP == "ReverseDiff" + @eval begin + function to_simplex(x::AbstractArray{<:ReverseDiff.TrackedReal}) + return ReverseDiff.track(to_simplex, x) + end + ReverseDiff.@grad function to_simplex(x) + _x = ReverseDiff.value(x) + y = to_simplex(_x) + pullback(ȳ) = (to_simplex_pullback(ȳ, y),) + return y, pullback + end + end +end + +# Define adjoints for Tracker +if GROUP == "All" || GROUP == "Tracker" + @eval begin + to_posdef(A::Tracker.TrackedMatrix) = Tracker.track(to_posdef, A) + Tracker.@grad function to_posdef(A::Tracker.TrackedMatrix) + data_A = Tracker.data(A) + S = data_A * data_A' + I + function pullback(∇) + return ((∇ + ∇') * data_A,) + end + return S, pullback + end + + to_simplex(x::Tracker.TrackedArray) = Tracker.track(to_simplex, x) + Tracker.@grad function to_simplex(x::Tracker.TrackedArray) + data_x = Tracker.data(x) + y = to_simplex(data_x) + pullback(ȳ) = (to_simplex_pullback(ȳ, y),) + return y, pullback + end + end +end + # Struct of distribution, corresponding parameters, and a sample. -struct DistSpec{VF<:VariateForm,VS<:ValueSupport,F,T,X,G,B<:Tuple} - name::Symbol +struct DistSpec{D,F,T,X,G,B<:Tuple} f::F "Distribution parameters." 
θ::T @@ -10,151 +143,214 @@ struct DistSpec{VF<:VariateForm,VS<:ValueSupport,F,T,X,G,B<:Tuple} xtrans::G "Broken backends" broken::B + + function DistSpec{D}(f::F, θ, x, xtrans::T, broken) where {D,F,T} + return new{D,F,typeof(θ),typeof(x),T,typeof(broken)}(f, θ, x, xtrans, broken) + end end -function DistSpec(f, θ, x, xtrans=nothing; broken=()) - name = f isa Distribution ? nameof(typeof(f)) : nameof(typeof(f(θ...))) - return DistSpec(name, f, θ, x, xtrans; broken=broken) +function DistSpec(d::Distribution, θ, x, xtrans=nothing; broken=()) + return DistSpec{typeof(d)}(d, θ, x, xtrans, broken) end -function DistSpec(name::Symbol, f, θ, x, xtrans=nothing; broken=()) - F = f isa Distribution ? typeof(f) : typeof(f(θ...)) - VF = Distributions.variate_form(F) - VS = Distributions.value_support(F) - return DistSpec{VF,VS,typeof(f),typeof(θ),typeof(x),typeof(xtrans),typeof(broken)}( - name, f, θ, x, xtrans, broken, - ) +function DistSpec(f::F, θ, x, xtrans=nothing; broken=()) where {F} + D = typeof(f(θ...)) + return DistSpec{D}(f, θ, x, xtrans, broken) end -Distributions.variate_form(::Type{<:DistSpec{VF}}) where VF = VF -Distributions.value_support(::Type{<:DistSpec{VF,VS}}) where {VF,VS} = VS +dist_type(::DistSpec{D}) where {D} = D -# Auxiliary method for vectorizing parameters and samples -vectorize(v::Number) = [v] -vectorize(v::Diagonal) = v.diag -vectorize(v::AbstractVector{<:AbstractMatrix}) = mapreduce(vectorize, vcat, v) -vectorize(v) = vec(v) +# Auxiliary methods for vectorizing parameters and samples and unflattening them +# similar to `FDM.to_vec` +# However, some implementations in FDM don't work with overload AD such as Tracker, +# ForwardDiff, and ReverseDiff +# Therefore we add a `_to_vec` function -""" - unpack(x, inds, original...) 
+function _to_vec(x::Real) + function Real_from_vec(v) + length(v) == 1 || error("vector has incorrect number of elements") + return first(v) + end + return [x], Real_from_vec +end -Return a tuple of unpacked parameters and samples in vector `x`. +function _to_vec(x::AbstractArray{<:Real}) + sz = size(x) + Array_from_vec(v) = reshape(v, sz) + return vec(x), Array_from_vec +end -Here `original` are the original full set of parameters and samples, and -`inds` contains the indices of the original parameters and samples for which -a possibly different value is given in `x`. If no value is provided in `x`, -the original value of the parameter is returned. The values are returned -in the same order as the original parameters. -""" -function unpack(x, inds, original...) - offset = 0 - newvals = ntuple(length(original)) do i - if i in inds - v, offset = unpack_offset(x, offset, original[i]) - else - v = original[i] +function _to_vec(x::Union{Tuple,AbstractVector{<:AbstractArray}}) + x_vecs_and_backs = map(_to_vec, x) + x_vecs, x_backs = map(first, x_vecs_and_backs), map(last, x_vecs_and_backs) + lengths = map(length, x_vecs) + sz = typeof(lengths)(cumsum(collect(lengths))) + function Tuple_or_Array_of_Array_from_vec(v) + map(x_backs, lengths, sz) do x_back, l, s + return x_back(v[(s - l + 1):s]) end - return v end - offset == length(x) || throw(ArgumentError()) + return reduce(vcat, x_vecs), Tuple_or_Array_of_Array_from_vec +end - return newvals +# Functor that fixes non-differentiable location `x` for discrete distributions +struct FixedLocation{X} + x::X end +(f::FixedLocation)(args...) 
= f.x, args -# Auxiliary methods for unpacking numbers and arrays -function unpack_offset(x, offset, original::Number) - newoffset = offset + 1 - val = x[newoffset] - return val, newoffset +# Functor that transforms differentiable location `x` for continuous distributions +# from unconstrained to constrained space +struct TransformedLocation{F} + trans::F end -function unpack_offset(x, offset, original::AbstractArray) - newoffset = offset + length(original) - val = reshape(x[(offset + 1):newoffset], size(original)) - return val, newoffset +(f::TransformedLocation)(x, args...) = f.trans(x), args +(f::TransformedLocation{Nothing})(x, args...) = x, args + +# Convenience function that returns the correct functor for +# discrete and continuous distributions +make_unpack_x_θ(_, x, ::Type{<:DiscreteDistribution}) = FixedLocation(x) +make_unpack_x_θ(trans, _, ::Type{<:ContinuousDistribution}) = TransformedLocation(trans) + +# "Unignore" arguments, i.e., add default arguments if they were ignored +struct Unignore{A} + args::A + ignores::BitVector +end + +function Unignore(args, ignores::BitVector) + n = length(args) + @assert length(ignores) == n + return Unignore{typeof(args)}(args, ignores) end -function unpack_offset(x, offset, original::AbstractArray{<:AbstractArray}) - newoffset = offset - val = map(original) do orig - out, newoffset = unpack_offset(x, newoffset, orig) - return out + +function (f::Unignore)(x...) + j = Ref(0) + newx = map(f.args, f.ignores) do argsi, ignoresi + return if ignoresi + argsi + else + x[(j[] += 1)] + end end - return val, newoffset + + @assert length(x) == j[] || error("wrong number of arguments") + + return newx +end + +# we define the following two functions to be able to tell Zygote that it should not +# compute derivatives for the fields of the functors `unpack_x_θ` +""" + loglikelihood_parameterized(unpack_x_θ, dist, args...) 
+ +Compute the log-likelihood of distribution `dist(θ...)` for `x` where +`x, θ = unpack_x_θ(args...)` are extracted from the arguments `args` with `unpack_x_θ`. + +Internally, computations are performed with `loglikelihood`. + +See also: [`sum_logpdf_parameterized`](@ref) +""" +function loglikelihood_parameterized(unpack_x_θ, f, args...) + x, θ = ignore_derivatives(unpack_x_θ)(args...) + return loglikelihood(f(θ...), x) +end + +""" + sum_logpdf_parameterized(unpack_x_θ, dist, args...) + +Compute the log-likelihood of distribution `dist(θ...)` for `x` where +`x, θ = unpack_x_θ(args...)` are extracted from the arguments `args` with `unpack_x_θ`. + +Internally, the log pdf of individual data points is computed with `logpdf` which are then +summed up. + +See also: [`loglikelihood_parameterized`](@ref) +""" +function sum_logpdf_parameterized(unpack_x_θ, f, args...) + x, θ = ignore_derivatives(unpack_x_θ)(args...) + # we use `_logpdf` to be able to handle univariate distributions correctly (see below) + return sum(_logpdf(f(θ...), x)) end -# Run AD tests of a -function test_ad(dist::DistSpec; kwargs...) - @info "Testing: $(dist.name)" +# Function that computes arrays of `logpdf` values +# `logpdf` does not handle arrays of samples for univariate distributions +_logpdf(d::Distribution, x) = logpdf(d, x) +_logpdf(d::UnivariateDistribution, x::AbstractArray) = logpdf.((d,), x) + +# Run AD tests +function test_ad(dist::DistSpec{D}; kwargs...) where {D} f = dist.f θ = dist.θ x = dist.x - g = dist.xtrans broken = dist.broken - # Create functions with all possible arguments - f_loglik_allargs = let f=f, g=g - function (x, θ...) - dist = f(θ...) - xtilde = g === nothing ? x : g(x) - return loglikelihood(dist, xtilde) - end - end - f_logpdf_allargs = let f=f, g=g - function (x, θ...) - dist = f(θ...) - xtilde = g === nothing ? 
x : g(x) - if dist isa UnivariateDistribution && xtilde isa AbstractArray - return sum(logpdf.(dist, xtilde)) - else - return sum(logpdf(dist, xtilde)) - end + # combine all arguments + # point `x` is not differentiable if the distribution is discrete + args = D <: ContinuousDistribution ? (x, θ...) : θ + + # Create function that splits arguments and transforms location x if needed + unpack_x_θ = make_unpack_x_θ(dist.xtrans, x, D) + + # short cut: since Zygote does not use special number types with + # different dispatches etc., it is sufficient to just test derivatives of + # all differentiable arguments at once + if GROUP === "All" || GROUP === "Zygote" + # is Zygote broken? + zygote_broken = :Zygote in broken + + if zygote_broken + testset_zygote_broken(dist, unpack_x_θ, args...; kwargs...) + else + testset_zygote(dist, unpack_x_θ, args...; kwargs...) end end - # For all combinations of distribution parameters `θ` - for inds in powerset(2:(length(θ) + 1)) - # Test only distribution parameters - if !isempty(inds) - xtest = mapreduce(vcat, inds) do i - vectorize(θ[i - 1]) - end - f_loglik_test = let xorig=x, θorig=θ, inds=inds - x -> f_loglik_allargs(unpack(x, inds, xorig, θorig...)...) - end - f_logpdf_test = let xorig=x, θorig=θ, inds=inds - x -> f_logpdf_allargs(unpack(x, inds, xorig, θorig...)...) - end + # Early exit + GROUP !== "Zygote" || return - @test f_loglik_test(xtest) ≈ f_logpdf_test(xtest) + # Define functions for computing the log-likelihood that ignore some arguments + # (i.e., set them to their default values) + # This is used to check if we can differentiate with respect to a subset of arguments + # with ForwardDiff, Tracker, and ReverseDiff + n = length(args) + ignores = falses(n) + unignore = Unignore(args, ignores) + function loglikelihood_test(x...) + return sum_logpdf_parameterized(unpack_x_θ, f, unignore(x...)...) + end + sum_logpdf_test(x...) = sum_logpdf_parameterized(unpack_x_θ, f, unignore(x...)...) 
+ + # Quick sanity check + @test loglikelihood_test(args...) ≈ sum_logpdf_test(args...) - test_ad(f_loglik_test, xtest, broken; kwargs...) - test_ad(f_logpdf_test, xtest, broken; kwargs...) + # For all combinations of arguments + for inds in powerset(1:n, 1, n) + # Update boolean vector of ignored arguments + fill!(ignores, true) + for i in inds + @inbounds ignores[i] = false end - # Test derivative with respect to location `x` as well - # if the distribution is continuous - if Distributions.value_support(typeof(dist)) === Continuous - xtest = isempty(inds) ? vectorize(x) : vcat(vectorize(x), xtest) - push!(inds, 1) - f_loglik_test = let xorig=x, θorig=θ, inds=inds - x -> f_loglik_allargs(unpack(x, inds, xorig, θorig...)...) - end - f_logpdf_test = let xorig=x, θorig=θ, inds=inds - x -> f_logpdf_allargs(unpack(x, inds, xorig, θorig...)...) - end + # Vectorize to-be-differentiated arguments for ForwardDiff, Tracker, and ReverseDiff + args_vec, args_unflatten = _to_vec(args[inds]) + loglik_test(x) = loglikelihood_test(args_unflatten(x)...) + logpdf_test(x) = sum_logpdf_test(args_unflatten(x)...) - @test f_loglik_test(xtest) ≈ f_logpdf_test(xtest) + @test loglik_test(args_vec) ≈ logpdf_test(args_vec) - test_ad(f_loglik_test, xtest, broken; kwargs...) - test_ad(f_logpdf_test, xtest, broken; kwargs...) - end + test_ad(loglik_test, args_vec, broken; kwargs...) + test_ad(logpdf_test, args_vec, broken; kwargs...) 
end + + return end function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) finitediff = FDM.grad(central_fdm(5, 1), f, x)[1] - if AD == "All" || AD == "Tracker" + if GROUP == "All" || GROUP == "Tracker" if :Tracker in broken @test_broken Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol=rtol atol=atol else @@ -162,7 +358,7 @@ function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) end end - if AD == "All" || AD == "ForwardDiff" + if GROUP == "All" || GROUP == "ForwardDiff" if :ForwardDiff in broken @test_broken ForwardDiff.gradient(f, x) ≈ finitediff rtol=rtol atol=atol else @@ -170,19 +366,7 @@ function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) end end - if AD == "All" || AD == "Zygote" - if :Zygote in broken - @test_broken zygote_isapprox( - Zygote.gradient(f, x)[1], finitediff; rtol=rtol, atol=atol, - ) - else - @test zygote_isapprox( - Zygote.gradient(f, x)[1], finitediff; rtol=rtol, atol=atol, - ) - end - end - - if AD == "All" || AD == "ReverseDiff" + if GROUP == "All" || GROUP == "ReverseDiff" if :ReverseDiff in broken @test_broken ReverseDiff.gradient(f, x) ≈ finitediff rtol=rtol atol=atol else @@ -193,8 +377,65 @@ function test_ad(f, x, broken = (); rtol = 1e-6, atol = 1e-6) return end -# Handle Zygote's `nothing` -zygote_isapprox(x, expected; kwargs...) = isapprox(x, expected; kwargs...) -function zygote_isapprox(::Nothing, expected; kwargs...) - return isapprox(zero(expected), expected; kwargs...) +function testset_zygote(distspec, unpack_x_θ, args...; kwargs...) + f = distspec.f + θ = distspec.θ + x = distspec.x + + @testset "Zygote: $(f(θ...)) at x=$x" begin + @test loglikelihood_parameterized(unpack_x_θ, f, args...) ≈ + sum_logpdf_parameterized(unpack_x_θ, f, args...) + + for l in (loglikelihood_parameterized, sum_logpdf_parameterized) + # Zygote has type inference problems so we don't check it + test_rrule( + Zygote.ZygoteRuleConfig(), l, unpack_x_θ, f, args...; + rrule_f=rrule_via_ad, check_inferred=false, kwargs... 
+ ) + end + end +end + +function testset_zygote_broken(args...; kwargs...) + # don't show test errors - tests are known to be broken :) + testset = suppress_stdout() do + testset_zygote(args...; kwargs...) + end + + # change errors and fails to broken results, and count number of errors and fails + efs = errors_to_broken!(testset) + + # ensure that passing tests are not marked as broken + if iszero(efs) + error("Zygote tests of $(f(θ...)) at x=$x passed unexpectedly, please mark not as broken") + end + + return testset +end + +# `redirect_stdout(f, devnull)` is only available in Julia >= 1.6 +function suppress_stdout(f) + @static if VERSION < v"1.6" + open((@static Sys.iswindows() ? "NUL" : "/dev/null"), "w") do devnull + redirect_stdout(f, devnull) + end + else + redirect_stdout(f, devnull) + end +end + +# change test errors and failures to broken results +function errors_to_broken!(ts::Test.DefaultTestSet) + results = ts.results + efs = 0 + for i in eachindex(results) + @inbounds t = results[i] + if t isa Test.DefaultTestSet + efs += errors_to_broken!(t) + elseif t isa Union{Test.Fail, Test.Error} + efs += 1 + results[i] = Test.Broken(t.test_type, t.orig_expr) + end + end + return efs end diff --git a/test/others.jl b/test/others.jl index 330b4060..10bd450c 100644 --- a/test/others.jl +++ b/test/others.jl @@ -72,21 +72,23 @@ if TD <: TuringDenseMvNormal C = Matrix{Float64}(I, 3, 3) d1 = TuringMvLogNormal(TuringMvNormal(m, C)) + d2 = MvLogNormal(MvNormal(m, C)) elseif TD <: TuringDiagMvNormal C = ones(3) d1 = TuringMvLogNormal(TuringMvNormal(m, C)) + d2 = MvLogNormal(MvNormal(m, Diagonal(C .^ 2))) else C = 1.0 d1 = TuringMvLogNormal(TuringMvNormal(m, C)) + d2 = MvLogNormal(MvNormal(m, C^2 * I)) end - d2 = MvLogNormal(MvNormal(m, C)) @test length(d1) == length(d2) x1 = rand(d1) x2 = rand(d1, 3) - @test isapprox(logpdf(d1, x1), logpdf(d2, x1), rtol = 1e-6) - @test isapprox(logpdf(d1, x2), logpdf(d2, x2), rtol = 1e-6) + @test logpdf(d1, x1) ≈ logpdf(d2, x1) rtol=1e-6 
+ @test logpdf(d1, x2) ≈ logpdf(d2, x2) rtol=1e-6 x2[:, 1] .= -1 @test isinf(logpdf(d1, x2)[1]) @@ -96,15 +98,6 @@ @testset "TuringUniform" begin @test logpdf(TuringUniform(), 0.5) == 0 - if AD == "All" || AD == "Tracker" - @test logpdf(TuringUniform(), param(0.5)) == 0 - end - end - - if AD == "All" || AD == "Tracker" - @testset "Semicircle" begin - @test Tracker.data(logpdf(Semicircle(1.0), param(0.5))) == logpdf(Semicircle(1.0), 0.5) - end end @testset "TuringPoissonBinomial" begin @@ -126,116 +119,10 @@ @test info == C.info end - function test_reverse_mode_ad( f, ȳ, x...; rtol=1e-6, atol=1e-6) - # Perform a regular forwards-pass. - y = f(x...) - - # Use finite differencing to compute reverse-mode sensitivities. - x̄s_fdm = FDM.j′vp(central_fdm(5, 1), f, ȳ, x...) - - if AD == "All" || AD == "Zygote" - # Use Zygote to compute reverse-mode sensitivities. - y_zygote, back_zygote = Zygote.pullback(f, x...) - x̄s_zygote = back_zygote(ȳ) - - # Check that Zygpte forwards-pass produces the correct answer. - @test isapprox(y, y_zygote, atol=atol, rtol=rtol) - - # Check that Zygote reverse-mode sensitivities are correct. - @test all(zip(x̄s_zygote, x̄s_fdm)) do (x̄_zygote, x̄_fdm) - isapprox(x̄_zygote, x̄_fdm; atol=atol, rtol=rtol) - end - end - - if AD == "All" || AD == "ReverseDiff" - test_rd = length(x) == 1 && y isa Number - if test_rd - # Use ReverseDiff to compute reverse-mode sensitivities. - if x[1] isa Array - x̄s_rd = similar(x[1]) - tp = ReverseDiff.GradientTape(x -> f(x), x[1]) - ReverseDiff.gradient!(x̄s_rd, tp, x[1]) - x̄s_rd .*= ȳ - y_rd = ReverseDiff.value(tp.output) - @assert y_rd isa Number - else - x̄s_rd = [x[1]] - tp = ReverseDiff.GradientTape(x -> f(x[1]), [x[1]]) - ReverseDiff.gradient!(x̄s_rd, tp, [x[1]]) - y_rd = ReverseDiff.value(tp.output)[1] - x̄s_rd = x̄s_rd[1] * ȳ - @assert y_rd isa Number - end - - # Check that ReverseDiff forwards-pass produces the correct answer. 
- @test isapprox(y, y_rd, atol=atol, rtol=rtol) - - # Check that ReverseDiff reverse-mode sensitivities are correct. - @test isapprox(x̄s_rd, x̄s_fdm[1]; atol=atol, rtol=rtol) - end - end - - if AD == "All" || AD == "Tracker" - # Use Tracker to compute reverse-mode sensitivities. - y_tracker, back_tracker = Tracker.forward(f, x...) - x̄s_tracker = back_tracker(ȳ) - - # Check that Tracker forwards-pass produces the correct answer. - @test isapprox(y, Tracker.data(y_tracker), atol=atol, rtol=rtol) - - # Check that Tracker reverse-mode sensitivities are correct. - @test all(zip(x̄s_tracker, x̄s_fdm)) do (x̄_tracker, x̄_fdm) - isapprox(Tracker.data(x̄_tracker), x̄_fdm; atol=atol, rtol=rtol) - end - end - end - _to_cov(B) = B + B' + 2 * size(B, 1) * Matrix(I, size(B)...) - - @testset "logsumexp" begin - x, y = rand(3), rand() - test_reverse_mode_ad(logsumexp, y, x; rtol=1e-8, atol=1e-6) - end - @testset "zygote_ldiv" begin - A = rand(3, 3)'; A = A + A' + 3I; + A = to_posdef(rand(3, 3)) B = copy(A) - Ȳ = rand(3, 3) @test DistributionsAD.zygote_ldiv(A, B) == A \ B - test_reverse_mode_ad((A,B)->DistributionsAD.zygote_ldiv(A,B), Ȳ, A, B) - end - - @testset "logdet" begin - rng, N = MersenneTwister(123456), 7 - y, B = randn(rng), randn(rng, N, N) - test_reverse_mode_ad(B->logdet(cholesky(_to_cov(B))), y, B; rtol=1e-8, atol=1e-6) - test_reverse_mode_ad(B->logdet(cholesky(Symmetric(_to_cov(B)))), y, B; rtol=1e-8, atol=1e-6) - end - - @testset "fill" begin - if AD == "All" || AD == "Tracker" - @test fill(param(1.0), 3) isa TrackedArray - end - rng = MersenneTwister(123456) - test_reverse_mode_ad(x->fill(x, 7), randn(rng, 7), randn(rng)) - test_reverse_mode_ad(x->fill(x, 7, 11), randn(rng, 7, 11), randn(rng)) - test_reverse_mode_ad(x->fill(x, 7, 11, 13), rand(rng, 7, 11, 13), randn(rng)) - end - @testset "Tracker, Zygote and ReverseDiff + MvNormal" begin - rng, N = MersenneTwister(123456), 11 - B = randn(rng, N, N) - m, A = randn(rng, N), B' * B + I - - # Generate from the 
TuringDenseMvNormal - d = TuringDenseMvNormal(m, A) - x = rand(d) - - # Check that the logpdf agrees with MvNormal. - d_ref = MvNormal(m, PDMat(A)) - @test logpdf(d, x) ≈ logpdf(d_ref, x) - - test_reverse_mode_ad((m, B, x)->logpdf(MvNormal(m, _to_cov(B)), x), randn(rng), m, B, x) - test_reverse_mode_ad((m, B, x)->logpdf(TuringMvNormal(m, _to_cov(B)), x), randn(rng), m, B, x) - test_reverse_mode_ad((m, B, x)->logpdf(TuringMvNormal(m, Symmetric(_to_cov(B))), x), randn(rng), m, B, x) end @testset "Entropy" begin @@ -269,35 +156,8 @@ @testset "adapt_randn" begin rng = MersenneTwister() - - xs = Any[(rng, T, n) -> rand(rng, T, n)] - if AD == "All" || AD == "ForwardDiff" - push!(xs, (rng, T, n) -> [ForwardDiff.Dual(rand(rng, T)) for _ in 1:n]) - end - if AD == "All" || AD == "Tracker" - push!(xs, (rng, T, n) -> Tracker.TrackedArray(rand(rng, T, n))) - end - if AD == "All" || AD == "ReverseDiff" - push!(xs, (rng, T, n) -> begin - v = rand(rng, T, n) - d = rand(Int, n) - tp = ReverseDiff.InstructionTape() - ReverseDiff.TrackedArray(v, d, tp) - end) - end - for T in (Float32, Float64) - for f in xs - x = f(rng, T, 50) - - Random.seed!(rng, 100) - y = DistributionsAD.adapt_randn(rng, x, 10, 30) - @test y isa Matrix{T} - @test size(y) == (10, 30) - - Random.seed!(rng, 100) - @test y == randn(rng, T, 10, 30) - end + test_adapt_randn(rng, rand(rng, T, 50), T, 10, 30) end end diff --git a/test/runtests.jl b/test/runtests.jl index e67bd4c2..d03627e0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,10 +1,7 @@ using DistributionsAD -using ChainRulesCore -using ChainRulesTestUtils using Combinatorics using Distributions -using FiniteDifferences using PDMats using Random, LinearAlgebra, Test @@ -17,91 +14,17 @@ using StatsFuns: StatsFuns, logsumexp, logistic Random.seed!(1) # Set seed that all testsets should reset to. 
-const FDM = FiniteDifferences const GROUP = get(ENV, "GROUP", "All") -# Figure out which AD backend to test -const AD = get(ENV, "AD", "All") -if AD == "All" || AD == "ForwardDiff" - @eval using ForwardDiff -end -if AD == "All" || AD == "Zygote" - @eval using Zygote -end -if AD == "All" || AD == "ReverseDiff" - @eval using ReverseDiff -end -if AD == "All" || AD == "Tracker" - @eval using Tracker -end +include("test_utils.jl") if GROUP == "All" || GROUP == "Others" include("others.jl") end -if GROUP == "All" || GROUP == "AD" - # Create positive definite matrix - to_posdef(A::AbstractMatrix) = A * A' + I - to_posdef_diagonal(a::AbstractVector) = Diagonal(a.^2 .+ 1) - - # Create vectors in probability simplex. - function to_simplex(x::AbstractArray) - max = maximum(x; dims=1) - y = exp.(x .- max) - y ./= sum(y; dims=1) - return y - end - to_simplex(x::AbstractArray{<:AbstractArray}) = to_simplex.(x) - function to_simplex_pullback(ȳ::AbstractArray, y::AbstractArray) - x̄ = ȳ .* y - x̄ .= x̄ .- y .* sum(x̄; dims=1) - return x̄ - end - function ChainRulesCore.rrule(::typeof(to_simplex), x::AbstractArray{<:Real}) - y = to_simplex(x) - pullback(ȳ) = (NoTangent(), to_simplex_pullback(ȳ, y)) - return y, pullback - end - - if AD == "All" || AD == "ReverseDiff" - @eval begin - # Define adjoint for ReverseDiff - function to_simplex(x::AbstractArray{<:ReverseDiff.TrackedReal}) - return ReverseDiff.track(to_simplex, x) - end - ReverseDiff.@grad function to_simplex(x) - _x = ReverseDiff.value(x) - y = to_simplex(_x) - pullback(ȳ) = (to_simplex_pullback(ȳ, y),) - return y, pullback - end - end - end - - if AD == "All" || AD == "Tracker" - @eval begin - # Define adjoints for Tracker - to_posdef(A::Tracker.TrackedMatrix) = Tracker.track(to_posdef, A) - Tracker.@grad function to_posdef(A::Tracker.TrackedMatrix) - data_A = Tracker.data(A) - S = data_A * data_A' + I - function pullback(∇) - return ((∇ + ∇') * data_A,) - end - return S, pullback - end - - 
to_simplex(x::Tracker.TrackedArray) = Tracker.track(to_simplex, x) - Tracker.@grad function to_simplex(x::Tracker.TrackedArray) - data_x = Tracker.data(x) - y = to_simplex(data_x) - pullback(ȳ) = (to_simplex_pullback(ȳ, y),) - return y, pullback - end - end - end - +if GROUP == "All" || GROUP in ("ForwardDiff", "Zygote", "ReverseDiff", "Tracker") include("ad/utils.jl") + include("ad/others.jl") include("ad/chainrules.jl") include("ad/distributions.jl") end diff --git a/test/test_utils.jl b/test/test_utils.jl new file mode 100644 index 00000000..73a2aa0f --- /dev/null +++ b/test/test_utils.jl @@ -0,0 +1,23 @@ +# Create positive definite matrix +to_posdef(A::AbstractMatrix) = A * A' + I +to_posdef_diagonal(a::AbstractVector) = Diagonal(a .^ 2 .+ 1) + +# Create vectors in probability simplex. +function to_simplex(x::AbstractArray) + max = maximum(x; dims = 1) + y = exp.(x .- max) + y ./= sum(y; dims = 1) + return y +end +to_simplex(x::AbstractArray{<:AbstractArray}) = to_simplex.(x) + +# Utility for testing `adapt_randn` +function test_adapt_randn(rng, x::AbstractVector, ::Type{T}, dims::Int...) where {T} + Random.seed!(rng, 100) + y = DistributionsAD.adapt_randn(rng, x, dims...) + @test y isa Array{T} + @test size(y) == dims + + Random.seed!(rng, 100) + @test y == randn(rng, T, dims...) +end