-
-
Notifications
You must be signed in to change notification settings - Fork 213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rewrap Dict Tangents for ChainRules #1288
base: master
Are you sure you want to change the base?
Conversation
I don't know how julia> gradient(d -> d["one"]*2 + d["one"]*3, Dict("one"=>1, "two"=>2))
(a, b) = (Dict{Any, Any}("one" => 5.0), Dict{Any, Any}("one" => 5.0))
(Dict{Any, Any}("one" => 5.0),) where I've added |
This is the issue from the Slack thread: JuliaDiff/ChainRules.jl#662. Note that testing that issue with this PR requires the Molly master branch. |
That's correct. Accumulation for Dicts (and mutable structs) is pretty weird in Zygote because of the need to support Line 36 in 5c80f55
|
We could probably add that special case of |
Ideally CR wouldn't have to make any changes here, will have to look into the PR + original issue in more depth though.
It should be doing this. If there's any rule where it's not then I think that should be called a bug. |
The z2d conversion looks good. All that's left other than tests is the other end in |
Here it doesn't seem to Lines 36 to 37 in 5c80f55
|
Oh true, I forgot that. |
Ah, you are right and I was mixing up |
After testing locally with the line in #1288 (comment) returning Lines 585 to 586 in 4183226
Were there more tests taking differentiating wrt a Note that this is just for Dicts. Doing the same for |
For reference, this is the test that fails when one removes the accum overload linked up top: Again, the only reason there's a single failure is because there is only one test for nested diff with implicit params. Is there perhaps a way for us to constrain this accum call so that it's only valid in that circumstance? |
Can someone link me to the issues about this?
This path isn't well explored in ChainRules, but it is defined.
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/fbb4936204cb1d857c2dd41eac4bd7bf497771b2/src/tangent_types/tangent.jl#L56-L58
This is the code Zygote uses for accumulating,
Zygote.jl/src/lib/base.jl
Lines 26 to 29 in c335d6d
which actually looks like it is different to what ChainRulesCore will do?
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/fbb4936204cb1d857c2dd41eac4bd7bf497771b2/src/tangent_types/tangent.jl#L334
TODO: