Skip to content

Commit

Permalink
Implement scalar rules for Zygote with ChainRules (#103)
Browse files Browse the repository at this point in the history
* Implement scalar rules for Zygote with ChainRules

* Fix adjoints of `uniformlogpdf`

* Fix typo

* Update src/chainrules.jl

Co-authored-by: Seth Axen <[email protected]>

* Remove one more `@thunk`

* Add tests

* Add missing import

* Introduce more randomness in tests

* Use `@scalar_rule` for `uniformlogpdf`

* Set seed

Co-authored-by: Seth Axen <[email protected]>
  • Loading branch information
devmotion and sethaxen authored Aug 23, 2020
1 parent f914b48 commit 01ad761
Show file tree
Hide file tree
Showing 9 changed files with 171 additions and 115 deletions.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "0.6.4"

[deps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -22,6 +23,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ChainRules = "0.7"
ChainRulesCore = "0.9.5"
Compat = "3.6"
DiffRules = "0.1, 1.0"
Distributions = "0.23.3"
Expand All @@ -39,6 +41,7 @@ ZygoteRules = "0.2"
julia = "1.3"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand All @@ -47,4 +50,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Combinatorics", "FiniteDifferences", "Test", "ReverseDiff", "Zygote", "Tracker"]
test = ["ChainRulesTestUtils", "Combinatorics", "FiniteDifferences", "Test", "ReverseDiff", "Zygote", "Tracker"]
2 changes: 2 additions & 0 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using PDMats,
Requires,
ZygoteRules,
ChainRules, # needed for `ChainRules.chol_blocked_rev`
ChainRulesCore,
FillArrays

using SpecialFunctions: logabsgamma, digamma
Expand Down Expand Up @@ -53,6 +54,7 @@ include("flatten.jl")
include("arraydist.jl")
include("filldist.jl")

include("chainrules.jl")
include("zygote.jl")

@init begin
Expand Down
85 changes: 85 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
## Uniform ##

@scalar_rule(
uniformlogpdf(a, b, x),
@setup(
insupport = a <= x <= b,
diff = b - a,
c = insupport ? inv(diff) : inv(one(diff)),
z = insupport ? zero(x) : oftype(x, NaN),
),
(c, -c, z),
)

## Beta ##

@scalar_rule(
betalogpdf::Real, β::Real, x::Number),
@setup(di = digamma+ β)),
(
@thunk(log(x) - digamma(α) + di),
@thunk(log(1 - x) - digamma(β) + di),
@thunk((α - 1)/x + (1 - β)/(1 - x)),
),
)

## Gamma ##

@scalar_rule(
gammalogpdf(k::Real, θ::Real, x::Number),
(
@thunk(-digamma(k) - log(θ) + log(x)),
@thunk(-k/θ + x/θ^2),
@thunk((k - 1)/x - 1/θ),
),
)

## Chisq ##

@scalar_rule(
chisqlogpdf(k::Real, x::Number),
@setup(ko2 = k / 2),
(@thunk((-logtwo - digamma(ko2) + log(x)) / 2), @thunk((ko2 - 1)/x - one(ko2) / 2)),
)

## FDist ##

@scalar_rule(
fdistlogpdf(v1::Real, v2::Real, x::Number),
@setup(
temp1 = v1 * x + v2,
temp2 = log(temp1),
vsum = v1 + v2,
temp3 = vsum / temp1,
temp4 = digamma(vsum / 2),
),
(
@thunk((log(v1 * x) + 1 - temp2 - x * temp3 - digamma(v1 / 2) + temp4) / 2),
@thunk((log(v2) + 1 - temp2 - temp3 - digamma(v2 / 2) + temp4) / 2),
@thunk(v1 / 2 * (1 / x - temp3) - 1 / x),
),
)

## TDist ##

@scalar_rule(
tdistlogpdf(v::Real, x::Number),
(
@thunk((digamma((v + 1) / 2) - 1 / v - digamma(v / 2) - log(1 + x^2 / v) + x^2 * (v + 1) / v^2 / (1 + x^2 / v)) / 2),
@thunk(-x * (v + 1) / (v + x^2)),
)
)

## Binomial ##

@scalar_rule(
binomlogpdf(n::Int, p::Real, x::Int),
(DoesNotExist(), x / p - (n - x) / (1 - p), DoesNotExist()),
)

## Poisson ##

