Skip to content

Commit

Permalink
Merge pull request #37 from TuringLang/mt/perf_fixes
Browse files Browse the repository at this point in the history
Minor performance and bug fixes (lessons learnt from TuringExamples)
  • Loading branch information
mohamed82008 authored Feb 27, 2020
2 parents 0f29efb + a0d96e0 commit 2f9d942
Show file tree
Hide file tree
Showing 13 changed files with 370 additions and 243 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/ForwardDiff_Tracker.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: ForwardDiff and Tracker tests

on:
push:
branches:
- master
pull_request:
types: [opened, synchronize, reopened]

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: [1.0.5, 1.2.0, 1.3]
julia-arch: [x64, x86]
os: [ubuntu-latest, macOS-latest]
exclude:
- os: macOS-latest
julia-arch: x86

steps:
- uses: actions/[email protected]
- uses: julia-actions/setup-julia@latest
with:
version: ${{ matrix.julia-version }}
- uses: julia-actions/julia-runtest@master
env:
STAGE: ForwardDiff_Tracker
4 changes: 3 additions & 1 deletion .github/workflows/CI.yml → .github/workflows/Others.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: CI
name: Other tests

on:
push:
Expand All @@ -25,3 +25,5 @@ jobs:
with:
version: ${{ matrix.julia-version }}
- uses: julia-actions/julia-runtest@master
env:
STAGE: Others
29 changes: 29 additions & 0 deletions .github/workflows/Zygote.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
name: Zygote tests

on:
push:
branches:
- master
pull_request:
types: [opened, synchronize, reopened]

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: [1.0.5, 1.2.0, 1.3]
julia-arch: [x64, x86]
os: [ubuntu-latest, macOS-latest]
exclude:
- os: macOS-latest
julia-arch: x86

steps:
- uses: actions/[email protected]
- uses: julia-actions/setup-julia@latest
with:
version: ${{ matrix.julia-version }}
- uses: julia-actions/julia-runtest@master
env:
STAGE: Zygote
42 changes: 22 additions & 20 deletions src/arraydist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,25 @@

const VectorOfUnivariate = Distributions.Product

function arraydist(dists::AbstractVector{<:Normal{T}}) where {T}
means = mean.(dists)
vars = var.(dists)
return MvNormal(means, vars)
end
function arraydist(dists::AbstractVector{<:Normal{<:TrackedReal}})
means = vcatmapreduce(mean, dists)
vars = vcatmapreduce(var, dists)
return MvNormal(means, vars)
end
function arraydist(dists::AbstractVector{<:UnivariateDistribution})
return product_distribution(dists)
end
function arraydist(dists::AbstractVector{<:Normal})
m = mapvcat(mean, dists)
v = mapvcat(var, dists)
return MvNormal(m, v)
end

