diff --git a/.github/workflows/ForwardDiff_Tracker.yml b/.github/workflows/ForwardDiff_Tracker.yml new file mode 100644 index 00000000..6db29645 --- /dev/null +++ b/.github/workflows/ForwardDiff_Tracker.yml @@ -0,0 +1,29 @@ +name: ForwardDiff and Tracker tests + +on: + push: + branches: + - master + pull_request: + types: [opened, synchronize, reopened] + +jobs: + test: + runs-on: ${{ matrix.os }} + strategy: + matrix: + julia-version: [1.0.5, 1.2.0, 1.3] + julia-arch: [x64, x86] + os: [ubuntu-latest, macOS-latest] + exclude: + - os: macOS-latest + julia-arch: x86 + + steps: + - uses: actions/checkout@v1.0.0 + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.julia-version }} + - uses: julia-actions/julia-runtest@master + env: + STAGE: ForwardDiff_Tracker \ No newline at end of file diff --git a/.github/workflows/CI.yml b/.github/workflows/Others.yml similarity index 91% rename from .github/workflows/CI.yml rename to .github/workflows/Others.yml index 4dc16fb8..8f1ce02b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/Others.yml @@ -1,4 +1,4 @@ -name: CI +name: Other tests on: push: @@ -25,3 +25,5 @@ jobs: with: version: ${{ matrix.julia-version }} - uses: julia-actions/julia-runtest@master + env: + STAGE: Others \ No newline at end of file diff --git a/.github/workflows/Zygote.yml b/.github/workflows/Zygote.yml new file mode 100644 index 00000000..95092af3 --- /dev/null +++ b/.github/workflows/Zygote.yml @@ -0,0 +1,29 @@ +name: Zygote tests + +on: + push: + branches: + - master + pull_request: + types: [opened, synchronize, reopened] + +jobs: + test: + runs-on: ${{ matrix.os }} + strategy: + matrix: + julia-version: [1.0.5, 1.2.0, 1.3] + julia-arch: [x64, x86] + os: [ubuntu-latest, macOS-latest] + exclude: + - os: macOS-latest + julia-arch: x86 + + steps: + - uses: actions/checkout@v1.0.0 + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.julia-version }} + - uses: julia-actions/julia-runtest@master + env: + STAGE: Zygote \ No newline at end of file diff --git a/src/arraydist.jl b/src/arraydist.jl index fbe38b8d..36ebf080 100644 --- a/src/arraydist.jl +++ b/src/arraydist.jl @@ -2,25 +2,25 @@ const VectorOfUnivariate = Distributions.Product -function arraydist(dists::AbstractVector{<:Normal{T}}) where {T} - means = mean.(dists) - vars = var.(dists) - return MvNormal(means, vars) -end -function arraydist(dists::AbstractVector{<:Normal{<:TrackedReal}}) - means = vcatmapreduce(mean, dists) - vars = vcatmapreduce(var, dists) - return MvNormal(means, vars) -end function arraydist(dists::AbstractVector{<:UnivariateDistribution}) return product_distribution(dists) end +function arraydist(dists::AbstractVector{<:Normal}) + m = mapvcat(mean, dists) + v = mapvcat(var, dists) + return MvNormal(m, v) +end + function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real}) - return sum(vcatmapreduce(logpdf, dist.v, x)) + return sum(map((d, x) -> logpdf(d, x), dist.v, x)) end function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real}) # eachcol breaks Zygote, so we need an adjoint - return vcatmapreduce((dist, c) -> logpdf.(dist, c), dist.v, eachcol(x)) + return mapvcat(dist.v, eachcol(x)) do dist, c + sum(map(c) do x + logpdf(dist, x) + end) + end end @adjoint function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real}) # Any other more efficient implementation breaks Zygote @@ -41,14 +41,16 @@ function arraydist(dists::AbstractMatrix{<:UnivariateDistribution}) end function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real}) # Broadcasting here breaks Tracker for some reason - # A Zygote adjoint is defined for vcatmapreduce to use broadcasting - return sum(vcatmapreduce(logpdf, dist.dists, x)) + # A Zygote adjoint is defined for mapvcat to use broadcasting + return sum(map(dist.dists, x) do dist, x + logpdf(dist, x) + end) end function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}}) - return vcatmapreduce(x -> logpdf(dist, x), x) + return mapvcat(x -> logpdf(dist, x), x) end function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}}) - return vcatmapreduce(x -> logpdf(dist, x), x) + return mapvcat(x -> logpdf(dist, x), x) end function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate) return rand.(Ref(rng), dist.dists) @@ -70,16 +72,16 @@ function arraydist(dists::AbstractVector{<:MultivariateDistribution}) end function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) # eachcol breaks Zygote, so we define an adjoint - return sum(vcatmapreduce(logpdf, dist.dists, eachcol(x))) + return sum(logpdf.(dist.dists, eachcol(x))) end function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}}) - return reshape(vcatmapreduce(x -> logpdf(dist, x), x), size(x)) + return mapvcat(x -> logpdf(dist, x), x) end function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}}) - return reshape(vcatmapreduce(x -> logpdf(dist, x), x), size(x)) + return mapvcat(x -> logpdf(dist, x), x) end @adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) - f(dist, x) = sum(vcatmapreduce(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2))) + f(dist, x) = sum(mapvcat(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2))) return pullback(f, dist, x) end function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate) diff --git a/src/common.jl b/src/common.jl index 5edcdc11..de9556e9 100644 --- a/src/common.jl +++ b/src/common.jl @@ -1,14 +1,19 @@ ## Generic ## -function vcatmapreduce(f, args...) - init = vcat(f(first.(args)...,)) - zipped_args = zip(args...,) - return mapreduce(vcat, drop(zipped_args, 1); init = init) do zarg - f(zarg...,) +_istracked(x) = false +_istracked(x::TrackedArray) = false +_istracked(x::AbstractArray{<:TrackedReal}) = true +function mapvcat(f, args...) + out = map(f, args...) + if _istracked(out) + init = vcat(out[1]) + return reshape(reduce(vcat, drop(out, 1); init = init), size(out)) + else + return out end end -@adjoint function vcatmapreduce(f, args...) - g(f, args...) = f.(args...) +@adjoint function mapvcat(f, args...) + g(f, args...) = map(f, args...) return pullback(g, f, args...) end diff --git a/src/filldist.jl b/src/filldist.jl index a665fdeb..6a325372 100644 --- a/src/filldist.jl +++ b/src/filldist.jl @@ -48,26 +48,20 @@ end function _flat_logpdf(dist, x) if toflatten(dist) f, args = flatten(dist) - if any(Tracker.istracked, args) - return sum(f.(args..., x)) - else - return sum(logpdf.(dist, x)) - end + return sum(f.(args..., x)) else - return sum(vcatmapreduce(x -> logpdf(dist, x), x)) + return sum(mapvcat(x) do x + logpdf(dist, x) + end) end end function _flat_logpdf_mat(dist, x) if toflatten(dist) f, args = flatten(dist) - if any(Tracker.istracked, args) - return vec(sum(f.(args..., x), dims = 1)) - else - return vec(sum(logpdf.(dist, x), dims = 1)) - end + return vec(sum(f.(args..., x), dims = 1)) else - temp = vcatmapreduce(x -> logpdf(dist, x), x) - return vec(sum(reshape(temp, size(x)), dims = 1)) + temp = mapvcat(x -> logpdf(dist, x), x) + return vec(sum(temp, dims = 1)) end end diff --git a/src/flatten.jl b/src/flatten.jl index 21822fba..37e13e42 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -41,7 +41,7 @@ const flattened_dists = [ Bernoulli, FDist, Frechet, Gamma, - #GeneralizedExtremeValue, + GeneralizedExtremeValue, GeneralizedPareto, Gumbel, #InverseGamma, @@ -63,6 +63,7 @@ const flattened_dists = [ Bernoulli, TDist, TriangularDist, Triweight, + TuringUniform, #Truncated, #VonMises, ] diff --git a/src/matrixvariate.jl b/src/matrixvariate.jl index f723b860..79e63a41 100644 --- a/src/matrixvariate.jl +++ b/src/matrixvariate.jl @@ -1,7 +1,7 @@ ## MatrixBeta function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:TrackedMatrix{<:Real}}) - return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X)) + return mapvcat(x -> logpdf(d, x), X) end @adjoint function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:Matrix{<:Real}}) f(d, X) = map(x -> logpdf(d, x), X) @@ -112,10 +112,10 @@ function Distributions.logpdf(d::TuringWishart, X::AbstractMatrix{<:Real}) return 0.5 * ((df - (p + 1)) * logdet(Xcf) - tr(d.chol \ X)) - d.c0 end function Distributions.logpdf(d::TuringWishart, X::AbstractArray{<:AbstractMatrix{<:Real}}) - return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X)) + return mapvcat(x -> logpdf(d, x), X) end function Distributions.logpdf(d::TuringWishart, X::AbstractArray{<:Matrix{<:Real}}) - return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X)) + return mapvcat(x -> logpdf(d, x), X) end #### Sampling @@ -233,10 +233,10 @@ function Distributions.logpdf(d::TuringInverseWishart, X::AbstractMatrix{<:Real} -0.5 * ((df + p + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) - d.c0 end function Distributions.logpdf(d::TuringInverseWishart, X::AbstractArray{<:AbstractMatrix{<:Real}}) - return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X)) + return mapvcat(x -> logpdf(d, x), X) end function Distributions.logpdf(d::TuringInverseWishart, X::AbstractArray{<:Matrix{<:Real}}) - return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X)) + return mapvcat(x -> logpdf(d, x), X) end #### Sampling diff --git a/src/univariate.jl b/src/univariate.jl index 27938d11..c67461c8 100644 --- a/src/univariate.jl +++ b/src/univariate.jl @@ -13,6 +13,8 @@ function TuringUniform(a::Real, b::Real) return TuringUniform{T}(T(a), T(b)) end Distributions.logpdf(d::TuringUniform, x::Real) = uniformlogpdf(d.a, d.b, x) +Base.minimum(d::TuringUniform) = d.a +Base.maximum(d::TuringUniform) = d.b Distributions.Uniform(a::TrackedReal, b::Real) = TuringUniform{TrackedReal}(a, b) Distributions.Uniform(a::Real, b::TrackedReal) = TuringUniform{TrackedReal}(a, b) @@ -348,3 +350,21 @@ function Base.convert( DiscreteNonParametric{T,P,Ts,Ps}(support(d), probs(d), check_args=false) end +# Fix SubArray support +function Distributions.DiscreteNonParametric{T,P,Ts,Ps}( + vs::Ts, + ps::Ps; + check_args=true, +) where {T<:Real, P<:Real, Ts<:AbstractVector{T}, Ps<:SubArray{P, 1}} + cps = ps[:] + return DiscreteNonParametric{T,P,Ts,typeof(cps)}(vs, cps; check_args = check_args) +end + +function Distributions.DiscreteNonParametric{T,P,Ts,Ps}( + vs::Ts, + ps::Ps; + check_args=true, +) where {T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:TrackedArray{P, 1, <:SubArray{P, 1}}} + cps = ps[:] + return DiscreteNonParametric{T,P,Ts,typeof(cps)}(vs, cps; check_args = check_args) +end diff --git a/test/distributions.jl b/test/distributions.jl index 08f7f4b8..50a072e2 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -456,8 +456,6 @@ matrix_cont_dists = [ #DistSpec(:TuringInverseWishart, (dim, cov_mat), fill(cov_mat, 2)), ] xmatrix_cont_dists = [ - matrix_cont_dists; - # Matrix x filter(!isnothing, filldist_spec.(uni_cont_dists; n = (2, 2))); filter(!isnothing, filldist_spec.(multi_cont_dists; disttype = :multi, n = 2)); @@ -469,6 +467,8 @@ xmatrix_cont_dists = [ filter(!isnothing, filldist_spec.(multi_cont_dists; disttype = :multi, n = 2, d = 2)); filter(!isnothing, arraydist_spec.(uni_cont_dists; n = (2, 2), d = 2)); filter(!isnothing, arraydist_spec.(multi_cont_dists; disttype = :multi, n = 2, d = 2)); + + matrix_cont_dists; ] broken_matrix_cont_dists = [ # Other @@ -478,65 +478,67 @@ broken_matrix_cont_dists = [ DistSpec(:MatrixFDist, (dim, dim, cov_mat), cov_mat), ] -test_head(s) = println("\n"*s*"\n") -separator() = println("\n"*"="^50) - -separator() -@testset "Univariate discrete distributions" begin - test_head("Testing: Univariate discrete distributions") - for d in uni_disc_dists - test_info(d.name) - for testf in get_all_functions(d, false) - test_ad(testf.f, testf.x) +if get_stage() != "Others" + test_head(s) = println("\n"*s*"\n") + separator() = println("\n"*"="^50) + + separator() + @testset "Univariate discrete distributions" begin + test_head("Testing: Univariate discrete distributions") + for d in uni_disc_dists + test_info(d.name) + for testf in get_all_functions(d, false) + test_ad(testf.f, testf.x) + end end end -end -separator() - -# Note: broadcasting logpdf with univariate distributions having tracked parameters breaks -# Tracker. Ref: https://github.com/FluxML/Tracker.jl/issues/65 -# filldist works around this so it is the recommended way for AD-friendly "broadcasting" -# of logpdf with univariate distributions -@testset "Univariate continuous distributions" begin - test_head("Testing: Univariate continuous distributions") - for d in uni_cont_dists - test_info(d.name) - for testf in get_all_functions(d, true) - test_ad(testf.f, testf.x) + separator() + + # Note: broadcasting logpdf with univariate distributions having tracked parameters breaks + # Tracker. Ref: https://github.com/FluxML/Tracker.jl/issues/65 + # filldist works around this so it is the recommended way for AD-friendly "broadcasting" + # of logpdf with univariate distributions + @testset "Univariate continuous distributions" begin + test_head("Testing: Univariate continuous distributions") + for d in uni_cont_dists + test_info(d.name) + for testf in get_all_functions(d, true) + test_ad(testf.f, testf.x) + end end end -end -separator() - -@testset "Multivariate discrete distributions" begin - test_head("Testing: Multivariate discrete distributions") - for d in xmulti_disc_dists - test_info(d.name) - for testf in get_all_functions(d, false) - test_ad(testf.f, testf.x) + separator() + + @testset "Multivariate discrete distributions" begin + test_head("Testing: Multivariate discrete distributions") + for d in xmulti_disc_dists + test_info(d.name) + for testf in get_all_functions(d, false) + test_ad(testf.f, testf.x) + end end end -end -separator() -@testset "Multivariate continuous distributions" begin - test_head("Testing: Multivariate continuous distributions") - for d in xmulti_cont_dists - test_info(d.name) - for testf in get_all_functions(d, true) - test_ad(testf.f, testf.x) + separator() + @testset "Multivariate continuous distributions" begin + test_head("Testing: Multivariate continuous distributions") + for d in xmulti_cont_dists + test_info(d.name) + for testf in get_all_functions(d, true) + test_ad(testf.f, testf.x) + end end end -end -separator() - -@testset "Matrix-variate continuous distributions" begin - test_head("Testing: Matrix-variate continuous distributions") - for d in xmatrix_cont_dists - test_info(d.name) - for testf in get_all_functions(d, true) - test_ad(testf.f, testf.x) + separator() + + @testset "Matrix-variate continuous distributions" begin + test_head("Testing: Matrix-variate continuous distributions") + for d in xmatrix_cont_dists + test_info(d.name) + for testf in get_all_functions(d, true) + test_ad(testf.f, testf.x) + end end end + separator() end -separator() diff --git a/test/others.jl b/test/others.jl index e43dd0b6..3b532586 100644 --- a/test/others.jl +++ b/test/others.jl @@ -1,146 +1,148 @@ using StatsBase: entropy -@testset "unsafe_cholesky" begin - A = rand(3, 3); A = A + A' + 3I - @test Matrix(DistributionsAD.unsafe_cholesky(A, true)) == Matrix(cholesky(A)) - @test !issuccess(DistributionsAD.unsafe_cholesky(rand(3,3), false)) - @test_throws PosDefException DistributionsAD.unsafe_cholesky(rand(3,3), true) -end - -@testset "TuringWishart" begin - dim = 3 - A = Matrix{Float64}(I, dim, dim) - dW1 = Wishart(dim + 4, A) - dW2 = TuringWishart(dim + 4, A) - - @testset "$F" for F in (size, rank, mean, meanlogdet, entropy, cov, var) - @test F(dW1) == F(dW2) +if get_stage() in ("Others", "all") + @testset "unsafe_cholesky" begin + A = rand(3, 3); A = A + A' + 3I + @test Matrix(DistributionsAD.unsafe_cholesky(A, true)) == Matrix(cholesky(A)) + @test !issuccess(DistributionsAD.unsafe_cholesky(rand(3,3), false)) + @test_throws PosDefException DistributionsAD.unsafe_cholesky(rand(3,3), true) end - @test Matrix(mode(dW1)) == mode(dW2) - xw = rand(dW2) - @test insupport(dW1, xw) - @test insupport(dW2, xw) - @test logpdf(dW1, xw) == logpdf(dW2, xw) -end - -@testset "TuringInverseWishart" begin - dim = 3 - A = Matrix{Float64}(I, dim, dim) - dIW1 = InverseWishart(dim + 4, A) - dIW2 = TuringInverseWishart(dim + 4, A) - - @testset "$F" for F in (size, rank, mean, cov, var) - @test F(dIW1) == F(dIW2) - end - @test Matrix(mode(dIW1)) == mode(dIW2) - xiw = rand(dIW2) - @test insupport(dIW1, xiw) - @test insupport(dIW2, xiw) - @test logpdf(dIW1, xiw) == logpdf(dIW2, xiw) -end - -@testset "TuringMvNormal" begin - @testset "$TD" for TD in [TuringDenseMvNormal, TuringDiagMvNormal, TuringScalMvNormal] - m = rand(3) - if TD <: TuringDenseMvNormal - C = Matrix{Float64}(I, 3, 3) - d1 = TuringMvNormal(m, C) - elseif TD <: TuringDiagMvNormal - C = ones(3) - d1 = TuringMvNormal(m, C) - else - C = 1.0 - d1 = TuringMvNormal(m, C) + + @testset "TuringWishart" begin + dim = 3 + A = Matrix{Float64}(I, dim, dim) + dW1 = Wishart(dim + 4, A) + dW2 = TuringWishart(dim + 4, A) + mean = Distributions.mean + @testset "$F" for F in (size, rank, mean, meanlogdet, entropy, cov, var) + @test F(dW1) == F(dW2) end - d2 = MvNormal(m, C) + @test Matrix(mode(dW1)) == mode(dW2) + xw = rand(dW2) + @test insupport(dW1, xw) + @test insupport(dW2, xw) + @test logpdf(dW1, xw) == logpdf(dW2, xw) + end - @testset "$F" for F in (length, size) - @test F(d1) == F(d2) + @testset "TuringInverseWishart" begin + dim = 3 + A = Matrix{Float64}(I, dim, dim) + dIW1 = InverseWishart(dim + 4, A) + dIW2 = TuringInverseWishart(dim + 4, A) + mean = Distributions.mean + @testset "$F" for F in (size, rank, mean, cov, var) + @test F(dIW1) == F(dIW2) end + @test Matrix(mode(dIW1)) == mode(dIW2) + xiw = rand(dIW2) + @test insupport(dIW1, xiw) + @test insupport(dIW2, xiw) + @test logpdf(dIW1, xiw) == logpdf(dIW2, xiw) + end - x1 = rand(d1) - x2 = rand(d1, 3) - @test isapprox(logpdf(d1, x1), logpdf(d2, x1), rtol = 1e-6) - @test isapprox(logpdf(d1, x2), logpdf(d2, x2), rtol = 1e-6) + @testset "TuringMvNormal" begin + @testset "$TD" for TD in [TuringDenseMvNormal, TuringDiagMvNormal, TuringScalMvNormal] + m = rand(3) + if TD <: TuringDenseMvNormal + C = Matrix{Float64}(I, 3, 3) + d1 = TuringMvNormal(m, C) + elseif TD <: TuringDiagMvNormal + C = ones(3) + d1 = TuringMvNormal(m, C) + else + C = 1.0 + d1 = TuringMvNormal(m, C) + end + d2 = MvNormal(m, C) + + @testset "$F" for F in (length, size) + @test F(d1) == F(d2) + end + + x1 = rand(d1) + x2 = rand(d1, 3) + @test isapprox(logpdf(d1, x1), logpdf(d2, x1), rtol = 1e-6) + @test isapprox(logpdf(d1, x2), logpdf(d2, x2), rtol = 1e-6) + end end -end - -@testset "TuringMvLogNormal" begin - @testset "$TD" for TD in [TuringDenseMvNormal, TuringDiagMvNormal, TuringScalMvNormal] - m = rand(3) - if TD <: TuringDenseMvNormal - C = Matrix{Float64}(I, 3, 3) - d1 = TuringMvLogNormal(TuringMvNormal(m, C)) - elseif TD <: TuringDiagMvNormal - C = ones(3) - d1 = TuringMvLogNormal(TuringMvNormal(m, C)) - else - C = 1.0 - d1 = TuringMvLogNormal(TuringMvNormal(m, C)) + + @testset "TuringMvLogNormal" begin + @testset "$TD" for TD in [TuringDenseMvNormal, TuringDiagMvNormal, TuringScalMvNormal] + m = rand(3) + if TD <: TuringDenseMvNormal + C = Matrix{Float64}(I, 3, 3) + d1 = TuringMvLogNormal(TuringMvNormal(m, C)) + elseif TD <: TuringDiagMvNormal + C = ones(3) + d1 = TuringMvLogNormal(TuringMvNormal(m, C)) + else + C = 1.0 + d1 = TuringMvLogNormal(TuringMvNormal(m, C)) + end + d2 = MvLogNormal(MvNormal(m, C)) + + @test length(d1) == length(d2) + + x1 = rand(d1) + x2 = rand(d1, 3) + @test isapprox(logpdf(d1, x1), logpdf(d2, x1), rtol = 1e-6) + @test isapprox(logpdf(d1, x2), logpdf(d2, x2), rtol = 1e-6) + + x2[:, 1] .= -1 + @test isinf(logpdf(d1, x2)[1]) + @test isinf(logpdf(d2, x2)[1]) end - d2 = MvLogNormal(MvNormal(m, C)) + end - @test length(d1) == length(d2) + @testset "TuringUniform" begin + @test logpdf(TuringUniform(), param(0.5)) == 0 + end + + @testset "Semicircle" begin + @test Tracker.data(logpdf(Semicircle(1.0), param(0.5))) == logpdf(Semicircle(1.0), 0.5) + end + + @testset "TuringPoissonBinomial" begin + d1 = TuringPoissonBinomial([0.5, 0.5]) + d2 = PoissonBinomial([0.5, 0.5]) + @test quantile(d1, 0.5) == quantile(d2, 0.5) + @test minimum(d1) == minimum(d2) + end + + @testset "Inverse of pi" begin + @test 1/pi == inv(pi) + end + + @testset "Others" begin + @test fill(param(1.0), 3) isa TrackedArray + x = rand(3) + @test isapprox(Tracker.data(Tracker.gradient(logsumexp, x)[1]), + ForwardDiff.gradient(logsumexp, x), atol = 1e-5) + A = rand(3, 3)'; A = A + A' + 3I; + C = cholesky(A; check = true) + factors, info = DistributionsAD.turing_chol(A, true) + @test factors == C.factors + @test info == C.info + B = copy(A) + @test DistributionsAD.zygote_ldiv(A, B) == A \ B + end + + @testset "Entropy" begin + sigmas = exp.(randn(10)) + d1 = TuringDiagMvNormal(zeros(10), sigmas) + d2 = MvNormal(zeros(10), sigmas) + + @test entropy(d1) == entropy(d2) + end - x1 = rand(d1) - x2 = rand(d1, 3) - @test isapprox(logpdf(d1, x1), logpdf(d2, x1), rtol = 1e-6) - @test isapprox(logpdf(d1, x2), logpdf(d2, x2), rtol = 1e-6) + @testset "Params" begin + m = rand(10) + sigmas = randexp(10) + + d = TuringDiagMvNormal(m, sigmas) + @test params(d) == (m, sigmas) - x2[:, 1] .= -1 - @test isinf(logpdf(d1, x2)[1]) - @test isinf(logpdf(d2, x2)[1]) + d = TuringScalMvNormal(m, sigmas[1]) + @test params(d) == (m, sigmas[1]) end -end - -@testset "TuringUniform" begin - @test logpdf(TuringUniform(), param(0.5)) == 0 -end - -@testset "Semicircle" begin - @test Tracker.data(logpdf(Semicircle(1.0), param(0.5))) == logpdf(Semicircle(1.0), 0.5) -end - -@testset "TuringPoissonBinomial" begin - d1 = TuringPoissonBinomial([0.5, 0.5]) - d2 = PoissonBinomial([0.5, 0.5]) - @test quantile(d1, 0.5) == quantile(d2, 0.5) - @test minimum(d1) == minimum(d2) -end - -@testset "Inverse of pi" begin - @test 1/pi == inv(pi) -end - -@testset "Others" begin - @test fill(param(1.0), 3) isa TrackedArray - x = rand(3) - @test isapprox(Tracker.data(Tracker.gradient(logsumexp, x)[1]), - ForwardDiff.gradient(logsumexp, x), atol = 1e-5) - A = rand(3, 3)'; A = A + A' + 3I; - C = cholesky(A; check = true) - factors, info = DistributionsAD.turing_chol(A, true) - @test factors == C.factors - @test info == C.info - B = copy(A) - @test DistributionsAD.zygote_ldiv(A, B) == A \ B -end - -@testset "Entropy" begin - sigmas = exp.(randn(10)) - d1 = TuringDiagMvNormal(zeros(10), sigmas) - d2 = MvNormal(zeros(10), sigmas) - - @test entropy(d1) == entropy(d2) -end - -@testset "Params" begin - m = rand(10) - sigmas = randexp(10) - - d = TuringDiagMvNormal(m, sigmas) - @test params(d) == (m, sigmas) - - d = TuringScalMvNormal(m, sigmas[1]) - @test params(d) == (m, sigmas[1]) end \ No newline at end of file diff --git a/test/staging.jl b/test/staging.jl new file mode 100644 index 00000000..e69de29b diff --git a/test/test_utils.jl b/test/test_utils.jl index fc9225bb..d7ac0759 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -134,20 +134,61 @@ function get_all_functions(dist::DistSpec, continuous=false) return fs end +# Taken from Turing.jl +function get_stage() + if get(ENV, "TRAVIS", "") == "true" || get(ENV, "GITHUB_ACTIONS", "") == "true" + if "STAGE" in keys(ENV) + return ENV["STAGE"] + else + return "all" + end + end + + return "all" +end + function test_ad(f, at = 0.5; rtol = 1e-8, atol = 1e-8) - isarr = isa(at, AbstractArray) - reverse_tracker = Tracker.data(Tracker.gradient(f, at)[1]) - reverse_zygote = Zygote.gradient(f, at)[1] - if isarr - forward = ForwardDiff.gradient(f, at) - @test isapprox(reverse_tracker, forward, rtol=rtol, atol=atol) - @test isapprox(reverse_zygote, forward, rtol=rtol, atol=atol) + stg = get_stage() + if stg == "all" + isarr = isa(at, AbstractArray) + reverse_tracker = Tracker.data(Tracker.gradient(f, at)[1]) + reverse_zygote = Zygote.gradient(f, at)[1] + if isarr + forward = ForwardDiff.gradient(f, at) + @test isapprox(reverse_tracker, forward, rtol=rtol, atol=atol) + @test isapprox(reverse_zygote, forward, rtol=rtol, atol=atol) + else + forward = ForwardDiff.derivative(f, at) + finite_diff = central_fdm(5,1)(f, at) + @test isapprox(reverse_tracker, forward, rtol=rtol, atol=atol) + @test isapprox(reverse_tracker, finite_diff, rtol=rtol, atol=atol) + @test isapprox(reverse_zygote, finite_diff, rtol=rtol, atol=atol) + end + elseif stg == "ForwardDiff_Tracker" + isarr = isa(at, AbstractArray) + reverse_tracker = Tracker.data(Tracker.gradient(f, at)[1]) + if isarr + forward = ForwardDiff.gradient(f, at) + @test isapprox(reverse_tracker, forward, rtol=rtol, atol=atol) + else + forward = ForwardDiff.derivative(f, at) + finite_diff = central_fdm(5,1)(f, at) + @test isapprox(reverse_tracker, forward, rtol=rtol, atol=atol) + @test isapprox(reverse_tracker, finite_diff, rtol=rtol, atol=atol) + end + elseif stg == "Zygote" + isarr = isa(at, AbstractArray) + reverse_zygote = Zygote.gradient(f, at)[1] + if isarr + forward = ForwardDiff.gradient(f, at) + @test isapprox(reverse_zygote, forward, rtol=rtol, atol=atol) + else + forward = ForwardDiff.derivative(f, at) + finite_diff = central_fdm(5,1)(f, at) + @test isapprox(reverse_zygote, finite_diff, rtol=rtol, atol=atol) + end else - forward = ForwardDiff.derivative(f, at) - finite_diff = central_fdm(5,1)(f, at) - @test isapprox(reverse_tracker, forward, rtol=rtol, atol=atol) - @test isapprox(reverse_tracker, finite_diff, rtol=rtol, atol=atol) - @test isapprox(reverse_zygote, finite_diff, rtol=rtol, atol=atol) + throw("Unsupported test stage.") end end