Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add trainables_nt #175

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.3.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Debugger = "31a5f54b-26ea-5ae9-a837-f05ce5417438"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -18,9 +19,10 @@ Zygote = "0.6.40"
julia = "1.6"

[extras]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "StaticArrays", "Zygote"]
test = ["Test", "ComponentArrays", "StaticArrays", "Zygote"]
3 changes: 2 additions & 1 deletion src/Optimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Functors: functor, fmap, fmap_with_path,
isleaf, @functor, fmapstructure, children, AbstractWalk
using LinearAlgebra


include("interface.jl")
export AbstractRule

Expand All @@ -16,7 +17,7 @@ include("destructure.jl")
export destructure

include("trainables.jl")
export trainables
export trainables, trainables_nt
export KeyPath, haskeypath, getkeypath # from Functors.jl

include("rules.jl")
Expand Down
208 changes: 206 additions & 2 deletions src/trainables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ end

function ∇trainables(x, Δ)
i = 0
return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
return fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
return Δ[i+=1]
end
end
Expand Down Expand Up @@ -113,7 +113,7 @@ end

function ∇trainables_with_path(x, Δ)
i = 0
return fmapstructure(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
return fmap(x; exclude = isnumeric, walk = TrainableStructWalk()) do _
Δi = Δ[i+=1]
if isnothing(Δi)
return nothing
Expand All @@ -122,3 +122,207 @@ function ∇trainables_with_path(x, Δ)
end
end
end


### trainables_nt ######################

"""
trainables_nt(model) -> ps, re

Return a pair `(ps, re)` where `ps` is a nested named tuple with the same structure as
the trainable part of `model` and with leaves the trainable parameters.

Parameters are not copied, but the returned `ps` is a view into the original model.

The `re` is a function that reconstructs a model from the parameters,
i.e. `re(ps)` is the same as the origin `model` but with the trainable parameters replaced by `ps`.

# Examples

```jldoctest
julia> using Flux, Optimisers

julia> model = Chain(Dense(784, 32, relu), Dense(32, 10));

julia> ps, re = trainables_nt(model);

julia> ps.layers._1.weight === model.layers[1].weight
true
```

```jldoctest

julia> v = ComponentVector(ps)

julia> model2 = re(2 * v)
Chain(
Dense(784 => 32, relu), # 25_120 parameters
Dense(32 => 10), # 330 parameters
) # Total: 4 arrays, 25_450 parameters, 100.281 KiB.
```
"""
function trainables_nt(model)
ps = _trainables_nt(model)
re = RestructureFromNT(model)
return ps, re
end

# Build the nested NamedTuple of trainable leaves of `x`.
# Uses `TrainableNamedTupleWalk` so each level of the structure is converted
# to a NamedTuple while leaves (numeric arrays) are passed through unchanged.
# `cache=nothing` disables fmap's sharing cache so every leaf is visited.
function _trainables_nt(x)
    return fmap(identity, x;
                exclude = isnumeric,
                walk = TrainableNamedTupleWalk(),
                cache = nothing)
end

# Reverse rule for `_trainables_nt`: the pullback maps a NamedTuple-shaped
# cotangent `Δps` back onto the structure of `model`.
function ChainRulesCore.rrule(::typeof(_trainables_nt), model)
    ps = _trainables_nt(model)
    function _trainables_nt_back(Δps)
        backwalk = TrainableNamedTupleBackWalk()
        # At each leaf, the cotangent leaf itself is the gradient.
        Δmodel = fmap((_, Δ) -> Δ, model, Δps;
                      exclude = isnumeric, walk = backwalk, cache = nothing)
        return (NoTangent(), Δmodel)
    end
    return ps, _trainables_nt_back
end


# Walk for `fmap` that descends only into `trainable(x)` children and
# converts each level of the structure to a NamedTuple (positional children
# get keys `:_1`, `:_2`, ... via `make_named_tuple`).
struct TrainableNamedTupleWalk <: AbstractWalk end

(::TrainableNamedTupleWalk)(recurse, x) = map(recurse, make_named_tuple(trainable(x)))

# Backward walk paired with `TrainableNamedTupleWalk`: walks the model and a
# NamedTuple-shaped cotangent `Δps` in lockstep, converting `Δps` back to the
# shape of the model's `trainable` children at each level.
struct TrainableNamedTupleBackWalk <: AbstractWalk end

function (::TrainableNamedTupleBackWalk)(recurse, model, Δps)
    ch = trainable(model)
    # Undo the `:_1`, `:_2`, ... keying applied by `make_named_tuple`.
    Δ = unmake_named_tuple(ch, Δps)
    # Absent / hard-zero cotangents short-circuit the recursion.
    Δ === nothing && return nothing
    Δ === ZeroTangent() && return ZeroTangent()
    return mapvalue(recurse, ch, Δ)
end


# Callable returned by `trainables_nt` as `re`: holds the original model `x`
# and rebuilds a model of the same structure from a NamedTuple of parameters.
struct RestructureFromNT{T}
    x::T
end

function (re::RestructureFromNT)(ps)
    return restructure_from_nt(re.x, ps)
end

# Rebuild `model` with its trainable leaves replaced by the matching leaves
# of the nested NamedTuple `ps`. Leaves not present in `ps` keep the model's
# original values (see `map_commons` used by the walk).
function restructure_from_nt(model, ps)
    take_param = (x, p) -> p  # at each leaf, the parameter from `ps` wins
    return fmap(take_param, model, ps;
                exclude = isnumeric,
                walk = RestructureFromNamedTupleWalk(),
                cache = nothing)
end

# Walk for `restructure_from_nt`: pairs the model's functor children with the
# NamedTuple `nt` by name (positional children via `:_1`, `:_2`, ...) and
# reconstructs the node with `re` from `functor`.
struct RestructureFromNamedTupleWalk <: AbstractWalk end

function (::RestructureFromNamedTupleWalk)(recurse, x, nt)
    # Renamed local from `children` to `kids`: the module imports
    # `Functors.children`, which the old name shadowed.
    kids, re = functor(x)
    newkids = map_commons(recurse, kids, nt)
    return re(newkids)
end

# Reverse rule for `restructure_from_nt`: the pullback converts a cotangent
# with the model's structure (`Δmodel`) into a cotangent with the NamedTuple
# structure of `ps`, projected back via `ProjectTo(ps)`.
#
# NOTE: the previous version left debugging scaffolding in place — several
# `@show` statements and, critically, a hard-coded
# `Δps = (_1=ones(3), _2=zeros(3))` that overwrote the computed gradient.
# Both are removed here.
function ChainRulesCore.rrule(::typeof(restructure_from_nt), x, ps)
    model = restructure_from_nt(x, ps)
    proj_ps = ProjectTo(ps)

    function restructure_from_nt_back(Δmodel_raw)
        Δmodel = unthunk(Δmodel_raw)
        walk = RestructureFromNamedTupleBackWalk()
        # At each leaf, the model-shaped cotangent is the gradient for `ps`.
        Δps = fmap(ps, Δmodel; exclude = isnumeric, walk, cache = nothing) do p, Δ
            return Δ
        end
        # Wrap as a structural Tangent so downstream rules see a proper cotangent.
        Δpst = Tangent{typeof(Δps)}(; Δps...)
        return (NoTangent(), NoTangent(), proj_ps(Δpst))
    end
    return model, restructure_from_nt_back
end

# Backward walk for the `restructure_from_nt` rrule: walks `ps` and a
# model-shaped cotangent `Δmodel` in lockstep, converting the cotangent to
# NamedTuple form at each level so it lines up with `ps`.
# (Debug `@show` statements and commented-out experiments removed.)
struct RestructureFromNamedTupleBackWalk <: AbstractWalk end

function (::RestructureFromNamedTupleBackWalk)(recurse, ps, Δmodel)
    Δm = make_named_tuple(Δmodel)
    # Absent / hard-zero cotangents short-circuit the recursion.
    Δm === nothing && return nothing
    Δm === ZeroTangent() && return ZeroTangent()
    return mapvalue(recurse, ps, Δm)
end

# Apply `f(x[k], y[k])` for every key `k` of `x` that also appears in `y`;
# keys of `x` missing from `y` keep their original value. The result has
# exactly the keys of `x`.
function map_commons(f, x::NamedTuple{xkeys}, y) where {xkeys}
    shared = propertynames(y)
    merged = map(xkeys) do k
        k in shared ? f(x[k], getproperty(y, k)) : x[k]
    end
    return NamedTuple{xkeys}(merged)
end

# Tuple version: position `i` of `x` pairs with property `:_i` of `y`.
# Positions with no matching `:_i` keep the original element.
function map_commons(f, x::Tuple, y)
    present = propertynames(y)
    return ntuple(length(x)) do i
        key = Symbol("_", i)
        if key in present
            f(x[i], getproperty(y, key))
        else
            x[i]
        end
    end
end

# Vector version: element `i` of `x` pairs with property `:_i` of `y`;
# elements with no matching `:_i` keep their original value.
# Idiom fix: iterate `eachindex(x)` instead of `1:length(x)`.
function map_commons(f, x::Vector, y)
    ykeys = propertynames(y)
    vals = map(eachindex(x)) do i
        k = Symbol("_", i)
        k in ykeys ? f(x[i], getproperty(y, k)) : x[i]
    end
    return vals
end

# Convert a Functors children collection to NamedTuple form.
# Positional containers (Tuple/Vector) get synthetic keys `:_1`, `:_2`, ...;
# Symbol-keyed dicts keep their keys; anything else passes through unchanged.
make_named_tuple(x) = x

# Synthetic positional keys (:_1, :_2, ..., :_n).
_positional_keys(n) = ntuple(i -> Symbol("_", i), n)

make_named_tuple(x::AbstractDict{Symbol}) = NamedTuple(x)
make_named_tuple(x::AbstractDict) = NamedTuple(Symbol("_", k) => v for (k, v) in pairs(x))
make_named_tuple(x::Tuple) = NamedTuple{_positional_keys(length(x))}(x)
make_named_tuple(x::Vector) = NamedTuple{_positional_keys(length(x))}(x)

# Tangent counterparts of the `make_named_tuple` methods: convert the backing
# of a ChainRulesCore.Tangent to NamedTuple form so it lines up with the
# primal NamedTuple produced during the forward pass.
# NOTE(review): these rely on `NamedTuple{names}(::Tangent)` and
# `pairs(::Tangent)` behaving like the backing collection — confirm against
# the ChainRulesCore Tangent indexing/iteration interface.
make_named_tuple(x::Tangent{<:Any,<:NamedTuple}) = x
make_named_tuple(x::Tangent{<:Any,<:AbstractDict{Symbol}}) = NamedTuple(x)
make_named_tuple(x::Tangent{<:Any,<:AbstractDict}) = NamedTuple(Symbol("_", k) => v for (k, v) in pairs(x))
make_named_tuple(x::Tangent{<:Any,<:Tuple}) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x)
make_named_tuple(x::Tangent{<:Any,<:Vector}) = NamedTuple{ntuple(i -> Symbol("_",i), length(x))}(x)


