Rewrap Dict Tangents for ChainRules #1288

Open
wants to merge 1 commit into
base: master

oxinabox
Member

Can someone link me to the issues about this?

This path isn't well explored in ChainRules, but it is defined.
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/fbb4936204cb1d857c2dd41eac4bd7bf497771b2/src/tangent_types/tangent.jl#L56-L58

This is the code Zygote uses for accumulating:

function accum(a::AbstractDict, b::AbstractDict)
    @assert a === b
    return a
end

which actually looks like it is different to what ChainRulesCore will do?
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/fbb4936204cb1d857c2dd41eac4bd7bf497771b2/src/tangent_types/tangent.jl#L334
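
To make the suspected difference concrete, here is a minimal sketch. It assumes the ChainRulesCore path adds the backing Dict values key by key, which this thread does not spell out. The names accum_same and accum_elementwise are hypothetical and exist in neither package.

```julia
# Hypothetical names for illustration only; not Zygote or ChainRulesCore API.

# Identity-style accumulation: both arguments must be the very same object,
# so there is nothing left to add and the Dict is returned unchanged.
function accum_same(a::AbstractDict, b::AbstractDict)
    @assert a === b
    return a
end

# Elementwise accumulation (assumed behaviour): sum the values key by key.
accum_elementwise(a::AbstractDict, b::AbstractDict) = mergewith(+, a, b)

g = Dict("one" => 5.0)
same = accum_same(g, g)          # returns g itself, value stays 5.0
summed = accum_elementwise(g, g) # new Dict, value is doubled
```

If both styles were applied to the same already-accumulated Dict, the elementwise version would count the gradient twice, which is the mismatch worth checking.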

TODO:

  • tests

@mzgubic
Collaborator

mzgubic commented Aug 17, 2022

I don't know how accum of the dicts works, but it seems to be done somewhere else.

julia> gradient(d -> d["one"]*2 + d["one"]*3, Dict("one"=>1, "two"=>2))
(a, b) = (Dict{Any, Any}("one" => 5.0), Dict{Any, Any}("one" => 5.0))
(Dict{Any, Any}("one" => 5.0),)

where I've added @show a, b in that accum method. So it looks like it's been accumulated somewhere earlier already?

@jgreener64
Contributor

This is the issue from the Slack thread: JuliaDiff/ChainRules.jl#662.

Note that testing that issue with this PR requires the Molly master branch.

@ToucheSir
Member

> So it looks like it's been accumulated somewhere earlier already?

That's correct. Accumulation for Dicts (and mutable structs) is pretty weird in Zygote because of the need to support setindex!/setfield! on them. In this case, the accumulation is happening in

grad[k] = accum(get(grad, k, nothing), Δ)
which looks up the actual tangent value from a (context-global) cache.
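
A rough sketch of that mechanism, assuming a cache keyed by object identity. CACHE, getindex_pullback and this two-method accum are hypothetical stand-ins, not Zygote's internals. It shows why the two gradients printed earlier are the same object and already hold 5.0.

```julia
# Hypothetical stand-ins; Zygote's real cache and pullbacks differ in detail.
accum(::Nothing, y) = y
accum(x, y) = x + y

# One shared gradient Dict per differentiated Dict, looked up by identity.
const CACHE = IdDict{Any,Any}()

function getindex_pullback(d::AbstractDict, k, Δ)
    grad = get!(() -> Dict{Any,Any}(), CACHE, d)
    # Accumulation happens here, inside the shared gradient Dict.
    grad[k] = accum(get(grad, k, nothing), Δ)
    return grad
end

d = Dict("one" => 1, "two" => 2)
a = getindex_pullback(d, "one", 2.0)  # contribution from d["one"]*2
b = getindex_pullback(d, "one", 3.0)  # contribution from d["one"]*3
```

Both calls return the same cached Dict, so by the time accum(a, b) runs there is nothing left to add, and a === b holds.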

@oxinabox
Member Author

We could probably add that special case (=== means do nothing) to ChainRulesCore.
It's certainly not a valid thing that should normally happen AFAIK -- accumulating a gradient against itself.
I feel like it is a bug that Zygote doesn't then return nothing in the place it has already done the accumulation, but that's a different issue.

@ToucheSir
Member

Ideally CR wouldn't have to make any changes here, will have to look into the PR + original issue in more depth though.

> I feel like it is a bug that Zygote doesn't then return nothing in the place it has already done the accumulation but that's a different issue

It should be doing this. If there's any rule where it's not then I think that should be called a bug.

@ToucheSir
Member

The z2d conversion looks good. All that's left other than tests is the other end in wrap_chainrules_output, I take it?

@oxinabox
Member Author

> It should be doing this. If there's any rule where it's not then I think that should be called a bug.

Here it doesn't seem to

grad[k] = accum(get(grad, k, nothing), Δ)
return (grad, nothing)

@oxinabox
Member Author

> The z2d conversion looks good. All that's left other than tests is the other end in wrap_chainrules_output, I take it?

Oh true, I forgot that.

@ToucheSir
Member

Ah, you are right and I was mixing up @adjoint and rrule return conventions again. Making those pullbacks return nothing would be an interesting experiment, though AFAICT they don't cause any issues right now?

@ToucheSir
Member

ToucheSir commented Sep 6, 2022

After testing locally with the line in #1288 (comment) returning nothing for the Dict itself, this test (Zygote.jl/test/features.jl, lines 585 to 586 in 4183226) fails:

d = Dict(:x=>1.0, :y=>3.0);
@test gradient(d -> Dict(:x => d[:x])[:x], d) == (Dict(:x => 1),)

Pairs: Test Failed at ~/.julia/dev/Zygote/test/features.jl:586
  Expression: gradient((d->begin
                (Dict(:x => d[:x]))[:x]
            end), d) == (Dict(:x => 1),)
   Evaluated: (nothing,) == (Dict(:x => 1),)

Were there more tests that differentiate with respect to a Dict argument, there would likely be more failures. But, quelle surprise, Zygote has just the one.

Note that this is just for Dicts. Doing the same for (literal_)getindex on mutable types breaks a whole lot more tests, but that can be a discussion for another day.

@ToucheSir
Member

For reference, this is the test that fails when one removes the accum overload linked up top:
https://github.com/FluxML/Zygote.jl/blob/v0.6.47/test/interface.jl#L246-L254

Again, the only reason there's a single failure is that there is only one test for nested diff with implicit params. Is there perhaps a way for us to constrain this accum call so that it's only valid in that circumstance?


6 participants