Skip to content

Commit

Permalink
Merge pull request #63 from TuringLang/mt/refactor
Browse files Browse the repository at this point in the history
Refactor
  • Loading branch information
mohamed82008 authored Apr 13, 2020
2 parents 5aa6c87 + 4170be4 commit 82d5507
Show file tree
Hide file tree
Showing 8 changed files with 341 additions and 387 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -34,7 +34,6 @@ ForwardDiff = "0.10.6"
MacroTools = "0.5"
NaNMath = "0.3"
PDMats = "0.9"
ReverseDiff = "1.1"
SpecialFunctions = "0.8, 0.9, 0.10"
StatsBase = "0.32, 0.33"
StaticArrays = "0.12"
Expand All @@ -46,7 +45,8 @@ julia = "1"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "FiniteDifferences"]
test = ["FiniteDifferences", "Test", "ReverseDiff"]
8 changes: 5 additions & 3 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ using PDMats,
Combinatorics,
SpecialFunctions,
StatsFuns,
Compat
Compat,
Requires

using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
TrackedVecOrMat, track, @grad, data
Expand All @@ -21,7 +22,6 @@ using Distributions: AbstractMvLogNormal,
using DiffRules, SpecialFunctions, FillArrays
using ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
using Base.Iterators: drop
using ReverseDiff

import StatsFuns: logsumexp,
binomlogpdf,
Expand Down Expand Up @@ -55,6 +55,8 @@ include("matrixvariate.jl")
include("flatten.jl")
include("arraydist.jl")
include("filldist.jl")
include("reversediff.jl")
@init @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
include("reversediff.jl")
end

end
53 changes: 30 additions & 23 deletions src/arraydist.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,37 @@
# Utils

function maporbroadcast(f, dists::AbstractArray, x::AbstractArray)
# Broadcasting here breaks Tracker for some reason
return sum(map(f, dists, x))
end
function maporbroadcast(f, dists::AbstractVector, x::AbstractMatrix)
return map(x -> maporbroadcast(f, dists, x), eachcol(x))
end
@init @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin
function maporbroadcast(f, dists::LazyArrays.BroadcastArray, x::AbstractArray)
return sum(copy(f.(dists, x)))
end
function maporbroadcast(f, dists::LazyArrays.BroadcastVector, x::AbstractMatrix)
return vec(sum(copy(f.(dists, x)), dims = 1))
end
lazyarray(f, x...) = LazyArrays.LazyArray(Base.broadcasted(f, x...))
export lazyarray
end

# Univariate

const VectorOfUnivariate = Distributions.Product

function arraydist(dists::AbstractVector{<:UnivariateDistribution})
return product_distribution(dists)
end
function arraydist(dists::AbstractVector{<:Normal})
m = mapvcat(mean, dists)
s = mapvcat(std, dists)
return TuringMvNormal(m, s)
return Product(dists)
end

