
Add total(f, model) to replace implicit sum(f, Flux.params(model)) #57

Closed · mcabbott wants to merge 1 commit

Conversation

@mcabbott (Member) commented Feb 14, 2022

This proposes to add some kind of differentiable sum(f, trainable(x)) which walks the model. I'm not certain this is the right thing yet.

Right now this gets all trainable parameters. But perhaps a variant which takes a type, total(f, Union{Dense, Conv}, model), might be a better explicit-parameters replacement for modules? Xref FluxML/Flux.jl#1863 (comment)

Closes FluxML/Functors.jl#35, probably.

Edit: since I couldn't find this, big Flux issue about explicit parameters is FluxML/Flux.jl#1986 and snippet with a quick way to write total here: FluxML/Flux.jl#2040 (comment)
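
For concreteness, roughly the usage this is aiming at (a sketch; the model and penalty function below are placeholders, and only the one-argument total(f, model) is what this PR adds):

using Flux, Optimisers

model = Chain(Dense(3 => 5, relu), Dense(5 => 1))
penalty(x) = sum(abs2, x)

sum(penalty, Flux.params(model))  # implicit-parameter style this would replace
total(penalty, model)             # proposed explicit, differentiable equivalent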

src/Optimisers.jl (outdated review thread, resolved)
Comment on lines -110 to 132
-  if p isa ProjectTo # e.g. Array, NamedTuple
-    p(y)
-  else # p === identity for unknown structs
+  # if p isa ProjectTo # e.g. Array, NamedTuple
+  #   p(y) # but for NamedTuple, this hits https://github.com/JuliaDiff/ChainRulesCore.jl/issues/538
+  if x isa Union{Number, AbstractArray} # these don't use Tangent
+    ProjectTo(x)(unthunk(y))
+  else
+    Tangent{typeof(x), typeof(y)}(y)
@mcabbott (Member, Author)

This is either a bug in earlier _Tangent_biwalk, or in ChainRulesCore.

Comment on lines +176 to +204
function total(f, x)
  values = []
  fmap(y -> push!(values, f(y)), x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z)))
  sum(values)
end
@mcabbott (Member, Author)

While Any[] doesn't seem great, this ends up about the same speed as my other idea:

const INIT = Base._InitialValue()  # sentinel meaning "nothing accumulated yet"

function total2(f, x; init = INIT)
  fmap(x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z))) do y
    val = f(y)
    init = init === INIT ? val : (init + val)  # first leaf replaces the sentinel, later leaves add on
  end
  init
end

julia> @btime total(norm, $model)  # Resnet from the docs
  min 23.863 ms, mean 23.995 ms (1541 allocations, 130.06 KiB)
730.5533f0

julia> @btime total2(norm, $model)
  min 23.834 ms, mean 23.982 ms (1538 allocations, 128.17 KiB)
730.5533f0

julia> m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7));

julia> @btime total(norm, $m)
  min 1.750 μs, mean 1.846 μs (16 allocations, 752 bytes)
10.0

julia> @btime total2(norm, $m)
  min 1.675 μs, mean 1.769 μs (15 allocations, 640 bytes)
10.0

@rejuvyesh

Should this be more general to allow computing the norm of the gradients as well?

@mcabbott added the enhancement (New feature or request) label on Jun 7, 2022
julia> total(norm, m)
10.0

julia> total(length, m) == length(destructure(m)[1])
@mcabbott (Member, Author)

This would solve FluxML/Flux.jl#2043 (as long as trainable parameters are what you want).

Or total(Base.summarysize, m) for bytes, total(_ -> 1, m) to count arrays.
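
For example, with the m from the benchmark above (a sketch; the outputs assume isnumeric skips the integer tuple z, just as destructure does):

julia> total(length, m)  # 2 elements in [3.0, 4.0] plus 1 in [5.0]
3

julia> total(_ -> 1, m)  # two trainable arrays
2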

@ToucheSir (Member)

One idea for making this differentiable is to take the non-diff part (caching) out of fmap and instead create a memoized callback wrapper type. Then fmap itself is trivially differentiable, and the only remaining tricky bit is how to make that wrapper N-times differentiable.
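
Roughly the kind of wrapper I have in mind, as a hypothetical sketch (nothing here is from this PR, and the names are made up):

struct Memoized{F,T}
  f::F
  seen::IdDict{Any,Nothing}
  neutral::T  # returned for leaves that were already visited
end
Memoized(f; neutral = false) = Memoized(f, IdDict{Any,Nothing}(), neutral)

function (m::Memoized)(x)
  haskey(m.seen, x) && return m.neutral  # duplicate leaf contributes a neutral element
  m.seen[x] = nothing
  m.f(x)
end

The open question is then how to make this wrapper itself differentiable.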

@mcabbott (Member, Author)

Not sure I follow. If fmap were made differentiable, how would you write total(norm, model) using it?

@ToucheSir (Member)

fmap itself is not sufficient for this. My thought was that you could use it as the map part of a mapreduce, where cached leaves are replaced with some neutral element. Then the sum reduction requires no caching and can be made fully type stable. The idea behind this is to avoid the Any[] in https://github.com/FluxML/Optimisers.jl/pull/57/files#r806313306. Things won't be type stable under AD unfortunately, but I can see a possible path for making it so with changes to Functors internals.

@mcabbott (Member, Author)

OK. I guess this PR's take is that, since essentially nothing else about Functors.jl is type-stable and everything takes a few μs, there's not much point pushing hard here.

Decomposing into fmap then reduce ... the prune keyword lets you drop duplicates, but I've forgotten why fmapstructure is different here:

julia> twice = [1,2.0];

julia> m = (twice, [3,4.0], twice, sin, 99);

julia> fmap(x -> Some(length(x)), m, prune=nothing, exclude=(x->x isa Array))
(Some(2), Some(2), nothing, sin, 99)  # need Some to know to ignore 99

julia> fmapstructure(x -> length(x), m, prune=0, exclude=(x->x isa Array))
(2, 2, 0, (), ())  # easier
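
So one way to finish the decomposition might be to reduce over the pruned structure's leaves, e.g. (an untested sketch using Functors.fleaves, not part of this PR):

julia> vals = Functors.fleaves(fmapstructure(length, m; prune = 0, exclude = x -> x isa Array));

julia> sum(v for v in vals if v isa Number)  # expecting 2 + 2 + 0, the duplicate pruned to 0
4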

@ToucheSir (Member)

fmap is the only intrinsically type-unstable part of Functors. Part of that is due to the IdDict cache, but another part is the fragility, under type inference, of the chain of mutually recursive functions we've accrued over time in the name of backwards compatibility. FluxML/Functors.jl#32 was in part exploring whether type-stable structural traversals in Functors are possible, and indeed some are (e.g. structural sum).

For this PR, my main ask is that it not cut off any paths which could bring us better type stability in the future. It doesn't seem to, but I don't understand the (un)thunking well enough to say for sure. Minor comments would be the inclusion of an init kwarg and a request for bikeshedding the name.

  z, total_back
end

function _total_hobbit(config::RuleConfig, f, x)
@ToucheSir (Member) commented Aug 28, 2022

A brief comment about what this and _total_grad do would help. "hobbit" in particular is alien terminology for anyone who hasn't read a couple of specific issues on the ChainRules repo 😛. Is there something more concise than _total_value_and_inner_pullbacks?

@mcabbott (Member, Author)

Indeed. I rebased this but realised I have no memory of how it worked. Will revise or re-write.

@ToucheSir (Member)

Now that people are starting to use explicit params, we've seen a few instances where it would be nice to have an easy method for adding regularization terms. I believe this function should be easier to implement in a post-FluxML/Functors.jl#43 world too.

@mcabbott (Member, Author)

#143 and #57 (comment) hint that the signature here should probably allow for total(f, model, grads) to mean roughly sum(f(x, dx) for ...). That opens the thorny question of what happens to dx when the same x appears twice: are they added? Does total(f, model, grads) need to be differentiable?

@ToucheSir (Member) commented Apr 20, 2023

That opens the thorny question of what happens to dx when the same x appears twice, are they added?

Makes sense, though thinking about this stuff is always rather mind-bending.

Does total(f, model, grads) need to be differentiable?

I guess it would be future-proofing for models with nested differentiation? We could always kick this can down the road until someone needs this for Flux models.

@darsnack (Member)

That opens the thorny question of what happens to dx when the same x appears twice, are they added?

In the case of total, I think yes. But in the general case, it should apply the reduction operator supplied?

@darsnack (Member)

Also, I don't think adding total(f, model, grads) is worth holding this PR up.

@mcabbott (Member, Author)

But in the general case, it should apply the reduction operator supplied?

Thinking more, I do think they are always added.

If I do something like mapreduce(norm, max, ([1,2], [3 4 5])), then the reduction operator only makes sense on the output of the mapped function -- here max on two scalars. It cannot be applied to arrays.

In some fmapreduce(f, op, model, grads), the accumulation of distinct dx in grads is part of AD; it should always happen before f is called, using + (or e.g. Zygote.accum), not op.
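
A tiny illustration of that last point, with made-up values rather than anything from this PR:

julia> using LinearAlgebra

julia> twice = [1.0, 2.0];  # one array appearing twice in the model

julia> dxs = ([1.0, 1.0], [2.0, 2.0]);  # its two structural gradient pieces

julia> f(x, dx) = dot(x, dx);

julia> f(twice, dxs[1] .+ dxs[2])  # accumulate with +, then apply f once
9.0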

@ablaom commented Sep 6, 2023

@mcabbott

I'm trying to adapt the total method proposed here for a different use case and wondered if the following was expected behaviour:

chain = Chain(Dense(3=>5), Dense(5=>1, relu))
f(A) = sum(abs2, A)

# old way of doing things:
sum(f.(Flux.params(chain)))
# 7.247263f0

# new way:
total(f, chain)
# ERROR: MethodError: no method matching isnumeric(::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}})

@mcabbott (Member, Author) commented Sep 6, 2023

No, it's not... I suspect you have Base.isnumeric; there's an unfortunate name clash? Or perhaps this has just rotted, sorry. I mean to revisit but have been busy.

@ablaom commented Sep 6, 2023

Okay, nevermind. I was using Base.isnumeric instead of the local one.

@ablaom commented Sep 6, 2023

Thanks for the lightning reply ❤️

@darsnack (Member) commented Feb 8, 2024

Bump on merging this? We still get regularization questions frequently.

@mcabbott (Member, Author) commented Feb 8, 2024

Now rebased, but the tests tell me I need to remember what on earth _Tangent_biwalk does.

@mcabbott changed the title from "RFC: add total" to "Add total(f, model) to replace implicit sum(f, Flux.params(model))" on Feb 8, 2024
@CarloLucibello (Member) commented Mar 30, 2024

I think that instead of introducing total, for more generality we should have a trainables(model) returning an iterable (also broadcastable and differentiable) over the trainable parameters. This can be used in expressions like sum(f, trainables(model)) and much more.

It could be implemented on top of Functors.fleaves.
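
For instance, the intended usage would be roughly (a sketch):

julia> using Flux, Optimisers, LinearAlgebra

julia> model = Chain(Dense(3 => 5, relu), Dense(5 => 1));

julia> sum(norm, trainables(model));  # the regularisation-style reduction total was aiming at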

@CarloLucibello (Member)

This could be closed since we now have sum(f, trainables(model)).

@ToucheSir (Member)

Are we still running into recompilation issues using sum(f, trainables(model))? If not, I'm fine with not pursuing this if @mcabbott is as well. Otherwise we may need to dig into whether the compilation issue is solvable or a fundamental flaw with Zygote.

@CarloLucibello (Member)

I don't think we have recompilation issues with sum(f, trainables(model)). Closing, as it is better not to increase the API surface when we have a neat way to do things.

Labels: enhancement (New feature or request)
Successfully merging this pull request may close these issues: fmapreduce
6 participants