@scalar_rule(
poislogpdf(v::Real, x::Int),
(x / v - 1, DoesNotExist()),
)
24 changes: 16 additions & 8 deletions src/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,16 +215,24 @@ uniformlogpdf(a::Real, b::Real, x::TrackedReal) = track(uniformlogpdf, a, b, x)
uniformlogpdf(a::TrackedReal, b::TrackedReal, x::Real) = track(uniformlogpdf, a, b, x)
uniformlogpdf(a::TrackedReal, b::TrackedReal, x::TrackedReal) = track(uniformlogpdf, a, b, x)
@grad function uniformlogpdf(a, b, x)
# compute log pdf
diff = data(b) - data(a)
T = typeof(diff)
if a <= data(x) <= b && a < b
l = -log(diff)
da = 1/diff^2
return l, Δ -> (da * Δ, -da * Δ, zero(T) * Δ)
else
n = T(NaN)
return n, Δ -> (n, n, n)
insupport = a <= data(x) <= b
lp = insupport ? -log(diff) : log(zero(diff))

function pullback(Δ)
z = zero(x) * Δ
if insupport
c = Δ / diff
return c, -c, z
else
c = Δ / one(diff)
cNaN = oftype(c, NaN)
return cNaN, cNaN, oftype(z, NaN)
end
end

return lp, pullback
end


Expand Down
7 changes: 3 additions & 4 deletions src/univariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@ Base.minimum(d::TuringUniform) = d.a
Base.maximum(d::TuringUniform) = d.b

function uniformlogpdf(a, b, x)
c = -log(b - a)
diff = b - a
if a <= x <= b
return c
return -log(diff)
else
return oftype(c, -Inf)
return log(zero(diff))
end
end


if VERSION < v"1.2"
Base.inv(::Irrational{:π}) = 1/π
end
Expand Down
98 changes: 0 additions & 98 deletions src/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,108 +8,10 @@ end

## Uniform ##

ZygoteRules.@adjoint function uniformlogpdf(a, b, x)
diff = b - a
T = typeof(diff)
if a <= x <= b && a < b
l = -log(diff)
da = 1/diff^2
return l, Δ -> (da * Δ, -da * Δ, zero(T) * Δ)
else
n = T(NaN)
return n, Δ -> (n, n, n)
end
end

ZygoteRules.@adjoint function Distributions.Uniform(args...)
return ZygoteRules.pullback(TuringUniform, args...)
end


## Beta ##

function _betalogpdfgrad(α, β, x)
di = digamma+ β)
= log(x) - digamma(α) + di
= log(1 - x) - digamma(β) + di
dx =- 1)/x + (1 - β)/(1 - x)
return (dα, dβ, dx)
end
ZygoteRules.@adjoint function betalogpdf::Real, β::Real, x::Number)
return betalogpdf(α, β, x), Δ ->.* _betalogpdfgrad(α, β, x))
end


## Gamma ##

function _gammalogpdfgrad(k, θ, x)
dk = -digamma(k) - log(θ) + log(x)
= -k/θ + x/θ^2
dx = (k - 1)/x - 1/θ
return (dk, dθ, dx)
end
ZygoteRules.@adjoint function gammalogpdf(k::Real, θ::Real, x::Number)
return gammalogpdf(k, θ, x), Δ ->.* _gammalogpdfgrad(k, θ, x))
end


## Chisq ##

function _chisqlogpdfgrad(k, x)
hk = k/2
d = digamma(hk)
dk = (-log(oftype(hk, 2)) - d + log(x))/2
dx = (hk - 1)/x - one(hk)/2
return (dk, dx)
end
ZygoteRules.@adjoint function chisqlogpdf(k::Real, x::Number)
return chisqlogpdf(k, x), Δ ->.* _chisqlogpdfgrad(k, x))
end

## FDist ##

function _fdistlogpdfgrad(v1, v2, x)
temp1 = v1 * x + v2
temp2 = log(temp1)
vsum = v1 + v2
temp3 = vsum / temp1
temp4 = digamma(vsum / 2)
dv1 = (log(v1 * x) + 1 - temp2 - x * temp3 - digamma(v1 / 2) + temp4) / 2
dv2 = (log(v2) + 1 - temp2 - temp3 - digamma(v2 / 2) + temp4) / 2
dx = v1 / 2 * (1 / x - temp3) - 1 / x
return (dv1, dv2, dx)
end
ZygoteRules.@adjoint function fdistlogpdf(v1::Real, v2::Real, x::Number)
return fdistlogpdf(v1, v2, x), Δ ->.* _fdistlogpdfgrad(v1, v2, x))
end

## TDist ##

function _tdistlogpdfgrad(v, x)
dv = (digamma((v + 1) / 2) - 1 / v - digamma(v / 2) - log(1 + x^2 / v) + x^2 * (v + 1) / v^2 / (1 + x^2 / v)) / 2
dx = -x * (v + 1) / (v + x^2)
return (dv, dx)
end
ZygoteRules.@adjoint function tdistlogpdf(v::Real, x::Number)
return tdistlogpdf(v, x), Δ ->.* _tdistlogpdfgrad(v, x))
end


## Binomial ##

ZygoteRules.@adjoint function binomlogpdf(n::Int, p::Real, x::Int)
return binomlogpdf(n, p, x),
Δ->(nothing, Δ * (x / p - (n - x) / (1 - p)), nothing)
end

