diff --git a/Project.toml b/Project.toml index cce9b5a4..793933d4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.8" +version = "0.6.9" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/common.jl b/src/common.jl index d30ffa83..6a5ec5b1 100644 --- a/src/common.jl +++ b/src/common.jl @@ -48,13 +48,13 @@ end # Tracker's implementation of ldiv isn't good. We'll use Zygote's instead. zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B -function adapt_randn(rng, x::AbstractArray, dims...) +function adapt_randn(rng::AbstractRNG, x::AbstractArray, dims...) adapt(typeof(x), randn(rng, eltype(x), dims...)) end # TODO: should be replaced by @non_differentiable when # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/212 is fixed -function ChainRules.rrule(::typeof(adapt_randn), rng, x, dims...) +function ChainRules.rrule(::typeof(adapt_randn), rng::AbstractRNG, x::AbstractArray, dims...) function adapt_randn_pullback(ΔQ) return (NO_FIELDS, Zero(), Zero(), map(_ -> Zero(), dims)...) end diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl index 7b86ed6a..3efbbf71 100644 --- a/src/forwarddiff.jl +++ b/src/forwarddiff.jl @@ -1,3 +1,7 @@ +function adapt_randn(rng::AbstractRNG, x::AbstractArray{<:ForwardDiff.Dual}, dims...) + adapt(typeof(x), randn(rng, ForwardDiff.valtype(eltype(x)), dims...)) +end + ## Binomial ## function binomlogpdf(n::Int, p::ForwardDiff.Dual{T}, x::Int) where {T} diff --git a/src/reversediff.jl b/src/reversediff.jl index e65b591c..6ced110b 100644 --- a/src/reversediff.jl +++ b/src/reversediff.jl @@ -18,13 +18,15 @@ using ..DistributionsAD: DistributionsAD import SpecialFunctions, NaNMath -import ..DistributionsAD: turing_chol, symm_turing_chol, _mv_categorical_logpdf +import ..DistributionsAD: turing_chol, symm_turing_chol, _mv_categorical_logpdf, adapt_randn import Base.Broadcast: materialize import StatsFuns: logsumexp const TrackedVecOrMat{V,D} = Union{TrackedVector{V,D},TrackedMatrix{V,D}} const RDBroadcasted{F, T} = Broadcasted{<:Any, <:Any, F, T} +import Random + import Distributions: logpdf, _logpdf, loglikelihood, @@ -49,6 +51,8 @@ using ..DistributionsAD: TuringPoissonBinomial, include("reversediffx.jl") +adapt_randn(rng::Random.AbstractRNG, x::TrackedArray, dims...) = adapt_randn(rng, value(x), dims...) + function PoissonBinomial(p::TrackedArray{<:Real}; check_args=true) return TuringPoissonBinomial(p; check_args = check_args) end diff --git a/src/tracker.jl b/src/tracker.jl index b338d4c3..93f7fcc4 100644 --- a/src/tracker.jl +++ b/src/tracker.jl @@ -202,6 +202,9 @@ for f in (:+, :-, :*, :/, :\, :dot), (T1, T2) in [ end end +## `adapt_randn` + +adapt_randn(rng::AbstractRNG, x::TrackedArray, dims...) = adapt_randn(rng, data(x), dims...) ## Uniform ## diff --git a/test/others.jl b/test/others.jl index c7ade069..07f01fb0 100644 --- a/test/others.jl +++ b/test/others.jl @@ -261,4 +261,38 @@ d = TuringScalMvNormal(m, sigmas[1]) @test params(d) == (m, sigmas[1]) end + + @testset "adapt_randn" begin + rng = MersenneTwister() + + xs = Any[(rng, T, n) -> rand(rng, T, n)] + if AD == "All" || AD == "ForwardDiff" + push!(xs, (rng, T, n) -> [ForwardDiff.Dual(rand(rng, T)) for _ in 1:n]) + end + if AD == "All" || AD == "Tracker" + push!(xs, (rng, T, n) -> Tracker.TrackedArray(rand(rng, T, n))) + end + if AD == "All" || AD == "ReverseDiff" + push!(xs, (rng, T, n) -> begin + v = rand(rng, T, n) + d = rand(Int, n) + tp = ReverseDiff.InstructionTape() + ReverseDiff.TrackedArray(v, d, tp) + end) + end + + for T in (Float32, Float64) + for f in xs + x = f(rng, T, 50) + + Random.seed!(rng, 100) + y = DistributionsAD.adapt_randn(rng, x, 10, 30) + @test y isa Matrix{T} + @test size(y) == (10, 30) + + Random.seed!(rng, 100) + @test y == randn(rng, T, 10, 30) + end + end + end end