Fix maporbroadcast for ReverseDiff #142
Codecov Report
@@ Coverage Diff @@
## master TuringLang/Turing.jl#142 +/- ##
==========================================
- Coverage 54.98% 54.79% -0.20%
==========================================
Files 27 27
Lines 1726 1732 +6
==========================================
Hits 949 949
- Misses 777 783 +6
We should add some tests here; it seems the existing ones did not cover the problem. In what way is
No, but we get the problem in the Turing issue, which is an array of
I added some comments. IMO the only thing needed seems to be a special implementation of maporbroadcast for ReverseDiff, I don't think it's good to change the default implementation.
In general, it seems the issue with map is a more general problem of ReverseDiff (and not of Bijectors) and ideally should be fixed there. But I'm fine with a workaround for ReverseDiff for now.
# Broadcasting here breaks Tracker for some reason
maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...)
I still think it would make sense to keep this definition (and only change it for ReverseDiff). If the arrays are of the same dimension, it should be simpler to not invoke the broadcast machinery (for which adjoints require quite complex implementations e.g. in Tracker and Zygote that even make use of ForwardDiff).
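As a minimal illustration of the point about same-dimension arrays, the sketch below uses only Base Julia and plain arrays (no AD types). The function f and the arrays x, y, and row are hypothetical names chosen for the example. It shows that map and broadcasting agree when the shapes match, and that broadcasting additionally expands singleton dimensions, which is the extra machinery map avoids.

```julia
# Hypothetical elementwise function, used only for this illustration.
f(a, b) = a + 2b

x = [1.0, 2.0, 3.0]
y = [4.0, 5.0, 6.0]

# With arrays of identical shape, map and broadcast give the same values.
same_shape_map = map(f, x, y)
same_shape_bc = f.(x, y)
@assert same_shape_map == same_shape_bc == [9.0, 12.0, 15.0]

# Broadcasting also expands singleton dimensions: a 3-element vector
# combined with a 1x2 row matrix yields a 3x2 matrix.
row = [10.0 20.0]
expanded = f.(x, row)
@assert size(expanded) == (3, 2)
```

This says nothing about how any AD backend differentiates either call; it only shows the semantic overlap that makes map a candidate when all arguments share a dimension.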
Tracker's map also reduces to an array of TrackedReals even when the input is a TrackedArray. So broadcasting by default is needed for 2 out of 3 AD backends, which is why it makes sense as the default. Only when we have mixed TrackedArrays and arrays of TrackedReals does Tracker fail, hence the need for the map workaround. ReverseDiff broadcasting doesn't fail in this case.
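A rough Base-only analogy for why the choice between map and broadcasting can change the returned container type: the sketch below uses a range rather than any Tracker or ReverseDiff type, so it is an assumption-laden illustration and not the behaviour of either AD package. The names r, mapped, and broadcasted are hypothetical.

```julia
r = 1:3

# map collects the results into a plain Vector, dropping the range structure.
mapped = map(x -> 2x, r)

# Broadcasting a scalar multiplication over a range has a specialized method
# in Base that returns another range.
broadcasted = 2 .* r

@assert mapped isa Vector{Int}
@assert broadcasted isa AbstractRange
@assert mapped == broadcasted  # same values, different container types
```

The values are identical, but code that later dispatches on the container type would take different paths, which is the kind of difference being discussed for TrackedArray versus arrays of TrackedReals.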
Co-authored-by: David Widmann <[email protected]>
maporbroadcast(f, x::Union{TrackedArray, AbstractArray{<:TrackedReal}}...) = map(f, x...)
maporbroadcast(f, x1::TrackedT, x2::AbstractArray) = map(f, x1, x2)
maporbroadcast(f, x1::AbstractArray, x2::TrackedT) = map(f, x1, x2)
maporbroadcast(f, x1::TrackedT, x2::AbstractArray, x3::AbstractArray) = map(f, x1, x2, x3)
maporbroadcast(f, x1::AbstractArray, x2::TrackedT, x3::AbstractArray) = map(f, x1, x2, x3)
maporbroadcast(f, x1::AbstractArray, x2::AbstractArray, x3::TrackedT) = map(f, x1, x2, x3)
maporbroadcast(f, x1::TrackedT, x2::TrackedT, x3::AbstractArray) = map(f, x1, x2, x3)
maporbroadcast(f, x1::AbstractArray, x2::TrackedT, x3::TrackedT) = map(f, x1, x2, x3)
maporbroadcast(f, x1::TrackedT, x2::AbstractArray, x3::TrackedT) = map(f, x1, x2, x3)
This does not seem like a robust solution, and needs tests. I still think the easiest solution would be to just change the ReverseDiff part in this PR - since this is what you want to fix and what you added a test for.
dist = arraydist(truncated.(Laplace.(0, [1, 2]), -10.0, 70.0))
x = ReverseDiff.track(rand(dist))
@test typeof(bijector(dist)(x)) <: ReverseDiff.TrackedArray
Can this test be simplified to just check the output of maporbroadcast? Or could at least a test with only maporbroadcast be added?
Co-authored-by: David Widmann <[email protected]>
@mohamed82008 can you take another look, and see whether we can address the comments and then merge this PR?
I'm closing this PR since
This PR fixes TuringLang/DistributionsAD.jl#217 by making sure broadcasting is used instead of mapping for ReverseDiff. Broadcasting mixed TrackedArrays and arrays of TrackedReals breaks Tracker, but ReverseDiff can handle it.