Skip to content

Commit 60dc43c

Browse files
authored
Update to ChainRulesCore 0.10 (#182)
* Update to ChainRulesCore 0.10 * Apply test fixes from DistributionsAD * Bump version * Update Project.toml
1 parent d56748d commit 60dc43c

File tree

3 files changed

+23
-21
lines changed

3 files changed

+23
-21
lines changed

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.9.2"
3+
version = "0.9.3"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
@@ -21,7 +21,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
2121

2222
[compat]
2323
ArgCheck = "1, 2"
24-
ChainRulesCore = "0.9"
24+
ChainRulesCore = "0.9, 0.10"
2525
Compat = "3"
2626
Distributions = "0.23.3, 0.24, 0.25"
2727
Functors = "0.1, 0.2"

test/Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1414
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1515

1616
[compat]
17-
ChainRulesTestUtils = "0.6.3"
17+
ChainRulesTestUtils = "0.6.3, 0.7"
1818
Combinatorics = "1.0.2"
1919
DistributionsAD = "0.6.3"
2020
FiniteDifferences = "0.11, 0.12"
2121
ForwardDiff = "0.10.12"
2222
Functors = "0.1, 0.2"
23-
NNlib = "0.7"
23+
NNlib = "0.7.18"
2424
ReverseDiff = "1.4.2"
2525
Tracker = "0.2.11"
2626
Zygote = "0.5.4, 0.6"

test/ad/distributions.jl

+19-17
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@
8585
DistSpec(Poisson, (0.5,), 1),
8686
DistSpec(Poisson, (0.5,), [1, 1]),
8787

88-
DistSpec(Skellam, (1.0, 2.0), -2; broken=(:Zygote,)),
89-
DistSpec(Skellam, (1.0, 2.0), [-2, -2]; broken=(:Zygote,)),
88+
DistSpec(Skellam, (1.0, 2.0), -2),
89+
DistSpec(Skellam, (1.0, 2.0), [-2, -2]),
9090

9191
DistSpec(PoissonBinomial, ([0.5, 0.5],), 0),
9292

@@ -193,8 +193,9 @@
193193

194194
DistSpec(NormalCanon, (1.0, 2.0), 0.5),
195195

196-
DistSpec(NormalInverseGaussian, (1.0, 2.0, 1.0, 1.0), 0.5; broken=(:Zygote,)),
196+
DistSpec(NormalInverseGaussian, (1.0, 2.0, 1.0, 1.0), 0.5),
197197

198+
DistSpec(Pareto, (), 1.5),
198199
DistSpec(Pareto, (1.0,), 1.5),
199200
DistSpec(Pareto, (1.0, 1.0), 1.5),
200201

@@ -245,11 +246,8 @@
245246
DistSpec(VonMises, (1.0,), 1.0),
246247
DistSpec(VonMises, (1, 1), 1),
247248

248-
# Only some Zygote tests are broken and therefore this can not be checked
249-
DistSpec(Pareto, (), 1.5; broken=(:Zygote,)),
250-
251249
# Some tests are broken on some Julia versions, therefore it can't be checked reliably
252-
DistSpec(PoissonBinomial, ([0.5, 0.5],), [0, 0]; broken=(:Zygote,)),
250+
DistSpec(PoissonBinomial, ([0.5, 0.5],), [0, 0]; broken=(:Zygote,)),
253251
]
254252

255253
# Tests that have a `broken` field can be executed but, according to FiniteDifferences,
@@ -405,7 +403,7 @@
405403
B,
406404
to_posdef,
407405
),
408-
DistSpec((eta) -> LKJ(10, eta), (1.), A_big, to_corr)
406+
DistSpec(eta -> LKJ(10, eta), (1.,), A_big, to_corr)
409407
# AD for parameters of LKJ requires more DistributionsAD supports
410408
]
411409

@@ -435,17 +433,21 @@
435433
# Skellam only fails in these tests with ReverseDiff
436434
# Ref: https://github.com/TuringLang/DistributionsAD.jl/issues/126
437435
# PoissonBinomial fails with Zygote
436+
# Matrix case does not work with Skellam:
437+
# https://github.com/TuringLang/DistributionsAD.jl/pull/172#issuecomment-853721493
438438
filldist_broken = if d.f(d.θ...) isa Skellam
439-
(d.broken..., :ReverseDiff)
439+
((d.broken..., :ReverseDiff), (d.broken..., :Zygote, :ReverseDiff))
440440
elseif d.f(d.θ...) isa PoissonBinomial
441-
(d.broken..., :Zygote)
441+
((d.broken..., :Zygote), (d.broken..., :Zygote))
442442
else
443-
d.broken
443+
(d.broken, d.broken)
444444
end
445-
arraydist_broken = if d.f(d.θ...) isa PoissonBinomial
446-
(d.broken..., :Zygote)
445+
arraydist_broken = if d.f(d.θ...) isa Skellam
446+
(d.broken, (d.broken..., :Zygote))
447+
elseif d.f(d.θ...) isa PoissonBinomial
448+
((d.broken..., :Zygote), (d.broken..., :Zygote))
447449
else
448-
d.broken
450+
(d.broken, d.broken)
449451
end
450452

451453
# Create `filldist` distribution
@@ -456,7 +458,7 @@
456458
f_arraydist =...,) -> arraydist([d.f...) for _ in 1:n])
457459
d_arraydist = f_arraydist(d.θ...)
458460

459-
for sz in ((n,), (n, 2))
461+
for (i, sz) in enumerate(((n,), (n, 2)))
460462
# Matrix case doesn't work for continuous distributions for some reason
461463
# now but not too important (?!)
462464
if length(sz) == 2 && Distributions.value_support(typeof(d)) === Continuous
@@ -474,7 +476,7 @@
474476
d.θ,
475477
x,
476478
d.xtrans;
477-
broken=filldist_broken,
479+
broken=filldist_broken[i],
478480
)
479481
)
480482
test_ad(
@@ -484,7 +486,7 @@
484486
d.θ,
485487
x,
486488
d.xtrans;
487-
broken=arraydist_broken,
489+
broken=arraydist_broken[i],
488490
)
489491
)
490492
end

0 commit comments

Comments
 (0)