Skip to content

Commit

Permalink
Merge pull request #58 from TuringLang/mt/reversediffx
Browse files Browse the repository at this point in the history
Custom gradient macro and broadcasting rework for ReverseDiff
  • Loading branch information
mohamed82008 authored Apr 1, 2020
2 parents 8431408 + 87b546d commit 14534fe
Show file tree
Hide file tree
Showing 7 changed files with 775 additions and 230 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand All @@ -29,11 +31,13 @@ Distributions = "0.22, 0.23"
FillArrays = "0.8"
FiniteDifferences = "0.9"
ForwardDiff = "0.10.6"
MacroTools = "0.5"
NaNMath = "0.3"
PDMats = "0.9"
ReverseDiff = "1.1"
SpecialFunctions = "0.8, 0.9, 0.10"
StatsBase = "0.32, 0.33"
StaticArrays = "0.12"
StatsFuns = "0.8, 0.9"
Tracker = "0.2.5"
Zygote = "0.4.10"
Expand Down
5 changes: 4 additions & 1 deletion src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ import Distributions: MvNormal,
poissonbinomial_pdf_fft,
logpdf,
quantile,
PoissonBinomial
PoissonBinomial,
Binomial,
BetaBinomial,
Erlang

export TuringScalMvNormal,
TuringDiagMvNormal,
Expand Down
221 changes: 4 additions & 217 deletions src/reversediff.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,8 @@
const RTR = ReverseDiff.TrackedReal
const RTV = ReverseDiff.TrackedVector
const RTM = ReverseDiff.TrackedMatrix
const RTA = ReverseDiff.TrackedArray
using ReverseDiff: SpecialInstruction
import NaNMath
using ForwardDiff: Dual
import SpecialFunctions: logbeta
import Distributions: Gamma
include("reversediffx.jl")

# Products involving an adjoint of a tracked vector are lowered to `dot`
# so that ReverseDiff records a tracked scalar rather than a 1-element
# array result.
# NOTE(review): the `vector * adjoint` orderings below also map to `dot`,
# which differs from the generic outer-product meaning — presumably
# intentional for this package's call sites; confirm upstream.
function Base.:*(A::Adjoint{<:Real, <:RTV{<:Real}}, B::AbstractVector{<:Real})
    return dot(A, B)
end
function Base.:*(A::Adjoint{<:Real, <:RTV{<:Real}}, B::RTV{<:Real})
    return dot(A, B)
end
function Base.:*(A::AbstractVector{<:Real}, B::Adjoint{<:Real, <:RTV{<:Real}})
    return dot(A, B)
end
function Base.:*(A::RTV{<:Real}, B::Adjoint{<:Real, <:RTV{<:Real}})
    return dot(A, B)
end
import Distributions: Gamma
using .ReverseDiffX
using .ReverseDiffX: RTR, RTV, RTM

Gamma(α::RTR, θ::Real; check_args=true) = pgamma(α, θ, check_args = check_args)
Gamma(α::Real, θ::RTR; check_args=true) = pgamma(α, θ, check_args = check_args)
Expand Down Expand Up @@ -181,39 +172,6 @@ end
# zero mean, constant variance
MvLogNormal(d::Int, σ::RTR) = TuringMvLogNormal(TuringMvNormal(d, σ))

# ReverseDiff-compatible Cholesky: delegate to the tape-aware
# `turing_chol` kernel, then wrap the tracked factor in a standard
# `Cholesky` container (upper-triangular convention).
function LinearAlgebra.cholesky(A::RTM; check=true)
    F, status = turing_chol(A, check)
    return Cholesky{eltype(F), typeof(F)}(F, 'U', status)
end