# Inverse of `make_named_tuple`: given the original children collection `x`
# (which fixes the target shape) and a NamedTuple `ps`, restore the original
# container type. NamedTuple children are already in the right form.
unmake_named_tuple(x::NamedTuple, ps) = ps

# Rebuild a Tuple by reading the synthetic positional keys :_1, :_2, ...
function unmake_named_tuple(x::Tuple, ps)
    return ntuple(i -> ps[Symbol("_", i)], length(x))
end

# Rebuild a Vector the same way (idiom fix: `eachindex` over `1:length`).
function unmake_named_tuple(x::Vector, ps)
    return map(i -> ps[Symbol("_", i)], eachindex(x))
end
8 changes: 8 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@

# Apply `f` elementwise over the values of one or more collections, keeping
# the container's key structure.
mapvalue(f, x...) = map(f, x...)
mapvalue(f, x::NamedTuple, ys::NamedTuple...) = map(f, x, ys...)
# Pair an arbitrary object `x` with a NamedTuple `y` by `y`'s property names;
# the result has `y`'s keys. Used in the rrule for `restructure_from_nt`.
mapvalue(f, x, y::NamedTuple{ykeys}) where {ykeys} =
    NamedTuple{ykeys}((f(getproperty(x ,k), yk) for (k, yk) in pairs(y))) # used in rrule for restructure_from_nt

