|
85 | 85 | DistSpec(Poisson, (0.5,), 1),
|
86 | 86 | DistSpec(Poisson, (0.5,), [1, 1]),
|
87 | 87 |
|
88 |
| - DistSpec(Skellam, (1.0, 2.0), -2; broken=(:Zygote,)), |
89 |
| - DistSpec(Skellam, (1.0, 2.0), [-2, -2]; broken=(:Zygote,)), |
| 88 | + DistSpec(Skellam, (1.0, 2.0), -2), |
| 89 | + DistSpec(Skellam, (1.0, 2.0), [-2, -2]), |
90 | 90 |
|
91 | 91 | DistSpec(PoissonBinomial, ([0.5, 0.5],), 0),
|
92 | 92 |
|
|
193 | 193 |
|
194 | 194 | DistSpec(NormalCanon, (1.0, 2.0), 0.5),
|
195 | 195 |
|
196 |
| - DistSpec(NormalInverseGaussian, (1.0, 2.0, 1.0, 1.0), 0.5; broken=(:Zygote,)), |
| 196 | + DistSpec(NormalInverseGaussian, (1.0, 2.0, 1.0, 1.0), 0.5), |
197 | 197 |
|
| 198 | + DistSpec(Pareto, (), 1.5), |
198 | 199 | DistSpec(Pareto, (1.0,), 1.5),
|
199 | 200 | DistSpec(Pareto, (1.0, 1.0), 1.5),
|
200 | 201 |
|
|
245 | 246 | DistSpec(VonMises, (1.0,), 1.0),
|
246 | 247 | DistSpec(VonMises, (1, 1), 1),
|
247 | 248 |
|
248 |
| - # Only some Zygote tests are broken and therefore this can not be checked |
249 |
| - DistSpec(Pareto, (), 1.5; broken=(:Zygote,)), |
250 |
| - |
251 | 249 | # Some tests are broken on some Julia versions, therefore it can't be checked reliably
|
252 |
| - DistSpec(PoissonBinomial, ([0.5, 0.5],), [0, 0]; broken=(:Zygote,)), |
| 250 | + DistSpec(PoissonBinomial, ([0.5, 0.5],), [0, 0]; broken=(:Zygote,)), |
253 | 251 | ]
|
254 | 252 |
|
255 | 253 | # Tests that have a `broken` field can be executed but, according to FiniteDifferences,
|
|
405 | 403 | B,
|
406 | 404 | to_posdef,
|
407 | 405 | ),
|
408 |
| - DistSpec((eta) -> LKJ(10, eta), (1.), A_big, to_corr) |
| 406 | + DistSpec(eta -> LKJ(10, eta), (1.,), A_big, to_corr) |
409 | 407 | # AD for parameters of LKJ requires more DistributionsAD supports
|
410 | 408 | ]
|
411 | 409 |
|
|
435 | 433 | # Skellam only fails in these tests with ReverseDiff
|
436 | 434 | # Ref: https://github.com/TuringLang/DistributionsAD.jl/issues/126
|
437 | 435 | # PoissonBinomial fails with Zygote
|
| 436 | + # Matrix case does not work with Skellam: |
| 437 | + # https://github.com/TuringLang/DistributionsAD.jl/pull/172#issuecomment-853721493 |
438 | 438 | filldist_broken = if d.f(d.θ...) isa Skellam
|
439 |
| - (d.broken..., :ReverseDiff) |
| 439 | + ((d.broken..., :ReverseDiff), (d.broken..., :Zygote, :ReverseDiff)) |
440 | 440 | elseif d.f(d.θ...) isa PoissonBinomial
|
441 |
| - (d.broken..., :Zygote) |
| 441 | + ((d.broken..., :Zygote), (d.broken..., :Zygote)) |
442 | 442 | else
|
443 |
| - d.broken |
| 443 | + (d.broken, d.broken) |
444 | 444 | end
|
445 |
| - arraydist_broken = if d.f(d.θ...) isa PoissonBinomial |
446 |
| - (d.broken..., :Zygote) |
| 445 | + arraydist_broken = if d.f(d.θ...) isa Skellam |
| 446 | + (d.broken, (d.broken..., :Zygote)) |
| 447 | + elseif d.f(d.θ...) isa PoissonBinomial |
| 448 | + ((d.broken..., :Zygote), (d.broken..., :Zygote)) |
447 | 449 | else
|
448 |
| - d.broken |
| 450 | + (d.broken, d.broken) |
449 | 451 | end
|
450 | 452 |
|
451 | 453 | # Create `filldist` distribution
|
|
456 | 458 | f_arraydist = (θ...,) -> arraydist([d.f(θ...) for _ in 1:n])
|
457 | 459 | d_arraydist = f_arraydist(d.θ...)
|
458 | 460 |
|
459 |
| - for sz in ((n,), (n, 2)) |
| 461 | + for (i, sz) in enumerate(((n,), (n, 2))) |
460 | 462 | # Matrix case doesn't work for continuous distributions for some reason
|
461 | 463 | # now but not too important (?!)
|
462 | 464 | if length(sz) == 2 && Distributions.value_support(typeof(d)) === Continuous
|
|
474 | 476 | d.θ,
|
475 | 477 | x,
|
476 | 478 | d.xtrans;
|
477 |
| - broken=filldist_broken, |
| 479 | + broken=filldist_broken[i], |
478 | 480 | )
|
479 | 481 | )
|
480 | 482 | test_ad(
|
|
484 | 486 | d.θ,
|
485 | 487 | x,
|
486 | 488 | d.xtrans;
|
487 |
| - broken=arraydist_broken, |
| 489 | + broken=arraydist_broken[i], |
488 | 490 | )
|
489 | 491 | )
|
490 | 492 | end
|
|
0 commit comments