# Tape-aware Cholesky kernel: factorizes the primal values through a
# Zygote-style `pullback`, then records a `SpecialInstruction` on the
# ReverseDiff tape so the adjoint can be replayed by
# `special_reverse_exec!`. Returns the tracked `factors` array and the
# factorization's `info` status code.
function turing_chol(x::ReverseDiff.TrackedArray{V,D}, check) where {V,D}
    tp = ReverseDiff.tape(x)
    x_value = ReverseDiff.value(x)
    check_value = ReverseDiff.value(check)
    # `_turing_chol` is defined elsewhere in the package; `back` is the
    # pullback closure consumed in the reverse sweep.
    C, back = pullback(_turing_chol, x_value, check_value)
    out = ReverseDiff.track(C.factors, D, tp)
    # Cache the pullback and the success flag for the reverse pass.
    ReverseDiff.record!(tp, SpecialInstruction, turing_chol, (x, check), out, (back, issuccess(C)))
    return out, C.info
end

# Reverse (backward) pass for `turing_chol`: propagates the adjoint of
# the Cholesky factor back to the input matrix via the cached pullback.
@noinline function ReverseDiff.special_reverse_exec!(instruction::SpecialInstruction{typeof(turing_chol)})
    output = instruction.output
    # cache = (pullback, issuccess(C)). If the forward factorization
    # failed, raise a proper `PosDefException`. The previous code threw
    # `PosDefException(C.info)`, but `C` is undefined in this scope, so
    # the failure path raised `UndefVarError` instead. The failing pivot
    # index is not cached on the tape, so report index 1.
    instruction.cache[2] || throw(PosDefException(1))
    input = instruction.input
    input_deriv = ReverseDiff.deriv(input[1])
    P = instruction.cache[1]
    # Pull the cotangent of the factors back through the factorization
    # and accumulate it into the input's derivative buffer.
    input_deriv .+= P((factors = ReverseDiff.deriv(output),))[1]
    # Clear the output adjoint so the tape can be replayed.
    ReverseDiff.unseed!(output)
    return nothing
end

# Forward re-execution for `turing_chol`: recompute the factorization
# from the current primal values and write the factors into the tracked
# output.
@noinline function ReverseDiff.special_forward_exec!(instruction::SpecialInstruction{typeof(turing_chol)})
    out = instruction.output
    args = instruction.input
    fact = cholesky(ReverseDiff.value(args[1]), check = ReverseDiff.value(args[2]))
    ReverseDiff.value!(out, fact.factors)
    return nothing
end

# Dirichlet constructors with tracked concentration parameters dispatch
# to the package's AD-friendly `TuringDirichlet`.
function Distributions.Dirichlet(alpha::RTV)
    return TuringDirichlet(alpha)
end
function Distributions.Dirichlet(d::Integer, alpha::RTR)
    return TuringDirichlet(d, alpha)
end

Expand Down Expand Up @@ -244,174 +202,3 @@ end
# Evaluate the InverseWishart log-density on tracked matrices through the
# AD-friendly `TuringInverseWishart` wrapper.
Distributions.logpdf(d::InverseWishart, X::AbstractArray{<:RTM}) =
    logpdf(TuringInverseWishart(d), X)

# Modified from Tracker.jl

# Tracked `vcat` (adapted from Tracker.jl): forwards to `_vcat`, which
# computes the primal result and records a hand-written pullback as a
# `SpecialInstruction` on the ReverseDiff tape.
Base.vcat(xs::RTM...) = _vcat(xs...)
Base.vcat(xs::RTV...) = _vcat(xs...)
function _vcat(xs::Union{RTV{<:Any, D}, RTM{<:Any, D}}...) where {D}
    tp = ReverseDiff.tape(xs...)
    xs_value = ReverseDiff.value.(xs)
    out_value = vcat(xs_value...)
    # Pullback: slice the incoming cotangent Δ back into per-argument
    # pieces along dimension 1, matching each input's row count.
    function back(Δ)
        start = 0
        Δs = [begin
            # `i` is a tuple of trailing Colons (empty for vectors) so the
            # same indexing works for vectors and matrices alike.
            x = map(_ -> :, size(xsi))
            i = isempty(x) ? x : Base.tail(x)
            d = Δ[start+1:start+size(xsi,1), i...]
            start += size(xsi, 1)
            d
        end for xsi in xs]
        return (Δs...,)
    end
    out = ReverseDiff.track(out_value, D, tp)
    # Record under `vcat` (not `_vcat`) so the exec! methods dispatch.
    ReverseDiff.record!(tp, SpecialInstruction, vcat, xs, out, (back,))
    return out