## Poisson ##

ZygoteRules.@adjoint function poislogpdf(v::Real, x::Int)
return poislogpdf(v, x),
Δ->* (x/v - 1), nothing)
end


## PoissonBinomial ##

# Zygote loads ForwardDiff, so this dummy adjoint should never be needed.
Expand Down
51 changes: 51 additions & 0 deletions test/ad/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@

@testset "chainrules" begin
x, Δx, x̄ = randn(3)
y, Δy, ȳ = randn(3)
z, Δz, z̄ = randn(3)
Δu = randn()

= x + exp(y) + exp(z)
= x + exp(y)
frule_test(DistributionsAD.uniformlogpdf, (x, Δx), (ỹ, Δy), (z̃, Δz))
rrule_test(DistributionsAD.uniformlogpdf, Δu, (x, x̄), (ỹ, ȳ), (z̃, z̄))

= exp(x)
= exp(y)
= logistic(z)
frule_test(DistributionsAD.betalogpdf, (x̃, Δx), (ỹ, Δy), (z̃, Δz))
rrule_test(DistributionsAD.betalogpdf, Δu, (x̃, x̄), (ỹ, ȳ), (z̃, z̄))

= exp(x)
= exp(y)
= exp(z)
frule_test(DistributionsAD.gammalogpdf, (x̃, Δx), (ỹ, Δy), (z̃, Δz))
rrule_test(DistributionsAD.gammalogpdf, Δu, (x̃, x̄), (ỹ, ȳ), (z̃, z̄))

= exp(x)
= exp(y)
= exp(z)
frule_test(DistributionsAD.chisqlogpdf, (x̃, Δx), (ỹ, Δy))
rrule_test(DistributionsAD.chisqlogpdf, Δu, (x̃, x̄), (ỹ, ȳ))

= exp(x)
= exp(y)
= exp(z)
frule_test(DistributionsAD.fdistlogpdf, (x̃, Δx), (ỹ, Δy), (z̃, Δz))
rrule_test(DistributionsAD.fdistlogpdf, Δu, (x̃, x̄), (ỹ, ȳ), (z̃, z̄))

= exp(x)
frule_test(DistributionsAD.tdistlogpdf, (x̃, Δx), (y, Δy))
rrule_test(DistributionsAD.tdistlogpdf, Δu, (x̃, x̄), (y, ȳ))

= rand(1:100)
= logistic(y)
= rand(1:x̃)
frule_test(DistributionsAD.binomlogpdf, (x̃, nothing), (ỹ, Δy), (z̃, nothing))
rrule_test(DistributionsAD.binomlogpdf, Δu, (x̃, nothing), (ỹ, ȳ), (z̃, nothing))

= exp(x)
= rand(1:100)
frule_test(DistributionsAD.poislogpdf, (x̃, Δx), (ỹ, nothing))
rrule_test(DistributionsAD.poislogpdf, Δu, (x̃, x̄), (ỹ, nothing))
end
8 changes: 5 additions & 3 deletions test/ad/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
B = rand(dim, dim)
C = rand(dim, dim)

# Create a random number
# Create random numbers
alpha = rand()
beta = rand()
gamma = rand()

# Create matrix `X` such that `X` and `I - X` are positive definite if `A ≠ 0`.
function to_beta_mat(A)
Expand Down Expand Up @@ -198,10 +200,10 @@
),

DistSpec(Uniform, (), 0.5),
DistSpec(Uniform, (0.0, 1.0), 0.5),
DistSpec(Uniform, (alpha, alpha + beta), alpha + beta * gamma),

DistSpec(TuringUniform, (), 0.5),
DistSpec(TuringUniform, (0.0, 1.0), 0.5),
DistSpec(TuringUniform, (alpha, alpha + beta), alpha + beta * gamma),

DistSpec(VonMises, (), 1.0),

Expand Down
6 changes: 5 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using DistributionsAD

using ChainRulesTestUtils
using Combinatorics
using Distributions
using FiniteDifferences
Expand Down Expand Up @@ -27,7 +28,9 @@ using Distributions: meanlogdet
using DistributionsAD: TuringUniform, TuringMvNormal, TuringMvLogNormal,
TuringPoissonBinomial
using StatsBase: entropy
using StatsFuns: binomlogpdf, logsumexp
using StatsFuns: binomlogpdf, logsumexp, logistic

Random.seed!(1) # Set seed that all testsets should reset to.

const FDM = FiniteDifferences
const GROUP = get(ENV, "GROUP", "All")
Expand Down Expand Up @@ -55,5 +58,6 @@ end

if GROUP == "All" || GROUP == "AD"
include("ad/utils.jl")
include("ad/chainrules.jl")
include("ad/distributions.jl")
end

0 comments on commit 01ad761

Please sign in to comment.