Fix maporbroadcast for ReverseDiff #142

Closed
wants to merge 5 commits into from

mohamed82008
Member

This PR fixes TuringLang/DistributionsAD.jl#217 by making sure broadcasting, rather than mapping, is used for ReverseDiff. Broadcasting over a mix of TrackedArrays and arrays of TrackedReals breaks Tracker, but ReverseDiff can handle it.


codecov bot commented Oct 10, 2020

Codecov Report

Merging #142 (8be3f14) into master (a70914e) will decrease coverage by 0.19%.
The diff coverage is 0.00%.


@@            Coverage Diff             @@
##           master     #142      +/-   ##
==========================================
- Coverage   54.98%   54.79%   -0.20%     
==========================================
  Files          27       27              
  Lines        1726     1732       +6     
==========================================
  Hits          949      949              
- Misses        777      783       +6     
Impacted Files Coverage Δ
src/Bijectors.jl 69.73% <ø> (+0.90%) ⬆️
src/compat/tracker.jl 40.90% <0.00%> (-1.22%) ⬇️


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update a70914e...4bfdc81.

devmotion
Member

We should add some tests here; it seems the existing ones did not cover the problem. In what way is maporbroadcast actually covered by any tests? Do tests fail if, e.g., map is used for all arrays?

mohamed82008
Member Author

> Do tests fail if e.g. map is used for all arrays?

No, but we do hit the problem from the Turing issue: applying a 1-dim TruncatedBijector returns an array of TrackedReals instead of a TrackedArray. maporbroadcast is tested, but only for correctness, not for the output type. I can add that issue's case as a test case.

devmotion (Member) left a comment

I added some comments. IMO the only thing needed is a special implementation of maporbroadcast for ReverseDiff; I don't think it's good to change the default implementation.

It seems the issue with map is a more general ReverseDiff problem (not a Bijectors one) and ideally should be fixed there. But I'm fine with a workaround for ReverseDiff for now.

src/compat/reversediff.jl (outdated, resolved)
src/compat/reversediff.jl (outdated, resolved)
src/compat/tracker.jl (resolved)
test/ad/distributions.jl (outdated, resolved)
Comment on lines -246 to -247
# Broadcasting here breaks Tracker for some reason
maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...)
Member

I still think it would make sense to keep this definition (and only change it for ReverseDiff). If the arrays have the same dimension, it should be simpler not to invoke the broadcast machinery, whose adjoints require fairly complex implementations (e.g., in Tracker and Zygote, which even make use of ForwardDiff).
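A minimal sketch of the dispatch rule described here, in plain Julia with no AD package: arguments that all share one dimensionality N are mapped, and anything else is broadcast. The name `maporbroadcast_sketch` is hypothetical, and the broadcasting fallback is an assumption about the generic method that accompanies the removed `map` definition, not a copy of Bijectors' source.

```julia
# Hypothetical helper illustrating the map-vs-broadcast rule discussed above;
# a sketch, not Bijectors' source.

# All arguments share one dimensionality N: map elementwise and skip the
# broadcast machinery. This method is more specific, so it wins whenever it applies.
maporbroadcast_sketch(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...)

# Mixed dimensionalities (e.g. a vector and a matrix): fall back to broadcasting.
maporbroadcast_sketch(f, x::AbstractArray...) = f.(x...)

same  = maporbroadcast_sketch(+, [1, 2], [3, 4])       # map path:       [4, 6]
mixed = maporbroadcast_sketch(+, [1, 2], [1 2; 3 4])   # broadcast path: [2 3; 5 6]
```

Whether the `map` path keeps a tracked-array output depends on the AD backend, which is the point under discussion in this thread.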

Member Author

mohamed82008, Oct 11, 2020

Tracker's map also returns an array of TrackedReals even when the input is a TrackedArray, so broadcasting is needed for 2 of the 3 AD backends, which makes it the sensible default. Tracker fails only when TrackedArrays are mixed with arrays of TrackedReals, hence the need for the map workaround there. ReverseDiff's broadcasting doesn't fail in this case.

Comment on lines +18 to +26
maporbroadcast(f, x::Union{TrackedArray, AbstractArray{<:TrackedReal}}...) = map(f, x...)
maporbroadcast(f, x1::TrackedT, x2::AbstractArray) = map(f, x1, x2)
maporbroadcast(f, x1::AbstractArray, x2::TrackedT) = map(f, x1, x2)
maporbroadcast(f, x1::TrackedT, x2::AbstractArray, x3::AbstractArray) = map(f, x1, x2, x3)
maporbroadcast(f, x1::AbstractArray, x2::TrackedT, x3::AbstractArray) = map(f, x1, x2, x3)
maporbroadcast(f, x1::AbstractArray, x2::AbstractArray, x3::TrackedT) = map(f, x1, x2, x3)
maporbroadcast(f, x1::TrackedT, x2::TrackedT, x3::AbstractArray) = map(f, x1, x2, x3)
maporbroadcast(f, x1::AbstractArray, x2::TrackedT, x3::TrackedT) = map(f, x1, x2, x3)
maporbroadcast(f, x1::TrackedT, x2::AbstractArray, x3::TrackedT) = map(f, x1, x2, x3)
Member

This does not seem like a robust solution, and it needs tests. I still think the easiest solution would be to change only the ReverseDiff part in this PR, since that is what you want to fix and what you added a test for.
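For comparison, a hypothetical sketch of how a single method could handle a tracked argument in any position and for any number of arguments (the mixed-argument methods above stop at three): first fold all arguments into one strategy, then dispatch once on it. `FakeTracked`, `Strategy`, `UseMap`, `UseBroadcast`, `strategy`, `combine`, `choose_strategy`, `apply_strategy`, and `maporbroadcast_trait` are all invented for illustration; `FakeTracked` is a plain wrapper standing in for an AD-tracked array, not ReverseDiff's or Tracker's type, and this is not what the PR or Bijectors implements.

```julia
# Hypothetical stand-in for an AD-tracked array (NOT ReverseDiff's or Tracker's type).
struct FakeTracked{T, N} <: AbstractArray{T, N}
    data::Array{T, N}
end
Base.size(x::FakeTracked) = size(x.data)
Base.getindex(x::FakeTracked, i::Int...) = x.data[i...]

# The two strategies a single dispatch step can choose between.
abstract type Strategy end
struct UseMap <: Strategy end
struct UseBroadcast <: Strategy end

# Per-argument preference: plain arrays accept `map`, tracked arrays ask for broadcasting.
strategy(::AbstractArray) = UseMap()
strategy(::FakeTracked) = UseBroadcast()

# Pairwise combination: keep `map` only if both sides agree on it.
combine(::UseMap, ::UseMap) = UseMap()
combine(::Strategy, ::Strategy) = UseBroadcast()

# Fold the preferences of all arguments, whatever their number and position.
choose_strategy(xs::AbstractArray...) = reduce(combine, map(strategy, xs))

# Map only when the strategy allows it and all arguments share one dimensionality N;
# every other case broadcasts.
apply_strategy(::UseMap, f, xs::AbstractArray{<:Any, N}...) where {N} = map(f, xs...)
apply_strategy(::Strategy, f, xs...) = f.(xs...)

maporbroadcast_trait(f, xs::AbstractArray...) =
    apply_strategy(choose_strategy(xs...), f, xs...)
```

Under this design, supporting another backend would mean adding a `strategy` method for its tracked types rather than another set of positional overloads.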

test/runtests.jl (outdated, resolved)
Comment on lines +40 to +42
dist = arraydist(truncated.(Laplace.(0, [1, 2]), -10.0, 70.0))
x = ReverseDiff.track(rand(dist))
@test typeof(bijector(dist)(x)) <: ReverseDiff.TrackedArray
Member

Can this test be simplified to check only the output of maporbroadcast? Or could at least a test of maporbroadcast alone be added?

Co-authored-by: David Widmann
Member

yebai commented May 25, 2021

@mohamed82008, can you take another look and see whether we can address the comments and then merge this PR?

Member

yebai commented Feb 3, 2022

I'm closing this PR since arraydist and filldist will be replaced by Product in the future.

yebai closed this Feb 3, 2022
yebai deleted the mt/fix_turing_1385 branch February 3, 2022 17:28
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

arraydist is showing unintended behavior.
3 participants