end

# Reverse pass for tracked `vcat`: split the output cotangent with the
# cached pullback and accumulate each piece into the matching input's
# derivative buffer.
@noinline function ReverseDiff.special_reverse_exec!(instruction::SpecialInstruction{typeof(vcat)})
    output = instruction.output
    input = instruction.input
    input_derivs = ReverseDiff.deriv.(input)
    P = instruction.cache[1]
    jtvs = P(ReverseDiff.deriv(output))
    for i in 1:length(input_derivs)
        input_derivs[i] .+= jtvs[i]
    end
    # Clear the output adjoint so the tape can be replayed.
    ReverseDiff.unseed!(output)
    return nothing
end

# Forward re-execution for tracked `vcat`: recompute the concatenation
# from current primal values and store it in the tracked output.
@noinline function ReverseDiff.special_forward_exec!(instruction::SpecialInstruction{typeof(vcat)})
    out = instruction.output
    vals = ReverseDiff.value.(instruction.input)
    ReverseDiff.value!(out, vcat(vals...))
    return nothing
end

# Tracked `hcat`: same tape-recording scheme as `_vcat`, slicing along
# dimension 2 in the pullback instead of dimension 1.
Base.hcat(xs::RTM...) = _hcat(xs...)
Base.hcat(xs::RTV...) = _hcat(xs...)
function _hcat(xs::Union{RTV{<:Any, D}, RTM{<:Any, D}}...) where {D}
    tp = ReverseDiff.tape(xs...)
    xs_value = ReverseDiff.value.(xs)
    out_value = hcat(xs_value...)
    # Pullback: carve the cotangent Δ into per-argument column blocks.
    function back(Δ)
        start = 0
        Δs = [begin
            d = if ndims(xsi) == 1
                # A vector contributes exactly one column.
                Δ[:, start+1]
            else
                # Drop the first two dims; `i` holds Colons for any
                # remaining trailing dimensions.
                i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
                Δ[:, start+1:start+size(xsi,2), i...]
            end
            start += size(xsi, 2)
            d
        end for xsi in xs]
        return (Δs...,)
    end
    out = ReverseDiff.track(out_value, D, tp)
    # Record under `hcat` so the exec! methods below dispatch.
    ReverseDiff.record!(tp, SpecialInstruction, hcat, xs, out, (back,))
    return out
end

# Reverse pass for tracked `hcat`: identical structure to the `vcat`
# reverse pass — split the cotangent and accumulate per input.
@noinline function ReverseDiff.special_reverse_exec!(instruction::SpecialInstruction{typeof(hcat)})
    output = instruction.output
    input = instruction.input
    input_derivs = ReverseDiff.deriv.(input)
    P = instruction.cache[1]
    jtvs = P(ReverseDiff.deriv(output))
    for i in 1:length(input_derivs)
        input_derivs[i] .+= jtvs[i]
    end
    # Clear the output adjoint so the tape can be replayed.
    ReverseDiff.unseed!(output)
    return nothing
end

# Forward re-execution for tracked `hcat`: recompute the concatenation
# from current primal values and store it in the tracked output.
@noinline function ReverseDiff.special_forward_exec!(instruction::SpecialInstruction{typeof(hcat)})
    out = instruction.output
    vals = ReverseDiff.value.(instruction.input)
    ReverseDiff.value!(out, hcat(vals...))
    return nothing
end

