From 6b3b4811556852457c9345512b113ab5c1c4e796 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Fri, 27 Mar 2020 10:28:37 +1100 Subject: [PATCH 01/10] fix broadcasting, vcat, hcat, cat, fill --- Project.toml | 2 + src/reversediff.jl | 221 +----------------------- src/reversediffx.jl | 412 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 418 insertions(+), 217 deletions(-) create mode 100644 src/reversediffx.jl diff --git a/Project.toml b/Project.toml index e77df306..bd844a03 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -29,6 +30,7 @@ Distributions = "0.22, 0.23" FillArrays = "0.8" FiniteDifferences = "0.9" ForwardDiff = "0.10.6" +MacroTools = "0.5" NaNMath = "0.3" PDMats = "0.9" ReverseDiff = "1.1" diff --git a/src/reversediff.jl b/src/reversediff.jl index f9bfb83f..b949b05f 100644 --- a/src/reversediff.jl +++ b/src/reversediff.jl @@ -1,17 +1,8 @@ -const RTR = ReverseDiff.TrackedReal -const RTV = ReverseDiff.TrackedVector -const RTM = ReverseDiff.TrackedMatrix -const RTA = ReverseDiff.TrackedArray -using ReverseDiff: SpecialInstruction -import NaNMath -using ForwardDiff: Dual -import SpecialFunctions: logbeta -import Distributions: Gamma +include("reversediffx.jl") -Base.:*(A::Adjoint{<:Real, <:RTV{<:Real}}, B::AbstractVector{<:Real}) = dot(A, B) -Base.:*(A::Adjoint{<:Real, <:RTV{<:Real}}, B::RTV{<:Real}) = dot(A, B) -Base.:*(A::AbstractVector{<:Real}, B::Adjoint{<:Real, <:RTV{<:Real}}) = dot(A, B) -Base.:*(A::RTV{<:Real}, B::Adjoint{<:Real, <:RTV{<:Real}}) = dot(A, B) +import Distributions: Gamma +using .ReverseDiffX +using .ReverseDiffX: RTR, RTV, RTM Gamma(α::RTR, θ::Real; check_args=true) = pgamma(α, θ, check_args = check_args) Gamma(α::Real, θ::RTR; check_args=true) = pgamma(α, θ, check_args = check_args) @@ -181,39 +172,6 @@ end # zero mean,, constant variance MvLogNormal(d::Int, σ::RTR) = TuringMvLogNormal(TuringMvNormal(d, σ)) -function LinearAlgebra.cholesky(A::RTM; check=true) - factors, info = turing_chol(A, check) - return Cholesky{eltype(factors), typeof(factors)}(factors, 'U', info) -end - -function turing_chol(x::ReverseDiff.TrackedArray{V,D}, check) where {V,D} - tp = ReverseDiff.tape(x) - x_value = ReverseDiff.value(x) - check_value = ReverseDiff.value(check) - C, back = pullback(_turing_chol, x_value, check_value) - out = ReverseDiff.track(C.factors, D, tp) - ReverseDiff.record!(tp, SpecialInstruction, turing_chol, (x, check), out, (back, issuccess(C))) - return out, C.info -end - -@noinline function ReverseDiff.special_reverse_exec!(instruction::SpecialInstruction{typeof(turing_chol)}) - output = instruction.output - instruction.cache[2] || throw(PosDefException(C.info)) - input = instruction.input - input_deriv = ReverseDiff.deriv(input[1]) - P = instruction.cache[1] - input_deriv .+= P((factors = ReverseDiff.deriv(output),))[1] - ReverseDiff.unseed!(output) - return nothing -end - -@noinline function ReverseDiff.special_forward_exec!(instruction::SpecialInstruction{typeof(turing_chol)}) - output, input = instruction.output, instruction.input - C = cholesky(ReverseDiff.value(input[1]), check = ReverseDiff.value(input[2])) - ReverseDiff.value!(output, C.factors) - return nothing -end - Distributions.Dirichlet(alpha::RTV) = TuringDirichlet(alpha) Distributions.Dirichlet(d::Integer, alpha::RTR) = TuringDirichlet(d, alpha) @@ -244,174 +202,3 @@ end function Distributions.logpdf(d::InverseWishart, X::AbstractArray{<:RTM}) return logpdf(TuringInverseWishart(d), X) end - -# Modified from Tracker.jl - -Base.vcat(xs::RTM...) = _vcat(xs...) -Base.vcat(xs::RTV...) = _vcat(xs...) -function _vcat(xs::Union{RTV{<:Any, D}, RTM{<:Any, D}}...) where {D} - tp = ReverseDiff.tape(xs...) - xs_value = ReverseDiff.value.(xs) - out_value = vcat(xs_value...) - function back(Δ) - start = 0 - Δs = [begin - x = map(_ -> :, size(xsi)) - i = isempty(x) ? x : Base.tail(x) - d = Δ[start+1:start+size(xsi,1), i...] - start += size(xsi, 1) - d - end for xsi in xs] - return (Δs...,) - end - out = ReverseDiff.track(out_value, D, tp) - ReverseDiff.record!(tp, SpecialInstruction, vcat, xs, out, (back,)) - return out -end - -@noinline function ReverseDiff.special_reverse_exec!(instruction::SpecialInstruction{typeof(vcat)}) - output = instruction.output - input = instruction.input - input_derivs = ReverseDiff.deriv.(input) - P = instruction.cache[1] - jtvs = P(ReverseDiff.deriv(output)) - for i in 1:length(input_derivs) - input_derivs[i] .+= jtvs[i] - end - ReverseDiff.unseed!(output) - return nothing -end - -@noinline function ReverseDiff.special_forward_exec!(instruction::SpecialInstruction{typeof(vcat)}) - output, input = instruction.output, instruction.input - out_value = vcat(ReverseDiff.value.(input)...) - ReverseDiff.value!(output, out_value) - return nothing -end - -Base.hcat(xs::RTM...) = _hcat(xs...) -Base.hcat(xs::RTV...) = _hcat(xs...) -function _hcat(xs::Union{RTV{<:Any, D}, RTM{<:Any, D}}...) where {D} - tp = ReverseDiff.tape(xs...) - xs_value = ReverseDiff.value.(xs) - out_value = hcat(xs_value...) - function back(Δ) - start = 0 - Δs = [begin - d = if ndims(xsi) == 1 - Δ[:, start+1] - else - i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail - Δ[:, start+1:start+size(xsi,2), i...] - end - start += size(xsi, 2) - d - end for xsi in xs] - return (Δs...,) - end - out = ReverseDiff.track(out_value, D, tp) - ReverseDiff.record!(tp, SpecialInstruction, hcat, xs, out, (back,)) - return out -end - -@noinline function ReverseDiff.special_reverse_exec!(instruction::SpecialInstruction{typeof(hcat)}) - output = instruction.output - input = instruction.input - input_derivs = ReverseDiff.deriv.(input) - P = instruction.cache[1] - jtvs = P(ReverseDiff.deriv(output)) - for i in 1:length(input_derivs) - input_derivs[i] .+= jtvs[i] - end - ReverseDiff.unseed!(output) - return nothing -end - -@noinline function ReverseDiff.special_forward_exec!(instruction::SpecialInstruction{typeof(hcat)}) - output, input = instruction.output, instruction.input - out_value = hcat(ReverseDiff.value.(input)...) - ReverseDiff.value!(output, out_value) - return nothing -end - -Base.cat(Xs::RTA...; dims) = _cat(dims, Xs...) -Base.cat(Xs::RTV...; dims) = _cat(dims, Xs...) -function _cat(dims, Xs::Union{RTV{<:Any, D}, RTM{<:Any, D}}...) where {D} - tp = ReverseDiff.tape(dims, Xs...) - Xs_value = ReverseDiff.value.(Xs) - out_value = cat(Xs_value...; dims = dims) - function back(Δ) - start = ntuple(i -> 0, Val(ndims(Δ))) - Δs = [begin - dim_xs = 1:ndims(xs) - till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ))) - xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ))) - d = reshape(Δ[xs_in_Δ...],size(xs)) - start = start .+ till_xs - d - end for xs in Xs] - return (Δs...,) - end - out = ReverseDiff.track(out_value, D, tp) - ReverseDiff.record!(tp, SpecialInstruction, cat, (dims, Xs...), out, (back,)) - return out -end - -@noinline function ReverseDiff.special_reverse_exec!(instruction::SpecialInstruction{typeof(cat)}) - output = instruction.output - input = instruction.input - input_derivs = ReverseDiff.deriv.(Base.tail(input)) - P = instruction.cache[1] - jtvs = P(ReverseDiff.deriv(output)) - for i in 1:length(jtvs) - input_derivs[i] .+= jtvs[i] - end - ReverseDiff.unseed!(output) - return nothing -end - -@noinline function ReverseDiff.special_forward_exec!(instruction::SpecialInstruction{typeof(cat)}) - output, input = instruction.output, instruction.input - dims = ReverseDiff.value(input[1]) - Xs = ReverseDiff.value.(Base.tail(input)) - out_value = cat(Xs..., dims = dims) - ReverseDiff.value!(output, out_value) - return nothing -end - -########### - -# Broadcasting - -using ReverseDiff: ForwardOptimize -using Base.Broadcast: Broadcasted -import Base.Broadcast: materialize -const RDBroadcasted{F, T} = Broadcasted{<:Any, <:Any, F, T} - -_materialize(f, args) = broadcast(ForwardOptimize(f), args...) - -for (M, f, arity) in ReverseDiff.DiffRules.diffrules() - if arity == 1 - @eval @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTA}}) = _materialize(bc.f, bc.args) - elseif arity == 2 - @eval begin - @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTA, RTA}}) = _materialize(bc.f, bc.args) - @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTA, RTR}}) = _materialize(bc.f, bc.args) - @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTR, RTA}}) = _materialize(bc.f, bc.args) - end - for A in ReverseDiff.ARRAY_TYPES - @eval begin - @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$A, RTA}}) = _materialize(bc.f, bc.args) - @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTA, $A}}) = _materialize(bc.f, bc.args) - @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$A, RTR}}) = _materialize(bc.f, bc.args) - @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTR, $A}}) = _materialize(bc.f, bc.args) - end - end - for R in ReverseDiff.REAL_TYPES - @eval begin - @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$R, RTA}}) = _materialize(bc.f, bc.args) - @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTA, $R}}) = _materialize(bc.f, bc.args) - end - end - end -end diff --git a/src/reversediffx.jl b/src/reversediffx.jl new file mode 100644 index 00000000..12325c62 --- /dev/null +++ b/src/reversediffx.jl @@ -0,0 +1,412 @@ +module ReverseDiffX + +# A lot of this module is adapted from Tracker.jl. +# ReverseDiff.jl is not actively developed but it would be nice to move the code in this +# module to ReverseDiff at some point. + +export NotTracked + +using MacroTools, LinearAlgebra +using ForwardDiff: Dual +import SpecialFunctions, NaNMath, Zygote +using ..ReverseDiff +const RTR = ReverseDiff.TrackedReal +const RTV = ReverseDiff.TrackedVector +const RTM = ReverseDiff.TrackedMatrix +const RTA = ReverseDiff.TrackedArray +using ..ReverseDiff: SpecialInstruction +using ..DistributionsAD: DistributionsAD, _turing_chol +import ..DistributionsAD: turing_chol +using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted + +""" + f(x) = dot(x, x) + f(x::ReverseDiff.TrackedVector) = ReverseDiff.track(f, x) + ReverseDiff.@grad function f(x) + xv = ReverseDiff.value(x) + return dot(xv, xv), ∇ -> (∇ * 2 * xv,) + end +The `@grad` macro provides a way for the users to define custom adjoints for single-output functions wrt to their input numbers or arrays. +""" +macro grad(expr) + if @capture(expr, + (f_(xs__) where {T__} = body_) | + (f_(xs__) = body_) | + (function f_(xs__) body_ end) | + (function f_(xs__) where {T__} body_ end) + ) + closure = gensym(:f) + tp = gensym(:tp) + output_value = gensym(:output_value) + output = gensym(:output) + back = gensym(:back) + args = gensym(:args) + xsv = getargs_expr(xs) + T = T == nothing ? [] : T + return quote + function ReverseDiff.track(::typeof($f), $(xs...)) where {$(T...),} + $args = $xsv + $closure = ($(xs...),) -> $body + $tp = ReverseDiff.tape($args...) + $output_value, $back = $closure($args...) + $output = ReverseDiff.track($output_value, $tp) + ReverseDiff.record!( + $tp, + ReverseDiff.SpecialInstruction, + $f, + $args, + $output, + ($back, $closure), + ) + return $output + end + + @static if !hasmethod( + ReverseDiff.special_reverse_exec!, + Tuple{ReverseDiff.SpecialInstruction{typeof($f)}}, + ) + @noinline function ReverseDiff.special_reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof($f)}) + output = instruction.output + input = instruction.input + back = instruction.cache[1] + input_derivs = back(ReverseDiff.deriv(output)) + @assert input_derivs isa Tuple + ReverseDiff.add_to_deriv!.(input, input_derivs) + ReverseDiff.unseed!(output) + return nothing + end + end + + @static if !hasmethod( + ReverseDiff.special_forward_exec!, + Tuple{ReverseDiff.SpecialInstruction{typeof($f)}}, + ) + @noinline function ReverseDiff.special_forward_exec!(instruction::ReverseDiff.SpecialInstruction{typeof($f)}) + output, input = instruction.output, instruction.input + pullback = instruction.cache[2] + out_value = pullback(input...)[1] + ReverseDiff.value!(output, out_value) + return nothing + end + end + end |> esc + else + throw("Invalid `ReverseDiff` custom gradient definition.") + end +end +add_to_deriv!(d1, d2) = nothing +function add_to_deriv!(d1::Union{RTR, RTA}, d2) + d = ReverseDiff.deriv(d1) + d .+= d2 +end +function getargs_expr(args_with_types) + expr = Expr(:tuple) + for at in args_with_types + x, tosplat = remove_tp(at) + if tosplat + push!(expr.args, :($x...)) + else + push!(expr.args, x) + end + end + return expr +end +function remove_tp(t) + if @capture(t, X_::T_...) + return X, true + elseif @capture(t, X_::T_) + return X, false + elseif @capture(t, ::typeof(T_)...) + return T, true + elseif @capture(t, ::typeof(T_)) + return T, false + elseif @capture(t, X_...) + return X, true + else + return t, false + end +end + +_fill(v::Real, dims::Vararg{Union{Integer, AbstractUnitRange}}) = fill(v[], dims...) +Base.fill(v::RTR, dims::Vararg{Union{Integer, AbstractUnitRange}}) = _fill(Ref(v), dims...) +function _fill( + value::Base.RefValue{<:RTR}, + dims::Vararg{Union{Integer, AbstractUnitRange}}, +) + return ReverseDiff.track(_fill, value, dims...) +end +@grad function _fill(value::Base.RefValue{<:Real}, dims...) + return fill(ReverseDiff.value(value[]), dims...), function(Δ) + size(Δ) ≢ dims && error("Dimension mismatch") + return (sum(Δ), map(_->nothing, dims)...) + end +end + +Base.:*(A::Adjoint{<:Real, <:RTV{<:Real}}, B::AbstractVector{<:Real}) = dot(A, B) +Base.:*(A::Adjoint{<:Real, <:RTV{<:Real}}, B::RTV{<:Real}) = dot(A, B) +Base.:*(A::AbstractVector{<:Real}, B::Adjoint{<:Real, <:RTV{<:Real}}) = dot(A, B) +Base.:*(A::RTV{<:Real}, B::Adjoint{<:Real, <:RTV{<:Real}}) = dot(A, B) + +function LinearAlgebra.cholesky(A::RTM; check=true) + factors, info = turing_chol(A, check) + return Cholesky{eltype(factors), typeof(factors)}(factors, 'U', info) +end + +function turing_chol(x::ReverseDiff.TrackedArray{V,D}, check) where {V,D} + tp = ReverseDiff.tape(x) + x_value = ReverseDiff.value(x) + check_value = ReverseDiff.value(check) + C, back = Zygote.pullback(_turing_chol, x_value, check_value) + out = ReverseDiff.track(C.factors, D, tp) + ReverseDiff.record!(tp, SpecialInstruction, turing_chol, (x, check), out, (back, issuccess(C))) + return out, C.info +end + +@noinline function ReverseDiff.special_reverse_exec!(instruction::SpecialInstruction{typeof(turing_chol)}) + output = instruction.output + instruction.cache[2] || throw(PosDefException(C.info)) + input = instruction.input + input_deriv = ReverseDiff.deriv(input[1]) + P = instruction.cache[1] + input_deriv .+= P((factors = ReverseDiff.deriv(output),))[1] + ReverseDiff.unseed!(output) + return nothing +end + +@noinline function ReverseDiff.special_forward_exec!(instruction::SpecialInstruction{typeof(turing_chol)}) + output, input = instruction.output, instruction.input + C = cholesky(ReverseDiff.value(input[1]), check = ReverseDiff.value(input[2])) + ReverseDiff.value!(output, C.factors) + return nothing +end + +# Modified from Tracker.jl + +Base.vcat(xs::RTM...) = ReverseDiff.track(vcat, xs...) +Base.vcat(xs::RTV...) = ReverseDiff.track(vcat, xs...) +@grad function vcat(xs::Union{RTV, RTM}...) + xs_value = ReverseDiff.value.(xs) + out_value = vcat(xs_value...) + function back(Δ) + start = 0 + Δs = [begin + x = map(_ -> :, size(xsi)) + i = isempty(x) ? x : Base.tail(x) + d = Δ[start+1:start+size(xsi,1), i...] + start += size(xsi, 1) + d + end for xsi in xs] + return (Δs...,) + end + return out_value, back +end + +Base.hcat(xs::RTM...) = ReverseDiff.track(hcat, xs...) +Base.hcat(xs::RTV...) = ReverseDiff.track(hcat, xs...) +@grad function hcat(xs::Union{RTV, RTM}...) + xs_value = ReverseDiff.value.(xs) + out_value = hcat(xs_value...) + function back(Δ) + start = 0 + Δs = [begin + d = if ndims(xsi) == 1 + Δ[:, start+1] + else + i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail + Δ[:, start+1:start+size(xsi,2), i...] + end + start += size(xsi, 2) + d + end for xsi in xs] + return (Δs...,) + end + return out_value, back +end + +Base.cat(Xs::RTA...; dims) = _cat(dims, Xs...) +Base.cat(Xs::RTV...; dims) = _cat(dims, Xs...) +function _cat(dims, Xs::Union{RTV{<:Any, D}, RTM{<:Any, D}}...) where {D} + tp = ReverseDiff.tape(dims, Xs...) + Xs_value = ReverseDiff.value.(Xs) + out_value = cat(Xs_value...; dims = dims) + function back(Δ) + start = ntuple(i -> 0, Val(ndims(Δ))) + Δs = [begin + dim_xs = 1:ndims(xs) + till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ))) + xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ))) + d = reshape(Δ[xs_in_Δ...],size(xs)) + start = start .+ till_xs + d + end for xs in Xs] + return (Δs...,) + end + out = ReverseDiff.track(out_value, D, tp) + ReverseDiff.record!(tp, SpecialInstruction, cat, (dims, Xs...), out, (back,)) + return out +end + +@noinline function ReverseDiff.special_reverse_exec!(instruction::SpecialInstruction{typeof(cat)}) + output = instruction.output + input = instruction.input + input_derivs = ReverseDiff.deriv.(Base.tail(input)) + P = instruction.cache[1] + jtvs = P(ReverseDiff.deriv(output)) + for i in 1:length(jtvs) + input_derivs[i] .+= jtvs[i] + end + ReverseDiff.unseed!(output) + return nothing +end + +@noinline function ReverseDiff.special_forward_exec!(instruction::SpecialInstruction{typeof(cat)}) + output, input = instruction.output, instruction.input + dims = ReverseDiff.value(input[1]) + Xs = ReverseDiff.value.(Base.tail(input)) + out_value = cat(Xs..., dims = dims) + ReverseDiff.value!(output, out_value) + return nothing +end + +########### + +# Broadcasting + +using ForwardDiff: Dual, partials + +trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) + +unbroadcast(x::AbstractArray, Δ) = + size(x) == size(Δ) ? Δ : + length(x) == length(Δ) ? trim(x, Δ) : + trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) + +unbroadcast(x::Number, Δ) = sum(Δ) +unbroadcast(x::Base.RefValue, _) = nothing + +dual(x, p) = x +dual(x::Real, p) = Dual(x, p) + +function partial(f, Δ, i, args::Vararg{Any,N}) where {N} + dargs = ntuple(j -> dual(args[j], i==j), Val(N)) + return Δ * f(dargs...).partials[1] +end + +isclosure(::Any) = false +@generated isclosure(::F) where {F <: Function} = :($(fieldcount(F) > 0)) +hasclosure(b) = isclosure(b) +hasclosure(b::Broadcasted) = isclosure(b.f) || any(hasclosure, b.args) + +""" + NotTracked(f::Function) + +A callable struct that can be used to wrap around closures declaring that they are not closures of tracked variables. This enables the broadcasting of such functions producing a `TrackedArray` instead of an `Array{<:TrackedReal}`. +""" +struct NotTracked{F <: Function} <: Function + f::F +end +(f::NotTracked)(args...; kwargs...) = f.f(args...; kwargs...) + +@inline maybetrackedclosure(f) = false +@inline maybetrackedclosure(f::NotTracked) = false +@inline maybetrackedclosure(f::Function) = isclosure(f) +@inline mayhavetrackedclosure(b) = false +@inline mayhavetrackedclosure(b::Broadcasted) = maybetrackedclosure(b.f) || + any(mayhavetrackedclosure, b.args) + +@inline function ∇broadcast(untracked_bc, fallback_style, axes, f::F, args::Vararg{<:Any,N}) where {F, N} + y = Base.materialize(untracked_bc) + tp = ReverseDiff.tape(f, args...) + eltype(y) <: Real || return copy(Broadcasted{fallback_style, typeof(axes), typeof(f), typeof(args)}(f, args, axes)) + eltype(y) == Bool && return y + function back(Δ) + Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N)) + dxs = map(unbroadcast, args, Δargs) + return dxs + end + out = ReverseDiff.track(y, tp) + _args = map(args) do a + a isa Number && return Ref(a) + return a + end + ReverseDiff.record!(tp, ReverseDiff.SpecialInstruction, ∇broadcast, _args, out, (back, untracked_bc)) + return out +end +@noinline function ReverseDiff.special_reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(∇broadcast)}) + output = instruction.output + input = instruction.input + back = instruction.cache[1] + input_derivs = back(ReverseDiff.deriv(output)) + @assert input_derivs isa Tuple + ReverseDiff.add_to_deriv!.(input, input_derivs) + ReverseDiff.unseed!(output) + return nothing +end +@noinline function ReverseDiff.special_forward_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(∇broadcast)}) + output, input = instruction.output, instruction.input + bc = instruction.cache[2] + out_value = Base.materialize(bc) + ReverseDiff.value!(output, out_value) + return nothing +end + +struct TrackedStyle <: BroadcastStyle end + +Broadcast.BroadcastStyle(::Type{<:Union{RTA, RTR}}) = TrackedStyle() +Broadcast.BroadcastStyle(::TrackedStyle, b::BroadcastStyle) = TrackedStyle() + +# We have to re-build the original broadcast struct to get the appropriate array +# style. We need this primarily to support CuArrays' broadcasting fixes. +broadcast_rebuild(xs) = ReverseDiff.value(xs) +function broadcast_rebuild(bc::Broadcasted) + broadcasted(bc.f, broadcast_rebuild.(bc.args)...) +end +preprocess(x) = x + +getstyle(::Broadcasted{Style}) where {Style} = Style +function Base.copy(bc::Broadcasted{TrackedStyle}) + bc1 = Broadcast.flatten(bc) + untracked_bc = broadcast_rebuild(bc) + bc2 = Broadcast.flatten(untracked_bc) + style = getstyle(bc2) + axes = bc1.axes + f, args = bc2.f, bc1.args + T = Core.Compiler.return_type(f, Tuple{eltype.(args)...}) + maybereal = T <: Real || T >: Real + if hasclosure(bc) && mayhavetrackedclosure(bc) || !maybereal + return copy(Broadcasted{style, typeof(axes), typeof(f), typeof(args)}(f, args, axes)) + else + return ∇broadcast(untracked_bc, style, axes, f, args...) + end +end + +# https://github.com/FluxML/Flux.jl/issues/353 +if VERSION < v"1.1.0-DEV.548" + @eval Base.Broadcast begin + function flatten(bc::Broadcasted{Style}) where {Style} + isflat(bc) && return bc + args = cat_nested(bc) + let makeargs = make_makeargs(bc), f = bc.f + newf = @inline function(args::Vararg{Any,N}) where N + f(makeargs(args...)...) + end + return Broadcasted{Style}(newf, args, bc.axes) + end + end + @inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}}) + bc = t[1] + let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f + let makeargs = make_makeargs(makeargs, bc.args) + headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args) + return @inline function(args::Vararg{Any,N}) where N + args1 = makeargs(args...) + a, b = headargs(args1...), tailargs(args1...) + (f(a...), b...) + end + end + end + end + end +end + +end From 5b0b6a7e89208c6eb464218100c0e49f806bfd04 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Fri, 27 Mar 2020 21:11:49 +1100 Subject: [PATCH 02/10] make ReverseDiff broadcasting fast --- Project.toml | 2 + src/reversediffx.jl | 238 ++++++++++++++++++++++++++------------------ 2 files changed, 143 insertions(+), 97 deletions(-) diff --git a/Project.toml b/Project.toml index bd844a03..31cd93f9 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -36,6 +37,7 @@ PDMats = "0.9" ReverseDiff = "1.1" SpecialFunctions = "0.8, 0.9, 0.10" StatsBase = "0.32, 0.33" +StaticArrays = "0.12" StatsFuns = "0.8, 0.9" Tracker = "0.2.5" Zygote = "0.4.10" diff --git a/src/reversediffx.jl b/src/reversediffx.jl index 12325c62..8e2d2d2a 100644 --- a/src/reversediffx.jl +++ b/src/reversediffx.jl @@ -7,14 +7,13 @@ module ReverseDiffX export NotTracked using MacroTools, LinearAlgebra -using ForwardDiff: Dual import SpecialFunctions, NaNMath, Zygote using ..ReverseDiff const RTR = ReverseDiff.TrackedReal const RTV = ReverseDiff.TrackedVector const RTM = ReverseDiff.TrackedMatrix const RTA = ReverseDiff.TrackedArray -using ..ReverseDiff: SpecialInstruction +using ..ReverseDiff: SpecialInstruction, value, value!, deriv, track, record!, tape, unseed! using ..DistributionsAD: DistributionsAD, _turing_chol import ..DistributionsAD: turing_chol using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted @@ -96,7 +95,7 @@ macro grad(expr) end add_to_deriv!(d1, d2) = nothing function add_to_deriv!(d1::Union{RTR, RTA}, d2) - d = ReverseDiff.deriv(d1) + d = deriv(d1) d .+= d2 end function getargs_expr(args_with_types) @@ -133,10 +132,10 @@ function _fill( value::Base.RefValue{<:RTR}, dims::Vararg{Union{Integer, AbstractUnitRange}}, ) - return ReverseDiff.track(_fill, value, dims...) + return track(_fill, value, dims...) end -@grad function _fill(value::Base.RefValue{<:Real}, dims...) - return fill(ReverseDiff.value(value[]), dims...), function(Δ) +@grad function _fill(v::Base.RefValue{<:Real}, dims...) + return fill(value(v[]), dims...), function(Δ) size(Δ) ≢ dims && error("Dimension mismatch") return (sum(Δ), map(_->nothing, dims)...) end @@ -152,13 +151,13 @@ function LinearAlgebra.cholesky(A::RTM; check=true) return Cholesky{eltype(factors), typeof(factors)}(factors, 'U', info) end -function turing_chol(x::ReverseDiff.TrackedArray{V,D}, check) where {V,D} - tp = ReverseDiff.tape(x) - x_value = ReverseDiff.value(x) - check_value = ReverseDiff.value(check) +function turing_chol(x::RTA{V,D}, check) where {V,D} + tp = tape(x) + x_value = value(x) + check_value = value(check) C, back = Zygote.pullback(_turing_chol, x_value, check_value) - out = ReverseDiff.track(C.factors, D, tp) - ReverseDiff.record!(tp, SpecialInstruction, turing_chol, (x, check), out, (back, issuccess(C))) + out = track(C.factors, D, tp) + record!(tp, SpecialInstruction, turing_chol, (x, check), out, (back, issuccess(C))) return out, C.info end @@ -166,26 +165,26 @@ end output = instruction.output instruction.cache[2] || throw(PosDefException(C.info)) input = instruction.input - input_deriv = ReverseDiff.deriv(input[1]) + input_deriv = deriv(input[1]) P = instruction.cache[1] - input_deriv .+= P((factors = ReverseDiff.deriv(output),))[1] - ReverseDiff.unseed!(output) + input_deriv .+= P((factors = deriv(output),))[1] + unseed!(output) return nothing end @noinline function ReverseDiff.special_forward_exec!(instruction::SpecialInstruction{typeof(turing_chol)}) output, input = instruction.output, instruction.input - C = cholesky(ReverseDiff.value(input[1]), check = ReverseDiff.value(input[2])) - ReverseDiff.value!(output, C.factors) + C = cholesky(value(input[1]), check = value(input[2])) + value!(output, C.factors) return nothing end # Modified from Tracker.jl -Base.vcat(xs::RTM...) = ReverseDiff.track(vcat, xs...) -Base.vcat(xs::RTV...) = ReverseDiff.track(vcat, xs...) +Base.vcat(xs::RTM...) = track(vcat, xs...) +Base.vcat(xs::RTV...) = track(vcat, xs...) @grad function vcat(xs::Union{RTV, RTM}...) - xs_value = ReverseDiff.value.(xs) + xs_value = value.(xs) out_value = vcat(xs_value...) function back(Δ) start = 0 @@ -201,10 +200,10 @@ Base.vcat(xs::RTV...) = ReverseDiff.track(vcat, xs...) return out_value, back end -Base.hcat(xs::RTM...) = ReverseDiff.track(hcat, xs...) -Base.hcat(xs::RTV...) = ReverseDiff.track(hcat, xs...) +Base.hcat(xs::RTM...) = track(hcat, xs...) +Base.hcat(xs::RTV...) = track(hcat, xs...) @grad function hcat(xs::Union{RTV, RTM}...) - xs_value = ReverseDiff.value.(xs) + xs_value = value.(xs) out_value = hcat(xs_value...) function back(Δ) start = 0 @@ -226,8 +225,8 @@ end Base.cat(Xs::RTA...; dims) = _cat(dims, Xs...) Base.cat(Xs::RTV...; dims) = _cat(dims, Xs...) function _cat(dims, Xs::Union{RTV{<:Any, D}, RTM{<:Any, D}}...) where {D} - tp = ReverseDiff.tape(dims, Xs...) - Xs_value = ReverseDiff.value.(Xs) + tp = tape(dims, Xs...) + Xs_value = value.(Xs) out_value = cat(Xs_value...; dims = dims) function back(Δ) start = ntuple(i -> 0, Val(ndims(Δ))) @@ -241,56 +240,39 @@ function _cat(dims, Xs::Union{RTV{<:Any, D}, RTM{<:Any, D}}...) where {D} end for xs in Xs] return (Δs...,) end - out = ReverseDiff.track(out_value, D, tp) - ReverseDiff.record!(tp, SpecialInstruction, cat, (dims, Xs...), out, (back,)) + out = track(out_value, D, tp) + record!(tp, SpecialInstruction, cat, (dims, Xs...), out, (back,)) return out end @noinline function ReverseDiff.special_reverse_exec!(instruction::SpecialInstruction{typeof(cat)}) output = instruction.output input = instruction.input - input_derivs = ReverseDiff.deriv.(Base.tail(input)) + input_derivs = deriv.(Base.tail(input)) P = instruction.cache[1] - jtvs = P(ReverseDiff.deriv(output)) + jtvs = P(deriv(output)) for i in 1:length(jtvs) input_derivs[i] .+= jtvs[i] end - ReverseDiff.unseed!(output) + unseed!(output) return nothing end @noinline function ReverseDiff.special_forward_exec!(instruction::SpecialInstruction{typeof(cat)}) output, input = instruction.output, instruction.input - dims = ReverseDiff.value(input[1]) - Xs = ReverseDiff.value.(Base.tail(input)) + dims = input[1] + Xs = value.(Base.tail(input)) out_value = cat(Xs..., dims = dims) - ReverseDiff.value!(output, out_value) + value!(output, out_value) return nothing end -########### +################ +# Broadcasting # +################ -# Broadcasting - -using ForwardDiff: Dual, partials - -trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) - -unbroadcast(x::AbstractArray, Δ) = - size(x) == size(Δ) ? Δ : - length(x) == length(Δ) ? trim(x, Δ) : - trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) - -unbroadcast(x::Number, Δ) = sum(Δ) -unbroadcast(x::Base.RefValue, _) = nothing - -dual(x, p) = x -dual(x::Real, p) = Dual(x, p) - -function partial(f, Δ, i, args::Vararg{Any,N}) where {N} - dargs = ntuple(j -> dual(args[j], i==j), Val(N)) - return Δ * f(dargs...).partials[1] -end +using StaticArrays +using ForwardDiff isclosure(::Any) = false @generated isclosure(::F) where {F <: Function} = :($(fieldcount(F) > 0)) @@ -314,42 +296,6 @@ end @inline mayhavetrackedclosure(b::Broadcasted) = maybetrackedclosure(b.f) || any(mayhavetrackedclosure, b.args) -@inline function ∇broadcast(untracked_bc, fallback_style, axes, f::F, args::Vararg{<:Any,N}) where {F, N} - y = Base.materialize(untracked_bc) - tp = ReverseDiff.tape(f, args...) - eltype(y) <: Real || return copy(Broadcasted{fallback_style, typeof(axes), typeof(f), typeof(args)}(f, args, axes)) - eltype(y) == Bool && return y - function back(Δ) - Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N)) - dxs = map(unbroadcast, args, Δargs) - return dxs - end - out = ReverseDiff.track(y, tp) - _args = map(args) do a - a isa Number && return Ref(a) - return a - end - ReverseDiff.record!(tp, ReverseDiff.SpecialInstruction, ∇broadcast, _args, out, (back, untracked_bc)) - return out -end -@noinline function ReverseDiff.special_reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(∇broadcast)}) - output = instruction.output - input = instruction.input - back = instruction.cache[1] - input_derivs = back(ReverseDiff.deriv(output)) - @assert input_derivs isa Tuple - ReverseDiff.add_to_deriv!.(input, input_derivs) - ReverseDiff.unseed!(output) - return nothing -end -@noinline function ReverseDiff.special_forward_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(∇broadcast)}) - output, input = instruction.output, instruction.input - bc = instruction.cache[2] - out_value = Base.materialize(bc) - ReverseDiff.value!(output, out_value) - return nothing -end - struct TrackedStyle <: BroadcastStyle end Broadcast.BroadcastStyle(::Type{<:Union{RTA, RTR}}) = TrackedStyle() @@ -357,7 +303,7 @@ Broadcast.BroadcastStyle(::TrackedStyle, b::BroadcastStyle) = TrackedStyle() # We have to re-build the original broadcast struct to get the appropriate array # style. We need this primarily to support CuArrays' broadcasting fixes. -broadcast_rebuild(xs) = ReverseDiff.value(xs) +broadcast_rebuild(xs) = value(xs) function broadcast_rebuild(bc::Broadcasted) broadcasted(bc.f, broadcast_rebuild.(bc.args)...) end @@ -368,15 +314,14 @@ function Base.copy(bc::Broadcasted{TrackedStyle}) bc1 = Broadcast.flatten(bc) untracked_bc = broadcast_rebuild(bc) bc2 = Broadcast.flatten(untracked_bc) - style = getstyle(bc2) - axes = bc1.axes f, args = bc2.f, bc1.args T = Core.Compiler.return_type(f, Tuple{eltype.(args)...}) - maybereal = T <: Real || T >: Real - if hasclosure(bc) && mayhavetrackedclosure(bc) || !maybereal + isreal = (T <: Real) && (T !== Union{}) + if hasclosure(bc) && mayhavetrackedclosure(bc) || !isreal + style, axes = getstyle(bc2), bc1.axes return copy(Broadcasted{style, typeof(axes), typeof(f), typeof(args)}(f, args, axes)) else - return ∇broadcast(untracked_bc, style, axes, f, args...) + return ∇broadcast(f, args...) end end @@ -409,4 +354,103 @@ if VERSION < v"1.1.0-DEV.548" end end +getouttype(::RTR{<:Any, D}) where {D} = D +getouttype(::RTA{<:Any, D}) where {D} = D +getouttype(::Any) = Union{} + +deref(x) = x +deref(x::Base.RefValue) = x[] + +@generated function splatcall(f, x::SVector{N}, utargs::T, ::Val{tinds}) where {N, T <: Tuple, tinds} + args = [] + ti = 1 + uti = 1 + for i in 1:(N + length(T.types)) + if i in tinds + push!(args, :(deref(x[$ti]))) + ti += 1 + else + push!(args, :(deref(utargs[$uti]))) + uti += 1 + end + end + return quote + $(Expr(:meta, :inline)) + $(Expr(:call, :f, args...)) + end +end + +@generated function splitargs(args::T) where {T <: Tuple} + N = length(T.types) + RealOrArray = Union{Real, AbstractArray} + inds = [i for i in 1:N if T.types[i] <: RealOrArray] + indsval = :(Val{$(Expr(:tuple, [:($i) for i in inds]...))}()) + maybetracked = Expr(:tuple, [:(args[$i]) for i in inds]...) + untracked = Expr(:tuple, [:(args[$i]) for i in 1:N if !(i in inds)]...) + return :($indsval, $maybetracked, $untracked) +end +@inline function ∇broadcast(f::F, args::Vararg{<:Any}) where {F} + inds, targs, untracked = trackedargs(args) + N = length(targs) + D = promote_type(getouttype.(targs)...) + result = DiffResults.GradientResult(zero(SVector{N, D})) + function df(x...) + return ForwardDiff.gradient!( + result, + s -> splatcall(f, s, untracked, inds), + SVector(x), + ) + end + results = broadcast(df, value.(targs)...) + tp = tape(targs...) + out = track(DiffResults.value.(results), D, tp) + cache = (results, df, ReverseDiff.index_bound.(targs, (out,))) + record!(tp, SpecialInstruction, ∇broadcast, targs, out, cache) + return out +end +@noinline function ReverseDiff.special_reverse_exec!(instruction::SpecialInstruction{typeof(∇broadcast)}) + input = instruction.input + output = instruction.output + output_deriv = deriv(output) + results, _, bounds = instruction.cache + N = length(input) + if N == 1 || all(isequal(size(input[1])), size.(Base.tail(input))) + add_to_deriv!(input, output_deriv, results) + else + add_to_deriv!(input, output_deriv, results, bounds) + end + unseed!(output) + return nothing +end + +@generated function add_to_deriv!(xs::T, o, r) where {T <: Tuple} + N = length(T.types) + return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, $i)) for i in 1:N]...) +end +_add_to_deriv!(_, _, _, _) = nothing +function _add_to_deriv!(x::Union{RTR, RTA}, out_deriv, results, i) + return ReverseDiff.istracked(x) && ReverseDiff.diffresult_increment_deriv!(x, out_deriv, results, i) +end + +@generated function add_to_deriv!(xs::T, o, r, bounds) where {T <: Tuple} + N = length(T.types) + return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, $i, bounds[$i])) for i in 1:N]...) +end +_add_to_deriv!(_, _, _, _, _) = nothing +function _add_to_deriv!(x::Union{RTR, RTA}, out_deriv, results, i, bound) + return ReverseDiff.istracked(x) && ReverseDiff.diffresult_increment_deriv!(x, out_deriv, results, i, bound) +end + +@noinline function ReverseDiff.special_forward_exec!(instruction::SpecialInstruction{typeof(∇broadcast)}) + input, output = instruction.input, instruction.output + results, df, _ = instruction.cache + ReverseDiff.pull_value!.(input) + broadcast!(df, results, value.(input)...) + output_value = value(output) + for i in eachindex(output_value) + output_value[i] = DiffResults.value(results[i]) + end + return nothing +end + end From 511767062bed532f8598463d44823b05a34e91ff Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 28 Mar 2020 19:47:12 +1100 Subject: [PATCH 03/10] many performance fixes --- src/reversediffx.jl | 225 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 176 insertions(+), 49 deletions(-) diff --git a/src/reversediffx.jl b/src/reversediffx.jl index 8e2d2d2a..c60296e7 100644 --- a/src/reversediffx.jl +++ b/src/reversediffx.jl @@ -1,22 +1,27 @@ module ReverseDiffX -# A lot of this module is adapted from Tracker.jl. +# A lot of this module is adapted from Tracker.jl and ReverseDiff.jl # ReverseDiff.jl is not actively developed but it would be nice to move the code in this # module to ReverseDiff at some point. export NotTracked -using MacroTools, LinearAlgebra -import SpecialFunctions, NaNMath, Zygote -using ..ReverseDiff +using MacroTools, LinearAlgebra, ..ReverseDiff +import SpecialFunctions, NaNMath, Zygote, StaticArrays + +using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted +using ForwardDiff: ForwardDiff, Dual +using ..ReverseDiff: SpecialInstruction, value, value!, deriv, track, record!, tape, unseed! +using ..DistributionsAD: DistributionsAD, _turing_chol + +import ..DistributionsAD: turing_chol +import Base.Broadcast: materialize + const RTR = ReverseDiff.TrackedReal const RTV = ReverseDiff.TrackedVector const RTM = ReverseDiff.TrackedMatrix const RTA = ReverseDiff.TrackedArray -using ..ReverseDiff: SpecialInstruction, value, value!, deriv, track, record!, tape, unseed! -using ..DistributionsAD: DistributionsAD, _turing_chol -import ..DistributionsAD: turing_chol -using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted +const RDBroadcasted{F, T} = Broadcasted{<:Any, <:Any, F, T} """ f(x) = dot(x, x) @@ -70,7 +75,7 @@ macro grad(expr) back = instruction.cache[1] input_derivs = back(ReverseDiff.deriv(output)) @assert input_derivs isa Tuple - ReverseDiff.add_to_deriv!.(input, input_derivs) + DistributionsAD.ReverseDiffX.add_to_deriv!.(input, input_derivs) ReverseDiff.unseed!(output) return nothing end @@ -82,6 +87,7 @@ macro grad(expr) ) @noinline function ReverseDiff.special_forward_exec!(instruction::ReverseDiff.SpecialInstruction{typeof($f)}) output, input = instruction.output, instruction.input + ReverseDiff.pull_value!.(input) pullback = instruction.cache[2] out_value = pullback(input...)[1] ReverseDiff.value!(output, out_value) @@ -95,8 +101,7 @@ macro grad(expr) end add_to_deriv!(d1, d2) = nothing function add_to_deriv!(d1::Union{RTR, RTA}, d2) - d = deriv(d1) - d .+= d2 + ReverseDiff.increment_deriv!(d1, d2) end function getargs_expr(args_with_types) expr = Expr(:tuple) @@ -271,30 +276,27 @@ end # Broadcasting # ################ -using StaticArrays -using ForwardDiff - -isclosure(::Any) = false -@generated isclosure(::F) where {F <: Function} = :($(fieldcount(F) > 0)) -hasclosure(b) = isclosure(b) -hasclosure(b::Broadcasted) = isclosure(b.f) || any(hasclosure, b.args) - """ NotTracked(f::Function) -A callable struct that can be used to wrap around closures declaring that they are not closures of tracked variables. This enables the broadcasting of such functions producing a `TrackedArray` instead of an `Array{<:TrackedReal}`. +A struct that can be used to wrap around closures, structs and arrays of structs declaring that they do not contain tracked variables. This enables a more efficient broadcasting of such functions and structs when doing automatic differentiation with `ReverseDiff` producing a `TrackedArray` instead of an `Array{<:TrackedReal}`. """ -struct NotTracked{F <: Function} <: Function +struct NotTracked{F} <: Function f::F end -(f::NotTracked)(args...; kwargs...) = f.f(args...; kwargs...) +(f::NotTracked{<:Union{Function, Type}})(args...; kwargs...) = f.f(args...; kwargs...) + +istypeorclosure(::F) where {F} = _istypeorclosure(F) +istypeorclosure(::AbstractArray{F}) where {F} = _istypeorclosure(F) +istypeorclosure(::Base.RefValue{F}) where {F} = _istypeorclosure(F) +istypeorclosure(::AbstractArray{<:Real}) = false +istypeorclosure(::Real) = false +@generated _istypeorclosure(::Type{F}) where {F} = :($(fieldcount(F) > 0)) -@inline maybetrackedclosure(f) = false -@inline maybetrackedclosure(f::NotTracked) = false -@inline maybetrackedclosure(f::Function) = isclosure(f) -@inline mayhavetrackedclosure(b) = false -@inline mayhavetrackedclosure(b::Broadcasted) = maybetrackedclosure(b.f) || - any(mayhavetrackedclosure, b.args) +@inline mayhavetracked(b) = istypeorclosure(b) +@inline mayhavetracked(b::NotTracked) = false +@inline mayhavetracked(b::Base.RefValue{<:NotTracked}) = false +@inline mayhavetracked(b::Broadcasted) = mayhavetracked(b.f) || any(mayhavetracked, b.args) struct TrackedStyle <: BroadcastStyle end @@ -307,21 +309,58 @@ broadcast_rebuild(xs) = value(xs) function broadcast_rebuild(bc::Broadcasted) broadcasted(bc.f, broadcast_rebuild.(bc.args)...) end -preprocess(x) = x getstyle(::Broadcasted{Style}) where {Style} = Style -function Base.copy(bc::Broadcasted{TrackedStyle}) - bc1 = Broadcast.flatten(bc) - untracked_bc = broadcast_rebuild(bc) - bc2 = Broadcast.flatten(untracked_bc) - f, args = bc2.f, bc1.args - T = Core.Compiler.return_type(f, Tuple{eltype.(args)...}) - isreal = (T <: Real) && (T !== Union{}) - if hasclosure(bc) && mayhavetrackedclosure(bc) || !isreal - style, axes = getstyle(bc2), bc1.axes - return copy(Broadcasted{style, typeof(axes), typeof(f), typeof(args)}(f, args, axes)) +remove_not_tracked(f) = f +remove_not_tracked(f::NotTracked) = f.f +remove_not_tracked(f::Base.RefValue{<:NotTracked}) = Ref(remove_not_tracked(f[])) +remove_not_tracked(f::Base.RefValue{<:NotTracked{<:AbstractArray}}) = remove_not_tracked(f[]) +function remove_not_tracked(b::Broadcasted{style}) where {style} + return Broadcasted{style}(remove_not_tracked(b.f), remove_not_tracked.(b.args), b.axes) +end + +@generated function onlyrealarrays(args::T) where {T <: Tuple} + o = all(map(x -> (x <: AbstractArray{<:Real} || !(x <: AbstractArray)), T.types)) + return :($o) +end +@generated function anyreals(args::T) where {T <: Tuple} + o = any(map(x -> x <: Real, T.types)) + return :($o) +end + +function get_implementation(bc, f, T, args) + outputisreal = (T <: AbstractArray{<:Real}) && (T !== Union{}) + # Any arg is a real number or an array of untracked non-reals, + # Output is real, and + # No tracked closure or arguments, except TrackedReal and TrackedArray. + if !mayhavetracked(bc) && outputisreal && (anyreals(args) || !onlyrealarrays(args)) + return Val(:tracker) + # No arg is a real number and array args must be arrays of reals, + # Output is real, and + # No tracked closure or arguments, except TrackedReal and TrackedArray. + elseif !mayhavetracked(bc) && outputisreal + return Val(:reversediff) + # Function or any arg is possibly a tracked non-real or array of tracked non-reals, + # Or output is not an array of reals else + return Val(:fallback) + end +end +function Base.copy(_bc::Broadcasted{TrackedStyle}) + bc = remove_not_tracked(_bc) + flattened_bc = Broadcast.flatten(bc) + untracked_bc = broadcast_rebuild(bc) + flattened_untracked_bc = Broadcast.flatten(untracked_bc) + T = Core.Compiler.return_type(copy, Tuple{typeof(untracked_bc)}) + f, args = flattened_untracked_bc.f, flattened_bc.args + implementation = get_implementation(_bc, f, T, args) + if implementation isa Val{:reversediff} return ∇broadcast(f, args...) + elseif implementation isa Val{:tracker} + return tracker_∇broadcast(f, args...) + else + style, axes = getstyle(flattened_untracked_bc), flattened_bc.axes + return copy(Broadcasted{style, typeof(axes), typeof(f), typeof(args)}(f, args, axes)) end end @@ -389,8 +428,11 @@ end untracked = Expr(:tuple, [:(args[$i]) for i in 1:N if !(i in inds)]...) return :($indsval, $maybetracked, $untracked) end + +## A generalization of the broadcasting approach in ReverseDiff for general functions + @inline function ∇broadcast(f::F, args::Vararg{<:Any}) where {F} - inds, targs, untracked = trackedargs(args) + inds, targs, untracked = splitargs(args) N = length(targs) D = promote_type(getouttype.(targs)...) result = DiffResults.GradientResult(zero(SVector{N, D})) @@ -403,7 +445,9 @@ end end results = broadcast(df, value.(targs)...) tp = tape(targs...) - out = track(DiffResults.value.(results), D, tp) + out_value = DiffResults.value.(results) + eltype(out_value) == Bool && return out_value + out = track(out_value, D, tp) cache = (results, df, ReverseDiff.index_bound.(targs, (out,))) record!(tp, SpecialInstruction, ∇broadcast, targs, out, cache) return out @@ -425,32 +469,115 @@ end @generated function add_to_deriv!(xs::T, o, r) where {T <: Tuple} N = length(T.types) - return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, $i)) for i in 1:N]...) + return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, Val($i))) for i in 1:N]...) end _add_to_deriv!(_, _, _, _) = nothing -function _add_to_deriv!(x::Union{RTR, RTA}, out_deriv, results, i) +function _add_to_deriv!(x::Union{RTR, RTA}, out_deriv, results, ::Val{i}) where {i} return ReverseDiff.istracked(x) && ReverseDiff.diffresult_increment_deriv!(x, out_deriv, results, i) end @generated function add_to_deriv!(xs::T, o, r, bounds) where {T <: Tuple} N = length(T.types) - return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, $i, bounds[$i])) for i in 1:N]...) + return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, Val($i), bounds[$i])) for i in 1:N]...) end _add_to_deriv!(_, _, _, _, _) = nothing -function _add_to_deriv!(x::Union{RTR, RTA}, out_deriv, results, i, bound) +function _add_to_deriv!(x::Union{RTR, RTA}, out_deriv, results, ::Val{i}, bound) where {i} return ReverseDiff.istracked(x) && ReverseDiff.diffresult_increment_deriv!(x, out_deriv, results, i, bound) end @noinline function ReverseDiff.special_forward_exec!(instruction::SpecialInstruction{typeof(∇broadcast)}) input, output = instruction.input, instruction.output results, df, _ = instruction.cache - ReverseDiff.pull_value!.(input) broadcast!(df, results, value.(input)...) output_value = value(output) - for i in eachindex(output_value) - output_value[i] = DiffResults.value(results[i]) - end + output_value .= DiffResults.value.(results) return nothing end +## Tracker style broadcasting +## Good for broadcasting real numbers or arrays of non-tracked structs + +trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) + +unbroadcast(x::AbstractArray, Δ) = + size(x) == size(Δ) ? Δ : + length(x) == length(Δ) ? trim(x, Δ) : + trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) + +unbroadcast(x::Number, Δ) = sum(Δ) +unbroadcast(x::Base.RefValue, _) = nothing + +dual(x, p) = x +dual(x::Real, p) = Dual(x, p) + +function _deriv(f, G, ::Val{i}, args::Vararg{Any, N}) where {N, i} + dargs = ntuple(j -> dual(args[j], i==j), Val(N)) + return f(dargs...).partials[1] * G +end +@generated function _derivs(f, G, args::Vararg{Any, N}) where {N} + return Expr(:tuple, [:(_deriv.(f, G, Val($i), args...)) for i in 1:N]...) +end +@inline function tracker_∇broadcast(f, args::Vararg{Any, N}) where {N} + args_values = map(value, args) + out_value = broadcast(f, args_values...) + tp = tape(args...) + eltype(out_value) == Bool && return out_value + out = track(out_value, tp) + cache = (f,) + record!(tp, SpecialInstruction, tracker_∇broadcast, args, out, cache) + return out +end + +@noinline function ReverseDiff.special_forward_exec!(instruction::SpecialInstruction{typeof(tracker_∇broadcast)}) + input, output = instruction.input, instruction.output + f = instruction.cache[1] + output_value = value(output) + broadcast!(f, output_value, value.(input)...) + return nothing +end + +@noinline function ReverseDiff.special_reverse_exec!(instruction::SpecialInstruction{typeof(tracker_∇broadcast)}) + input = instruction.input + output = instruction.output + f = instruction.cache[1] + output_deriv = deriv(output) + N = length(input) + Δargs = _derivs(f, output_deriv, value.(input)...) + dxs = map(unbroadcast, input, Δargs) + map(add_to_deriv!, input, dxs) + unseed!(output) + return nothing +end + +## Limited ReverseDiff broadcasting +## Efficient broadcasting for specific functions, e.g. +, - + +@inline _materialize(f, args) = broadcast(f, args...) + +for (M, f, arity) in ReverseDiff.DiffRules.diffrules() + if arity == 1 + @eval @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTA}}) = _materialize(bc.f, bc.args) + elseif arity == 2 + @eval begin + @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTA, RTA}}) = _materialize(bc.f, bc.args) + @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTA, RTR}}) = _materialize(bc.f, bc.args) + @noinline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTR, RTA}}) = _materialize(bc.f, bc.args) + end + for A in ReverseDiff.ARRAY_TYPES + @eval begin + @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$A, RTA}}) = _materialize(bc.f, bc.args) + @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTA, $A}}) = _materialize(bc.f, bc.args) + @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$A, RTR}}) = _materialize(bc.f, bc.args) + @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTR, $A}}) = _materialize(bc.f, bc.args) + end + end + for R in ReverseDiff.REAL_TYPES + @eval begin + @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$R, RTA}}) = _materialize(bc.f, bc.args) + @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTA, $R}}) = _materialize(bc.f, bc.args) + end + end + end +end + end From 59ec193bd41bfc65e54b05ec8076436cf103ada5 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 28 Mar 2020 20:19:08 +1100 Subject: [PATCH 04/10] fix import bug --- src/reversediffx.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/reversediffx.jl b/src/reversediffx.jl index c60296e7..3d9faf8a 100644 --- a/src/reversediffx.jl +++ b/src/reversediffx.jl @@ -6,14 +6,13 @@ module ReverseDiffX export NotTracked -using MacroTools, LinearAlgebra, ..ReverseDiff -import SpecialFunctions, NaNMath, Zygote, StaticArrays - +using MacroTools, LinearAlgebra, ..ReverseDiff, StaticArrays using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted using ForwardDiff: ForwardDiff, Dual using ..ReverseDiff: SpecialInstruction, value, value!, deriv, track, record!, tape, unseed! using ..DistributionsAD: DistributionsAD, _turing_chol +import SpecialFunctions, NaNMath, Zygote import ..DistributionsAD: turing_chol import Base.Broadcast: materialize From 4d3ffdc1a71ce3b2f7323d9993f876cfdfee6927 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 29 Mar 2020 11:46:55 +1100 Subject: [PATCH 05/10] test fixes --- src/DistributionsAD.jl | 5 ++++- src/reversediffx.jl | 19 ++++++++++--------- src/univariate.jl | 10 ++++++++++ test/test_utils.jl | 8 ++++---- 4 files changed, 28 insertions(+), 14 deletions(-) diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index 4bf5fc80..3193a166 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -33,7 +33,10 @@ import Distributions: MvNormal, poissonbinomial_pdf_fft, logpdf, quantile, - PoissonBinomial + PoissonBinomial, + Binomial, + BetaBinomial, + Erlang export TuringScalMvNormal, TuringDiagMvNormal, diff --git a/src/reversediffx.jl b/src/reversediffx.jl index 3d9faf8a..45b6c6f5 100644 --- a/src/reversediffx.jl +++ b/src/reversediffx.jl @@ -130,16 +130,16 @@ function remove_tp(t) end end -_fill(v::Real, dims::Vararg{Union{Integer, AbstractUnitRange}}) = fill(v[], dims...) -Base.fill(v::RTR, dims::Vararg{Union{Integer, AbstractUnitRange}}) = _fill(Ref(v), dims...) +_fill(v::Real, dims::Vararg{Union{Integer, AbstractUnitRange}}) = fill(v, dims...) +Base.fill(v::RTR, dims::Vararg{Union{Integer, AbstractUnitRange}}) = _fill(v, dims...) function _fill( - value::Base.RefValue{<:RTR}, + value::RTR, dims::Vararg{Union{Integer, AbstractUnitRange}}, ) return track(_fill, value, dims...) end -@grad function _fill(v::Base.RefValue{<:Real}, dims...) - return fill(value(v[]), dims...), function(Δ) +@grad function _fill(v::Real, dims...) + return fill(value(v), dims...), function(Δ) size(Δ) ≢ dims && error("Dimension mismatch") return (sum(Δ), map(_->nothing, dims)...) end @@ -289,6 +289,7 @@ istypeorclosure(::F) where {F} = _istypeorclosure(F) istypeorclosure(::AbstractArray{F}) where {F} = _istypeorclosure(F) istypeorclosure(::Base.RefValue{F}) where {F} = _istypeorclosure(F) istypeorclosure(::AbstractArray{<:Real}) = false +istypeorclosure(::AbstractArray{<:RTR}) = true istypeorclosure(::Real) = false @generated _istypeorclosure(::Type{F}) where {F} = :($(fieldcount(F) > 0)) @@ -329,17 +330,17 @@ end function get_implementation(bc, f, T, args) outputisreal = (T <: AbstractArray{<:Real}) && (T !== Union{}) - # Any arg is a real number or an array of untracked non-reals, + # Each arg is either a real number, an array of untraked reals, a tracked array of reals or an array of untracked non-reals, # Output is real, and # No tracked closure or arguments, except TrackedReal and TrackedArray. if !mayhavetracked(bc) && outputisreal && (anyreals(args) || !onlyrealarrays(args)) return Val(:tracker) - # No arg is a real number and array args must be arrays of reals, + # No arg is a real number and array args must be arrays of untracked reals or tracked arrays of reals, # Output is real, and # No tracked closure or arguments, except TrackedReal and TrackedArray. elseif !mayhavetracked(bc) && outputisreal return Val(:reversediff) - # Function or any arg is possibly a tracked non-real or array of tracked non-reals, + # Function or any arg is possibly a tracked non-real or an array of tracked reals/non-reals, # Or output is not an array of reals else return Val(:fallback) @@ -522,7 +523,7 @@ end tp = tape(args...) eltype(out_value) == Bool && return out_value out = track(out_value, tp) - cache = (f,) + cache = (f,) record!(tp, SpecialInstruction, tracker_∇broadcast, args, out, cache) return out end diff --git a/src/univariate.jl b/src/univariate.jl index af4b7f7d..0aa147a0 100644 --- a/src/univariate.jl +++ b/src/univariate.jl @@ -226,6 +226,16 @@ function nbinomlogpdf(r::ForwardDiff.Dual{T}, p::Real, k::Int) where {T} return FD(nbinomlogpdf(val_r, p, k), Δ_r) end +## Integer dual ## + +function BetaBinomial(n::ForwardDiff.Dual{<:Any, <:Integer}, α::Real, β::Real; check_args = true) + return BetaBinomial(ForwardDiff.value(n), α, β; check_args = check_args) +end +Binomial(n::ForwardDiff.Dual{<:Any, <:Integer}, p::Real) = Binomial(ForwardDiff.value(n), p) +function Erlang(α::ForwardDiff.Dual{<:Any, <:Integer}, θ::Real; check_args = true) + return Erlang(ForwardDiff.value(α), θ, check_args = check_args) +end + ## Poisson ## poislogpdf(v::TrackedReal, x::Int) = track(poislogpdf, v, x) diff --git a/test/test_utils.jl b/test/test_utils.jl index bd9fe824..94c0b637 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -160,14 +160,14 @@ function test_ad(f, at = 0.5; rtol = 1e-8, atol = 1e-8) forward = ForwardDiff.gradient(f, at) @test isapprox(reverse_tracker, forward, rtol=rtol, atol=atol) @test isapprox(reverse_zygote, forward, rtol=rtol, atol=atol) - @test isapprox(reverse_diff, reverse_tracker, rtol=rtol, atol=atol) + @test isapprox(reverse_diff, forward, rtol=rtol, atol=atol) else forward = ForwardDiff.derivative(f, at) finite_diff = central_fdm(5,1)(f, at) @test isapprox(reverse_tracker, forward, rtol=rtol, atol=atol) @test isapprox(reverse_tracker, finite_diff, rtol=rtol, atol=atol) @test isapprox(reverse_zygote, finite_diff, rtol=rtol, atol=atol) - @test isapprox(reverse_diff, reverse_tracker, rtol=rtol, atol=atol) + @test isapprox(reverse_diff, forward, rtol=rtol, atol=atol) end elseif stg == "ForwardDiff_Tracker" isarr = isa(at, AbstractArray) @@ -192,8 +192,8 @@ function test_ad(f, at = 0.5; rtol = 1e-8, atol = 1e-8) @test isapprox(reverse_zygote, forward, rtol=rtol, atol=atol) elseif stg == "ReverseDiff" reverse_diff = ReverseDiff.gradient(f, at) - reverse_tracker = Tracker.data(Tracker.gradient(f, at)[1]) - @test isapprox(reverse_diff, reverse_tracker, rtol=rtol, atol=atol) + forward = ForwardDiff.gradient(f, at) + @test isapprox(reverse_diff, forward, rtol=rtol, atol=atol) else throw("Unsupported test stage.") end From 0ad74f5c84aa4d0916b9ab6d51d780bb891a69c5 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 29 Mar 2020 12:27:28 +1100 Subject: [PATCH 06/10] apply David's and Philipp's comments --- src/reversediffx.jl | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/reversediffx.jl b/src/reversediffx.jl index 45b6c6f5..55892f8a 100644 --- a/src/reversediffx.jl +++ b/src/reversediffx.jl @@ -27,7 +27,7 @@ const RDBroadcasted{F, T} = Broadcasted{<:Any, <:Any, F, T} f(x::ReverseDiff.TrackedVector) = ReverseDiff.track(f, x) ReverseDiff.@grad function f(x) xv = ReverseDiff.value(x) - return dot(xv, xv), ∇ -> (∇ * 2 * xv,) + return dot(xv, xv), Δ -> (Δ * 2 * xv,) end The `@grad` macro provides a way for the users to define custom adjoints for single-output functions wrt to their input numbers or arrays. """ @@ -192,13 +192,13 @@ Base.vcat(xs::RTV...) = track(vcat, xs...) out_value = vcat(xs_value...) function back(Δ) start = 0 - Δs = [begin + Δs = map(xs) do xsi x = map(_ -> :, size(xsi)) i = isempty(x) ? x : Base.tail(x) d = Δ[start+1:start+size(xsi,1), i...] start += size(xsi, 1) d - end for xsi in xs] + end return (Δs...,) end return out_value, back @@ -211,7 +211,7 @@ Base.hcat(xs::RTV...) = track(hcat, xs...) out_value = hcat(xs_value...) function back(Δ) start = 0 - Δs = [begin + Δs = map(xs) do xsi d = if ndims(xsi) == 1 Δ[:, start+1] else @@ -220,7 +220,7 @@ Base.hcat(xs::RTV...) = track(hcat, xs...) end start += size(xsi, 2) d - end for xsi in xs] + end return (Δs...,) end return out_value, back @@ -234,14 +234,14 @@ function _cat(dims, Xs::Union{RTV{<:Any, D}, RTM{<:Any, D}}...) where {D} out_value = cat(Xs_value...; dims = dims) function back(Δ) start = ntuple(i -> 0, Val(ndims(Δ))) - Δs = [begin + Δs = map(Xs) do xs dim_xs = 1:ndims(xs) till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ))) xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ))) d = reshape(Δ[xs_in_Δ...],size(xs)) start = start .+ till_xs d - end for xs in Xs] + end return (Δs...,) end out = track(out_value, D, tp) @@ -293,10 +293,10 @@ istypeorclosure(::AbstractArray{<:RTR}) = true istypeorclosure(::Real) = false @generated _istypeorclosure(::Type{F}) where {F} = :($(fieldcount(F) > 0)) -@inline mayhavetracked(b) = istypeorclosure(b) -@inline mayhavetracked(b::NotTracked) = false -@inline mayhavetracked(b::Base.RefValue{<:NotTracked}) = false -@inline mayhavetracked(b::Broadcasted) = mayhavetracked(b.f) || any(mayhavetracked, b.args) +mayhavetracked(b) = istypeorclosure(b) +mayhavetracked(b::NotTracked) = false +mayhavetracked(b::Base.RefValue{<:NotTracked}) = false +mayhavetracked(b::Broadcasted) = mayhavetracked(b.f) || any(mayhavetracked, b.args) struct TrackedStyle <: BroadcastStyle end @@ -319,14 +319,14 @@ function remove_not_tracked(b::Broadcasted{style}) where {style} return Broadcasted{style}(remove_not_tracked(b.f), remove_not_tracked.(b.args), b.axes) end -@generated function onlyrealarrays(args::T) where {T <: Tuple} - o = all(map(x -> (x <: AbstractArray{<:Real} || !(x <: AbstractArray)), T.types)) - return :($o) -end -@generated function anyreals(args::T) where {T <: Tuple} - o = any(map(x -> x <: Real, T.types)) - return :($o) -end +onlyrealarrays(args::Tuple) = onlyrealarray(first(args)) && onlyrealarrays(Base.tail(args)) +onlyrealarrays(::Tuple{}) = true +onlyrealarray(::AbstractArray{<:Real}) = true +onlyrealarray(::AbstractArray) = false +onlyrealarray(::Any) = true + +anyreals(args::Tuple) = first(args) isa Real || anyreals(Base.tail(args)) +anyreals(args::Tuple{}) = false function get_implementation(bc, f, T, args) outputisreal = (T <: AbstractArray{<:Real}) && (T !== Union{}) From 5c3a97507fba31eba2202756c1b8af63bace3bfe Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 29 Mar 2020 15:43:30 +1100 Subject: [PATCH 07/10] avoid fallback impl for TrackedArray inputs --- src/reversediffx.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/reversediffx.jl b/src/reversediffx.jl index 55892f8a..ac941c5f 100644 --- a/src/reversediffx.jl +++ b/src/reversediffx.jl @@ -289,6 +289,7 @@ istypeorclosure(::F) where {F} = _istypeorclosure(F) istypeorclosure(::AbstractArray{F}) where {F} = _istypeorclosure(F) istypeorclosure(::Base.RefValue{F}) where {F} = _istypeorclosure(F) istypeorclosure(::AbstractArray{<:Real}) = false +istypeorclosure(::RTA) = false istypeorclosure(::AbstractArray{<:RTR}) = true istypeorclosure(::Real) = false @generated _istypeorclosure(::Type{F}) where {F} = :($(fieldcount(F) > 0)) From 48ad330c4a12862f833433ade7ad81e4b36f0f17 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Tue, 31 Mar 2020 02:52:03 +1100 Subject: [PATCH 08/10] fix perf of broadcast involving type constructors --- src/reversediffx.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/reversediffx.jl b/src/reversediffx.jl index ac941c5f..23df96be 100644 --- a/src/reversediffx.jl +++ b/src/reversediffx.jl @@ -295,6 +295,7 @@ istypeorclosure(::Real) = false @generated _istypeorclosure(::Type{F}) where {F} = :($(fieldcount(F) > 0)) mayhavetracked(b) = istypeorclosure(b) +mayhavetracked(b::Type) = false mayhavetracked(b::NotTracked) = false mayhavetracked(b::Base.RefValue{<:NotTracked}) = false mayhavetracked(b::Broadcasted) = mayhavetracked(b.f) || any(mayhavetracked, b.args) @@ -334,6 +335,7 @@ function get_implementation(bc, f, T, args) # Each arg is either a real number, an array of untraked reals, a tracked array of reals or an array of untracked non-reals, # Output is real, and # No tracked closure or arguments, except TrackedReal and TrackedArray. + @show mayhavetracked(bc) if !mayhavetracked(bc) && outputisreal && (anyreals(args) || !onlyrealarrays(args)) return Val(:tracker) # No arg is a real number and array args must be arrays of untracked reals or tracked arrays of reals, From d85fb98252798a7c7ead5ef8c363ca767c66cc0e Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Tue, 31 Mar 2020 03:00:17 +1100 Subject: [PATCH 09/10] remove a show --- src/reversediffx.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/reversediffx.jl b/src/reversediffx.jl index 23df96be..a2c9e7f4 100644 --- a/src/reversediffx.jl +++ b/src/reversediffx.jl @@ -335,7 +335,6 @@ function get_implementation(bc, f, T, args) # Each arg is either a real number, an array of untraked reals, a tracked array of reals or an array of untracked non-reals, # Output is real, and # No tracked closure or arguments, except TrackedReal and TrackedArray. - @show mayhavetracked(bc) if !mayhavetracked(bc) && outputisreal && (anyreals(args) || !onlyrealarrays(args)) return Val(:tracker) # No arg is a real number and array args must be arrays of untracked reals or tracked arrays of reals, From 87b546d8ce6abf8ff799a7fb934867ff7eaaacb4 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Wed, 1 Apr 2020 19:16:50 +1100 Subject: [PATCH 10/10] kwarg support in grad macro and tests --- src/reversediffx.jl | 149 ++++++++++++++++++++++++-------------------- test/others.jl | 143 ++++++++++++++++++++++++++++++++++++++++++ test/test_utils.jl | 12 ++-- 3 files changed, 230 insertions(+), 74 deletions(-) diff --git a/src/reversediffx.jl b/src/reversediffx.jl index a2c9e7f4..1d3048aa 100644 --- a/src/reversediffx.jl +++ b/src/reversediffx.jl @@ -32,71 +32,70 @@ const RDBroadcasted{F, T} = Broadcasted{<:Any, <:Any, F, T} The `@grad` macro provides a way for the users to define custom adjoints for single-output functions wrt to their input numbers or arrays. """ macro grad(expr) - if @capture(expr, - (f_(xs__) where {T__} = body_) | - (f_(xs__) = body_) | - (function f_(xs__) body_ end) | - (function f_(xs__) where {T__} body_ end) - ) - closure = gensym(:f) - tp = gensym(:tp) - output_value = gensym(:output_value) - output = gensym(:output) - back = gensym(:back) - args = gensym(:args) - xsv = getargs_expr(xs) - T = T == nothing ? [] : T - return quote - function ReverseDiff.track(::typeof($f), $(xs...)) where {$(T...),} - $args = $xsv - $closure = ($(xs...),) -> $body - $tp = ReverseDiff.tape($args...) - $output_value, $back = $closure($args...) - $output = ReverseDiff.track($output_value, $tp) - ReverseDiff.record!( - $tp, - ReverseDiff.SpecialInstruction, - $f, - $args, - $output, - ($back, $closure), - ) - return $output - end - - @static if !hasmethod( - ReverseDiff.special_reverse_exec!, - Tuple{ReverseDiff.SpecialInstruction{typeof($f)}}, + d = MacroTools.splitdef(expr) + f = d[:name] + closure = gensym(f) + d[:name] = closure + closure_ex = MacroTools.combinedef(d) + + tp = gensym(:tp) + output_value = gensym(:output_value) + output = gensym(:output) + back = gensym(:back) + args = gensym(:args) + kwargs = gensym(:kwargs) + args_ex = getargs_expr(d[:args]) + kwargs_ex = getkwargs_expr(d[:kwargs]) + return quote + function ReverseDiff.track(::typeof($f), $(d[:args]...); $(d[:kwargs]...)) where {$(d[:whereparams]...),} + $closure_ex + $args = $args_ex + $kwargs = $kwargs_ex + $tp = ReverseDiff.tape($args...) + $output_value, $back = $closure($args...; $kwargs...) + $output = ReverseDiff.track($output_value, $tp) + ReverseDiff.record!( + $tp, + ReverseDiff.SpecialInstruction, + $f, + $args, + $output, + ($back, $closure, $kwargs), ) - @noinline function ReverseDiff.special_reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof($f)}) - output = instruction.output - input = instruction.input - back = instruction.cache[1] - input_derivs = back(ReverseDiff.deriv(output)) - @assert input_derivs isa Tuple - DistributionsAD.ReverseDiffX.add_to_deriv!.(input, input_derivs) - ReverseDiff.unseed!(output) - return nothing - end + return $output + end + + if !hasmethod( + ReverseDiff.special_reverse_exec!, + Tuple{ReverseDiff.SpecialInstruction{typeof($f)}}, + ) + @noinline function ReverseDiff.special_reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof($f)}) + output = instruction.output + input = instruction.input + back = instruction.cache[1] + input_derivs = back(ReverseDiff.deriv(output)) + @assert input_derivs isa Tuple + DistributionsAD.ReverseDiffX.add_to_deriv!.(input, input_derivs) + ReverseDiff.unseed!(output) + return nothing end + end - @static if !hasmethod( - ReverseDiff.special_forward_exec!, - Tuple{ReverseDiff.SpecialInstruction{typeof($f)}}, - ) - @noinline function ReverseDiff.special_forward_exec!(instruction::ReverseDiff.SpecialInstruction{typeof($f)}) - output, input = instruction.output, instruction.input - ReverseDiff.pull_value!.(input) - pullback = instruction.cache[2] - out_value = pullback(input...)[1] - ReverseDiff.value!(output, out_value) - return nothing - end + if !hasmethod( + ReverseDiff.special_forward_exec!, + Tuple{ReverseDiff.SpecialInstruction{typeof($f)}}, + ) + @noinline function ReverseDiff.special_forward_exec!(instruction::ReverseDiff.SpecialInstruction{typeof($f)}) + output, input = instruction.output, instruction.input + ReverseDiff.pull_value!.(input) + pullback = instruction.cache[2] + kwargs = instruction.cache[3] + out_value = pullback(input...; kwargs...)[1] + ReverseDiff.value!(output, out_value) + return nothing end - end |> esc - else - throw("Invalid `ReverseDiff` custom gradient definition.") - end + end + end |> esc end add_to_deriv!(d1, d2) = nothing function add_to_deriv!(d1::Union{RTR, RTA}, d2) @@ -114,31 +113,49 @@ function getargs_expr(args_with_types) end return expr end +function getkwargs_expr(kwargs_with_types) + syms = [] + final = nothing + for at in kwargs_with_types + final isa Nothing || throw("Invalid kwargs.") + x, tosplat = remove_tp(at) + if tosplat + final = x + else + push!(syms, x) + end + end + expr = length(syms) == 0 ? :(NamedTuple()) : Expr(:tuple, [:($f = $f) for f in syms]...) + final = final == nothing ? :(NamedTuple()) : final + return :(Base.merge($expr, $final)) +end function remove_tp(t) if @capture(t, X_::T_...) return X, true elseif @capture(t, X_::T_) return X, false + elseif @capture(t, X_::T_ = V_) + return X, false elseif @capture(t, ::typeof(T_)...) return T, true elseif @capture(t, ::typeof(T_)) return T, false elseif @capture(t, X_...) return X, true + elseif @capture(t, X_ = V_) + return X, false else return t, false end end -_fill(v::Real, dims::Vararg{Union{Integer, AbstractUnitRange}}) = fill(v, dims...) -Base.fill(v::RTR, dims::Vararg{Union{Integer, AbstractUnitRange}}) = _fill(v, dims...) -function _fill( +function Base.fill( value::RTR, dims::Vararg{Union{Integer, AbstractUnitRange}}, ) - return track(_fill, value, dims...) + return track(fill, value, dims...) end -@grad function _fill(v::Real, dims...) +@grad function fill(v::Real, dims...) return fill(value(v), dims...), function(Δ) size(Δ) ≢ dims && error("Dimension mismatch") return (sum(Δ), map(_->nothing, dims)...) diff --git a/test/others.jl b/test/others.jl index 5faac2b6..00d72a5f 100644 --- a/test/others.jl +++ b/test/others.jl @@ -1,3 +1,8 @@ +const RTR = ReverseDiff.TrackedReal +const RTV = ReverseDiff.TrackedVector +const RTM = ReverseDiff.TrackedMatrix +const RTA = ReverseDiff.TrackedArray +using DistributionsAD.ReverseDiffX: @grad using StatsBase: entropy if get_stage() in ("Others", "all") @@ -142,4 +147,142 @@ if get_stage() in ("Others", "all") d = TuringScalMvNormal(m, sigmas[1]) @test params(d) == (m, sigmas[1]) end + + @testset "ReverseDiff @grad macro" begin + x = rand(3); + A = rand(3, 3); + A_x = [vec(A); x]; + global custom_grad_called + + f1(x) = dot(x, x) + f1(x::RTV) = ReverseDiff.track(f1, x) + @grad function f1(x::AbstractVector) + global custom_grad_called = true + xv = ReverseDiff.value(x) + dot(xv, xv), Δ -> (Δ * 2 * xv,) + end + + custom_grad_called = false + g1 = ReverseDiff.gradient(f1, x) + g2 = ReverseDiff.gradient(x -> dot(x, x), x) + @test g1 == g2 + @test custom_grad_called + + f2(A, x) = A * x + f2(A, x::RTV) = ReverseDiff.track(f2, A, x) + f2(A::RTM, x) = ReverseDiff.track(f2, A, x) + f2(A::RTM, x::RTV) = ReverseDiff.track(f2, A, x) + @grad function f2(A::AbstractMatrix, x::AbstractVector) + global custom_grad_called = true + Av = ReverseDiff.value(A) + xv = ReverseDiff.value(x) + Av * xv, Δ -> (Δ * xv', Av' * Δ) + end + + custom_grad_called = false + g1 = ReverseDiff.gradient(x -> sum(f2(A, x)), x) + g2 = ReverseDiff.gradient(x -> sum(A * x), x) + @test g1 == g2 + @test custom_grad_called + + custom_grad_called = false + g1 = ReverseDiff.gradient(A -> sum(f2(A, x)), A) + g2 = ReverseDiff.gradient(A -> sum(A * x), A) + @test g1 == g2 + @test custom_grad_called + + custom_grad_called = false + g1 = ReverseDiff.gradient(A_x -> sum(f2(reshape(A_x[1:9], 3, 3), A_x[10:end])), A_x) + g2 = ReverseDiff.gradient(A_x -> sum(reshape(A_x[1:9], 3, 3) * A_x[10:end]), A_x) + @test g1 == g2 + @test custom_grad_called + + f3(A; dims) = sum(A, dims = dims) + f3(A::RTM; dims) = ReverseDiff.track(f3, A; dims = dims) + @grad function f3(A::AbstractMatrix; dims) + global custom_grad_called = true + Av = ReverseDiff.value(A) + sum(Av, dims = dims), Δ -> (zero(Av) .+ Δ,) + end + custom_grad_called = false + g1 = ReverseDiff.gradient(A -> sum(f3(A, dims = 1)), A) + g2 = ReverseDiff.gradient(A -> sum(sum(A, dims = 1)), A) + @test g1 == g2 + @test custom_grad_called + + f4(::typeof(log), A; dims) = sum(log, A, dims = dims) + f4(::typeof(log), A::RTM; dims) = ReverseDiff.track(f4, log, A; dims = dims) + @grad function f4(::typeof(log), A::AbstractMatrix; dims) + global custom_grad_called = true + Av = ReverseDiff.value(A) + sum(log, Av, dims = dims), Δ -> (nothing, 1 ./ Av .* Δ) + end + custom_grad_called = false + g1 = ReverseDiff.gradient(A -> sum(f4(log, A, dims = 1)), A) + g2 = ReverseDiff.gradient(A -> sum(sum(log, A, dims = 1)), A) + @test g1 == g2 + @test custom_grad_called + + f5(x) = log(x) + f5(x::RTR) = ReverseDiff.track(f5, x) + @grad function f5(x::Real) + global custom_grad_called = true + xv = ReverseDiff.value(x) + log(xv), Δ -> (1 / xv * Δ,) + end + custom_grad_called = false + g1 = ReverseDiff.gradient(x -> f5(x[1]) * f5(x[2]) + exp(x[3]), x) + g2 = ReverseDiff.gradient(x -> log(x[1]) * log(x[2]) + exp(x[3]), x) + @test g1 == g2 + @test custom_grad_called + + f6(x) = sum(x) + f6(x::RTA{<:AbstractFloat}) = ReverseDiff.track(f6, x) + @grad function f6(x::RTA{T}) where {T <: AbstractFloat} + global custom_grad_called = true + xv = ReverseDiff.value(x) + sum(xv), Δ -> (one.(xv) .* Δ,) + end + + custom_grad_called = false + g1 = ReverseDiff.gradient(f6, x) + g2 = ReverseDiff.gradient(sum, x) + @test g1 == g2 + @test custom_grad_called + + x2 = round.(Int, x) + custom_grad_called = false + g1 = ReverseDiff.gradient(f6, x2) + g2 = ReverseDiff.gradient(sum, x2) + @test g1 == g2 + @test !custom_grad_called + f6(x::RTA) = ReverseDiff.track(f6, x) + @test_throws MethodError ReverseDiff.gradient(f6, x2) + + f7(x...) = +(x...) + f7(x::RTR{<:AbstractFloat}...) = ReverseDiff.track(f7, x...) + @grad function f7(x::RTR{T}...) where {T <: AbstractFloat} + global custom_grad_called = true + xv = ReverseDiff.value.(x) + +(xv...), Δ -> one.(xv) .* Δ + end + custom_grad_called = false + g1 = ReverseDiff.gradient(x -> f7(x...), x) + g2 = ReverseDiff.gradient(sum, x) + @test g1 == g2 + @test custom_grad_called + + f8(A; kwargs...) = sum(A, kwargs...) + f8(A::RTM; kwargs...) = ReverseDiff.track(f8, A; kwargs...) + @grad function f8(A::AbstractMatrix; kwargs...) + global custom_grad_called = true + Av = ReverseDiff.value(A) + sum(Av; kwargs...), Δ -> (zero(Av) .+ Δ,) + end + custom_grad_called = false + g1 = ReverseDiff.gradient(A -> sum(f8(A, dims = 1)), A) + g2 = ReverseDiff.gradient(A -> sum(sum(A, dims = 1)), A) + @test g1 == g2 + @test custom_grad_called + end end \ No newline at end of file diff --git a/test/test_utils.jl b/test/test_utils.jl index 94c0b637..fa999cb4 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -136,15 +136,11 @@ end # Taken from Turing.jl function get_stage() - if get(ENV, "TRAVIS", "") == "true" || get(ENV, "GITHUB_ACTIONS", "") == "true" - if "STAGE" in keys(ENV) - return ENV["STAGE"] - else - return "all" - end + if "STAGE" in keys(ENV) + return ENV["STAGE"] + else + return "all" end - - return "all" end const zygote_counter = Ref(0)