Add trainables_nt
#175
Conversation
Keeping differentiability aside, is …
Exactly. And we need a nested namedtuple-only return in order to be compatible with ComponentArrays.
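For reference, ComponentArrays can build its flat storage directly from exactly this kind of nested NamedTuple, which is what makes that return type attractive. A minimal sketch, independent of this PR (the `layer1`/`layer2` names are made up):

```julia
using ComponentArrays

# A nested NamedTuple of arrays, shaped like the value trainables_nt returns.
nt = (layer1 = (weight = [1.0 2.0; 3.0 4.0], bias = [0.0, 0.0]),
      layer2 = (scale = [5.0],))

v = ComponentVector(nt)        # flattens every leaf into one contiguous vector
v.layer1.weight                # ...while the nested labels still index into it
getfield(v, :data) |> println  # [1.0, 3.0, 2.0, 4.0, 0.0, 0.0, 5.0]
```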
What about replacing …
Wait, there are two big differences from `fmapstructure` / `Flux.state`:

ComponentArrays has no notion of shared parameters. That's a large part of what makes everything touching Functors tricky. (In fact the replacement of a vector with a NamedTuple opens the door to weirdness here, before you get to ComponentArrays, as you replace a mutable thing with an immutable one. Probably not in a way that matters for Flux models.) Example with this:

```julia
julia> sh = [1f0, 2f0];

julia> ps, re = trainables_nt((sh, sh, [3,4.]))
((_1 = Float32[1.0, 2.0], _2 = Float32[1.0, 2.0], _3 = [3.0, 4.0]), Optimisers.RestructureFromNT{Tuple{Vector{Float32}, Vector{Float32}, Vector{Float64}}}((Float32[1.0, 2.0], Float32[1.0, 2.0], [3.0, 4.0])))

julia> ps._1 === ps._2
true

julia> v = ComponentVector(ps);

julia> getfield(v, :data) |> println
[1.0, 2.0, 1.0, 2.0, 3.0, 4.0]

julia> v[3] = 99;

julia> re(v)  # sharing is broken
([1.0, 2.0], [99.0, 2.0], [3.0, 4.0])
```

And unrelated to sharing:

```julia
julia> re(v)[1] |> eltype  # accidental promotion is back
Float64

julia> re(v)[1]  # no copy on reconstruction, but will view(::CuArray) work everywhere?
2-element view(::Vector{Float64}, 1:2) with eltype Float64:
 1.0
 2.0
```

cf. `destructure`:

```julia
julia> v2, re2 = destructure((sh, sh, [3,4.]))
([1.0, 2.0, 3.0, 4.0], Restructure(Tuple, ..., 4))

julia> v2[2] = 999;

julia> re2(v2)
(Float32[1.0, 999.0], Float32[1.0, 999.0], [3.0, 4.0])
```

When I last looked, ComponentArrays also made more whole copies in the gradient. More broadly, what's this for? Why do we care about ComponentArrays?
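To make the sharing point concrete: Functors' `fmap` preserves identity between shared leaves through its cache, which is exactly the information a flat ComponentVector has no slot for. A minimal sketch using plain Functors (not this PR's code):

```julia
using Functors

sh = [1f0, 2f0]
m = (sh, sh, [3.0, 4.0])

# fmap caches by object identity, so both occurrences of `sh` map to the
# same output array and the tie survives the traversal.
m2 = fmap(x -> 2 .* x, m)
m2[1] === m2[2]  # true

# A flat parameter vector stores one slot per scalar, so this identity
# relationship has nowhere to live once the model is flattened.
```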
I would like to have something in the …
I need help with the …

```julia
using Zygote, Optimisers, ComponentArrays, Test

m = (collect(1:3.0), collect(4:6.0))
ps, re = trainables_nt(m)
Zygote.refresh()
gps = gradient(x -> re(x)[1][2], ps)[1]
@test gps == (_1 = [0.0, 1.0, 0.0], _2 = nothing)  # ok

v = ComponentVector(ps)
gv = gradient(x -> re(x)[1][2], v)[1]  # this is `nothing`!!!!
```

The relevant rule is

```julia
function ChainRulesCore.rrule(::typeof(restructure_from_nt), x, ps)
    model = restructure_from_nt(x, ps)
    proj_ps = ProjectTo(ps)
    function restructure_from_nt_back(Δmodel_raw)
        Δmodel = unthunk(Δmodel_raw)
        walk = RestructureFromNamedTupleBackWalk()
        function exclude(x)
            @show "exclude" x isnumeric(x)
            # i += 1
            # return i > 1
            return isnumeric(x)
        end
        Δps = fmap(ps, Δmodel; exclude, walk, cache=nothing) do p, Δ
            @show "fmap" Δ p
            return Δ
        end
        Δpst = Tangent{typeof(Δps)}(; Δps...)
        @show "rrule" Δmodel x ps Δps Δpst  # here Δps = (_1 = [0.0, 1.0, 0.0], _2 = ChainRulesCore.ZeroTangent())
        @show typeof(Δmodel) typeof(ps) typeof(Δps)
        return (NoTangent(), NoTangent(), Δps)
        # return (NoTangent(), NoTangent(), proj_ps(Δpst))
    end
    return model, restructure_from_nt_back
end

struct RestructureFromNamedTupleBackWalk <: AbstractWalk end

function (::RestructureFromNamedTupleBackWalk)(recurse, ps, Δmodel)
    @show 1 typeof(Δmodel) typeof(ps)
    Δm = make_named_tuple(Δmodel)
    @show 2 typeof(Δm) ps Δm
    Δm === nothing && return nothing
    Δm === ZeroTangent() && return ZeroTangent()
    y = mapvalue(recurse, ps, Δm)
    @show 3 typeof(Δmodel) typeof(Δm) typeof(y)
    return y
end
```

Why do I get `nothing`?
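A first diagnostic step might be to check, separately from `trainables_nt`, what tangent Zygote actually produces for a plain ComponentVector, since the pullback above is written against a NamedTuple-shaped `ps`. A hedged sketch (the `_1`/`_2` labels here are illustrative, and the exact tangent type may vary with ComponentArrays versions):

```julia
using Zygote, ComponentArrays

cv = ComponentVector(_1 = [1.0, 2.0, 3.0], _2 = [4.0, 5.0, 6.0])

# Differentiate through label-based access and inspect what comes back:
# if the tangent arrives as an array-like value rather than a NamedTuple-
# shaped structure, a pullback expecting NamedTuple fields may not match it.
g = gradient(x -> x._1[2], cv)[1]
@show typeof(g)
```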
This is a proposal for an alternative to `destructure` which doesn't completely flatten the parameters but instead returns a nested named tuple. The associated reconstructor can also be used on `ComponentArray`s.
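Putting the thread's pieces together, the intended usage (as exercised in the comments above; output shapes may change as the PR evolves) is roughly:

```julia
using Optimisers, ComponentArrays

m = (collect(1:3.0), collect(4:6.0))

# A nested NamedTuple of the trainable arrays plus a reconstructor,
# instead of destructure's single flat vector:
ps, re = trainables_nt(m)  # ps == (_1 = [1.0, 2.0, 3.0], _2 = [4.0, 5.0, 6.0])
m1 = re(ps)                # rebuilds the original tuple structure

# The same reconstructor also accepts a ComponentVector built from ps:
v = ComponentVector(ps)
m2 = re(v)
```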