From 88affba5e78dec632f983b5d171bf72c741d64e2 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 10 Oct 2020 11:06:49 +1100 Subject: [PATCH 1/5] fix maporbroadcast for ReverseDiff --- src/Bijectors.jl | 2 -- src/compat/reversediff.jl | 4 +++- src/compat/tracker.jl | 2 ++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index e435ebf7..6991d446 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -243,8 +243,6 @@ end include("interface.jl") -# Broadcasting here breaks Tracker for some reason -maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) # optional dependencies diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 61d61b03..9413cc35 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -1,7 +1,7 @@ module ReverseDiffCompat using ..ReverseDiff: ReverseDiff, @grad, value, track, TrackedReal, TrackedVector, - TrackedMatrix + TrackedMatrix, TrackedArray using Requires, LinearAlgebra using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian, @@ -46,6 +46,8 @@ function Base.maximum(d::LocationScale{<:TrackedReal}) end end +maporbroadcast(f, x::Union{AbstractArray, TrackedArray, AbstractArray{<:TrackedReal}}...) = f.(x...) + logabsdetjac(b::Log{1}, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x) @grad function logabsdetjac(b::Log{1}, x::AbstractVector) return -sum(log, value(x)), Δ -> (nothing, -Δ ./ value(x)) diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index cb8d6549..789d4f4c 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -12,6 +12,8 @@ using .Tracker: Tracker, using Compat: eachcol using LinearAlgebra +# Broadcasting here breaks Tracker for some reason +maporbroadcast(f, x::Union{AbstractArray, TrackedArray, AbstractArray{<:TrackedReal}}...) = map(f, x...) maporbroadcast(f, x::TrackedArray...) = f.(x...) function maporbroadcast( f, From c238ce60b0bed1735a82a9adf5c30a7253607205 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 11 Oct 2020 02:44:11 +1100 Subject: [PATCH 2/5] add test --- test/ad/distributions.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 7f017fe0..264ea727 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -550,4 +550,10 @@ ) end end + + @testset "Turing issue 1385" begin + dist = arraydist(truncated.(Laplace.(0, [1, 2]), -10.0, 70.0)) + x = ReverseDiff.track(rand(dist)) + @test typeof(bijector(dist)(x)) <: ReverseDiff.TrackedArray + end end From b53c8a45d2f7c67935a3d5ef72cbc37e9dfd1284 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 11 Oct 2020 05:10:58 +1100 Subject: [PATCH 3/5] fix Tracker + ReverseDiff case --- src/compat/reversediff.jl | 2 -- src/compat/tracker.jl | 20 +++++++++++--------- test/ad/distributions.jl | 6 ------ test/runtests.jl | 7 +++++++ 4 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 9413cc35..d9fcb587 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -46,8 +46,6 @@ function Base.maximum(d::LocationScale{<:TrackedReal}) end end -maporbroadcast(f, x::Union{AbstractArray, TrackedArray, AbstractArray{<:TrackedReal}}...) = f.(x...) - logabsdetjac(b::Log{1}, x::Union{TrackedVector, TrackedMatrix}) = track(logabsdetjac, b, x) @grad function logabsdetjac(b::Log{1}, x::AbstractVector) return -sum(log, value(x)), Δ -> (nothing, -Δ ./ value(x)) diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index 789d4f4c..f0dea430 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -12,16 +12,18 @@ using .Tracker: Tracker, using Compat: eachcol using LinearAlgebra -# Broadcasting here breaks Tracker for some reason -maporbroadcast(f, x::Union{AbstractArray, TrackedArray, AbstractArray{<:TrackedReal}}...) = map(f, x...) +# Broadcasting here breaks Tracker +const TrackedT = Union{TrackedArray, AbstractArray{<:TrackedReal}} maporbroadcast(f, x::TrackedArray...) = f.(x...) -function maporbroadcast( - f, - x1::TrackedArray{T, N}, - x::AbstractArray{<:TrackedReal}..., -) where {T, N} - return f.(convert(Array{TrackedReal{T}, N}, x1), x...) -end +maporbroadcast(f, x::Union{TrackedArray, AbstractArray{<:TrackedReal}}...) = map(f, x...) +maporbroadcast(f, x1::TrackedT, x2::AbstractArray) = map(f, x1, x2) +maporbroadcast(f, x1::AbstractArray, x2::TrackedT) = map(f, x1, x2) +maporbroadcast(f, x1::TrackedT, x2::AbstractArray, x3::AbstractArray) = map(f, x1, x2, x3) +maporbroadcast(f, x1::AbstractArray, x2::TrackedT, x3::AbstractArray) = map(f, x1, x2, x3) +maporbroadcast(f, x1::AbstractArray, x2::AbstractArray, x3::TrackedT) = map(f, x1, x2, x3) +maporbroadcast(f, x1::TrackedT, x2::TrackedT, x3::AbstractArray) = map(f, x1, x2, x3) +maporbroadcast(f, x1::AbstractArray, x2::TrackedT, x3::TrackedT) = map(f, x1, x2, x3) +maporbroadcast(f, x1::TrackedT, x2::AbstractArray, x3::TrackedT) = map(f, x1, x2, x3) _eps(::Type{<:TrackedReal{T}}) where {T} = _eps(T) function Base.minimum(d::LocationScale{<:TrackedReal}) diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 264ea727..7f017fe0 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -550,10 +550,4 @@ ) end end - - @testset "Turing issue 1385" begin - dist = arraydist(truncated.(Laplace.(0, [1, 2]), -10.0, 70.0)) - x = ReverseDiff.track(rand(dist)) - @test typeof(bijector(dist)(x)) <: ReverseDiff.TrackedArray - end end diff --git a/test/runtests.jl b/test/runtests.jl index 47f36d95..2c17c6f0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,5 +35,12 @@ end if !is_TRAVIS && (GROUP == "All" || GROUP == "AD") include("ad/distributions.jl") + if AD == "ReverseDiff" + @testset "Turing issue 1385" begin + dist = arraydist(truncated.(Laplace.(0, [1, 2]), -10.0, 70.0)) + x = ReverseDiff.track(rand(dist)) + @test typeof(bijector(dist)(x)) <: ReverseDiff.TrackedArray + end + end end From 8be3f14dc974e04797486759f95ff602375b34f1 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sun, 11 Oct 2020 13:30:11 +1100 Subject: [PATCH 4/5] Update src/compat/reversediff.jl Co-authored-by: David Widmann --- src/compat/reversediff.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index d9fcb587..61d61b03 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -1,7 +1,7 @@ module ReverseDiffCompat using ..ReverseDiff: ReverseDiff, @grad, value, track, TrackedReal, TrackedVector, - TrackedMatrix, TrackedArray + TrackedMatrix using Requires, LinearAlgebra using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian, From 4bfdc8133f5f6eb8b61a57ad927038a2310286ad Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Thu, 28 Jan 2021 16:56:49 +0000 Subject: [PATCH 5/5] Update test/runtests.jl Co-authored-by: David Widmann --- test/runtests.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 2c17c6f0..3e4a364a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,7 +35,7 @@ end if !is_TRAVIS && (GROUP == "All" || GROUP == "AD") include("ad/distributions.jl") - if AD == "ReverseDiff" + if AD == "All" || AD == "ReverseDiff" @testset "Turing issue 1385" begin dist = arraydist(truncated.(Laplace.(0, [1, 2]), -10.0, 70.0)) x = ReverseDiff.track(rand(dist)) @@ -43,4 +43,3 @@ if !is_TRAVIS && (GROUP == "All" || GROUP == "AD") end end end -