fmapreduce #35

Open
mcabbott opened this issue Feb 1, 2022 · 0 comments


@mcabbott
Member

mcabbott commented Feb 1, 2022

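This package probably wants a way to write `mapreduce` over the leaves of a nested structure, to replace e.g. `sum(norm(p) for p in params(m))` in Flux. The version below seems like the minimal attempt: it returns the right value, but it's not Zygote-friendly, since `gradient` gives `nothing` for every argument. Can this be fixed, and is there a better way?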

```julia
julia> using Functors, Zygote

julia> const INIT = Base._InitialValue();

julia> function fmapreduce(f, op, x; init = INIT, walk = (f, x) -> foreach(f, Functors.children(x)), kw...)
         fmap(x; walk, kw...) do y
           init = init===INIT ? f(y) : op(init, f(y))
         end
         init===INIT ? Base.mapreduce_empty(f, op) : init
       end
fmapreduce (generic function with 1 method)

julia> m = ([1,2], (x=[3,4], y=5), 6);

julia> fmapreduce(sum, +, m)
21

julia> gradient(fmapreduce, sum, +, m)
(nothing, nothing, nothing)
```
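For comparison, here is a minimal non-mutating sketch of the same idea. It rests on a guess, not a diagnosis: the attempt above builds its result by reassigning the captured `init` inside the `fmap` closure and throws away the value `fmap` returns, so the answer never travels through return values that an AD system can follow. `fmapreduce_pure` is a hypothetical name, not part of Functors.jl; it uses only `Functors.isleaf` and `Functors.children`. It shows the shape such a function could take, not a confirmed fix for Zygote.

```julia
using Functors

# Hypothetical helper, not part of Functors.jl: a non-mutating fmapreduce.
# Each leaf is passed to `f`, and the results from a node's children are
# combined with `op`. Every partial result is *returned* rather than stored
# in a captured variable, so the answer is ordinary data flow.
function fmapreduce_pure(f, op, x)
    # Leaves (e.g. numeric arrays, plain numbers) are mapped directly.
    Functors.isleaf(x) && return f(x)
    # `children` gives e.g. a Tuple or NamedTuple for tuple-like nodes;
    # `values` turns either into a plain collection of child nodes.
    parts = map(c -> fmapreduce_pure(f, op, c), values(Functors.children(x)))
    # Assumes at least one child: `reduce` on an empty collection needs `init`.
    return reduce(op, parts)
end

m = ([1, 2], (x = [3, 4], y = 5), 6)
fmapreduce_pure(sum, +, m)  # 21
```

One design difference: `fmap` keeps an `IdDict` cache by default, so an array shared between two branches is visited once, whereas this plain recursion visits it once per occurrence.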