From f074b8da7d7f2c1a6a0c0f98350e31e81cf9004e Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 31 Mar 2021 02:20:13 +0200 Subject: [PATCH] Remove workaround for view with ReverseDiff (#159) --- Project.toml | 2 +- src/reversediff.jl | 17 ----------------- 2 files changed, 1 insertion(+), 18 deletions(-) diff --git a/Project.toml b/Project.toml index 4cba67ab..3704083d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.20" +version = "0.6.21" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/reversediff.jl b/src/reversediff.jl index 6e853d3b..21c6285d 100644 --- a/src/reversediff.jl +++ b/src/reversediff.jl @@ -56,17 +56,6 @@ include("reversediffx.jl") adapt_randn(rng::Random.AbstractRNG, x::TrackedArray, dims...) = adapt_randn(rng, value(x), dims...) -# without this definition tests of `VectorOfMultivariate` with `Dirichlet` fail -# upstream bug caused by `view` + `track`: https://github.com/JuliaDiff/ReverseDiff.jl/pull/164 -function _logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:TrackedReal}) - return sum(i -> _logpdf(dist.dists[i], x[:, i]), axes(x, 2)) -end - -# fix method ambiguity -function _logpdf(dist::FillVectorOfMultivariate, x::AbstractMatrix{<:TrackedReal}) - return loglikelihood(dist.dists.value, x) -end - function PoissonBinomial(p::TrackedArray{<:Real}; check_args=true) return TuringPoissonBinomial(p; check_args = check_args) end @@ -269,12 +258,6 @@ function loglikelihood(d::Dirichlet, x::AbstractMatrix{<:TrackedReal}) return loglikelihood(TuringDirichlet(d), x) end -# default definition of `loglikelihood` yields gradients of zero?! -# upstream bug caused by `view` + `track`: https://github.com/JuliaDiff/ReverseDiff.jl/pull/164 -function loglikelihood(d::TuringDirichlet, x::AbstractMatrix{<:TrackedReal}) - return sum(i -> logpdf(d, x[:, i]), axes(x, 2)) -end - for func_header in [ :(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::Real, x::AbstractVector)), :(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractVector)),