function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real})
return sum(map((d, x) -> logpdf(d, x), dist.v, x))
return maporbroadcast(logpdf, dist.v, x)
end
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
# eachcol breaks Zygote, so we need an adjoint
return mapvcat(dist.v, eachcol(x)) do dist, c
sum(map(c) do x
logpdf(dist, x)
end)
end
return maporbroadcast(logpdf, dist.v, x)
end
@adjoint function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
# Any other more efficient implementation breaks Zygote
Expand All @@ -40,17 +51,13 @@ function arraydist(dists::AbstractMatrix{<:UnivariateDistribution})
return MatrixOfUnivariate(dists)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
# Broadcasting here breaks Tracker for some reason
# A Zygote adjoint is defined for mapvcat to use broadcasting
return sum(map(dist.dists, x) do dist, x
logpdf(dist, x)
end)
return maporbroadcast(logpdf, dist.dists, x)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
return mapvcat(x -> logpdf(dist, x), x)
return map(x -> logpdf(dist, x), x)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}})
return mapvcat(x -> logpdf(dist, x), x)
return map(x -> logpdf(dist, x), x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
return rand.(Ref(rng), dist.dists)
Expand All @@ -72,16 +79,16 @@ function arraydist(dists::AbstractVector{<:MultivariateDistribution})
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
# eachcol breaks Zygote, so we define an adjoint
return sum(logpdf.(dist.dists, eachcol(x)))
return sum(map(logpdf, dist.dists, eachcol(x)))
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
return mapvcat(x -> logpdf(dist, x), x)
return map(x -> logpdf(dist, x), x)
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}})
return mapvcat(x -> logpdf(dist, x), x)
return map(x -> logpdf(dist, x), x)
end
@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
f(dist, x) = sum(mapvcat(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
f(dist, x) = sum(map(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
return pullback(f, dist, x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)
Expand Down
112 changes: 65 additions & 47 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,62 @@
## Generic ##

_istracked(x) = false
_istracked(x::TrackedArray) = false
_istracked(x::AbstractArray{<:TrackedReal}) = true
function mapvcat(f, args...)
out = map(f, args...)
if _istracked(out)
init = vcat(out[1])
return reshape(reduce(vcat, drop(out, 1); init = init), size(out))
else
return out
Tracker.dual(x::Bool, p) = x
Base.prevfloat(r::TrackedReal) = track(prevfloat, r)
@grad function prevfloat(r::Real)
prevfloat(data(r)), Δ -> Δ
end
Base.nextfloat(r::TrackedReal) = track(nextfloat, r)
@grad function nextfloat(r::Real)
nextfloat(data(r)), Δ -> Δ
end

for f = [:hcat, :vcat]
for c = [
[:TrackedReal],
[:AbstractVecOrMat, :TrackedReal],
[:TrackedVecOrMat, :TrackedReal],
]
cnames = map(_ -> gensym(), c)
@eval begin
function Base.$f(
$([:($x::$c) for (x, c) in zip(cnames, c)]...),
x::Union{TrackedArray,TrackedReal},
xs::Union{AbstractArray,Number}...,
)
return track($f, $(cnames...), x, xs...)
end
end
end
@eval begin
@grad function $f(x::Real)
$f(data(x)), (Δ) -> (Δ[1],)
end
@grad function $f(x1::Real, x2::Real)
$f(data(x1), data(x2)), (Δ) -> (Δ[1], Δ[2])
end
@grad function $f(x1::AbstractVector, x2::Real)
$f(data(x1), data(x2)), (Δ) -> (Δ[1:length(x1)], Δ[length(x1)+1])
end
end
end
@adjoint function mapvcat(f, args...)
g(f, args...) = map(f, args...)
return pullback(g, f, args...)

function Base.copy(
A::TrackedArray{T, 2, <:Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}}},
) where {T <: Real}
return track(copy, A)
end
@grad function Base.copy(
A::TrackedArray{T, 2, <:Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}}},
) where {T <: Real}
return copy(data(A)), ∇ -> (copy(∇),)
end

Base.:*(A::TrackedMatrix, B::AbstractTriangular) = track(*, A, B)
Base.:*(A::AbstractTriangular{T}, B::TrackedVector) where {T} = track(*, A, B)
Base.:*(A::AbstractTriangular{T}, B::TrackedMatrix) where {T} = track(*, A, B)
Base.:*(A::Adjoint{T, <:AbstractTriangular{T}}, B::TrackedMatrix) where {T} = track(*, A, B)
Base.:*(A::Adjoint{T, <:AbstractTriangular{T}}, B::TrackedVector) where {T} = track(*, A, B)

function Base.fill(
value::TrackedReal,
dims::Vararg{Union{Integer, AbstractUnitRange}},
Expand All @@ -32,10 +72,11 @@ end

## StatsFuns ##

logsumexp(x::TrackedArray) = track(logsumexp, x)
@grad function logsumexp(x::TrackedArray)
lse = logsumexp(data(x))
return lse, Δ ->.* exp.(x .- lse),)
logsumexp(x::TrackedArray; dims=:) = _logsumexp(x, dims)
_logsumexp(x::TrackedArray, dims=:) = track(_logsumexp, x, dims)
@grad function _logsumexp(x::TrackedArray, dims)
lse = logsumexp(data(x), dims = dims)
return lse, Δ ->.* exp.(x .- lse), nothing)
end

## Linear algebra ##
Expand Down Expand Up @@ -67,17 +108,6 @@ LinearAlgebra.UpperTriangular(A::TrackedMatrix) = upper(A)
upper(A::TrackedMatrix) = track(upper, A)
@grad upper(A) = upper(Tracker.data(A)), ∇ -> (upper(∇),)

