diff --git a/src/Bijectors.jl b/src/Bijectors.jl index e435ebf7..6991d446 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -243,8 +243,6 @@ end include("interface.jl") -# Broadcasting here breaks Tracker for some reason -maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) # optional dependencies diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index cb8d6549..f0dea430 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -12,14 +12,18 @@ using .Tracker: Tracker, using Compat: eachcol using LinearAlgebra +# Broadcasting here breaks Tracker +const TrackedT = Union{TrackedArray, AbstractArray{<:TrackedReal}} maporbroadcast(f, x::TrackedArray...) = f.(x...) -function maporbroadcast( - f, - x1::TrackedArray{T, N}, - x::AbstractArray{<:TrackedReal}..., -) where {T, N} - return f.(convert(Array{TrackedReal{T}, N}, x1), x...) -end +maporbroadcast(f, x::Union{TrackedArray, AbstractArray{<:TrackedReal}}...) = map(f, x...) +maporbroadcast(f, x1::TrackedT, x2::AbstractArray) = map(f, x1, x2) +maporbroadcast(f, x1::AbstractArray, x2::TrackedT) = map(f, x1, x2) +maporbroadcast(f, x1::TrackedT, x2::AbstractArray, x3::AbstractArray) = map(f, x1, x2, x3) +maporbroadcast(f, x1::AbstractArray, x2::TrackedT, x3::AbstractArray) = map(f, x1, x2, x3) +maporbroadcast(f, x1::AbstractArray, x2::AbstractArray, x3::TrackedT) = map(f, x1, x2, x3) +maporbroadcast(f, x1::TrackedT, x2::TrackedT, x3::AbstractArray) = map(f, x1, x2, x3) +maporbroadcast(f, x1::AbstractArray, x2::TrackedT, x3::TrackedT) = map(f, x1, x2, x3) +maporbroadcast(f, x1::TrackedT, x2::AbstractArray, x3::TrackedT) = map(f, x1, x2, x3) _eps(::Type{<:TrackedReal{T}}) where {T} = _eps(T) function Base.minimum(d::LocationScale{<:TrackedReal}) diff --git a/test/runtests.jl b/test/runtests.jl index 47f36d95..3e4a364a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,5 +35,11 @@ end if !is_TRAVIS && (GROUP == "All" || GROUP == "AD") include("ad/distributions.jl") + if AD == "All" || AD == "ReverseDiff" + @testset "Turing issue 1385" begin + dist = arraydist(truncated.(Laplace.(0, [1, 2]), -10.0, 70.0)) + x = ReverseDiff.track(rand(dist)) + @test typeof(bijector(dist)(x)) <: ReverseDiff.TrackedArray + end + end end -