Skip to content

Commit

Permalink
Remove adjoint for fill and fix tests (#203)
Browse files Browse the repository at this point in the history
* Remove adjoint for `fill` and fix Zygote tests

* Bump version

* Fix some more problems

* Extension of #203: Fix deprecations in test (#204)

* Fix deprecations

* Improve CI (AD): cancel builds and no coverage

* Improve CI (Others): cancel builds and no coverage

* Change parameters to avoid issues with `xlogy`

* Tracker does not like Diagonal(Fill(...))

* Unify CI

* Fix tests

* Update test structure and separate AD better

* Fix tests

* Relax type constraint

* Simplify Zygote tests and use CR

* Improve test design

* Fix typo

* Fix typo

* Replace `unpack` with `_to_vec`

* Fix tests (a bit)

* Fix another test problem

* Fix `_to_vec`

* Fix handling of broken Zygote tests

* Workarounds for `rand_tangent`

* Improvements and fixes for Julia 1.3

* Remove Zygote test hack
  • Loading branch information
devmotion authored Nov 8, 2021
1 parent af2ea73 commit 47214f8
Show file tree
Hide file tree
Showing 12 changed files with 617 additions and 483 deletions.
39 changes: 30 additions & 9 deletions .github/workflows/AD.yml → .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
name: AD tests
name: CI

on:
push:
branches:
- master
pull_request:

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
test:
runs-on: ${{ matrix.os }}
continue-on-error: ${{ matrix.version == 'nightly' }}
strategy:
matrix:
version:
Expand All @@ -19,7 +24,8 @@ jobs:
- ubuntu-latest
arch:
- x64
AD:
group:
- Others
- ForwardDiff
- Tracker
- ReverseDiff
Expand All @@ -28,27 +34,42 @@ jobs:
- version: '1'
os: macOS-latest
arch: x64
AD: ForwardDiff
group: Others
- version: '1'
os: macOS-latest
arch: x64
group: ForwardDiff
- version: '1'
os: macOS-latest
arch: x64
AD: Tracker
group: Tracker
- version: '1'
os: macOS-latest
arch: x64
AD: ReverseDiff
group: ReverseDiff
- version: '1'
os: macOS-latest
arch: x64
AD: Zygote
group: Zygote
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: actions/cache@v1
env:
cache-name: cache-artifacts
with:
path: ~/.julia/artifacts
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
restore-keys: |
${{ runner.os }}-test-${{ env.cache-name }}-
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-runtest@latest
with:
coverage: false
env:
GROUP: AD
AD: ${{ matrix.AD }}
GROUP: ${{ matrix.group }}
34 changes: 0 additions & 34 deletions .github/workflows/Others.yml

This file was deleted.

2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DistributionsAD"
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
version = "0.6.31"
version = "0.6.32"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# DistributionsAD.jl

[![AD tests](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/AD.yml/badge.svg?branch=master)](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/AD.yml?query=branch%3Amaster)
[![Other tests](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/Others.yml/badge.svg?branch=master)](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/Others.yml?query=branch%3Amaster)
[![CI](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/CI.yml/badge.svg?branch=master)](https://github.com/TuringLang/DistributionsAD.jl/actions/workflows/CI.yml?query=branch%3Amaster)

This package defines the necessary functions to enable automatic differentiation (AD) of the `logpdf` function from [Distributions.jl](https://github.com/JuliaStats/Distributions.jl) using the packages [Tracker.jl](https://github.com/FluxML/Tracker.jl), [Zygote.jl](https://github.com/FluxML/Zygote.jl), [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) and [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl). The goal of this package is to make the output of `logpdf` differentiable wrt all continuous parameters of a distribution as well as the random variable in the case of continuous distributions.

Expand Down
7 changes: 0 additions & 7 deletions src/flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ const flattened_dists = [ Bernoulli,
Poisson,
Skellam,
Arcsine,
Beta,
BetaPrime,
Biweight,
Cauchy,
Expand All @@ -40,11 +39,9 @@ const flattened_dists = [ Bernoulli,
Exponential,
FDist,
Frechet,
Gamma,
GeneralizedExtremeValue,
GeneralizedPareto,
Gumbel,
#InverseGamma,
InverseGaussian,
Kolmogorov,
Laplace,
Expand All @@ -54,8 +51,6 @@ const flattened_dists = [ Bernoulli,
LogitNormal,
LogNormal,
Normal,
#NormalCanon,
#NormalInverseGaussian,
Pareto,
PGeneralizedGaussian,
Rayleigh,
Expand All @@ -64,8 +59,6 @@ const flattened_dists = [ Bernoulli,
TriangularDist,
Triweight,
TuringUniform,
#Truncated,
#VonMises,
]
for T in flattened_dists
@eval toflatten(::$T) = true
Expand Down
8 changes: 0 additions & 8 deletions src/zygote.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
# Zygote fill has issues with non-numbers
ZygoteRules.@adjoint function fill(x::T, dims...) where {T}
return ZygoteRules.pullback(x, dims...) do x, dims...
return reshape([x for i in 1:prod(dims)], dims)
end
end


## Uniform ##

ZygoteRules.@adjoint function Distributions.Uniform(args...)
Expand Down
Loading

2 comments on commit 47214f8

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/48384

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.32 -m "<description of version>" 47214f806790a6a281d1ec608059924e5c592192
git push origin v0.6.32

Please sign in to comment.