From b11399d871723fffad3758461f5f7874ad57a510 Mon Sep 17 00:00:00 2001 From: Evgeny Tankhilevich Date: Fri, 14 Feb 2020 12:01:06 +0000 Subject: [PATCH 1/3] adjoint for Base.reverse --- .gitignore | 1 + src/lib/array.jl | 6 ++++++ test/gradcheck.jl | 5 +++++ 3 files changed, 12 insertions(+) diff --git a/.gitignore b/.gitignore index 9e6791bdc..e34d18a9f 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ *.jl.*.cov *.jl.mem docs/build +.tags* diff --git a/src/lib/array.jl b/src/lib/array.jl index e7fcf8a5e..41cdbdb9c 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -63,6 +63,12 @@ end circshift(A, shifts), Δ -> (circshift(Δ, map(-, shifts)), nothing) end +@adjoint function reverse(x::AbstractArray, args...; kwargs...) + _reverse(t) = reverse(t, args...; kwargs...) + _nothings(t) = map(_->nothing, keys(t)) + _reverse(x), Δ->(_reverse(Δ), _nothings(args)..., _nothings(kwargs)...) +end + @adjoint permutedims(xs) = permutedims(xs), Δ -> (permutedims(Δ),) @adjoint permutedims(xs::AbstractVector) = permutedims(xs), Δ -> (vec(permutedims(Δ)),) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 2bf4e5610..e4ad93301 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -129,6 +129,11 @@ end @test gradtest(x -> meanpool(x, pdims), x) end +@test gradtest(x -> reverse(x), rand(17)) +@test gradtest(x -> reverse(x, 8), rand(17)) +@test gradtest(x -> reverse(x, 8, 13), rand(17)) +@test gradtest(x -> reverse(x, dims=2), rand(17, 42)) + @test gradtest(x -> permutedims(x), rand(2)) @test gradtest(x -> permutedims(x), rand(2,3)) @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) From 0a92f385b93e289493e771ebf0fe835cbfe5f5a8 Mon Sep 17 00:00:00 2001 From: Evgeny Tankhilevich Date: Sat, 15 Feb 2020 17:26:15 +0000 Subject: [PATCH 2/3] revert gitignore --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index e34d18a9f..9e6791bdc 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,3 @@ *.jl.*.cov *.jl.mem docs/build -.tags* From 8f7da4db98eeeacd451fcde0d30f45875928b634 Mon Sep 17 00:00:00 2001 From: Evgeny Tankhilevich Date: Mon, 17 Feb 2020 15:03:11 +0000 Subject: [PATCH 3/3] remove kwargs grads --- src/lib/array.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 41cdbdb9c..08bee6bf2 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -65,8 +65,7 @@ end @adjoint function reverse(x::AbstractArray, args...; kwargs...) _reverse(t) = reverse(t, args...; kwargs...) - _nothings(t) = map(_->nothing, keys(t)) - _reverse(x), Δ->(_reverse(Δ), _nothings(args)..., _nothings(kwargs)...) + _reverse(x), Δ->(_reverse(Δ), map(_->nothing, args)...) end @adjoint permutedims(xs) = permutedims(xs), Δ -> (permutedims(Δ),)