Skip to content

Commit

Permalink
Add support for nested derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Sep 27, 2023
1 parent 87a3dba commit 61326ef
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
12 changes: 6 additions & 6 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N}
mapreduce(*, +, value(a), value(b))
end

function rrule(::Type{TaylorScalar{T, N}}, v::NTuple{N, T}) where {N, T <: Number}
function rrule(::Type{TaylorScalar{T, N}}, v::NTuple{N, T}) where {N, T}
taylor_scalar_pullback(t̄) = NoTangent(), value(t̄)
return TaylorScalar(v), taylor_scalar_pullback
end
Expand All @@ -22,7 +22,7 @@ function rrule(::typeof(value), t::TaylorScalar{T, N}) where {N, T}
end

function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
i::Integer) where {N, T <: Number}
i::Integer) where {N, T}
function extract_derivative_pullback(d̄)
NoTangent(), TaylorScalar{T, N}(ntuple(j -> j === i ?: zero(T), Val(N))),
NoTangent()
Expand All @@ -31,7 +31,7 @@ function rrule(::typeof(extract_derivative), t::TaylorScalar{T, N},
end

function rrule(::typeof(*), A::AbstractMatrix{S},
t::AbstractVector{TaylorScalar{T, N}}) where {N, S <: Number, T}
t::AbstractVector{TaylorScalar{T, N}}) where {N, S, T}
project_A = ProjectTo(A)
function gemv_pullback(x̄)
= reinterpret(reshape, T, x̄)
Expand All @@ -41,17 +41,17 @@ function rrule(::typeof(*), A::AbstractMatrix{S},
return A * t, gemv_pullback
end

@adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T <: Number}
@adjoint function +(t::Vector{TaylorScalar{T, N}}, v::Vector{T}) where {N, T}
project_v = ProjectTo(v)
t + v, x̄ -> (x̄, project_v(x̄))
end

@adjoint function +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T <: Number}
@adjoint function +(v::Vector{T}, t::Vector{TaylorScalar{T, N}}) where {N, T}
project_v = ProjectTo(v)
v + t, x̄ -> (project_v(x̄), x̄)
end

(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = primal(dx)
(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T} = primal(dx)

# Not-a-number patches

Expand Down
8 changes: 4 additions & 4 deletions src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,26 @@ end
# Added to help Zygote infer types
make_taylor(t0::T, t1::S, ::Val{N}) where {T, S, N} = TaylorScalar{T, N}(t0, T(t1))

@inline function derivative(f, x::T, ::Val{N}) where {T <: Number, N}
@inline function derivative(f, x::T, ::Val{N}) where {T <: TN, N}
t = TaylorScalar{T, N}(x, one(x))
return extract_derivative(f(t), N)
end

@inline function derivative(f, x::AbstractVector{T}, l::AbstractVector{S},
vN::Val{N}) where {T <: Number, S <: Number, N}
vN::Val{N}) where {T <: TN, S <: TN, N}
t = map((t0, t1) -> make_taylor(t0, t1, vN), x, l)
# equivalent to map(TaylorScalar{T, N}, x, l)
return extract_derivative(f(t), N)
end

# shorthand notations for matrices

@inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: Number, N}
@inline function derivative(f, x::AbstractMatrix{T}, vN::Val{N}) where {T <: TN, N}
size(x)[1] != 1 && @warn "x is not a row vector."
mapcols(u -> derivative(f, u[1], vN), x)
end

@inline function derivative(f, x::AbstractMatrix{T}, l::AbstractVector{S},
vN::Val{N}) where {T <: Number, S <: Number, N}
vN::Val{N}) where {T <: TN, S <: TN, N}
mapcols(u -> derivative(f, u, l, vN), x)
end
2 changes: 1 addition & 1 deletion src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ struct TaylorScalar{T, N}
value::NTuple{N, T}
end

TaylorOrNumber = Union{TaylorScalar, Number}
TN = Union{TaylorScalar, Number}

@inline TaylorScalar(xs::Vararg{T, N}) where {T, N} = TaylorScalar(xs)

Expand Down

0 comments on commit 61326ef

Please sign in to comment.