function Base.copy(
A::TrackedArray{T, 2, <:Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}}},
) where {T <: Real}
return track(copy, A)
end
@grad function Base.copy(
A::TrackedArray{T, 2, <:Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}}},
) where {T <: Real}
return copy(data(A)), ∇ -> (copy(∇),)
end

function LinearAlgebra.cholesky(A::TrackedMatrix; check=true)
factors_info = turing_chol(A, check)
factors = factors_info[1]
Expand Down Expand Up @@ -136,25 +166,6 @@ end
return logabsgamma(x), Δ -> (digamma(x) * Δ[1],)
end

# Some Tracker fixes

for i = 0:2, c = Tracker.combinations([:AbstractArray, :TrackedArray, :TrackedReal, :Number], i), f = [:hcat, :vcat]
if :TrackedReal in c
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
track($f, $(cnames...), x, xs...)
end
end
@grad function vcat(x::Real)
vcat(data(x)), (Δ) -> (Δ[1],)
end
@grad function vcat(x1::Real, x2::Real)
vcat(data(x1), data(x2)), (Δ) -> (Δ[1], Δ[2])
end
@grad function vcat(x1::AbstractVector, x2::Real)
vcat(data(x1), data(x2)), (Δ) -> (Δ[1:length(x1)], Δ[length(x1)+1])
end

# Zygote fill has issues with non-numbers

@adjoint function fill(x::T, dims...) where {T}
Expand All @@ -163,3 +174,10 @@ end
end
pullback(zfill, x, dims...)
end

# isprobvec

function Distributions.isprobvec(p::TrackedArray{<:Real})
pdata = Tracker.data(p)
all(x -> x zero(x), pdata) && isapprox(sum(pdata), one(eltype(pdata)), atol = 1e-6)
end
4 changes: 2 additions & 2 deletions src/filldist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ function _flat_logpdf(dist, x)
f, args = flatten(dist)
return sum(f.(args..., x))
else
return sum(mapvcat(x) do x
return sum(map(x) do x
logpdf(dist, x)
end)
end
Expand All @@ -60,7 +60,7 @@ function _flat_logpdf_mat(dist, x)
f, args = flatten(dist)
return vec(sum(f.(args..., x), dims = 1))
else
temp = mapvcat(x -> logpdf(dist, x), x)
temp = map(x -> logpdf(dist, x), x)
return vec(sum(temp, dims = 1))
end
end
Expand Down
10 changes: 5 additions & 5 deletions src/matrixvariate.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
## MatrixBeta

function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:TrackedMatrix{<:Real}})
return mapvcat(x -> logpdf(d, x), X)
return map(x -> logpdf(d, x), X)
end
@adjoint function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:Matrix{<:Real}})
f(d, X) = map(x -> logpdf(d, x), X)
Expand Down Expand Up @@ -112,10 +112,10 @@ function Distributions.logpdf(d::TuringWishart, X::AbstractMatrix{<:Real})
return 0.5 * ((df - (p + 1)) * logdet(Xcf) - tr(d.chol \ X)) - d.c0
end
function Distributions.logpdf(d::TuringWishart, X::AbstractArray{<:AbstractMatrix{<:Real}})
return mapvcat(x -> logpdf(d, x), X)
return map(x -> logpdf(d, x), X)
end
function Distributions.logpdf(d::TuringWishart, X::AbstractArray{<:Matrix{<:Real}})
return mapvcat(x -> logpdf(d, x), X)
return map(x -> logpdf(d, x), X)
end

#### Sampling
Expand Down Expand Up @@ -233,10 +233,10 @@ function Distributions.logpdf(d::TuringInverseWishart, X::AbstractMatrix{<:Real}
-0.5 * ((df + p + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) - d.c0
end
function Distributions.logpdf(d::TuringInverseWishart, X::AbstractArray{<:AbstractMatrix{<:Real}})
return mapvcat(x -> logpdf(d, x), X)
return map(x -> logpdf(d, x), X)
end
function Distributions.logpdf(d::TuringInverseWishart, X::AbstractArray{<:Matrix{<:Real}})
return mapvcat(x -> logpdf(d, x), X)
return map(x -> logpdf(d, x), X)
end

#### Sampling
Expand Down
Loading

0 comments on commit 82d5507

Please sign in to comment.