|
1 | 1 |
|
| 2 | +AD_distributionsad = if VERSION >= v"1.10" |
| 3 | + Dict( |
| 4 | + :ForwarDiff => AutoForwardDiff(), |
| 5 | + #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment |
| 6 | + :Zygote => AutoZygote(), |
| 7 | + :Enzyme => AutoEnzyme(), |
| 8 | + ) |
| 9 | +else |
| 10 | + Dict( |
| 11 | + :ForwarDiff => AutoForwardDiff(), |
| 12 | + #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment |
| 13 | + :Zygote => AutoZygote(), |
| 14 | + ) |
| 15 | +end |
| 16 | + |
2 | 17 | @testset "inference RepGradELBO DistributionsAD" begin
|
3 | 18 | @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
|
4 | 19 | [Float64, Float32],
|
|
9 | 24 | :RepGradELBOStickingTheLanding =>
|
10 | 25 | RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy()),
|
11 | 26 | ),
|
12 |
| - (adbackname, adtype) in Dict( |
13 |
| - :ForwarDiff => AutoForwardDiff(), |
14 |
| - #:ReverseDiff => AutoReverseDiff(), |
15 |
| - :Zygote => AutoZygote(), |
16 |
| - #:Enzyme => AutoEnzyme(), |
17 |
| - ) |
| 27 | + (adbackname, adtype) in AD_distributionsad |
18 | 28 |
|
19 | 29 | seed = (0x38bef07cf9cc549d)
|
20 | 30 | rng = StableRNG(seed)
|
|
31 | 41 | # where ρ = 1 - ημ, μ is the strong convexity constant.
|
32 | 42 | contraction_rate = 1 - η * strong_convexity
|
33 | 43 |
|
34 |
| - μ0 = Zeros(realtype, n_dims) |
35 |
| - L0 = Diagonal(Ones(realtype, n_dims)) |
| 44 | + μ0 = zeros(realtype, n_dims) |
| 45 | + L0 = Diagonal(ones(realtype, n_dims)) |
36 | 46 | q0 = TuringDiagMvNormal(μ0, diag(L0))
|
37 | 47 |
|
38 | 48 | @testset "convergence" begin
|
|
0 commit comments