function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real})
return sum(vcatmapreduce(logpdf, dist.v, x))
return sum(map((d, x) -> logpdf(d, x), dist.v, x))
end
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
# eachcol breaks Zygote, so we need an adjoint
return vcatmapreduce((dist, c) -> logpdf.(dist, c), dist.v, eachcol(x))
return mapvcat(dist.v, eachcol(x)) do dist, c
sum(map(c) do x
logpdf(dist, x)
end)
end
end
@adjoint function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
# Any other more efficient implementation breaks Zygote
Expand All @@ -41,14 +41,16 @@ function arraydist(dists::AbstractMatrix{<:UnivariateDistribution})
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
# Broadcasting here breaks Tracker for some reason
# A Zygote adjoint is defined for vcatmapreduce to use broadcasting
return sum(vcatmapreduce(logpdf, dist.dists, x))
# A Zygote adjoint is defined for mapvcat to use broadcasting
return sum(map(dist.dists, x) do dist, x
logpdf(dist, x)
end)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
return vcatmapreduce(x -> logpdf(dist, x), x)
return mapvcat(x -> logpdf(dist, x), x)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}})
return vcatmapreduce(x -> logpdf(dist, x), x)
return mapvcat(x -> logpdf(dist, x), x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
return rand.(Ref(rng), dist.dists)
Expand All @@ -70,16 +72,16 @@ function arraydist(dists::AbstractVector{<:MultivariateDistribution})
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
# eachcol breaks Zygote, so we define an adjoint
return sum(vcatmapreduce(logpdf, dist.dists, eachcol(x)))
return sum(logpdf.(dist.dists, eachcol(x)))
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
return reshape(vcatmapreduce(x -> logpdf(dist, x), x), size(x))
return mapvcat(x -> logpdf(dist, x), x)
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}})
return reshape(vcatmapreduce(x -> logpdf(dist, x), x), size(x))
return mapvcat(x -> logpdf(dist, x), x)
end
@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
f(dist, x) = sum(vcatmapreduce(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
f(dist, x) = sum(mapvcat(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
return pullback(f, dist, x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)
Expand Down
19 changes: 12 additions & 7 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
## Generic ##

function vcatmapreduce(f, args...)
init = vcat(f(first.(args)...,))
zipped_args = zip(args...,)
return mapreduce(vcat, drop(zipped_args, 1); init = init) do zarg
f(zarg...,)
_istracked(x) = false
_istracked(x::TrackedArray) = false
_istracked(x::AbstractArray{<:TrackedReal}) = true
function mapvcat(f, args...)
out = map(f, args...)
if _istracked(out)
init = vcat(out[1])
return reshape(reduce(vcat, drop(out, 1); init = init), size(out))
else
return out
end
end
@adjoint function vcatmapreduce(f, args...)
g(f, args...) = f.(args...)
@adjoint function mapvcat(f, args...)
g(f, args...) = map(f, args...)
return pullback(g, f, args...)
end

Expand Down
20 changes: 7 additions & 13 deletions src/filldist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,20 @@ end
function _flat_logpdf(dist, x)
if toflatten(dist)
f, args = flatten(dist)
if any(Tracker.istracked, args)
return sum(f.(args..., x))
else
return sum(logpdf.(dist, x))
end
return sum(f.(args..., x))
else
return sum(vcatmapreduce(x -> logpdf(dist, x), x))
return sum(mapvcat(x) do x
logpdf(dist, x)
end)
end
end
function _flat_logpdf_mat(dist, x)
if toflatten(dist)
f, args = flatten(dist)
if any(Tracker.istracked, args)
return vec(sum(f.(args..., x), dims = 1))
else
return vec(sum(logpdf.(dist, x), dims = 1))
end
return vec(sum(f.(args..., x), dims = 1))
else
temp = vcatmapreduce(x -> logpdf(dist, x), x)
return vec(sum(reshape(temp, size(x)), dims = 1))
temp = mapvcat(x -> logpdf(dist, x), x)
return vec(sum(temp, dims = 1))
end
end

Expand Down
3 changes: 2 additions & 1 deletion src/flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ const flattened_dists = [ Bernoulli,
FDist,
Frechet,
Gamma,
#GeneralizedExtremeValue,
GeneralizedExtremeValue,
GeneralizedPareto,
Gumbel,
#InverseGamma,
Expand All @@ -63,6 +63,7 @@ const flattened_dists = [ Bernoulli,
TDist,
TriangularDist,
Triweight,
TuringUniform,
#Truncated,
#VonMises,
]
Expand Down
10 changes: 5 additions & 5 deletions src/matrixvariate.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
## MatrixBeta

function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:TrackedMatrix{<:Real}})
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
return mapvcat(x -> logpdf(d, x), X)
end
@adjoint function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:Matrix{<:Real}})
f(d, X) = map(x -> logpdf(d, x), X)
Expand Down Expand Up @@ -112,10 +112,10 @@ function Distributions.logpdf(d::TuringWishart, X::AbstractMatrix{<:Real})
return 0.5 * ((df - (p + 1)) * logdet(Xcf) - tr(d.chol \ X)) - d.c0
end
function Distributions.logpdf(d::TuringWishart, X::AbstractArray{<:AbstractMatrix{<:Real}})
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
return mapvcat(x -> logpdf(d, x), X)
end
function Distributions.logpdf(d::TuringWishart, X::AbstractArray{<:Matrix{<:Real}})
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
return mapvcat(x -> logpdf(d, x), X)
end

#### Sampling
Expand Down Expand Up @@ -233,10 +233,10 @@ function Distributions.logpdf(d::TuringInverseWishart, X::AbstractMatrix{<:Real}
-0.5 * ((df + p + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) - d.c0
end
function Distributions.logpdf(d::TuringInverseWishart, X::AbstractArray{<:AbstractMatrix{<:Real}})
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
return mapvcat(x -> logpdf(d, x), X)
end
function Distributions.logpdf(d::TuringInverseWishart, X::AbstractArray{<:Matrix{<:Real}})
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
return mapvcat(x -> logpdf(d, x), X)
end

#### Sampling
Expand Down
20 changes: 20 additions & 0 deletions src/univariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ function TuringUniform(a::Real, b::Real)
return TuringUniform{T}(T(a), T(b))
end
Distributions.logpdf(d::TuringUniform, x::Real) = uniformlogpdf(d.a, d.b, x)
Base.minimum(d::TuringUniform) = d.a
Base.maximum(d::TuringUniform) = d.b

Distributions.Uniform(a::TrackedReal, b::Real) = TuringUniform{TrackedReal}(a, b)
Distributions.Uniform(a::Real, b::TrackedReal) = TuringUniform{TrackedReal}(a, b)
Expand Down Expand Up @@ -348,3 +350,21 @@ function Base.convert(
DiscreteNonParametric{T,P,Ts,Ps}(support(d), probs(d), check_args=false)
end

# Fix SubArray support
function Distributions.DiscreteNonParametric{T,P,Ts,Ps}(
vs::Ts,
ps::Ps;
check_args=true,
) where {T<:Real, P<:Real, Ts<:AbstractVector{T}, Ps<:SubArray{P, 1}}
cps = ps[:]
return DiscreteNonParametric{T,P,Ts,typeof(cps)}(vs, cps; check_args = check_args)
end

function Distributions.DiscreteNonParametric{T,P,Ts,Ps}(
vs::Ts,
ps::Ps;
check_args=true,
) where {T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:TrackedArray{P, 1, <:SubArray{P, 1}}}
cps = ps[:]
return DiscreteNonParametric{T,P,Ts,typeof(cps)}(vs, cps; check_args = check_args)
end
Loading

0 comments on commit 2f9d942

Please sign in to comment.