Skip to content

Commit

Permalink
Fix adapt_randn for tracked arrays (#115)
Browse files Browse the repository at this point in the history
* Fix `adapt_randn` for tracked arrays

* Bump version
  • Loading branch information
devmotion authored Sep 14, 2020
1 parent a892d75 commit 09ea3e7
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DistributionsAD"
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
version = "0.6.8"
version = "0.6.9"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
4 changes: 2 additions & 2 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ end
# Tracker's implementation of ldiv isn't good. We'll use Zygote's instead.
zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B

function adapt_randn(rng, x::AbstractArray, dims...)
function adapt_randn(rng::AbstractRNG, x::AbstractArray, dims...)
adapt(typeof(x), randn(rng, eltype(x), dims...))
end

# TODO: should be replaced by @non_differentiable when
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/212 is fixed
function ChainRules.rrule(::typeof(adapt_randn), rng, x, dims...)
function ChainRules.rrule(::typeof(adapt_randn), rng::AbstractRNG, x::AbstractArray, dims...)
function adapt_randn_pullback(ΔQ)
return (NO_FIELDS, Zero(), Zero(), map(_ -> Zero(), dims)...)
end
Expand Down
4 changes: 4 additions & 0 deletions src/forwarddiff.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
function adapt_randn(rng::AbstractRNG, x::AbstractArray{<:ForwardDiff.Dual}, dims...)
adapt(typeof(x), randn(rng, ForwardDiff.valtype(eltype(x)), dims...))
end

## Binomial ##

function binomlogpdf(n::Int, p::ForwardDiff.Dual{T}, x::Int) where {T}
Expand Down
6 changes: 5 additions & 1 deletion src/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ using ..DistributionsAD: DistributionsAD


import SpecialFunctions, NaNMath
import ..DistributionsAD: turing_chol, symm_turing_chol, _mv_categorical_logpdf
import ..DistributionsAD: turing_chol, symm_turing_chol, _mv_categorical_logpdf, adapt_randn
import Base.Broadcast: materialize
import StatsFuns: logsumexp

const TrackedVecOrMat{V,D} = Union{TrackedVector{V,D},TrackedMatrix{V,D}}
const RDBroadcasted{F, T} = Broadcasted{<:Any, <:Any, F, T}

import Random

import Distributions: logpdf,
_logpdf,
loglikelihood,
Expand All @@ -49,6 +51,8 @@ using ..DistributionsAD: TuringPoissonBinomial,

include("reversediffx.jl")

adapt_randn(rng::Random.AbstractRNG, x::TrackedArray, dims...) = adapt_randn(rng, value(x), dims...)

function PoissonBinomial(p::TrackedArray{<:Real}; check_args=true)
return TuringPoissonBinomial(p; check_args = check_args)
end
Expand Down
3 changes: 3 additions & 0 deletions src/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ for f in (:+, :-, :*, :/, :\, :dot), (T1, T2) in [
end
end

## `adapt_randn`

adapt_randn(rng::AbstractRNG, x::TrackedArray, dims...) = adapt_randn(rng, data(x), dims...)

## Uniform ##

Expand Down
34 changes: 34 additions & 0 deletions test/others.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,4 +261,38 @@
d = TuringScalMvNormal(m, sigmas[1])
@test params(d) == (m, sigmas[1])
end

@testset "adapt_randn" begin
rng = MersenneTwister()

xs = Any[(rng, T, n) -> rand(rng, T, n)]
if AD == "All" || AD == "ForwardDiff"
push!(xs, (rng, T, n) -> [ForwardDiff.Dual(rand(rng, T)) for _ in 1:n])
end
if AD == "All" || AD == "Tracker"
push!(xs, (rng, T, n) -> Tracker.TrackedArray(rand(rng, T, n)))
end
if AD == "All" || AD == "ReverseDiff"
push!(xs, (rng, T, n) -> begin
v = rand(rng, T, n)
d = rand(Int, n)
tp = ReverseDiff.InstructionTape()
ReverseDiff.TrackedArray(v, d, tp)
end)
end

for T in (Float32, Float64)
for f in xs
x = f(rng, T, 50)

Random.seed!(rng, 100)
y = DistributionsAD.adapt_randn(rng, x, 10, 30)
@test y isa Matrix{T}
@test size(y) == (10, 30)

Random.seed!(rng, 100)
@test y == randn(rng, T, 10, 30)
end
end
end
end

2 comments on commit 09ea3e7

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/21350

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.9 -m "<description of version>" 09ea3e78b64cc77d5ed9f8b90f0048c40915036c
git push origin v0.6.9

Please sign in to comment.