# Tracked `cat` along arbitrary `dims`. `dims` is passed positionally to
# `_cat` and recorded as the first tape input (see `Base.tail` usage in
# the exec! methods).
Base.cat(Xs::RTA...; dims) = _cat(dims, Xs...)
Base.cat(Xs::RTV...; dims) = _cat(dims, Xs...)
function _cat(dims, Xs::Union{RTV{<:Any, D}, RTM{<:Any, D}}...) where {D}
    tp = ReverseDiff.tape(dims, Xs...)
    Xs_value = ReverseDiff.value.(Xs)
    out_value = cat(Xs_value...; dims = dims)
    # Pullback: carve the cotangent Δ into one block per input, advancing
    # a per-dimension offset only along the concatenated dimensions.
    function back(Δ)
        start = ntuple(i -> 0, Val(ndims(Δ)))
        Δs = [begin
            dim_xs = 1:ndims(xs)
            # Block extent along each concatenated dimension; 0 marks
            # dimensions that are not concatenated (take the full range).
            till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
            xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
            d = reshape(Δ[xs_in_Δ...],size(xs))
            start = start .+ till_xs
            d
        end for xs in Xs]
        return (Δs...,)
    end
    out = ReverseDiff.track(out_value, D, tp)
    ReverseDiff.record!(tp, SpecialInstruction, cat, (dims, Xs...), out, (back,))
    return out
end

# Reverse pass for tracked `cat`: `input` is `(dims, Xs...)`, so skip the
# leading `dims` entry with `Base.tail` when accumulating derivatives.
@noinline function ReverseDiff.special_reverse_exec!(instruction::SpecialInstruction{typeof(cat)})
    output = instruction.output
    input = instruction.input
    input_derivs = ReverseDiff.deriv.(Base.tail(input))
    P = instruction.cache[1]
    jtvs = P(ReverseDiff.deriv(output))
    for i in 1:length(jtvs)
        input_derivs[i] .+= jtvs[i]
    end
    # Clear the output adjoint so the tape can be replayed.
    ReverseDiff.unseed!(output)
    return nothing
end

# Forward re-execution for tracked `cat`: the first tape input is `dims`,
# the rest are the arrays; recompute the concatenation from the current
# primal values.
@noinline function ReverseDiff.special_forward_exec!(instruction::SpecialInstruction{typeof(cat)})
    out = instruction.output
    d = ReverseDiff.value(instruction.input[1])
    vals = ReverseDiff.value.(Base.tail(instruction.input))
    ReverseDiff.value!(out, cat(vals..., dims = d))
    return nothing
end

###########

# Broadcasting

using ReverseDiff: ForwardOptimize
using Base.Broadcast: Broadcasted
import Base.Broadcast: materialize
# Alias for a `Broadcasted` object whose function type is `F` and whose
# argument tuple type is `T`.
const RDBroadcasted{F, T} = Broadcasted{<:Any, <:Any, F, T}

# Materialize a broadcast through ReverseDiff's `ForwardOptimize`, which
# evaluates the function and its derivative with forward-mode duals in a
# single fused pass.
_materialize(f, args) = broadcast(ForwardOptimize(f), args...)

# For every unary/binary function with a registered DiffRule, intercept
# `materialize` on broadcasts involving tracked arrays/reals so they
# route through `_materialize` rather than generic tracked broadcasting.
for (M, f, arity) in ReverseDiff.DiffRules.diffrules()
    if arity == 1
        @eval @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTA}}) = _materialize(bc.f, bc.args)
    elseif arity == 2
        # Tracked-tracked and tracked-scalar argument combinations.
        @eval begin
            @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTA, RTA}}) = _materialize(bc.f, bc.args)
            @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTA, RTR}}) = _materialize(bc.f, bc.args)
            @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTR, RTA}}) = _materialize(bc.f, bc.args)
        end
        # Combinations mixing tracked values with plain array types.
        for A in ReverseDiff.ARRAY_TYPES
            @eval begin
                @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$A, RTA}}) = _materialize(bc.f, bc.args)
                @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTA, $A}}) = _materialize(bc.f, bc.args)
                @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$A, RTR}}) = _materialize(bc.f, bc.args)
                @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTR, $A}}) = _materialize(bc.f, bc.args)
            end
        end
        # Combinations mixing tracked arrays with plain real scalars.
        for R in ReverseDiff.REAL_TYPES
            @eval begin
                @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{$R, RTA}}) = _materialize(bc.f, bc.args)
                @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{RTA, $R}}) = _materialize(bc.f, bc.args)
            end
        end
    end
end
Loading

0 comments on commit 14534fe

Please sign in to comment.