-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix maporbroadcast for ReverseDiff #142
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,14 +12,18 @@ using .Tracker: Tracker, | |
using Compat: eachcol | ||
using LinearAlgebra | ||
|
||
# Broadcasting here breaks Tracker | ||
const TrackedT = Union{TrackedArray, AbstractArray{<:TrackedReal}} | ||
maporbroadcast(f, x::TrackedArray...) = f.(x...) | ||
mohamed82008 marked this conversation as resolved.
Show resolved
Hide resolved
mohamed82008 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
function maporbroadcast( | ||
f, | ||
x1::TrackedArray{T, N}, | ||
x::AbstractArray{<:TrackedReal}..., | ||
) where {T, N} | ||
return f.(convert(Array{TrackedReal{T}, N}, x1), x...) | ||
end | ||
maporbroadcast(f, x::Union{TrackedArray, AbstractArray{<:TrackedReal}}...) = map(f, x...) | ||
maporbroadcast(f, x1::TrackedT, x2::AbstractArray) = map(f, x1, x2) | ||
maporbroadcast(f, x1::AbstractArray, x2::TrackedT) = map(f, x1, x2) | ||
maporbroadcast(f, x1::TrackedT, x2::AbstractArray, x3::AbstractArray) = map(f, x1, x2, x3) | ||
maporbroadcast(f, x1::AbstractArray, x2::TrackedT, x3::AbstractArray) = map(f, x1, x2, x3) | ||
maporbroadcast(f, x1::AbstractArray, x2::AbstractArray, x3::TrackedT) = map(f, x1, x2, x3) | ||
maporbroadcast(f, x1::TrackedT, x2::TrackedT, x3::AbstractArray) = map(f, x1, x2, x3) | ||
maporbroadcast(f, x1::AbstractArray, x2::TrackedT, x3::TrackedT) = map(f, x1, x2, x3) | ||
maporbroadcast(f, x1::TrackedT, x2::AbstractArray, x3::TrackedT) = map(f, x1, x2, x3) | ||
Comment on lines
+18
to
+26
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This does not seem like a robust solution, and needs tests. I still think the easiest solution would be to just change the ReverseDiff part in this PR - since this is what you want to fix and what you added a test for. |
||
|
||
_eps(::Type{<:TrackedReal{T}}) where {T} = _eps(T) | ||
function Base.minimum(d::LocationScale{<:TrackedReal}) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,5 +35,12 @@ end | |
|
||
if !is_TRAVIS && (GROUP == "All" || GROUP == "AD") | ||
include("ad/distributions.jl") | ||
if AD == "ReverseDiff" | ||
yebai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@testset "Turing issue 1385" begin | ||
dist = arraydist(truncated.(Laplace.(0, [1, 2]), -10.0, 70.0)) | ||
x = ReverseDiff.track(rand(dist)) | ||
@test typeof(bijector(dist)(x)) <: ReverseDiff.TrackedArray | ||
Comment on lines
+40
to
+42
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this test be simplified to just check the output of |
||
end | ||
end | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I still think it would make sense to keep this definition (and only change it for ReverseDiff). If the arrays are of the same dimension, it should be simpler to not invoke the broadcast machinery (for which adjoints require quite complex implementations e.g. in Tracker and Zygote that even make use of ForwardDiff).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tracker's map also reduces to an array of TrackedReals even when the input is a TrackedArray. So broadcasting by default is needed for 2 out of 3 AD backends. So it makes sense as the default. Only when we have mixed TrackedArrays and arrays of TrackedReals that Tracker fails hence the need for the map workaround. ReverseDiff broadcasting doesn't fail in this case.