# Dict version: missing keys in the trailing collections contribute `nothing`.
mapvalue(f, x::Dict, ys...) = Dict(k => f(v, (get(y, k, nothing) for y in ys)...) for (k,v) in x)

# Without these, plain tuples would be returned instead of NamedTuples when
# the second argument is a Tangent wrapping a NamedTuple.
mapvalue(f, x::NamedTuple{Ks}, y::Tangent{<:Any,<:NamedTuple}) where {Ks} =
    NamedTuple{Ks}((f(v, y[k]) for (k,v) in pairs(x)))

# Apply `f` to each key of the container, keeping the key structure:
# NamedTuple keys are Symbols, Dict keys are passed as-is, and for Tuples
# `f` receives the integer positions `1:length(x)`.
mapkey(f, x::NamedTuple{Ks}) where Ks = NamedTuple{Ks}(map(f, Ks))
mapkey(f, x::Dict) = Dict(k => f(k) for k in keys(x))
mapkey(f, x::Tuple) = ntuple(f, length(x))  # idiom fix: drop redundant `i -> f(i)` lambda
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using ChainRulesCore, Functors, StaticArrays, Zygote
using LinearAlgebra, Statistics, Test, Random
using Optimisers: @.., @lazy
using Base.Broadcast: broadcasted, instantiate, Broadcasted
using ComponentArrays: ComponentArrays, ComponentVector

Random.seed!(1)

Expand Down
Loading
Loading